[WIP] #19
This commit is contained in:
parent
9c56ad1728
commit
2cb8d687e6
|
@ -1,56 +0,0 @@
|
||||||
import logging.config
|
|
||||||
from orator.orm import Model
|
|
||||||
from orator import DatabaseManager
|
|
||||||
from src.config import config, redis
|
|
||||||
from src.entity.pair import Pair
|
|
||||||
|
|
||||||
|
|
||||||
def _pair_word(pair, relation):
    """Return the cleaned word behind ``pair.<relation>``, or the stop word.

    ``relation`` is the attribute name to read from ``pair`` (``'first'`` or
    ``'second'``). Any ``AttributeError`` raised while reaching or cleaning the
    word (for example when the related record is missing) falls back to
    ``config['grammar']['stop_word']``.
    """
    try:
        return clear_word(getattr(pair, relation).word)
    except AttributeError:
        return config['grammar']['stop_word']


def main():
    """Copy stored pairs and their reply words from the database into Redis sets.

    Pairs are read in chunks of 500 with their ``replies``, ``first`` and
    ``second`` relations eager-loaded. For every pair that has at least one
    usable reply word, the words are added to the Redis set whose key is built
    by ``to_key`` from the ``'trigrams:{}:{}'`` template, the pair's chat id
    and its two words. Progress is printed after every 1000 imported pairs.
    """
    logging.basicConfig(level='DEBUG')
    Model.set_connection_resolver(DatabaseManager({'db': config['db']}))

    redis_c = redis.instance()
    counter = 0
    key = 'trigrams:{}:{}'

    for pairs in Pair.with_('replies', 'first', 'second').chunk(500):
        for pair in pairs:
            first = _pair_word(pair, 'first')
            second = _pair_word(pair, 'second')
            point = to_key(key=key, chat_id=pair.chat_id, pair=[first, second])

            # get_word returns None for replies without a word; empty strings
            # left after cleaning are dropped as well.
            replies = [word for word in map(get_word, pair.replies.all()) if word]
            if not replies:
                continue

            # One SADD carrying every member replaces a pipeline holding one
            # SADD per reply: same resulting set, a single command.
            redis_c.sadd(point, *replies)

            # Only pairs that actually wrote something are counted.
            counter += 1
            if counter % 1000 == 0:
                print('Imported: {}'.format(counter))
|
|
||||||
|
|
||||||
def clear_word(word):
    """Return ``word`` with punctuation stripped from both ends.

    The stripped characters are ``; ( ) \\ - — " [ ] { } « » / * & ^ # $``.
    Characters inside the word are left untouched.

    Raises ``AttributeError`` when ``word`` has no ``strip`` method (for
    example ``None``); callers in this module rely on that.
    """
    # The backslash is written as "\\": a bare "\-" is not a valid escape
    # sequence and newer Python versions warn about it at compile time.
    return word.strip(';()\\-—"[]{}«»/*&^#$')
|
|
||||||
|
|
||||||
def get_word(reply):
    """Return the cleaned word text of ``reply``, or ``None`` if unavailable.

    The text is read from ``reply.word.word`` and passed through
    ``clear_word``. An ``AttributeError`` anywhere along that path yields
    ``None`` instead of propagating.
    """
    try:
        cleaned = clear_word(reply.word.word)
    except AttributeError:
        cleaned = None
    return cleaned
|
|
||||||
|
|
||||||
def to_key(key, chat_id, pair):
    """Build a storage key from a template, a chat id and a list of words.

    ``pair`` is joined with ``config['grammar']['separator']``; the chat id
    and the joined words then fill the two positional placeholders of ``key``.
    """
    separator = config['grammar']['separator']
    joined_words = separator.join(pair)
    return key.format(chat_id, joined_words)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -32,10 +32,12 @@ for section, options in sections.items():
|
||||||
from src.redis_c import Redis
|
from src.redis_c import Redis
|
||||||
redis = Redis(config)
|
redis = Redis(config)
|
||||||
|
|
||||||
|
from src.service.tokenizer import Tokenizer
|
||||||
|
tokenz = Tokenizer()
|
||||||
|
|
||||||
from src.service import *
|
from src.service import *
|
||||||
|
data_learner = DataLearner(tokenz)
|
||||||
|
reply_generator = ReplyGenerator(tokenz)
|
||||||
chance_manager = ChanceManager()
|
chance_manager = ChanceManager()
|
||||||
media_checker = MediaUniquenessChecker()
|
media_checker = MediaUniquenessChecker()
|
||||||
tokenizer = Tokenizer()
|
|
||||||
data_learner = DataLearner()
|
|
||||||
reply_generator = ReplyGenerator()
|
|
||||||
chat_purge_queue = ChatPurgeQueue()
|
chat_purge_queue = ChatPurgeQueue()
|
||||||
|
|
|
@ -10,6 +10,7 @@ class Message(AbstractEntity):
|
||||||
super(Message, self).__init__(message)
|
super(Message, self).__init__(message)
|
||||||
|
|
||||||
self.chance = chance
|
self.chance = chance
|
||||||
|
self.entities = message.entities
|
||||||
|
|
||||||
if self.has_text():
|
if self.has_text():
|
||||||
self.text = message.text
|
self.text = message.text
|
||||||
|
@ -29,7 +30,7 @@ class Message(AbstractEntity):
|
||||||
def has_entities(self):
|
def has_entities(self):
|
||||||
"""Returns True if the message has entities (attachments).
|
"""Returns True if the message has entities (attachments).
|
||||||
"""
|
"""
|
||||||
return self.message.entities is not None
|
return self.entities is not None
|
||||||
|
|
||||||
def has_anchors(self):
|
def has_anchors(self):
|
||||||
"""Returns True if the message contains at least one anchor from anchors config.
|
"""Returns True if the message contains at least one anchor from anchors config.
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
from orator.orm import Model
|
|
||||||
from orator.orm import belongs_to
|
|
||||||
from orator.orm import has_many
|
|
||||||
|
|
||||||
from src.utils import random_element
|
|
||||||
import src.entity.reply
|
|
||||||
import src.entity.word
|
|
||||||
|
|
||||||
|
|
||||||
class Pair(Model):
    """ORM model linking two ``Word`` records within a chat."""

    __fillable__ = ['chat_id', 'first_id', 'second_id']
    __timestamps__ = ['created_at']

    @has_many
    def replies(self):
        """Relation to the ``Reply`` records of this pair."""
        return src.entity.reply.Reply

    @belongs_to
    def first(self):
        """Relation to the ``Word`` stored as the first element of the pair."""
        return src.entity.word.Word

    @belongs_to
    def second(self):
        """Relation to the ``Word`` stored as the second element of the pair."""
        return src.entity.word.Word

    @staticmethod
    def get_random_pair(chat_id, first_id, second_id_list):
        """Pick one pair at random among up to three matching candidates.

        Candidates belong to ``chat_id``, have ``first_id`` as their first
        word and a second word whose id is in ``second_id_list``. Their
        replies are eager-loaded ordered by ``count`` descending with a limit
        of 3. The choice itself is delegated to ``random_element``.
        """
        eager_replies = {
            'replies': lambda query: query.order_by('count', 'desc').limit(3)
        }
        candidates = (
            Pair.with_(eager_replies)
            .where('chat_id', chat_id)
            .where('first_id', first_id)
            .where_in('second_id', second_id_list)
            .limit(3)
            .get()
            .all()
        )
        return random_element(candidates)
|
|
|
@ -1,19 +0,0 @@
|
||||||
from orator.orm import Model
|
|
||||||
from orator.orm import belongs_to
|
|
||||||
from orator.orm import belongs_to_many
|
|
||||||
|
|
||||||
import src.entity.pair
|
|
||||||
import src.entity.word
|
|
||||||
|
|
||||||
|
|
||||||
class Reply(Model):
    """ORM model of a reply row; ``pair_id``, ``word_id`` and ``count`` are fillable."""

    __timestamps__ = False
    __fillable__ = ['pair_id', 'word_id', 'count']

    @belongs_to
    def word(self):
        """Relation to the ``Word`` this reply refers to."""
        return src.entity.word.Word

    # NOTE(review): this is declared many-to-many although the model carries a
    # ``pair_id`` column — confirm whether a plain ``belongs_to`` was intended.
    @belongs_to_many
    def pairs(self):
        """Relation to ``Pair`` records."""
        return src.entity.pair.Pair
|
|
|
@ -1,5 +0,0 @@
|
||||||
from orator.orm import Model
|
|
||||||
|
|
||||||
class Word(Model):
    """ORM model of a single word; only the ``word`` column is fillable."""

    # Timestamp handling is switched off for this model.
    __timestamps__ = False
    __fillable__ = ['word']
|
|
|
@ -1,6 +1,5 @@
|
||||||
|
from .data_learner import DataLearner
|
||||||
|
from .reply_generator import ReplyGenerator
|
||||||
from .chance_manager import ChanceManager
|
from .chance_manager import ChanceManager
|
||||||
from .chat_purge_queue import ChatPurgeQueue
|
from .chat_purge_queue import ChatPurgeQueue
|
||||||
from .data_learner import DataLearner
|
|
||||||
from .media_uniqueness_checker import MediaUniquenessChecker
|
from .media_uniqueness_checker import MediaUniquenessChecker
|
||||||
from .reply_generator import ReplyGenerator
|
|
||||||
from .tokenizer import Tokenizer
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from src.config import redis, tokenizer
|
from src.config import redis
|
||||||
|
|
||||||
|
|
||||||
class DataLearner:
|
class DataLearner:
|
||||||
def __init__(self):
|
def __init__(self, tokenizer):
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
from src.config import config, redis, tokenizer
|
from src.config import config, redis
|
||||||
from src.utils import strings_has_equal_letters, capitalize, random_element
|
from src.utils import strings_has_equal_letters, capitalize, random_element
|
||||||
|
|
||||||
|
|
||||||
class ReplyGenerator:
|
class ReplyGenerator:
|
||||||
def __init__(self):
|
def __init__(self, tokenizer):
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_words = config.getint('grammar', 'max_words')
|
self.max_words = config.getint('grammar', 'max_words')
|
||||||
|
@ -18,7 +18,7 @@ class ReplyGenerator:
|
||||||
|
|
||||||
best_message = ''
|
best_message = ''
|
||||||
for _ in range(self.max_messages):
|
for _ in range(self.max_messages):
|
||||||
generated = self.__generate_sentence(pair)
|
generated = self.__generate_sentence(message.chat_id, pair)
|
||||||
if len(generated) > len(best_message):
|
if len(generated) > len(best_message):
|
||||||
best_message = generated
|
best_message = generated
|
||||||
|
|
||||||
|
@ -27,30 +27,35 @@ class ReplyGenerator:
|
||||||
|
|
||||||
result = random_element(messages) if len(messages) else ''
|
result = random_element(messages) if len(messages) else ''
|
||||||
|
|
||||||
if strings_has_equal_letters(result, ''.join(message.words)):
|
if strings_has_equal_letters(result, ''.join(words)):
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def __generate_sentence(self, seed):
|
def __generate_sentence(self, chat_id, seed):
|
||||||
key = seed
|
key = seed
|
||||||
gen_words = []
|
gen_words = []
|
||||||
redis = self.redis.instance()
|
redis = self.redis.instance()
|
||||||
|
|
||||||
for _ in range(self.max_words):
|
for _ in range(self.max_words):
|
||||||
if len(gen_words):
|
words = key
|
||||||
gen_words.append(key[0])
|
|
||||||
else:
|
|
||||||
gen_words.append(capitalize(key[0]))
|
|
||||||
|
|
||||||
next_word = redis.srandmember(self.tokenizer.to_key(key))
|
if len(gen_words):
|
||||||
if not next_word:
|
gen_words.append(words[0])
|
||||||
|
else:
|
||||||
|
gen_words.append(capitalize(words[0]))
|
||||||
|
|
||||||
|
next_word = redis.srandmember(self.tokenizer.to_key(chat_id=chat_id, pair=key))
|
||||||
|
if next_word is None:
|
||||||
|
break
|
||||||
|
next_word = next_word.decode("utf-8")
|
||||||
|
if next_word == self.tokenizer.stop_word:
|
||||||
break
|
break
|
||||||
|
|
||||||
key = self.tokenizer.separator.join(key[1:] + [next_word])
|
key = words[1:] + [next_word]
|
||||||
|
|
||||||
sentence = ' '.join(gen_words).strip()
|
sentence = ' '.join(gen_words).strip()
|
||||||
if sentence[-1:] not in self.tokenizer.end_sentence:
|
if sentence[-1:] not in self.tokenizer.end_sentence:
|
||||||
sentence += self.tokenizer.random_end_sentence_token()
|
sentence += self.tokenizer.random_end_sentence_token()
|
||||||
|
|
||||||
return ' '.join(gen_words)
|
return sentence
|
||||||
|
|
Loading…
Reference in New Issue