learner: add possibility to learn from chat log

This commit is contained in:
Gitea 2020-03-24 17:43:20 +01:00
parent d583b26db7
commit 15566baf1a
3 changed files with 35 additions and 2 deletions

View File

@@ -19,10 +19,15 @@ class Bot:
purge_queue_instance = chat_purge_queue.instance(self.updater.job_queue)
self.dispatcher.add_handler(MessageHandler())
msghandler = MessageHandler()
self.dispatcher.add_handler(msghandler)
self.dispatcher.add_handler(CommandHandler())
self.dispatcher.add_handler(StatusHandler(chat_purge_queue=purge_queue_instance))
if config['bot']['learn_from_file']:
msghandler.data_learner.learn_from_file(config['bot']['learn_from_file'])
while True: pass
if config['updates']['mode'] == 'polling':
self.updater.start_polling()
elif config['updates']['mode'] == 'webhook':

View File

@@ -5,7 +5,7 @@ import os
encoding = 'utf-8'
sections = {
'bot': ['token', 'name', 'anchors', 'god_mode', 'purge_interval', 'default_chance', 'spam_stickers'],
'bot': ['token', 'name', 'anchors', 'god_mode', 'purge_interval', 'default_chance', 'spam_stickers', 'mention_name', 'mention_name_full', 'learn_from_file'],
'grammar': ['chain_len', 'sep', 'stop_word', 'max_wrds', 'max_msgs', 'endsen', 'garbage', 'garbage_entities'],
'logging': ['level'],
'updates': ['mode'],

View File

@@ -1,5 +1,10 @@
from src.config import trigram_repository, tokenizer
class FakeMessage:
    """Minimal stand-in for a chat message built from plain text.

    ``DataLearner.learn_from_file`` wraps each text fragment read from a
    file in one of these so it can be passed to ``DataLearner.learn``,
    which reads ``chat_id`` directly and hands the whole object to the
    tokenizer.
    """

    def __init__(self, chat_id, text):
        """
        :param chat_id: Identifier of the chat the text is attributed to
        :param text: Message text
        """
        self.chat_id = chat_id
        self.text = text
        # Always empty: text loaded from a file carries no entity markup.
        # NOTE(review): presumably the tokenizer reads this attribute the
        # same way it does on a real message -- verify against the tokenizer.
        self.entities = []

    def __repr__(self):
        return '{}(chat_id={!r}, text={!r})'.format(
            type(self).__name__, self.chat_id, self.text)
class DataLearner:
def __init__(self):
@@ -11,7 +16,30 @@ class DataLearner:
Split message to trigrams and write to DB
:param message: Message
"""
# print("{}".format(message))
words = self.tokenizer.extract_words(message)
trigrams = self.tokenizer.split_to_trigrams(words)
self.trigram_repository.store(message.chat_id, trigrams)
def learn_from_file(self, file, chat_id=-1001180265993):
    """
    Learn from a chat log stored in a text file.

    The file content is split on the ``;;;`` separator. Each fragment is
    scrubbed of URLs, bare domain-like tokens, e-mail-only lines and
    ``@mentions``, stripped of surrounding whitespace, wrapped in a
    ``FakeMessage`` and passed to ``learn``. Fragments that are empty after
    scrubbing are skipped and not counted. Progress is printed after every
    1000 learned messages.

    :param file: Path of the chat log to read (decoded as UTF-8)
    :param chat_id: Chat the learned messages are attributed to; defaults
        to the id that was previously hard-coded here
    """
    # Function-scope import: the module's top-level imports are not known
    # to include ``re``, and importing once here replaces the per-iteration
    # import the loop used to perform.
    import re

    # Compiled once, applied in this order to every fragment.
    scrubbers = (
        # http:// and https:// links
        re.compile(r'https?:\/\/\S+', flags=re.MULTILINE),
        # scheme-less "domain.tld/path"-like tokens
        re.compile(r'[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)', flags=re.MULTILINE),
        # lines consisting solely of an e-mail address
        re.compile(r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)", flags=re.MULTILINE),
        # @mentions
        re.compile(r'@\S+', flags=re.MULTILINE),
    )

    # Explicit encoding so decoding does not depend on the platform default;
    # the handle is released before the (potentially long) learning loop.
    with open(file, 'r', encoding='utf-8') as fd:
        content = fd.read()

    learned = 0
    for fragment in content.split(';;;'):
        text = fragment
        for pattern in scrubbers:
            text = pattern.sub('', text)
        text = text.strip()
        if not text:
            # Nothing left to learn from (e.g. trailing separator or a
            # fragment that was only a link).
            continue

        self.learn(FakeMessage(chat_id, text))
        learned += 1
        if learned % 1000 == 0:
            print('learned %d messages' % learned)