This commit is contained in:
REDNBLACK 2016-12-09 22:45:50 +03:00
parent 9c56ad1728
commit 2cb8d687e6
10 changed files with 29 additions and 141 deletions

View File

@@ -1,56 +0,0 @@
import logging.config
from orator.orm import Model
from orator import DatabaseManager
from src.config import config, redis
from src.entity.pair import Pair
def main():
    """One-off migration: copy every word pair and its replies from the
    relational store into redis sets keyed per chat and word pair."""
    logging.basicConfig(level='DEBUG')
    Model.set_connection_resolver(DatabaseManager({'db': config['db']}))
    redis_c = redis.instance()
    imported = 0
    key = 'trigrams:{}:{}'
    # Stream pairs in chunks of 500 with their relations eager-loaded.
    for batch in Pair.with_('replies', 'first', 'second').chunk(500):
        for pair in batch:
            pair_words = []
            for relation in (pair.first, pair.second):
                try:
                    pair_words.append(clear_word(relation.word))
                except AttributeError:
                    # Related word row is missing: fall back to the stop word.
                    pair_words.append(config['grammar']['stop_word'])
            point = to_key(key=key, chat_id=pair.chat_id, pair=pair_words)
            # Drop replies whose word could not be resolved (falsy entries).
            replies = [word for word in map(get_word, pair.replies.all()) if word]
            if not replies:
                continue
            # Batch all sadd calls for this pair into one round trip.
            pipe = redis_c.pipeline()
            for reply in replies:
                pipe.sadd(point, reply)
            pipe.execute()
            imported += 1
            if imported % 1000 == 0:
                print("Imported: " + str(imported))
def clear_word(word):
    """Strip surrounding punctuation (brackets, quotes, dashes, symbols)
    from *word*; interior characters are left untouched.

    The character set is given as a raw string: the original used "\\-"
    in a plain string, an invalid escape sequence that raises a
    SyntaxWarning on modern Python. The raw string keeps the exact same
    set of stripped characters (backslash and dash included).
    """
    return word.strip(r';()\-—"[]{}«»/*&^#$')
def get_word(reply):
    """Return the cleaned text of the reply's word, or None when the
    reply has no usable word relation."""
    try:
        return clear_word(reply.word.word)
    except AttributeError:
        # reply.word is missing/None (or has no .word): skip this reply.
        return None
def to_key(key, chat_id, pair):
    """Build a redis key from the *key* template, the chat id and the
    word pair joined with the configured separator."""
    separator = config['grammar']['separator']
    return key.format(chat_id, separator.join(pair))
if __name__ == '__main__':
    # Script entry point: run the redis import when executed directly.
    main()

View File

@@ -32,10 +32,12 @@ for section, options in sections.items():
from src.redis_c import Redis

# Shared redis wrapper built from the application config.
redis = Redis(config)

# NOTE(review): the lines below interleave two revisions of this file as
# rendered by the diff viewer — the `tokenz`-based constructions
# (DataLearner(tokenz), ReplyGenerator(tokenz), ChanceManager,
# MediaUniquenessChecker) belong to one revision, the argument-less
# constructions plus ChatPurgeQueue to the other. Only one set should
# survive in the real file — confirm against the repository.
from src.service.tokenizer import Tokenizer
tokenz = Tokenizer()
from src.service import *
data_learner = DataLearner(tokenz)
reply_generator = ReplyGenerator(tokenz)
chance_manager = ChanceManager()
media_checker = MediaUniquenessChecker()
tokenizer = Tokenizer()
data_learner = DataLearner()
reply_generator = ReplyGenerator()
chat_purge_queue = ChatPurgeQueue()

View File

@@ -10,6 +10,7 @@ class Message(AbstractEntity):
super(Message, self).__init__(message)
self.chance = chance
self.entities = message.entities
if self.has_text():
self.text = message.text
@@ -29,7 +30,7 @@ class Message(AbstractEntity):
def has_entities(self):
    """Returns True if the message has entities (attachments).
    """
    # NOTE(review): two return statements — diff residue of the old
    # (`self.message.entities`) vs new (`self.entities`) revision; the
    # second line is unreachable as written. Only one belongs in the
    # real file — confirm against the repository.
    return self.message.entities is not None
    return self.entities is not None
def has_anchors(self):
"""Returns True if the message contains at least one anchor from anchors config.

View File

View File

@@ -1,39 +0,0 @@
from orator.orm import Model
from orator.orm import belongs_to
from orator.orm import has_many
from src.utils import random_element
import src.entity.reply
import src.entity.word
class Pair(Model):
    """A (first, second) word bigram observed in a chat, linked to the
    replies that were seen after it."""

    __fillable__ = ['chat_id', 'first_id', 'second_id']
    __timestamps__ = ['created_at']

    @has_many
    def replies(self):
        # Words observed right after this pair.
        return src.entity.reply.Reply

    @belongs_to
    def first(self):
        # Leading word of the pair.
        return src.entity.word.Word

    @belongs_to
    def second(self):
        # Trailing word of the pair.
        return src.entity.word.Word

    @staticmethod
    def get_random_pair(chat_id, first_id, second_id_list):
        """Pick one random pair among up to three matches for the chat,
        leading word and candidate trailing words, with each pair's top
        three replies (by count) eager-loaded."""
        top_replies = {'replies': lambda q: q.order_by('count', 'desc').limit(3)}
        candidates = (Pair
                      .with_(top_replies)
                      .where('chat_id', chat_id)
                      .where('first_id', first_id)
                      .where_in('second_id', second_id_list)
                      .limit(3)
                      .get()
                      .all())
        return random_element(candidates)

View File

@@ -1,19 +0,0 @@
from orator.orm import Model
from orator.orm import belongs_to
from orator.orm import belongs_to_many
import src.entity.pair
import src.entity.word
class Reply(Model):
    """ORM model: a word observed as a reply to a word pair, with an
    occurrence count."""

    # Mass-assignable columns; no created/updated timestamps are kept.
    __fillable__ = ['pair_id', 'word_id', 'count']
    __timestamps__ = False

    @belongs_to_many
    def pairs(self):
        # Pairs this reply was observed after.
        return src.entity.pair.Pair

    @belongs_to
    def word(self):
        # The word row this reply points at.
        return src.entity.word.Word

View File

@@ -1,5 +0,0 @@
from orator.orm import Model
class Word(Model):
    """ORM model: a single unique word; no timestamps are tracked."""

    __fillable__ = ['word']
    __timestamps__ = False

View File

@@ -1,6 +1,5 @@
from .data_learner import DataLearner
from .reply_generator import ReplyGenerator
from .chance_manager import ChanceManager
from .chat_purge_queue import ChatPurgeQueue
from .data_learner import DataLearner
from .media_uniqueness_checker import MediaUniquenessChecker
from .reply_generator import ReplyGenerator
from .tokenizer import Tokenizer

View File

@@ -1,8 +1,8 @@
from src.config import redis, tokenizer
from src.config import redis
class DataLearner:
def __init__(self):
def __init__(self, tokenizer):
self.redis = redis
self.tokenizer = tokenizer

View File

@@ -1,9 +1,9 @@
from src.config import config, redis, tokenizer
from src.config import config, redis
from src.utils import strings_has_equal_letters, capitalize, random_element
class ReplyGenerator:
def __init__(self):
def __init__(self, tokenizer):
self.redis = redis
self.tokenizer = tokenizer
self.max_words = config.getint('grammar', 'max_words')
@@ -18,7 +18,7 @@ class ReplyGenerator:
best_message = ''
for _ in range(self.max_messages):
generated = self.__generate_sentence(pair)
generated = self.__generate_sentence(message.chat_id, pair)
if len(generated) > len(best_message):
best_message = generated
@@ -27,30 +27,35 @@ class ReplyGenerator:
result = random_element(messages) if len(messages) else ''
if strings_has_equal_letters(result, ''.join(message.words)):
if strings_has_equal_letters(result, ''.join(words)):
return ''
return result
def __generate_sentence(self, seed):
def __generate_sentence(self, chat_id, seed):
key = seed
gen_words = []
redis = self.redis.instance()
for _ in range(self.max_words):
if len(gen_words):
gen_words.append(key[0])
else:
gen_words.append(capitalize(key[0]))
words = key
next_word = redis.srandmember(self.tokenizer.to_key(key))
if not next_word:
if len(gen_words):
gen_words.append(words[0])
else:
gen_words.append(capitalize(words[0]))
next_word = redis.srandmember(self.tokenizer.to_key(chat_id=chat_id, pair=key))
if next_word is None:
break
next_word = next_word.decode("utf-8")
if next_word == self.tokenizer.stop_word:
break
key = self.tokenizer.separator.join(key[1:] + [next_word])
key = words[1:] + [next_word]
sentence = ' '.join(gen_words).strip()
if sentence[-1:] not in self.tokenizer.end_sentence:
sentence += self.tokenizer.random_end_sentence_token()
return ' '.join(gen_words)
return sentence