imaginaryfriend/src/service/reply_generator.py

63 lines
2.3 KiB
Python
Raw Normal View History

2016-12-28 19:01:37 +01:00
from src.config import config, redis, tokenizer, trigram_repository
2016-12-27 19:23:49 +01:00
from src.utils import strings_has_equal_letters, capitalize
class ReplyGenerator:
    """Build chat replies by walking word chains from a trigram repository.

    Seed word pairs are taken from the incoming message; each pair is
    extended word by word via ``trigram_repository.get_random_reply`` and the
    longest generated sentence is returned.
    """

    def __init__(self):
        self.redis = redis
        self.tokenizer = tokenizer
        self.trigram_repository = trigram_repository

        # Maximum number of chain steps per generated sentence.
        self.max_words = config.getint('grammar', 'max_words')
        # Number of candidate sentences generated per seed pair; the longest wins.
        self.max_messages = config.getint('grammar', 'max_messages')

        self.stop_word = config['grammar']['stop_word']
        self.separator = config['grammar']['separator']
        self.end_sentence = config['grammar']['end_sentence']

    def generate(self, message):
        """Return a reply for ``message``, or ``''`` when no reply should be sent.

        ``message`` is passed to ``tokenizer.extract_words`` and must expose a
        ``chat_id`` attribute. Each trigram of the extracted words contributes
        its leading two words as a seed pair; the longest sentence generated
        across all seeds is returned.

        Returns ``''`` when there are no seed pairs, or when
        ``strings_has_equal_letters`` matches the reply against the joined
        input words (presumably to avoid echoing the message back).
        """
        words = self.tokenizer.extract_words(message)
        pairs = [trigram[:-1] for trigram in self.tokenizer.split_to_trigrams(words)]
        messages = [
            self.__generate_best_message(chat_id=message.chat_id, pair=pair)
            for pair in pairs
        ]
        longest_message = max(messages, key=len, default='')

        if longest_message and strings_has_equal_letters(longest_message, ''.join(words)):
            return ''

        return longest_message

    def __generate_best_message(self, chat_id, pair):
        """Generate ``max_messages`` sentences from ``pair`` and return the longest.

        On equal length the earliest candidate is kept; returns ``''`` when
        ``max_messages`` is not positive or every candidate is empty.
        """
        best_message = ''
        for _ in range(self.max_messages):
            generated = self.__generate_sentence(chat_id=chat_id, pair=pair)
            if len(generated) > len(best_message):
                best_message = generated
        return best_message

    def __walk_chain(self, chat_id, pair):
        """Return the raw words obtained by extending ``pair`` through the chain.

        The second word of the current window is emitted, then the repository
        is asked (with the window joined by ``separator``) for a continuation.
        The walk stops when no continuation exists (``None``) or after
        ``max_words`` steps. Stop words are not removed here.

        ``pair`` is assumed to hold two words (callers pass ``trigram[:-1]``).
        """
        gen_words = []
        window = list(pair)
        for _ in range(self.max_words):
            gen_words.append(window[1])
            next_word = self.trigram_repository.get_random_reply(
                chat_id, self.separator.join(window))
            if next_word is None:
                break
            window = window[1:] + [next_word]
        else:
            # No break: the step budget ran out (or max_words is 0), so the
            # newest word of the window has not been emitted yet. Emit it
            # unconditionally -- a word that already occurred earlier in the
            # sentence is still a legitimate ending.
            gen_words.append(window[-1])
        return gen_words

    def __generate_sentence(self, chat_id, pair):
        """Generate one capitalized sentence seeded by ``pair``.

        Stop words are dropped and the words are joined with spaces. A random
        end-of-sentence token from the tokenizer is appended unless the
        sentence already ends with a character contained in ``end_sentence``.
        """
        gen_words = [
            word for word in self.__walk_chain(chat_id, pair)
            if word != self.stop_word
        ]
        sentence = ' '.join(gen_words).strip()

        # Assuming end_sentence is a string (config values are strings here),
        # '' is a substring of it, so an empty sentence gets no end token.
        if sentence[-1:] not in self.end_sentence:
            sentence += self.tokenizer.random_end_sentence_token()

        return capitalize(sentence)