[WIP] #19
This commit is contained in:
parent
784214ef90
commit
9c56ad1728
|
@ -7,13 +7,18 @@ purge_interval=43200.0
|
|||
default_chance=5
|
||||
spam_stickers=BQADAgADSAIAAkcGQwU-G-9SZUDTWAI
|
||||
|
||||
[grammar]
|
||||
end_sentence=.....!!?
|
||||
all=.!?;()\-—"[]{}«»/*&^#$
|
||||
|
||||
[logging]
|
||||
level=INFO
|
||||
|
||||
[grammar]
|
||||
chain_length=2
|
||||
separator=\x02
|
||||
stop_word='\x00'
|
||||
max_words=30
|
||||
max_messages=5
|
||||
end_sentence=.....!!?
|
||||
all=.!?;()\-—"[]{}«»/*&^#$
|
||||
|
||||
[media_checker]
|
||||
lifetime=28800.0
|
||||
stickers=BQADAgADGwEAAjbsGwVVGLVNyOWfuwI
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import logging.config
|
||||
from orator.orm import Model
|
||||
from orator import DatabaseManager
|
||||
from src.config import config, redis
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
def main():
    """One-off migration: copy pair/reply rows from the database into
    redis sets keyed as ``trigrams:<chat_id>:<first><sep><second>``."""
    logging.basicConfig(level='DEBUG')
    Model.set_connection_resolver(DatabaseManager({'db': config['db']}))

    redis_c = redis.instance()
    counter = 0
    key = 'trigrams:{}:{}'
    # Hoisted out of the loop: the stop word is a constant for the run.
    stop_word = config['grammar']['stop_word']

    for pairs in Pair.with_('replies', 'first', 'second').chunk(500):
        for pair in pairs:
            # A missing word relation raises AttributeError (EAFP); such
            # slots mark sentence boundaries and map to the stop word.
            try:
                first = clear_word(pair.first.word)
            except AttributeError:
                first = stop_word

            try:
                second = clear_word(pair.second.word)
            except AttributeError:
                second = stop_word

            point = to_key(key=key, chat_id=pair.chat_id, pair=[first, second])
            replies = list(filter(None, map(get_word, pair.replies.all())))

            if not replies:  # nothing usable to import for this pair
                continue

            # One variadic SADD instead of one command per member.
            pipe = redis_c.pipeline()
            pipe.sadd(point, *replies)
            pipe.execute()

            counter += 1
            if counter % 1000 == 0:
                # Progress via logging (was print); basicConfig is set above.
                logging.info("Imported: %d", counter)
|
||||
|
||||
def clear_word(word):
    """Strip surrounding punctuation/garbage characters from *word*.

    The original used a non-raw string containing ``\\-``, an invalid
    escape sequence (SyntaxWarning on modern Python).  A raw string
    keeps the identical character set — str.strip treats the argument
    as a set, so backslash and dash are still both stripped.
    """
    return word.strip(r';()\-—"[]{}«»/*&^#$')
|
||||
|
||||
def get_word(reply):
    """Return the cleaned word text attached to *reply*.

    Returns ``None`` when the reply's word relation (or its text) is
    missing — attribute access on the absent relation raises
    AttributeError, which we treat as "no word" (EAFP).
    """
    try:
        return clear_word(reply.word.word)
    except AttributeError:
        return None
|
||||
|
||||
def to_key(key, chat_id, pair):
    """Render a redis trigram key: ``trigrams:<chat_id>:<w1><sep><w2>``."""
    separator = config['grammar']['separator']
    return key.format(chat_id, separator.join(pair))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,17 +0,0 @@
|
|||
from orator.migrations import Migration
|
||||
|
||||
|
||||
class CreateWordsTable(Migration):
    """Schema migration for the unique-word dictionary table."""

    def up(self):
        """Run the migrations."""
        with self.schema.create('words') as table:
            table.increments('id')
            table.string('word').unique()  # one row per distinct word

    def down(self):
        """Revert the migrations."""
        self.schema.drop('words')
|
|
@ -1,24 +0,0 @@
|
|||
from orator.migrations import Migration
|
||||
|
||||
|
||||
class CreatePairsTable(Migration):
    """Schema migration for (chat, first word, second word) pair rows."""

    def up(self):
        """Run the migrations."""
        with self.schema.create('pairs') as table:
            table.increments('id')
            table.integer('chat_id')
            # Nullable word ids represent sentence boundaries.
            table.integer('first_id').nullable()
            table.integer('second_id').nullable()
            # NOTE: composite unique index deliberately left disabled:
            # table.unique(
            #     ['chat_id', 'first_id', 'second_id'],
            #     name='unique_pairs_idx'
            # )
            table.timestamp('created_at').default('CURRENT_TIMESTAMP')

    def down(self):
        """Revert the migrations."""
        self.schema.drop('pairs')
|
|
@ -1,19 +0,0 @@
|
|||
from orator.migrations import Migration
|
||||
|
||||
|
||||
class CreateRepliesTable(Migration):
    """Schema migration for reply rows (word following a pair)."""

    def up(self):
        """Run the migrations."""
        with self.schema.create('replies') as table:
            table.increments('id')
            table.integer('pair_id')
            # Nullable: a reply may point at no word (sentence end).
            table.integer('word_id').nullable()
            table.integer('count').default(1)  # occurrence counter

    def down(self):
        """Revert the migrations."""
        self.schema.drop('replies')
|
|
@ -3,7 +3,7 @@ import configparser
|
|||
|
||||
sections = {
|
||||
'bot': ['token', 'name', 'anchors', 'messages', 'purge_interval', 'default_chance', 'spam_stickers'],
|
||||
'grammar': ['end_sentence', 'all'],
|
||||
'grammar': ['chain_length', 'separator', 'stop_word', 'end_sentence', 'all'],
|
||||
'logging': ['level'],
|
||||
'updates': ['mode'],
|
||||
'media_checker': ['lifetime', 'stickers'],
|
||||
|
@ -35,6 +35,7 @@ redis = Redis(config)
|
|||
from src.service import *
|
||||
chance_manager = ChanceManager()
|
||||
media_checker = MediaUniquenessChecker()
|
||||
tokenizer = Tokenizer()
|
||||
data_learner = DataLearner()
|
||||
reply_generator = ReplyGenerator()
|
||||
chat_purge_queue = ChatPurgeQueue()
|
||||
|
|
|
@ -13,10 +13,8 @@ class Message(AbstractEntity):
|
|||
|
||||
if self.has_text():
|
||||
self.text = message.text
|
||||
self.words = self.__get_words()
|
||||
else:
|
||||
self.text = ''
|
||||
self.words = []
|
||||
|
||||
def has_text(self):
|
||||
"""Returns True if the message has text.
|
||||
|
@ -56,25 +54,3 @@ class Message(AbstractEntity):
|
|||
or self.is_private() \
|
||||
or self.is_reply_to_bot() \
|
||||
or self.is_random_answer()
|
||||
|
||||
def __get_words(self):
|
||||
symbols = list(re.sub('\s', ' ', self.text))
|
||||
|
||||
def prettify(word):
|
||||
lowercase_word = word.lower().strip()
|
||||
last_symbol = lowercase_word[-1:]
|
||||
if last_symbol not in config['grammar']['end_sentence']:
|
||||
last_symbol = ''
|
||||
pretty_word = lowercase_word.strip(config['grammar']['all'])
|
||||
|
||||
if pretty_word != '' and len(pretty_word) > 2:
|
||||
return pretty_word + last_symbol
|
||||
elif lowercase_word in config['grammar']['all']:
|
||||
return None
|
||||
|
||||
return lowercase_word
|
||||
|
||||
for entity in self.message.entities:
|
||||
symbols[entity.offset:entity.length + entity.offset] = ' ' * entity.length
|
||||
|
||||
return list(filter(None, map(prettify, ''.join(symbols).split(' '))))
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from .base import Base
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
class GetStats(Base):
|
||||
|
@ -7,4 +6,5 @@ class GetStats(Base):
|
|||
|
||||
@staticmethod
|
||||
def execute(bot, command):
|
||||
GetStats.reply(bot, command, 'Pairs: {}'.format(Pair.where('chat_id', command.chat_id).count()))
|
||||
GetStats.reply(bot, command, 'Command currently disabled')
|
||||
# GetStats.reply(bot, command, 'Pairs: {}'.format(Pair.where('chat_id', command.chat_id).count()))
|
||||
|
|
|
@ -1,7 +1,3 @@
|
|||
from src.entity.pair import Pair
|
||||
from src.entity.reply import Reply
|
||||
from src.entity.word import Word
|
||||
from src.utils import safe_cast
|
||||
from .base import Base
|
||||
|
||||
|
||||
|
@ -17,10 +13,12 @@ class Moderate(Base):
|
|||
if len(command.args) == 0:
|
||||
raise IndexError
|
||||
|
||||
if safe_cast(command.args[0], int) is None:
|
||||
Moderate.reply(bot, command, Moderate.find_similar_words(command.chat_id, command.args[0]))
|
||||
else:
|
||||
Moderate.remove_word(command.chat_id, int(command.args[0]))
|
||||
Moderate.reply(bot, command, 'Command currently disabled.')
|
||||
|
||||
# if safe_cast(command.args[0], int) is None:
|
||||
# Moderate.reply(bot, command, Moderate.find_similar_words(command.chat_id, command.args[0]))
|
||||
# else:
|
||||
# Moderate.remove_word(command.chat_id, int(command.args[0]))
|
||||
except (IndexError, ValueError):
|
||||
Moderate.reply(bot, command, """Usage:
|
||||
/moderate <word> for search
|
||||
|
@ -35,61 +33,3 @@ class Moderate(Base):
|
|||
))
|
||||
|
||||
return user_id in admin_ids
|
||||
|
||||
@staticmethod
|
||||
def remove_word(chat_id, word_id):
|
||||
pairs_ids = Moderate.__find_pairs(chat_id, [word_id]).lists('id')
|
||||
|
||||
Pair.where_in('id', pairs_ids).delete()
|
||||
Reply.where_in('pair_id', pairs_ids).delete()
|
||||
|
||||
@staticmethod
|
||||
def find_similar_words(chat_id, word):
|
||||
found_words = Moderate.__find_chat_words(chat_id, word)
|
||||
|
||||
if len(found_words) == 0:
|
||||
return 'No words found!'
|
||||
|
||||
return Moderate.__formatted_view(found_words)
|
||||
|
||||
@staticmethod
|
||||
def __formatted_view(words):
|
||||
result = []
|
||||
for k, v in words.items():
|
||||
result.append("%s : %d" % (v, k))
|
||||
|
||||
return '\n'.join(result)
|
||||
|
||||
@staticmethod
|
||||
def __find_pairs(chat_id, word_ids):
|
||||
return Pair.where('chat_id', chat_id) \
|
||||
.where(
|
||||
Pair.query().where_in('first_id', word_ids)
|
||||
.or_where_in('second_id', word_ids)
|
||||
) \
|
||||
.get()
|
||||
|
||||
@staticmethod
|
||||
def __prepare_word(word):
|
||||
return word.strip("'\"")
|
||||
|
||||
@staticmethod
|
||||
def __find_chat_words(chat_id, search_word):
|
||||
found_words = Word.where('word', 'like', Moderate.__prepare_word(search_word) + '%') \
|
||||
.order_by('word', 'asc') \
|
||||
.limit(10) \
|
||||
.lists('word', 'id')
|
||||
|
||||
if len(found_words) == 0:
|
||||
return []
|
||||
|
||||
to_keep = []
|
||||
for pair in Moderate.__find_pairs(chat_id, list(found_words.keys())):
|
||||
if pair.first_id in found_words:
|
||||
to_keep.append(pair.first_id)
|
||||
if pair.second_id in found_words:
|
||||
to_keep.append(pair.second_id)
|
||||
|
||||
to_keep = set(to_keep)
|
||||
|
||||
return dict((k, found_words[k]) for k in found_words if k in to_keep)
|
||||
|
|
|
@ -3,3 +3,4 @@ from .chat_purge_queue import ChatPurgeQueue
|
|||
from .data_learner import DataLearner
|
||||
from .media_uniqueness_checker import MediaUniquenessChecker
|
||||
from .reply_generator import ReplyGenerator
|
||||
from .tokenizer import Tokenizer
|
||||
|
|
|
@ -3,8 +3,6 @@ import json
|
|||
|
||||
from datetime import datetime, timedelta
|
||||
from telegram.ext import Job
|
||||
from src.entity.reply import Reply
|
||||
from src.entity.pair import Pair
|
||||
from src.config import config, redis
|
||||
|
||||
|
||||
|
@ -74,8 +72,8 @@ class ChatPurgeQueue:
|
|||
|
||||
logging.info("Removing chat #%d data..." % chat_id)
|
||||
|
||||
for pairs in Pair.where('chat_id', chat_id).select('id').chunk(500):
|
||||
Reply.where_in('pair_id', pairs.pluck('id').all()).delete()
|
||||
Pair.where('chat_id', chat_id).delete()
|
||||
# for pairs in Pair.where('chat_id', chat_id).select('id').chunk(500):
|
||||
# Reply.where_in('pair_id', pairs.pluck('id').all()).delete()
|
||||
# Pair.where('chat_id', chat_id).delete()
|
||||
|
||||
self.redis.instance().hdel(self.key, chat_id)
|
||||
|
|
|
@ -1,56 +1,19 @@
|
|||
from collections import OrderedDict
|
||||
from src.config import config
|
||||
from src.entity.word import Word
|
||||
from src.entity.pair import Pair
|
||||
from src.config import redis, tokenizer
|
||||
|
||||
|
||||
class DataLearner:
|
||||
def __init__(self):
|
||||
self.redis = redis
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def learn(self, message):
|
||||
self.__write_new_unique_words(message.words)
|
||||
pipe = self.redis.instance().pipeline()
|
||||
|
||||
words = self.__normalize_words(message.words)
|
||||
while any(word for word in words):
|
||||
trigram = words[:3]
|
||||
first_word_id, second_word_id, *third_word_id = list(map(
|
||||
lambda x: None if x is None else Word.where('word', x).first(['id']).id,
|
||||
trigram
|
||||
))
|
||||
third_word_id = None if len(third_word_id) == 0 else third_word_id[0]
|
||||
words = self.tokenizer.extract_words(message)
|
||||
for trigram in self.tokenizer.split_to_trigrams(words):
|
||||
key = self.tokenizer.to_key(message.chat_id, trigram[:-1])
|
||||
last_word = trigram[-1]
|
||||
|
||||
words.pop(0)
|
||||
pipe.sadd(key, last_word)
|
||||
|
||||
pair = Pair.where('chat_id', message.chat_id) \
|
||||
.where('first_id', first_word_id) \
|
||||
.where('second_id', second_word_id) \
|
||||
.first()
|
||||
if pair is None:
|
||||
pair = Pair.create(chat_id=message.chat_id,
|
||||
first_id=first_word_id,
|
||||
second_id=second_word_id)
|
||||
|
||||
reply = pair.replies().where('word_id', third_word_id).first()
|
||||
|
||||
if reply is not None:
|
||||
reply.count += 1
|
||||
reply.save()
|
||||
else:
|
||||
pair.replies().create(pair_id=pair.id, word_id=third_word_id)
|
||||
|
||||
def __normalize_words(self, src_words):
|
||||
words = [None]
|
||||
for word in src_words:
|
||||
words.append(word)
|
||||
if word[-1] in config['grammar']['end_sentence']:
|
||||
words.append(None)
|
||||
if words[-1] is not None:
|
||||
words.append(None)
|
||||
|
||||
return words
|
||||
|
||||
def __write_new_unique_words(self, words):
|
||||
# TODO. Слова должны быть уникальные И ТАКЖЕ ОБЯЗАТЕЛЬНО в оригинальном порядке
|
||||
existing_words = Word.where_in('word', words).lists('word').all()
|
||||
new_words = [word for word in OrderedDict.fromkeys(words).keys() if word not in existing_words]
|
||||
|
||||
for word in new_words:
|
||||
Word.create(word=word)
|
||||
pipe.execute()
|
||||
|
|
|
@ -1,66 +1,56 @@
|
|||
import random
|
||||
from src.config import config
|
||||
from src.config import config, redis, tokenizer
|
||||
from src.utils import strings_has_equal_letters, capitalize, random_element
|
||||
from src.entity.word import Word
|
||||
from src.entity.pair import Pair
|
||||
|
||||
|
||||
class ReplyGenerator:
|
||||
def __init__(self):
|
||||
self.redis = redis
|
||||
self.tokenizer = tokenizer
|
||||
self.max_words = config.getint('grammar', 'max_words')
|
||||
self.max_messages = config.getint('grammar', 'max_messages')
|
||||
|
||||
def generate(self, message):
|
||||
result = self.generate_story(message, message.words, random.randint(0, 2) + 1)
|
||||
messages = []
|
||||
|
||||
words = self.tokenizer.extract_words(message)
|
||||
for trigram in self.tokenizer.split_to_trigrams(words):
|
||||
pair = trigram[:-1]
|
||||
|
||||
best_message = ''
|
||||
for _ in range(self.max_messages):
|
||||
generated = self.__generate_sentence(pair)
|
||||
if len(generated) > len(best_message):
|
||||
best_message = generated
|
||||
|
||||
if best_message:
|
||||
messages.append(best_message)
|
||||
|
||||
result = random_element(messages) if len(messages) else ''
|
||||
|
||||
if strings_has_equal_letters(result, ''.join(message.words)):
|
||||
return ''
|
||||
|
||||
return result
|
||||
|
||||
def generate_story(self, message, words, sentences_count):
|
||||
word_ids = Word.where_in('word', words).lists('id').all()
|
||||
def __generate_sentence(self, seed):
|
||||
key = seed
|
||||
gen_words = []
|
||||
redis = self.redis.instance()
|
||||
|
||||
return ' '.join([self.__generate_sentence(message, word_ids) for _ in range(sentences_count)])
|
||||
|
||||
def __generate_sentence(self, message, word_ids):
|
||||
sentences = []
|
||||
safety_counter = 50
|
||||
first_word_id = None
|
||||
second_word_id_list = word_ids
|
||||
|
||||
while safety_counter > 0:
|
||||
pair = Pair.get_random_pair(chat_id=message.chat_id,
|
||||
first_id=first_word_id,
|
||||
second_id_list=second_word_id_list)
|
||||
replies = getattr(pair, 'replies', [])
|
||||
safety_counter -= 1
|
||||
|
||||
if pair is None or len(replies) == 0:
|
||||
continue
|
||||
|
||||
reply = random.choice(replies.all())
|
||||
first_word_id = pair.second.id
|
||||
|
||||
# FIXME. WARNING! Do not try to fix, it's magic, i have no clue why
|
||||
try:
|
||||
second_word_id_list = [reply.word.id]
|
||||
except AttributeError:
|
||||
second_word_id_list = None
|
||||
|
||||
if len(sentences) == 0:
|
||||
sentences.append(capitalize(pair.second.word))
|
||||
word_ids.remove(pair.second.id)
|
||||
|
||||
# FIXME. WARNING! Do not try to fix, it's magic, i have no clue why
|
||||
try:
|
||||
reply_word = reply.word.word
|
||||
except AttributeError:
|
||||
reply_word = None
|
||||
|
||||
if reply_word is not None:
|
||||
sentences.append(reply_word)
|
||||
for _ in range(self.max_words):
|
||||
if len(gen_words):
|
||||
gen_words.append(key[0])
|
||||
else:
|
||||
gen_words.append(capitalize(key[0]))
|
||||
|
||||
next_word = redis.srandmember(self.tokenizer.to_key(key))
|
||||
if not next_word:
|
||||
break
|
||||
|
||||
sentence = ' '.join(sentences).strip()
|
||||
if sentence[-1:] not in config['grammar']['end_sentence']:
|
||||
sentence += random_element(list(config['grammar']['end_sentence']))
|
||||
key = self.tokenizer.separator.join(key[1:] + [next_word])
|
||||
|
||||
return sentence
|
||||
sentence = ' '.join(gen_words).strip()
|
||||
if sentence[-1:] not in self.tokenizer.end_sentence:
|
||||
sentence += self.tokenizer.random_end_sentence_token()
|
||||
|
||||
return ' '.join(gen_words)
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
import re
|
||||
from src.utils import random_element
|
||||
from src.config import config
|
||||
|
||||
|
||||
class Tokenizer:
    """Splits message text into normalized words and redis trigram keys."""

    def __init__(self):
        # Redis key template: trigrams:<chat_id>:<word1><sep><word2>
        self.key = 'trigrams:{}:{}'
        self.chain_length = config.getint('grammar', 'chain_length')
        self.separator = config['grammar']['separator']
        self.stop_word = config['grammar']['stop_word']
        self.end_sentence = config['grammar']['end_sentence']
        self.garbage_tokens = config['grammar']['all']

    def split_to_trigrams(self, words):
        """Yield overlapping (chain_length + 1)-grams of *words*, with a
        trailing stop-word sentinel appended.

        Yields nothing when there are not enough words for a full gram.
        """
        if len(words) <= self.chain_length:
            # BUG FIX: the original did `yield from ()` here without
            # returning, so execution fell through and short inputs still
            # produced a stop-word-padded trigram.
            return

        # Work on a copy so the caller's list is not mutated.
        padded = words + [self.stop_word]
        for i in range(len(padded) - self.chain_length):
            yield padded[i:i + self.chain_length + 1]

    def extract_words(self, message):
        """Extract normalized words from *message*, blanking out entities."""
        # Raw string for the regex (original '\s' is an invalid escape
        # sequence in a plain string on modern Python).
        symbols = list(re.sub(r'\s', ' ', message.text))

        # Overwrite entity spans (urls, mentions, ...) with spaces so
        # they never become words.
        for entity in message.entities:
            symbols[entity.offset:entity.length + entity.offset] = ' ' * entity.length

        return list(filter(None, map(self.__prettify, ''.join(symbols).split(' '))))

    def to_key(self, chat_id, pair):
        """Format the redis set key for *pair* within *chat_id*."""
        return self.key.format(chat_id, self.separator.join(pair))

    def random_end_sentence_token(self):
        """Pick a random sentence-ending character."""
        return random_element(list(self.end_sentence))

    def __prettify(self, word):
        # Lowercase, remember a trailing end-of-sentence mark, then strip
        # garbage characters from both ends.
        lowercase_word = word.lower().strip()
        last_symbol = lowercase_word[-1:]
        if last_symbol not in self.end_sentence:
            last_symbol = ''
        pretty_word = lowercase_word.strip(self.garbage_tokens)

        if pretty_word != '' and len(pretty_word) > 2:
            return pretty_word + last_symbol
        elif lowercase_word in self.garbage_tokens:
            # Pure-garbage token: drop it (filtered out by the caller).
            return None

        return lowercase_word
|
Loading…
Reference in New Issue