[WIP] #19
This commit is contained in:
parent
9c56ad1728
commit
2cb8d687e6
|
@ -1,56 +0,0 @@
|
|||
import logging.config
|
||||
from orator.orm import Model
|
||||
from orator import DatabaseManager
|
||||
from src.config import config, redis
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level='DEBUG')
|
||||
Model.set_connection_resolver(DatabaseManager({'db': config['db']}))
|
||||
|
||||
redis_c = redis.instance()
|
||||
counter = 0
|
||||
key = 'trigrams:{}:{}'
|
||||
for pairs in Pair.with_('replies', 'first', 'second').chunk(500):
|
||||
for pair in pairs:
|
||||
try:
|
||||
first = clear_word(pair.first.word)
|
||||
except AttributeError:
|
||||
first = config['grammar']['stop_word']
|
||||
|
||||
try:
|
||||
second = clear_word(pair.second.word)
|
||||
except AttributeError:
|
||||
second = config['grammar']['stop_word']
|
||||
|
||||
point = to_key(key=key, chat_id=pair.chat_id, pair=[first, second])
|
||||
replies = list(filter(None, map(get_word, pair.replies.all())))
|
||||
|
||||
if len(replies) == 0:
|
||||
continue
|
||||
|
||||
pipe = redis_c.pipeline()
|
||||
for reply in replies:
|
||||
pipe.sadd(point, reply)
|
||||
pipe.execute()
|
||||
|
||||
counter += 1
|
||||
|
||||
if counter % 1000 == 0:
|
||||
print("Imported: " + str(counter))
|
||||
|
||||
def clear_word(word):
|
||||
return word.strip(";()\-—\"[]{}«»/*&^#$")
|
||||
|
||||
def get_word(reply):
|
||||
try:
|
||||
return clear_word(reply.word.word)
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
def to_key(key, chat_id, pair):
|
||||
return key.format(chat_id, config['grammar']['separator'].join(pair))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -32,10 +32,12 @@ for section, options in sections.items():
|
|||
from src.redis_c import Redis
|
||||
redis = Redis(config)
|
||||
|
||||
from src.service.tokenizer import Tokenizer
|
||||
tokenz = Tokenizer()
|
||||
|
||||
from src.service import *
|
||||
data_learner = DataLearner(tokenz)
|
||||
reply_generator = ReplyGenerator(tokenz)
|
||||
chance_manager = ChanceManager()
|
||||
media_checker = MediaUniquenessChecker()
|
||||
tokenizer = Tokenizer()
|
||||
data_learner = DataLearner()
|
||||
reply_generator = ReplyGenerator()
|
||||
chat_purge_queue = ChatPurgeQueue()
|
||||
|
|
|
@ -10,6 +10,7 @@ class Message(AbstractEntity):
|
|||
super(Message, self).__init__(message)
|
||||
|
||||
self.chance = chance
|
||||
self.entities = message.entities
|
||||
|
||||
if self.has_text():
|
||||
self.text = message.text
|
||||
|
@ -29,7 +30,7 @@ class Message(AbstractEntity):
|
|||
def has_entities(self):
|
||||
"""Returns True if the message has entities (attachments).
|
||||
"""
|
||||
return self.message.entities is not None
|
||||
return self.entities is not None
|
||||
|
||||
def has_anchors(self):
|
||||
"""Returns True if the message contains at least one anchor from anchors config.
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
from orator.orm import Model
|
||||
from orator.orm import belongs_to
|
||||
from orator.orm import has_many
|
||||
|
||||
from src.utils import random_element
|
||||
import src.entity.reply
|
||||
import src.entity.word
|
||||
|
||||
|
||||
class Pair(Model):
|
||||
__fillable__ = ['chat_id', 'first_id', 'second_id']
|
||||
__timestamps__ = ['created_at']
|
||||
|
||||
@has_many
|
||||
def replies(self):
|
||||
return src.entity.reply.Reply
|
||||
|
||||
@belongs_to
|
||||
def first(self):
|
||||
return src.entity.word.Word
|
||||
|
||||
@belongs_to
|
||||
def second(self):
|
||||
return src.entity.word.Word
|
||||
|
||||
@staticmethod
|
||||
def get_random_pair(chat_id, first_id, second_id_list):
|
||||
pairs = Pair\
|
||||
.with_({
|
||||
'replies': lambda q: q.order_by('count', 'desc').limit(3)
|
||||
})\
|
||||
.where('chat_id', chat_id)\
|
||||
.where('first_id', first_id)\
|
||||
.where_in('second_id', second_id_list)\
|
||||
.limit(3)\
|
||||
.get()\
|
||||
.all()
|
||||
|
||||
return random_element(pairs)
|
|
@ -1,19 +0,0 @@
|
|||
from orator.orm import Model
|
||||
from orator.orm import belongs_to
|
||||
from orator.orm import belongs_to_many
|
||||
|
||||
import src.entity.pair
|
||||
import src.entity.word
|
||||
|
||||
|
||||
class Reply(Model):
|
||||
__fillable__ = ['pair_id', 'word_id', 'count']
|
||||
__timestamps__ = False
|
||||
|
||||
@belongs_to_many
|
||||
def pairs(self):
|
||||
return src.entity.pair.Pair
|
||||
|
||||
@belongs_to
|
||||
def word(self):
|
||||
return src.entity.word.Word
|
|
@ -1,5 +0,0 @@
|
|||
from orator.orm import Model
|
||||
|
||||
class Word(Model):
|
||||
__fillable__ = ['word']
|
||||
__timestamps__ = False
|
|
@ -1,6 +1,5 @@
|
|||
from .data_learner import DataLearner
|
||||
from .reply_generator import ReplyGenerator
|
||||
from .chance_manager import ChanceManager
|
||||
from .chat_purge_queue import ChatPurgeQueue
|
||||
from .data_learner import DataLearner
|
||||
from .media_uniqueness_checker import MediaUniquenessChecker
|
||||
from .reply_generator import ReplyGenerator
|
||||
from .tokenizer import Tokenizer
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from src.config import redis, tokenizer
|
||||
from src.config import redis
|
||||
|
||||
|
||||
class DataLearner:
|
||||
def __init__(self):
|
||||
def __init__(self, tokenizer):
|
||||
self.redis = redis
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from src.config import config, redis, tokenizer
|
||||
from src.config import config, redis
|
||||
from src.utils import strings_has_equal_letters, capitalize, random_element
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
def __init__(self):
|
||||
def __init__(self, tokenizer):
|
||||
self.redis = redis
|
||||
self.tokenizer = tokenizer
|
||||
self.max_words = config.getint('grammar', 'max_words')
|
||||
|
@ -18,7 +18,7 @@ class ReplyGenerator:
|
|||
|
||||
best_message = ''
|
||||
for _ in range(self.max_messages):
|
||||
generated = self.__generate_sentence(pair)
|
||||
generated = self.__generate_sentence(message.chat_id, pair)
|
||||
if len(generated) > len(best_message):
|
||||
best_message = generated
|
||||
|
||||
|
@ -27,30 +27,35 @@ class ReplyGenerator:
|
|||
|
||||
result = random_element(messages) if len(messages) else ''
|
||||
|
||||
if strings_has_equal_letters(result, ''.join(message.words)):
|
||||
if strings_has_equal_letters(result, ''.join(words)):
|
||||
return ''
|
||||
|
||||
return result
|
||||
|
||||
def __generate_sentence(self, seed):
|
||||
def __generate_sentence(self, chat_id, seed):
|
||||
key = seed
|
||||
gen_words = []
|
||||
redis = self.redis.instance()
|
||||
|
||||
for _ in range(self.max_words):
|
||||
if len(gen_words):
|
||||
gen_words.append(key[0])
|
||||
else:
|
||||
gen_words.append(capitalize(key[0]))
|
||||
words = key
|
||||
|
||||
next_word = redis.srandmember(self.tokenizer.to_key(key))
|
||||
if not next_word:
|
||||
if len(gen_words):
|
||||
gen_words.append(words[0])
|
||||
else:
|
||||
gen_words.append(capitalize(words[0]))
|
||||
|
||||
next_word = redis.srandmember(self.tokenizer.to_key(chat_id=chat_id, pair=key))
|
||||
if next_word is None:
|
||||
break
|
||||
next_word = next_word.decode("utf-8")
|
||||
if next_word == self.tokenizer.stop_word:
|
||||
break
|
||||
|
||||
key = self.tokenizer.separator.join(key[1:] + [next_word])
|
||||
key = words[1:] + [next_word]
|
||||
|
||||
sentence = ' '.join(gen_words).strip()
|
||||
if sentence[-1:] not in self.tokenizer.end_sentence:
|
||||
sentence += self.tokenizer.random_end_sentence_token()
|
||||
|
||||
return ' '.join(gen_words)
|
||||
return sentence
|
||||
|
|
Loading…
Reference in New Issue