Separated learning and generation by components
This commit is contained in:
parent
889a8f7fde
commit
6ee9feb61a
10
src/bot.py
10
src/bot.py
|
@ -1,10 +1,12 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from telegram.ext import Updater
|
from telegram.ext import Updater
|
||||||
from src.handlers.message_handler import MessageHandler
|
from src.handler.message_handler import MessageHandler
|
||||||
from src.handlers.command_handler import CommandHandler
|
from src.handler.command_handler import CommandHandler
|
||||||
from src.handlers.status_handler import StatusHandler
|
from src.handler.status_handler import StatusHandler
|
||||||
from src.chat_purge_queue import ChatPurgeQueue
|
from src.chat_purge_queue import ChatPurgeQueue
|
||||||
|
from src.service.reply_generator import ReplyGenerator
|
||||||
|
from src.service.data_learner import DataLearner
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +18,7 @@ class Bot:
|
||||||
def run(self):
|
def run(self):
|
||||||
logging.info("Bot started")
|
logging.info("Bot started")
|
||||||
|
|
||||||
self.dispatcher.add_handler(MessageHandler())
|
self.dispatcher.add_handler(MessageHandler(data_learner=DataLearner(), reply_generator=ReplyGenerator()))
|
||||||
self.dispatcher.add_handler(CommandHandler())
|
self.dispatcher.add_handler(CommandHandler())
|
||||||
self.dispatcher.add_handler(StatusHandler(chat_purge_queue=ChatPurgeQueue(self.updater.job_queue)))
|
self.dispatcher.add_handler(StatusHandler(chat_purge_queue=ChatPurgeQueue(self.updater.job_queue)))
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,10 @@ from orator.orm import has_many
|
||||||
import src.entity.reply
|
import src.entity.reply
|
||||||
import src.entity.chat
|
import src.entity.chat
|
||||||
import src.entity.word
|
import src.entity.word
|
||||||
from src.utils import *
|
|
||||||
from src.config import config
|
|
||||||
|
|
||||||
|
|
||||||
class Pair(Model):
|
class Pair(Model):
|
||||||
__guarded__ = ['id']
|
__fillable__ = ['chat_id', 'first_id', 'second_id']
|
||||||
__timestamps__ = ['created_at']
|
__timestamps__ = ['created_at']
|
||||||
|
|
||||||
@has_many
|
@has_many
|
||||||
|
@ -31,113 +29,15 @@ class Pair(Model):
|
||||||
return src.entity.word.Word
|
return src.entity.word.Word
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(message):
|
def get_random_pair(chat_id, first_id, second_id_list):
|
||||||
return Pair.generate_story(message, message.words, random.randint(0, 2) + 1)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_story(message, words, sentences):
|
|
||||||
words_ids = src.entity.word.Word.where_in('word', words).get().pluck('id').all()
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for _ in range(0, sentences):
|
|
||||||
result.append(Pair.__generate_sentence(message, words_ids))
|
|
||||||
|
|
||||||
return ' '.join(result)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def learn(message):
|
|
||||||
src.entity.word.Word.learn(message.words)
|
|
||||||
|
|
||||||
words = [None]
|
|
||||||
for word in message.words:
|
|
||||||
words.append(word)
|
|
||||||
if word[-1] in config['grammar']['end_sentence']:
|
|
||||||
words.append(None)
|
|
||||||
if words[-1] is not None:
|
|
||||||
words.append(None)
|
|
||||||
|
|
||||||
while any(word for word in words):
|
|
||||||
trigram = words[:3]
|
|
||||||
first_word_id, second_word_id, *third_word_id = list(map(
|
|
||||||
lambda x: None if x is None else src.entity.word.Word.where('word', x).first().id,
|
|
||||||
trigram
|
|
||||||
))
|
|
||||||
third_word_id = None if len(third_word_id) == 0 else third_word_id[0]
|
|
||||||
|
|
||||||
words.pop(0)
|
|
||||||
|
|
||||||
pair = Pair.where('chat_id', message.chat.id)\
|
|
||||||
.where('first_id', first_word_id)\
|
|
||||||
.where('second_id', second_word_id)\
|
|
||||||
.first()
|
|
||||||
if pair is None:
|
|
||||||
pair = Pair.create(chat_id=message.chat.id,
|
|
||||||
first_id=first_word_id,
|
|
||||||
second_id=second_word_id)
|
|
||||||
|
|
||||||
reply = pair.replies().where('word_id', third_word_id).first()
|
|
||||||
|
|
||||||
if reply is not None:
|
|
||||||
reply.count += 1
|
|
||||||
reply.save()
|
|
||||||
else:
|
|
||||||
pair.replies().create(pair_id=pair.id, word_id=third_word_id)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __generate_sentence(message, word_ids):
|
|
||||||
sentences = []
|
|
||||||
safety_counter = 50
|
|
||||||
first_word = None
|
|
||||||
second_words = list(word_ids)
|
|
||||||
|
|
||||||
while safety_counter > 0:
|
|
||||||
pair = Pair.__get_pair(chat_id=message.chat.id, first_id=first_word, second_ids=second_words)
|
|
||||||
replies = getattr(pair, 'replies', [])
|
|
||||||
safety_counter -= 1
|
|
||||||
|
|
||||||
if pair is None or len(replies) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
reply = random.choice(replies.all())
|
|
||||||
first_word = pair.second.id
|
|
||||||
|
|
||||||
# TODO. WARNING! Do not try to fix, it's magic, i have no clue why
|
|
||||||
try:
|
|
||||||
second_words = [reply.word.id]
|
|
||||||
except AttributeError:
|
|
||||||
second_words = None
|
|
||||||
|
|
||||||
if len(sentences) == 0:
|
|
||||||
sentences.append(capitalize(pair.second.word))
|
|
||||||
word_ids.remove(pair.second.id)
|
|
||||||
|
|
||||||
# TODO. WARNING! Do not try to fix, it's magic, i have no clue why
|
|
||||||
try:
|
|
||||||
reply_word = reply.word.word
|
|
||||||
except AttributeError:
|
|
||||||
reply_word = None
|
|
||||||
|
|
||||||
if reply_word is not None:
|
|
||||||
sentences.append(reply_word)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
sentence = ' '.join(sentences).strip()
|
|
||||||
if sentence[-1:] not in config['grammar']['end_sentence']:
|
|
||||||
sentence += random_element(list(config['grammar']['end_sentence']))
|
|
||||||
|
|
||||||
return sentence
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __get_pair(chat_id, first_id, second_ids):
|
|
||||||
ten_minutes_ago = datetime.now() - timedelta(seconds=10 * 60)
|
ten_minutes_ago = datetime.now() - timedelta(seconds=10 * 60)
|
||||||
pairs = Pair.with_('replies')\
|
return Pair\
|
||||||
|
.with_({
|
||||||
|
'replies': lambda q: q.order_by('count', 'desc').limit(3)
|
||||||
|
})\
|
||||||
.where('chat_id', chat_id)\
|
.where('chat_id', chat_id)\
|
||||||
.where('first_id', first_id)\
|
.where('first_id', first_id)\
|
||||||
.where_in('second_id', second_ids)\
|
.where_in('second_id', second_id_list)\
|
||||||
.where('created_at', '<', ten_minutes_ago)\
|
.where('created_at', '<', ten_minutes_ago)\
|
||||||
.limit(3)\
|
.order_by_raw('RANDOM()')\
|
||||||
.get()\
|
.first()
|
||||||
.all()
|
|
||||||
|
|
||||||
return random_element(pairs)
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
from orator.orm import Model
|
from orator.orm import Model
|
||||||
from orator.orm import has_many
|
from orator.orm import has_many
|
||||||
|
|
||||||
|
@ -13,12 +11,3 @@ class Word(Model):
|
||||||
@has_many
|
@has_many
|
||||||
def chats(self):
|
def chats(self):
|
||||||
return src.entity.chat.Chat
|
return src.entity.chat.Chat
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def learn(words):
|
|
||||||
existing_words = Word.where_in('word', words).get().pluck('word').all()
|
|
||||||
# TODO. Слова должны быть уникальные И ТАКЖЕ ОБЯЗАТЕЛЬНО в оригинальном порядке
|
|
||||||
new_words = [word for word in OrderedDict.fromkeys(words).keys() if word not in existing_words]
|
|
||||||
|
|
||||||
for word in new_words:
|
|
||||||
Word.create(word=word)
|
|
||||||
|
|
|
@ -3,16 +3,18 @@ import logging
|
||||||
from telegram.ext import MessageHandler as ParentHandler, Filters
|
from telegram.ext import MessageHandler as ParentHandler, Filters
|
||||||
|
|
||||||
from src.domain.message import Message
|
from src.domain.message import Message
|
||||||
from src.entity.pair import Pair
|
|
||||||
from src.entity.chat import Chat
|
from src.entity.chat import Chat
|
||||||
|
|
||||||
|
|
||||||
class MessageHandler(ParentHandler):
|
class MessageHandler(ParentHandler):
|
||||||
def __init__(self):
|
def __init__(self, data_learner, reply_generator):
|
||||||
super(MessageHandler, self).__init__(
|
super(MessageHandler, self).__init__(
|
||||||
Filters.text | Filters.sticker,
|
Filters.text | Filters.sticker,
|
||||||
self.handle)
|
self.handle)
|
||||||
|
|
||||||
|
self.data_learner = data_learner
|
||||||
|
self.reply_generator = reply_generator
|
||||||
|
|
||||||
def handle(self, bot, update):
|
def handle(self, bot, update):
|
||||||
chat = Chat.get_chat(update.message)
|
chat = Chat.get_chat(update.message)
|
||||||
message = Message(chat=chat, message=update.message)
|
message = Message(chat=chat, message=update.message)
|
||||||
|
@ -29,14 +31,13 @@ class MessageHandler(ParentHandler):
|
||||||
return self.__process_sticker(bot, message)
|
return self.__process_sticker(bot, message)
|
||||||
|
|
||||||
def __process_message(self, bot, message):
|
def __process_message(self, bot, message):
|
||||||
Pair.learn(message)
|
self.data_learner.learn(message)
|
||||||
|
|
||||||
if message.has_anchors() \
|
if message.has_anchors() \
|
||||||
or message.is_private() \
|
or message.is_private() \
|
||||||
or message.is_reply_to_bot() \
|
or message.is_reply_to_bot() \
|
||||||
or message.is_random_answer():
|
or message.is_random_answer():
|
||||||
|
reply = self.reply_generator.generate(message)
|
||||||
reply = Pair.generate(message)
|
|
||||||
if reply != '':
|
if reply != '':
|
||||||
self.__answer(bot, message, reply)
|
self.__answer(bot, message, reply)
|
||||||
|
|
||||||
|
@ -56,16 +57,6 @@ class MessageHandler(ParentHandler):
|
||||||
|
|
||||||
bot.sendMessage(chat_id=message.chat.telegram_id, text=reply)
|
bot.sendMessage(chat_id=message.chat.telegram_id, text=reply)
|
||||||
|
|
||||||
def __reply(self, bot, message, reply):
|
|
||||||
logging.debug("[Chat %s %s reply] %s" %
|
|
||||||
(message.chat.chat_type,
|
|
||||||
message.chat.telegram_id,
|
|
||||||
reply))
|
|
||||||
|
|
||||||
bot.sendMessage(chat_id=message.chat.telegram_id,
|
|
||||||
reply_to_message_id=message.message.message_id,
|
|
||||||
text=reply)
|
|
||||||
|
|
||||||
def __send_sticker(self, bot, message, sticker_id):
|
def __send_sticker(self, bot, message, sticker_id):
|
||||||
logging.debug("[Chat %s %s send_sticker]" %
|
logging.debug("[Chat %s %s send_sticker]" %
|
||||||
(message.chat.chat_type, message.chat.telegram_id))
|
(message.chat.chat_type, message.chat.telegram_id))
|
|
@ -0,0 +1,59 @@
|
||||||
|
from collections import OrderedDict
|
||||||
|
from src.config import config
|
||||||
|
from src.entity.word import Word
|
||||||
|
from src.entity.pair import Pair
|
||||||
|
|
||||||
|
|
||||||
|
class DataLearner:
    """Stores what can be learned from an incoming message.

    Learning has two steps: every previously unseen word is written to the
    ``Word`` table, then the message is walked as overlapping trigrams and
    each one is recorded as a ``Pair`` (first word, second word) with a
    reply (third word) whose ``count`` is incremented on repeats.
    """

    def __init__(self):
        pass

    def learn(self, message):
        """Record the words and trigrams of ``message``.

        Reads ``message.words`` and ``message.chat.id``; pairs are stored
        per chat. Returns nothing.
        """
        self.__write_new_unique_words(message.words)

        words = self.__normalize_words(message.words)

        # A word takes part in up to three consecutive trigrams, so its id is
        # looked up once and remembered. ``None`` is the sentence-boundary
        # marker and has no id.
        known_ids = {None: None}

        def word_id(word):
            if word not in known_ids:
                known_ids[word] = Word.where('word', word).first(['id']).id
            return known_ids[word]

        # ``words`` always ends with a boundary marker, so while any real word
        # is left the window below holds at least two items.
        while any(word for word in words):
            first_word_id, second_word_id, *rest = [word_id(word) for word in words[:3]]
            # The final window holds only two items: there is no third word.
            third_word_id = rest[0] if rest else None

            # Slide the window one word forward.
            words.pop(0)

            pair = Pair.where('chat_id', message.chat.id) \
                .where('first_id', first_word_id) \
                .where('second_id', second_word_id) \
                .first()
            if pair is None:
                pair = Pair.create(chat_id=message.chat.id,
                                   first_id=first_word_id,
                                   second_id=second_word_id)

            reply = pair.replies().where('word_id', third_word_id).first()

            if reply is not None:
                reply.count += 1
                reply.save()
            else:
                pair.replies().create(pair_id=pair.id, word_id=third_word_id)

    @staticmethod
    def _mark_sentence_bounds(src_words, end_chars):
        """Return ``src_words`` framed by ``None`` sentence-boundary markers.

        The result starts with ``None``, gets a ``None`` after every word
        whose last character is in ``end_chars``, and always ends with
        ``None``. An empty word is kept but never treated as a sentence end
        (indexing its last character would raise ``IndexError``).
        """
        words = [None]
        for word in src_words:
            words.append(word)
            if word and word[-1] in end_chars:
                words.append(None)
        if words[-1] is not None:
            words.append(None)

        return words

    def __normalize_words(self, src_words):
        """Frame ``src_words`` with boundary markers using the configured
        ``config['grammar']['end_sentence']`` characters."""
        return self._mark_sentence_bounds(src_words, config['grammar']['end_sentence'])

    def __write_new_unique_words(self, words):
        """Create a ``Word`` row for every word of ``words`` not stored yet."""
        # TODO. Words must be unique AND ALSO, necessarily, in the original order
        existing_words = set(Word.where_in('word', words).lists('word').all())
        # OrderedDict.fromkeys drops duplicates while keeping first-seen order.
        new_words = [word for word in OrderedDict.fromkeys(words) if word not in existing_words]

        for word in new_words:
            Word.create(word=word)
|
|
@ -0,0 +1,63 @@
|
||||||
|
from src.config import config
|
||||||
|
from src.utils import *
|
||||||
|
from src.entity.word import Word
|
||||||
|
from src.entity.pair import Pair
|
||||||
|
|
||||||
|
|
||||||
|
class ReplyGenerator:
    """Builds a reply by chaining stored pairs and their reply words.

    NOTE(review): ``random``, ``capitalize`` and ``random_element`` are not
    imported explicitly in this file; presumably they arrive through
    ``from src.utils import *`` - confirm.
    """

    # Upper bound on pair lookups per sentence, so a chain cannot run forever.
    MAX_LOOKUPS = 50

    def __init__(self):
        pass

    def generate(self, message):
        """Return a reply of one to three sentences seeded by ``message.words``.

        The result is an empty string when nothing could be generated.
        """
        return self.generate_story(message, message.words, random.randint(1, 3))

    def generate_story(self, message, words, sentences_count):
        """Return up to ``sentences_count`` generated sentences joined by spaces.

        ``words`` are the seed words; their ids are fetched once and shared by
        all sentences, so a seed used to start one sentence is not reused by
        the next (``__generate_sentence`` removes it from the list).
        """
        word_ids = Word.where_in('word', words).lists('id').all()

        return self._join_sentences(
            [self.__generate_sentence(message, word_ids) for _ in range(sentences_count)]
        )

    @staticmethod
    def _join_sentences(sentences):
        """Join the non-empty ``sentences`` with single spaces.

        Empty sentences are dropped so that "nothing generated" yields ``''``
        rather than a whitespace-only string, which callers comparing the
        reply against ``''`` would mistake for real text.
        """
        return ' '.join(sentence for sentence in sentences if sentence)

    def __generate_sentence(self, message, word_ids):
        """Return one generated sentence, or ``''`` if no pair chain was found.

        ``word_ids`` is the mutable list of seed word ids; the id of the word
        that opens the sentence is removed from it.
        """
        sentences = []
        first_word_id = None
        second_word_id_list = word_ids

        for _ in range(self.MAX_LOOKUPS):
            pair = Pair.get_random_pair(chat_id=message.chat.id,
                                        first_id=first_word_id,
                                        second_id_list=second_word_id_list)
            if pair is None:
                # Nothing matches these ids; repeating the same query would
                # only return the same miss.
                break

            replies = pair.replies
            if len(replies) == 0:
                # The pair is picked at random, so another attempt may draw a
                # pair that does have replies.
                continue

            reply = random.choice(replies.all())
            first_word_id = pair.second.id

            # A reply stored with word_id=None marks the end of a sentence;
            # presumably ``reply.word`` is then None and the attribute access
            # fails - that is what the AttributeError handler covers.
            try:
                next_word = reply.word
                next_word_id, reply_word = next_word.id, next_word.word
            except AttributeError:
                next_word_id = reply_word = None
            second_word_id_list = None if reply_word is None and next_word_id is None else [next_word_id]

            if len(sentences) == 0:
                # The first matched pair supplies the opening word.
                sentences.append(capitalize(pair.second.word))
                word_ids.remove(pair.second.id)

            if reply_word is None:
                break
            sentences.append(reply_word)

        sentence = ' '.join(sentences).strip()
        if not sentence:
            # No words were chained: report "no sentence" explicitly.
            return ''

        if sentence[-1:] not in config['grammar']['end_sentence']:
            sentence += random_element(list(config['grammar']['end_sentence']))

        return sentence
|
Loading…
Reference in New Issue