Separated learning and generation by components
This commit is contained in:
parent
889a8f7fde
commit
6ee9feb61a
10
src/bot.py
10
src/bot.py
|
@ -1,10 +1,12 @@
|
|||
import logging
|
||||
|
||||
from telegram.ext import Updater
|
||||
from src.handlers.message_handler import MessageHandler
|
||||
from src.handlers.command_handler import CommandHandler
|
||||
from src.handlers.status_handler import StatusHandler
|
||||
from src.handler.message_handler import MessageHandler
|
||||
from src.handler.command_handler import CommandHandler
|
||||
from src.handler.status_handler import StatusHandler
|
||||
from src.chat_purge_queue import ChatPurgeQueue
|
||||
from src.service.reply_generator import ReplyGenerator
|
||||
from src.service.data_learner import DataLearner
|
||||
from src.config import config
|
||||
|
||||
|
||||
|
@ -16,7 +18,7 @@ class Bot:
|
|||
def run(self):
|
||||
logging.info("Bot started")
|
||||
|
||||
self.dispatcher.add_handler(MessageHandler())
|
||||
self.dispatcher.add_handler(MessageHandler(data_learner=DataLearner(), reply_generator=ReplyGenerator()))
|
||||
self.dispatcher.add_handler(CommandHandler())
|
||||
self.dispatcher.add_handler(StatusHandler(chat_purge_queue=ChatPurgeQueue(self.updater.job_queue)))
|
||||
|
||||
|
|
|
@ -6,12 +6,10 @@ from orator.orm import has_many
|
|||
import src.entity.reply
|
||||
import src.entity.chat
|
||||
import src.entity.word
|
||||
from src.utils import *
|
||||
from src.config import config
|
||||
|
||||
|
||||
class Pair(Model):
|
||||
__guarded__ = ['id']
|
||||
__fillable__ = ['chat_id', 'first_id', 'second_id']
|
||||
__timestamps__ = ['created_at']
|
||||
|
||||
@has_many
|
||||
|
@ -31,113 +29,15 @@ class Pair(Model):
|
|||
return src.entity.word.Word
|
||||
|
||||
@staticmethod
|
||||
def generate(message):
|
||||
return Pair.generate_story(message, message.words, random.randint(0, 2) + 1)
|
||||
|
||||
@staticmethod
|
||||
def generate_story(message, words, sentences):
|
||||
words_ids = src.entity.word.Word.where_in('word', words).get().pluck('id').all()
|
||||
|
||||
result = []
|
||||
for _ in range(0, sentences):
|
||||
result.append(Pair.__generate_sentence(message, words_ids))
|
||||
|
||||
return ' '.join(result)
|
||||
|
||||
@staticmethod
|
||||
def learn(message):
|
||||
src.entity.word.Word.learn(message.words)
|
||||
|
||||
words = [None]
|
||||
for word in message.words:
|
||||
words.append(word)
|
||||
if word[-1] in config['grammar']['end_sentence']:
|
||||
words.append(None)
|
||||
if words[-1] is not None:
|
||||
words.append(None)
|
||||
|
||||
while any(word for word in words):
|
||||
trigram = words[:3]
|
||||
first_word_id, second_word_id, *third_word_id = list(map(
|
||||
lambda x: None if x is None else src.entity.word.Word.where('word', x).first().id,
|
||||
trigram
|
||||
))
|
||||
third_word_id = None if len(third_word_id) == 0 else third_word_id[0]
|
||||
|
||||
words.pop(0)
|
||||
|
||||
pair = Pair.where('chat_id', message.chat.id)\
|
||||
.where('first_id', first_word_id)\
|
||||
.where('second_id', second_word_id)\
|
||||
.first()
|
||||
if pair is None:
|
||||
pair = Pair.create(chat_id=message.chat.id,
|
||||
first_id=first_word_id,
|
||||
second_id=second_word_id)
|
||||
|
||||
reply = pair.replies().where('word_id', third_word_id).first()
|
||||
|
||||
if reply is not None:
|
||||
reply.count += 1
|
||||
reply.save()
|
||||
else:
|
||||
pair.replies().create(pair_id=pair.id, word_id=third_word_id)
|
||||
|
||||
@staticmethod
|
||||
def __generate_sentence(message, word_ids):
|
||||
sentences = []
|
||||
safety_counter = 50
|
||||
first_word = None
|
||||
second_words = list(word_ids)
|
||||
|
||||
while safety_counter > 0:
|
||||
pair = Pair.__get_pair(chat_id=message.chat.id, first_id=first_word, second_ids=second_words)
|
||||
replies = getattr(pair, 'replies', [])
|
||||
safety_counter -= 1
|
||||
|
||||
if pair is None or len(replies) == 0:
|
||||
continue
|
||||
|
||||
reply = random.choice(replies.all())
|
||||
first_word = pair.second.id
|
||||
|
||||
# TODO. WARNING! Do not try to fix, it's magic, i have no clue why
|
||||
try:
|
||||
second_words = [reply.word.id]
|
||||
except AttributeError:
|
||||
second_words = None
|
||||
|
||||
if len(sentences) == 0:
|
||||
sentences.append(capitalize(pair.second.word))
|
||||
word_ids.remove(pair.second.id)
|
||||
|
||||
# TODO. WARNING! Do not try to fix, it's magic, i have no clue why
|
||||
try:
|
||||
reply_word = reply.word.word
|
||||
except AttributeError:
|
||||
reply_word = None
|
||||
|
||||
if reply_word is not None:
|
||||
sentences.append(reply_word)
|
||||
else:
|
||||
break
|
||||
|
||||
sentence = ' '.join(sentences).strip()
|
||||
if sentence[-1:] not in config['grammar']['end_sentence']:
|
||||
sentence += random_element(list(config['grammar']['end_sentence']))
|
||||
|
||||
return sentence
|
||||
|
||||
@staticmethod
|
||||
def __get_pair(chat_id, first_id, second_ids):
|
||||
def get_random_pair(chat_id, first_id, second_id_list):
|
||||
ten_minutes_ago = datetime.now() - timedelta(seconds=10 * 60)
|
||||
pairs = Pair.with_('replies')\
|
||||
return Pair\
|
||||
.with_({
|
||||
'replies': lambda q: q.order_by('count', 'desc').limit(3)
|
||||
})\
|
||||
.where('chat_id', chat_id)\
|
||||
.where('first_id', first_id)\
|
||||
.where_in('second_id', second_ids)\
|
||||
.where_in('second_id', second_id_list)\
|
||||
.where('created_at', '<', ten_minutes_ago)\
|
||||
.limit(3)\
|
||||
.get()\
|
||||
.all()
|
||||
|
||||
return random_element(pairs)
|
||||
.order_by_raw('RANDOM()')\
|
||||
.first()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from orator.orm import Model
|
||||
from orator.orm import has_many
|
||||
|
||||
|
@ -13,12 +11,3 @@ class Word(Model):
|
|||
@has_many
|
||||
def chats(self):
|
||||
return src.entity.chat.Chat
|
||||
|
||||
@staticmethod
|
||||
def learn(words):
|
||||
existing_words = Word.where_in('word', words).get().pluck('word').all()
|
||||
# TODO. Слова должны быть уникальные И ТАКЖЕ ОБЯЗАТЕЛЬНО в оригинальном порядке
|
||||
new_words = [word for word in OrderedDict.fromkeys(words).keys() if word not in existing_words]
|
||||
|
||||
for word in new_words:
|
||||
Word.create(word=word)
|
||||
|
|
|
@ -3,16 +3,18 @@ import logging
|
|||
from telegram.ext import MessageHandler as ParentHandler, Filters
|
||||
|
||||
from src.domain.message import Message
|
||||
from src.entity.pair import Pair
|
||||
from src.entity.chat import Chat
|
||||
|
||||
|
||||
class MessageHandler(ParentHandler):
|
||||
def __init__(self):
|
||||
def __init__(self, data_learner, reply_generator):
|
||||
super(MessageHandler, self).__init__(
|
||||
Filters.text | Filters.sticker,
|
||||
self.handle)
|
||||
|
||||
self.data_learner = data_learner
|
||||
self.reply_generator = reply_generator
|
||||
|
||||
def handle(self, bot, update):
|
||||
chat = Chat.get_chat(update.message)
|
||||
message = Message(chat=chat, message=update.message)
|
||||
|
@ -29,14 +31,13 @@ class MessageHandler(ParentHandler):
|
|||
return self.__process_sticker(bot, message)
|
||||
|
||||
def __process_message(self, bot, message):
|
||||
Pair.learn(message)
|
||||
self.data_learner.learn(message)
|
||||
|
||||
if message.has_anchors() \
|
||||
or message.is_private() \
|
||||
or message.is_reply_to_bot() \
|
||||
or message.is_random_answer():
|
||||
|
||||
reply = Pair.generate(message)
|
||||
reply = self.reply_generator.generate(message)
|
||||
if reply != '':
|
||||
self.__answer(bot, message, reply)
|
||||
|
||||
|
@ -56,16 +57,6 @@ class MessageHandler(ParentHandler):
|
|||
|
||||
bot.sendMessage(chat_id=message.chat.telegram_id, text=reply)
|
||||
|
||||
def __reply(self, bot, message, reply):
|
||||
logging.debug("[Chat %s %s reply] %s" %
|
||||
(message.chat.chat_type,
|
||||
message.chat.telegram_id,
|
||||
reply))
|
||||
|
||||
bot.sendMessage(chat_id=message.chat.telegram_id,
|
||||
reply_to_message_id=message.message.message_id,
|
||||
text=reply)
|
||||
|
||||
def __send_sticker(self, bot, message, sticker_id):
|
||||
logging.debug("[Chat %s %s send_sticker]" %
|
||||
(message.chat.chat_type, message.chat.telegram_id))
|
|
@ -0,0 +1,59 @@
|
|||
from collections import OrderedDict
|
||||
from src.config import config
|
||||
from src.entity.word import Word
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
class DataLearner:
    """Persist the word trigrams of incoming messages.

    A message's words are turned into a stream with ``None`` markers at
    sentence boundaries. Every three-entry window of that stream is stored
    as a ``Pair`` (first and second word ids, scoped to the chat) plus a
    reply row for the third word id, whose ``count`` is incremented when
    the same trigram is seen again.
    """

    def __init__(self):
        pass

    def learn(self, message):
        """Store every trigram found in ``message.words``.

        Reads ``message.words`` and ``message.chat.id``. Returns ``None``.
        """
        if not message.words:
            # No words means no trigrams; skip the database entirely.
            return

        self.__write_new_unique_words(message.words)
        words = self.__normalize_words(message.words)

        # A word takes part in up to three overlapping trigrams, so its id
        # is looked up once and remembered. ``None`` (sentence boundary)
        # maps to ``None`` without a query.
        known_ids = {None: None}

        def word_id(word):
            if word not in known_ids:
                known_ids[word] = Word.where('word', word).first(['id']).id
            return known_ids[word]

        for first, second, third in self._trigrams(words):
            first_word_id = word_id(first)
            second_word_id = word_id(second)
            third_word_id = word_id(third)

            pair = Pair.where('chat_id', message.chat.id) \
                .where('first_id', first_word_id) \
                .where('second_id', second_word_id) \
                .first()
            if pair is None:
                pair = Pair.create(chat_id=message.chat.id,
                                   first_id=first_word_id,
                                   second_id=second_word_id)

            reply = pair.replies().where('word_id', third_word_id).first()
            if reply is not None:
                reply.count += 1
                reply.save()
            else:
                pair.replies().create(pair_id=pair.id, word_id=third_word_id)

    @staticmethod
    def _trigrams(words):
        """Yield ``(first, second, third)`` for each sliding window of ``words``.

        Windows are produced for every start position up to and including
        the last truthy entry; windows that run past the end of the list
        are padded with ``None``.
        """
        last_word_at = max((i for i, word in enumerate(words) if word), default=-1)
        for start in range(last_word_at + 1):
            window = list(words[start:start + 3]) + [None, None]
            yield window[0], window[1], window[2]

    @staticmethod
    def _with_sentence_boundaries(src_words, end_marks):
        """Return ``src_words`` wrapped in ``None`` sentence-boundary markers.

        The result starts with ``None``, gets a ``None`` after every word
        whose last character is in ``end_marks``, and always ends with
        ``None``. Empty words are kept but never treated as sentence ends.
        """
        words = [None]
        for word in src_words:
            words.append(word)
            # ``word and`` guards against IndexError on an empty string.
            if word and word[-1] in end_marks:
                words.append(None)
        if words[-1] is not None:
            words.append(None)
        return words

    def __normalize_words(self, src_words):
        """Mark sentence boundaries using the configured end-of-sentence marks."""
        return self._with_sentence_boundaries(src_words, config['grammar']['end_sentence'])

    def __write_new_unique_words(self, words):
        """Create a ``Word`` row for each word that is not stored yet."""
        # TODO: words must be unique AND must keep their original order.
        existing_words = set(Word.where_in('word', words).lists('word').all())
        for word in OrderedDict.fromkeys(words):
            if word not in existing_words:
                Word.create(word=word)
|
|
@ -0,0 +1,63 @@
|
|||
from src.config import config
|
||||
from src.utils import *
|
||||
from src.entity.word import Word
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
class ReplyGenerator:
    """Generate a text reply by chaining learned word pairs for a chat."""

    def __init__(self):
        pass

    def generate(self, message):
        """Return a reply made of one to three generated sentences."""
        return self.generate_story(message, message.words, random.randint(0, 2) + 1)

    def generate_story(self, message, words, sentences_count):
        """Return ``sentences_count`` generated sentences joined by spaces.

        ``words`` seed the first word of each sentence. Sentences that come
        back empty are dropped, so the result is ``''`` (not a run of
        spaces) when nothing could be generated.
        """
        word_ids = Word.where_in('word', words).lists('id').all()

        return self._join_sentences(
            self.__generate_sentence(message, word_ids) for _ in range(sentences_count)
        )

    @staticmethod
    def _join_sentences(sentences):
        """Join the non-empty items of ``sentences`` with single spaces."""
        return ' '.join(sentence for sentence in sentences if sentence)

    def __generate_sentence(self, message, word_ids):
        """Build one sentence, or ``''`` when no matching pair exists.

        ``word_ids`` is mutated: the id of the word that starts the sentence
        is removed from it, so the list shared between successive calls does
        not offer the same starting word again.
        """
        sentences = []
        first_word_id = None
        second_word_id_list = word_ids

        # At most 50 lookups per sentence, as a guard against endless chains.
        for _ in range(50):
            pair = Pair.get_random_pair(chat_id=message.chat.id,
                                        first_id=first_word_id,
                                        second_id_list=second_word_id_list)
            if pair is None:
                # No pair matches these arguments; repeating the identical
                # query would only burn the remaining attempts.
                break

            replies = getattr(pair, 'replies', [])
            if len(replies) == 0:
                continue

            reply = random.choice(replies.all())
            first_word_id = pair.second.id

            # NOTE(review): the original marked this as unexplained "magic".
            # Presumably ``reply.word`` is missing for replies stored with no
            # third word (end of a sentence) - confirm against the learner.
            try:
                next_word = reply.word
                next_word_id, reply_word = next_word.id, next_word.word
            except AttributeError:
                next_word_id = reply_word = None

            if not sentences:
                sentences.append(capitalize(pair.second.word))
                word_ids.remove(pair.second.id)

            if reply_word is None:
                break

            sentences.append(reply_word)
            second_word_id_list = [next_word_id]

        if not sentences:
            # Nothing was generated: report that as an empty string rather
            # than a lone punctuation mark.
            return ''

        sentence = ' '.join(sentences).strip()
        if sentence[-1:] not in config['grammar']['end_sentence']:
            sentence += random_element(list(config['grammar']['end_sentence']))

        return sentence
|
Loading…
Reference in New Issue