From dcdb4ef60d66e591249dc874c99c73353e0aa46f Mon Sep 17 00:00:00 2001 From: BarsTiger Date: Tue, 14 Feb 2023 13:49:58 +0200 Subject: [PATCH] Admin and start handlers work. Added testing txt2img from first bot version (will be replaced soon) --- bot/config.py | 2 +- bot/db/pull_db.py | 1 - bot/handlers/admin/__init__.py | 6 +++ bot/handlers/admin/aliases.py | 23 +++++++++ bot/handlers/help/help_strings.py | 2 - bot/handlers/help_command/__init__.py | 6 +++ .../help.py => help_command/help_handler.py} | 2 - bot/handlers/help_command/help_strings.py | 3 ++ bot/handlers/initialize/__init__.py | 9 +++- bot/handlers/initialize/all_messages.py | 7 +++ bot/handlers/initialize/pull_db.py | 7 --- bot/handlers/initialize/start.py | 23 +++++++++ bot/handlers/register.py | 11 +++-- bot/handlers/txt2img/__init__.py | 6 +++ bot/handlers/txt2img/txt2img.py | 31 ++++++++++++ .../help => modules/api}/__init__.py | 0 bot/modules/api/models.py | 31 ++++++++++++ bot/modules/api/txt2img.py | 32 ++++++++++++ bot/utils/commands.py | 2 +- bot/utils/cooldown.py | 49 +++++++++++++++++++ 20 files changed, 234 insertions(+), 19 deletions(-) create mode 100644 bot/handlers/admin/__init__.py create mode 100644 bot/handlers/admin/aliases.py delete mode 100644 bot/handlers/help/help_strings.py create mode 100644 bot/handlers/help_command/__init__.py rename bot/handlers/{help/help.py => help_command/help_handler.py} (91%) create mode 100644 bot/handlers/help_command/help_strings.py create mode 100644 bot/handlers/initialize/all_messages.py delete mode 100644 bot/handlers/initialize/pull_db.py create mode 100644 bot/handlers/initialize/start.py create mode 100644 bot/handlers/txt2img/__init__.py create mode 100644 bot/handlers/txt2img/txt2img.py rename bot/{handlers/help => modules/api}/__init__.py (100%) create mode 100644 bot/modules/api/models.py create mode 100644 bot/modules/api/txt2img.py create mode 100644 bot/utils/cooldown.py diff --git a/bot/config.py b/bot/config.py index af3d94e..cac6699 100644 --- a/bot/config.py +++ b/bot/config.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv load_dotenv() TOKEN = os.getenv('TOKEN') -ADMIN = os.getenv('ADMIN') +ADMIN = int(os.getenv('ADMIN')) DB_CHAT = os.getenv('DB_CHAT') _DB_PATH = os.getenv('DB_PATH') DB = _DB_PATH + '/db' diff --git a/bot/db/pull_db.py b/bot/db/pull_db.py index 065efce..bd01dac 100644 --- a/bot/db/pull_db.py +++ b/bot/db/pull_db.py @@ -40,7 +40,6 @@ async def pull(): from .db import db for table in DBTables.tables: - db[table].clear() new_table = SqliteDict(DB + 'b', tablename=table) for key in new_table.keys(): db[table][key] = new_table[key] diff --git a/bot/handlers/admin/__init__.py b/bot/handlers/admin/__init__.py new file mode 100644 index 0000000..80c28e4 --- /dev/null +++ b/bot/handlers/admin/__init__.py @@ -0,0 +1,6 @@ +from bot.common import dp +from .aliases import * + + +def register(): + dp.register_message_handler(set_endpoint, commands='setendpoint') diff --git a/bot/handlers/admin/aliases.py b/bot/handlers/admin/aliases.py new file mode 100644 index 0000000..10acff5 --- /dev/null +++ b/bot/handlers/admin/aliases.py @@ -0,0 +1,23 @@ +from aiogram import types +from bot.db import db, DBTables +import validators +from bot.config import ADMIN +from bot.utils.cooldown import throttle + + +@throttle(5) +async def set_endpoint(message: types.Message): + if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN: + await message.reply('❌ You are not permitted to do that. ' + 'It is only for this bot instance maintainers and admins') + return + + if not message.get_args() or not validators.url(message.get_args()): + await message.reply("❌ Specify correct url for endpoint") + return + + db[DBTables.config]['endpoint'] = message.get_args() + + await db[DBTables.config].write() + + await message.reply("✅ New url set") diff --git a/bot/handlers/help/help_strings.py b/bot/handlers/help/help_strings.py deleted file mode 100644 index fd7195f..0000000 --- a/bot/handlers/help/help_strings.py +++ /dev/null @@ -1,2 +0,0 @@ -help_data = { -} diff --git a/bot/handlers/help_command/__init__.py b/bot/handlers/help_command/__init__.py new file mode 100644 index 0000000..85b8c8d --- /dev/null +++ b/bot/handlers/help_command/__init__.py @@ -0,0 +1,6 @@ +from bot.common import dp +from . import help_handler + + +def register(): + dp.register_message_handler(help_handler.help_command, commands='help') diff --git a/bot/handlers/help/help.py b/bot/handlers/help_command/help_handler.py similarity index 91% rename from bot/handlers/help/help.py rename to bot/handlers/help_command/help_handler.py index aa184f4..d0da708 100644 --- a/bot/handlers/help/help.py +++ b/bot/handlers/help_command/help_handler.py @@ -1,9 +1,7 @@ from aiogram import types -from bot.common import dp from .help_strings import help_data -@dp.message_handler(commands='help') async def help_command(message: types.Message): if message.get_args() == "": await message.reply( diff --git a/bot/handlers/help_command/help_strings.py b/bot/handlers/help_command/help_strings.py new file mode 100644 index 0000000..c34d2c9 --- /dev/null +++ b/bot/handlers/help_command/help_strings.py @@ -0,0 +1,3 @@ +help_data = { + 'setendpoint': '(admin) Set StableDiffusion API endpoint' +} diff --git a/bot/handlers/initialize/__init__.py b/bot/handlers/initialize/__init__.py index 2ae2839..acb0de8 100644 --- a/bot/handlers/initialize/__init__.py +++ b/bot/handlers/initialize/__init__.py @@ -1 +1,8 @@ -pass +from bot.common import dp, bot +from .start import * +from .all_messages import * + + +def register(): + dp.register_message_handler(all_messages.sync_db_filter, lambda *_: not hasattr(bot, 'cloudmeta_message_text')) + dp.register_message_handler(start.start_command, commands='start') diff --git a/bot/handlers/initialize/all_messages.py b/bot/handlers/initialize/all_messages.py new file mode 100644 index 0000000..321ef04 --- /dev/null +++ b/bot/handlers/initialize/all_messages.py @@ -0,0 +1,7 @@ +from aiogram.types import Message +from bot.db.pull_db import pull + + +async def sync_db_filter(message: Message): + await pull() + await message.reply(f'🔄️ Bot database synchronised. If you tried to run a command, run it again') diff --git a/bot/handlers/initialize/pull_db.py b/bot/handlers/initialize/pull_db.py deleted file mode 100644 index ab0a26d..0000000 --- a/bot/handlers/initialize/pull_db.py +++ /dev/null @@ -1,7 +0,0 @@ -from bot.common import dp -from bot.db.pull_db import pull - - -@dp.message_handler() -async def pull_db_if_new(_): - await pull() diff --git a/bot/handlers/initialize/start.py b/bot/handlers/initialize/start.py new file mode 100644 index 0000000..37c0b16 --- /dev/null +++ b/bot/handlers/initialize/start.py @@ -0,0 +1,23 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.config import ADMIN +from bot.utils.cooldown import throttle + + +@throttle(10) +async def start_command(message: types.Message): + if message.from_id == ADMIN: + await message.reply(f'👋 Hello, {message.from_user.username}. You are admin of this instance, ' + f'so we will check config for you now') + if not isinstance(db[DBTables.config].get('admins'), list): + db[DBTables.config]['admins'] = list() + if ADMIN not in db[DBTables.config].get('admins'): + admins_ = db[DBTables.config].get('admins') + admins_.append(ADMIN) + db[DBTables.config]['admins'] = admins_ + await db[DBTables.config].write() + await message.reply(f'✅ Added {message.from_user.username} to admins. You can add other admins, ' + f'check bot settings menu') + return + + await message.reply(f'👋 Hello, {message.from_user.username}. Use /help to see available commands.') diff --git a/bot/handlers/register.py b/bot/handlers/register.py index 7e29201..4ca332c 100644 --- a/bot/handlers/register.py +++ b/bot/handlers/register.py @@ -2,10 +2,13 @@ from rich import print def import_handlers(): - import bot.handlers.help.help - assert bot.handlers.help.help + from bot.handlers import ( + initialize, admin, help_command, txt2img + ) - import bot.handlers.initialize.pull_db - assert bot.handlers.initialize.pull_db + initialize.register() + admin.register() + help_command.register() + txt2img.register() print('[gray]All handlers imported[/]') diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py new file mode 100644 index 0000000..39e2ecf --- /dev/null +++ b/bot/handlers/txt2img/__init__.py @@ -0,0 +1,6 @@ +from bot.common import dp +from .txt2img import * + + +def register(): + dp.register_message_handler(txt2img, commands='txt2img') diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py new file mode 100644 index 0000000..5dc26b6 --- /dev/null +++ b/bot/handlers/txt2img/txt2img.py @@ -0,0 +1,31 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.utils.cooldown import throttle +import os.path +from bot.modules.api.txt2img import txt2img + + +@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) +async def txt2img_comand(message: types.Message): + temp_message = await message.reply("⏳ Generating image...") + if not message.get_args(): + await temp_message.edit_text("Specify prompt for this command. Check /help txt2img") + return + + if not os.path.isfile('adstring'): + with open('adstring', 'w') as f: + f.write('@aiistop_bot') + + try: + image = await txt2img(message.get_args()) + await message.reply_photo(photo=image[0], caption=str( + image[1]["infotexts"][0]) + "\n\n" + open('adstring').read()) + except Exception as e: + assert e + await message.reply("We ran into error while processing your request. StableDiffusion models may not be " + "configured on specified endpoint or server with StableDiffusion may be turned " + "off. Ask admins of this bot instance if you have contacts for further info") + await temp_message.delete() + return + + await temp_message.delete() diff --git a/bot/handlers/help/__init__.py b/bot/modules/api/__init__.py similarity index 100% rename from bot/handlers/help/__init__.py rename to bot/modules/api/__init__.py diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py new file mode 100644 index 0000000..2de3006 --- /dev/null +++ b/bot/modules/api/models.py @@ -0,0 +1,31 @@ +import aiohttp +from bot.db import db, DBTables +from rich import print + + +async def get_models(): + endpoint = db[DBTables.config].get('endpoint') + try: + async with aiohttp.ClientSession() as session: + r = await session.get(endpoint + "/sdapi/v1/sd-models") + if r.status != 200: + return None + return [x["title"] for x in await r.json()] + except Exception as e: + print(e) + return None + + +async def set_model(model_name: str): + endpoint = db[DBTables.config].get('endpoint') + try: + async with aiohttp.ClientSession() as session: + r = await session.post(endpoint + "/sdapi/v1/options", json={ + "sd_model_checkpoint": model_name + }) + if r.status != 200: + return False + return True + except Exception as e: + print(e) + return False diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py new file mode 100644 index 0000000..963f361 --- /dev/null +++ b/bot/modules/api/txt2img.py @@ -0,0 +1,32 @@ +import aiohttp +from bot.db import db, DBTables +import json +import base64 + + +async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20, + cfg_scale: int = 7, width: int = 768, height: int = 768, + restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None: + endpoint = db[DBTables.config].get('endpoint') + try: + async with aiohttp.ClientSession() as session: + r = await session.post( + endpoint + "/sdapi/v1/txt2img", + json={ + "prompt": prompt, + "steps": steps, + "cfg_scale": cfg_scale, + "width": width, + "height": height, + "restore_faces": restore_faces, + "negative_prompt": negative_prompt, + "sampler_index": sampler + } + ) + if r.status != 200: + return None + return [base64.b64decode((await r.json())["images"][0]), + json.loads((await r.json())["info"])] + except Exception as e: + assert e + return None diff --git a/bot/utils/commands.py b/bot/utils/commands.py index 4be4c89..70c3135 100644 --- a/bot/utils/commands.py +++ b/bot/utils/commands.py @@ -4,7 +4,7 @@ from bot.common import bot async def set_commands(): - from bot.handlers.help.help_strings import help_data + from bot.handlers.help_command.help_strings import help_data await bot.set_my_commands( commands=list( diff --git a/bot/utils/cooldown.py b/bot/utils/cooldown.py new file mode 100644 index 0000000..c0dad9f --- /dev/null +++ b/bot/utils/cooldown.py @@ -0,0 +1,49 @@ +from functools import wraps +import datetime +from bot.common import bot +from bot.db import db, DBTables +import asyncio +from aiogram import types + + +def not_allowed(message: types.Message, cd: int, by_id: bool): + return asyncio.create_task(message.reply( + text= + f"❌ Wait for cooldown ({cd}s for this command)" + f"{'. Please note that this cooldown is for all users' if not by_id else ''}" + )) + + +def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + user_id = int(args[0]["from"]["id"]) + if admin_ids and user_id in admin_ids: + return asyncio.create_task(func(*args, **kwargs)) + user_id = str(user_id) if by_id else "0" + now = datetime.datetime.now() + delta = now - datetime.timedelta(seconds=cooldown) + try: + last_time = db[DBTables.cooldown].get(func.__name__).get(user_id) + except AttributeError: + last_time = None + if not last_time: + last_time = delta + + if last_time <= delta: + try: + db[DBTables.cooldown][func.__name__][user_id] = now + except KeyError: + db[DBTables.cooldown][func.__name__] = dict() + db[DBTables.cooldown][func.__name__][user_id] = now + try: + return asyncio.create_task(func(*args, **kwargs)) + except Exception as e: + assert e + else: + return not_allowed(*args, cooldown, by_id) + + return wrapper + + return decorator