From d56f1d386f4c076024c6bfac39fcf3373dfc8a7d Mon Sep 17 00:00:00 2001 From: BarsTiger Date: Mon, 20 Feb 2023 16:41:55 +0200 Subject: [PATCH] Testing text2img command (will be replaced, only for tests); added exception handling, queue and database table with generated images --- bot/callbacks/__init__.py | 0 bot/callbacks/exception.py | 15 ++++++++ bot/callbacks/factories/__init__.py | 0 bot/callbacks/factories/exception.py | 4 ++ bot/callbacks/register.py | 11 ++++++ bot/db/db.py | 5 ++- bot/db/db_model.py | 5 ++- bot/handlers/admin/__init__.py | 2 + bot/handlers/admin/reset.py | 18 +++++++++ bot/handlers/initialize/all_messages.py | 5 ++- bot/handlers/register.py | 4 +- bot/handlers/txt2img/__init__.py | 4 +- bot/handlers/txt2img/txt2img.py | 47 +++++++++++++++++------ bot/keyboards/__init__.py | 0 bot/keyboards/exception.py | 9 +++++ bot/modules/api/objects/__init__.py | 0 bot/modules/api/objects/prompt_request.py | 14 +++++++ bot/modules/api/status.py | 31 +++++++++++++++ bot/modules/api/txt2img.py | 26 ++++++------- bot/utils/trace_exception.py | 47 +++++++++++++++++++++++ main.py | 6 ++- 21 files changed, 219 insertions(+), 34 deletions(-) create mode 100644 bot/callbacks/__init__.py create mode 100644 bot/callbacks/exception.py create mode 100644 bot/callbacks/factories/__init__.py create mode 100644 bot/callbacks/factories/exception.py create mode 100644 bot/callbacks/register.py create mode 100644 bot/handlers/admin/reset.py create mode 100644 bot/keyboards/__init__.py create mode 100644 bot/keyboards/exception.py create mode 100644 bot/modules/api/objects/__init__.py create mode 100644 bot/modules/api/objects/prompt_request.py create mode 100644 bot/modules/api/status.py create mode 100644 bot/utils/trace_exception.py diff --git a/bot/callbacks/__init__.py b/bot/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/callbacks/exception.py b/bot/callbacks/exception.py new file mode 100644 index 0000000..3e4e264 --- /dev/null +++ b/bot/callbacks/exception.py @@ -0,0 +1,15 @@ +from bot.common import dp +from bot.db import db, DBTables +from aiogram import types +from .factories.exception import exception_callback + + +async def on_exception(call: types.CallbackQuery, callback_data: dict): + e_id = callback_data['e_id'] + e = db[DBTables.exceptions][e_id] + del db[DBTables.exceptions][e_id] + await call.message.edit_text(e, parse_mode='html') + + +def register(): + dp.register_callback_query_handler(on_exception, exception_callback.filter()) diff --git a/bot/callbacks/factories/__init__.py b/bot/callbacks/factories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/callbacks/factories/exception.py b/bot/callbacks/factories/exception.py new file mode 100644 index 0000000..9adcc96 --- /dev/null +++ b/bot/callbacks/factories/exception.py @@ -0,0 +1,4 @@ +from aiogram.utils.callback_data import CallbackData + + +exception_callback = CallbackData("full_exception", "e_id") diff --git a/bot/callbacks/register.py b/bot/callbacks/register.py new file mode 100644 index 0000000..306c40b --- /dev/null +++ b/bot/callbacks/register.py @@ -0,0 +1,11 @@ +from rich import print + + +def register_callbacks(): + from bot.callbacks import ( + exception + ) + + exception.register() + + print('[gray]All callbacks registered[/]') diff --git a/bot/db/db.py b/bot/db/db.py index 837aba8..5025ce5 100644 --- a/bot/db/db.py +++ b/bot/db/db.py @@ -8,5 +8,8 @@ if not os.path.isfile(DB): db = { 'config': DBDict(DB, autocommit=True, tablename='config'), - 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown') + 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'), + 'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'), + 'queue': DBDict(DB, autocommit=True, tablename='queue'), + 'generated': DBDict(DB, autocommit=True, tablename='generated') } diff --git a/bot/db/db_model.py b/bot/db/db_model.py index ae32ef4..43411fb 100644 --- a/bot/db/db_model.py +++ b/bot/db/db_model.py @@ -8,9 +8,12 @@ from .meta import DBMeta class DBTables: - tables = ['config', 'cooldown'] + tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated'] config = "config" cooldown = "cooldown" + exceptions = "exceptions" + queue = "queue" + generated = "generated" class DBDict(SqliteDict): diff --git a/bot/handlers/admin/__init__.py b/bot/handlers/admin/__init__.py index 80c28e4..1e759b7 100644 --- a/bot/handlers/admin/__init__.py +++ b/bot/handlers/admin/__init__.py @@ -1,6 +1,8 @@ from bot.common import dp from .aliases import * +from .reset import * def register(): dp.register_message_handler(set_endpoint, commands='setendpoint') + dp.register_message_handler(reset.resetqueue, commands='resetqueue') diff --git a/bot/handlers/admin/reset.py b/bot/handlers/admin/reset.py new file mode 100644 index 0000000..076a209 --- /dev/null +++ b/bot/handlers/admin/reset.py @@ -0,0 +1,18 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.config import ADMIN +from bot.utils.cooldown import throttle + + +@throttle(5) +async def resetqueue(message: types.Message): + if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN: + await message.reply('❌ You are not permitted to do that. ' + 'It is only for this bot instance maintainers and admins') + return + + db[DBTables.queue]['n'] = 0 + + await db[DBTables.config].write() + + await message.reply("βœ… Reset queue") diff --git a/bot/handlers/initialize/all_messages.py b/bot/handlers/initialize/all_messages.py index a0488bc..ffc15de 100644 --- a/bot/handlers/initialize/all_messages.py +++ b/bot/handlers/initialize/all_messages.py @@ -4,5 +4,6 @@ from bot.db.pull_db import pull async def sync_db_filter(message: Message): await pull() - await message.reply(f'πŸ”„οΈ Bot database synchronised because of restart. ' - f'If you tried to run a command, run it again') + if message.is_command(): + await message.reply(f'πŸ”„οΈ Bot database synchronised because of restart. ' + f'If you tried to run a command, run it again') diff --git a/bot/handlers/register.py b/bot/handlers/register.py index 4ca332c..b050c6b 100644 --- a/bot/handlers/register.py +++ b/bot/handlers/register.py @@ -1,7 +1,7 @@ from rich import print -def import_handlers(): +def register_handlers(): from bot.handlers import ( initialize, admin, help_command, txt2img ) @@ -11,4 +11,4 @@ def import_handlers(): help_command.register() txt2img.register() - print('[gray]All handlers imported[/]') + print('[gray]All handlers registered[/]') diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py index 39e2ecf..f75dcaf 100644 --- a/bot/handlers/txt2img/__init__.py +++ b/bot/handlers/txt2img/__init__.py @@ -1,6 +1,6 @@ from bot.common import dp -from .txt2img import * +from .txt2img import txt2img_comand def register(): - dp.register_message_handler(txt2img, commands='txt2img') + dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img') diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py index 2212684..66aa0bf 100644 --- a/bot/handlers/txt2img/txt2img.py +++ b/bot/handlers/txt2img/txt2img.py @@ -2,25 +2,50 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle from bot.modules.api.txt2img import txt2img +from bot.modules.api.objects.prompt_request import Prompt +from bot.modules.api.status import wait_for_status +from bot.keyboards.exception import get_exception_keyboard +from bot.utils.trace_exception import PrettyException +from aiohttp import ClientConnectorError @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) async def txt2img_comand(message: types.Message): - temp_message = await message.reply("⏳ Generating image...") + temp_message = await message.reply("⏳ Enqueued...") if not message.get_args(): - await temp_message.edit_text("Specify prompt for this command. Check /help txt2img") + await temp_message.edit_text("πŸ˜Άβ€πŸŒ«οΈ Specify prompt for this command. Check /help txt2img") return try: - image = await txt2img(message.get_args()) - await message.reply_photo(photo=image[0], caption=str( - image[1]["infotexts"][0])) - except Exception as e: - assert e - await message.reply("We ran into error while processing your request. StableDiffusion models may not be " - "configured on specified endpoint or server with StableDiffusion may be turned " - "off. Ask admins of this bot instance if you have contacts for further info") + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1 + await temp_message.edit_text(f"⏳ Enqueued in position {db[DBTables.queue].get('n', 0)}...") + + await wait_for_status() + + await temp_message.edit_text(f"βŒ› Generating...") + prompt = Prompt(prompt=message.get_args(), creator=message.from_id) + image = await txt2img(prompt) + image_message = await message.reply_photo(photo=image[0]) + + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt + await temp_message.delete() + + await db[DBTables.config].write() + + except ClientConnectorError: + await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' + 'or turned off') + await temp_message.delete() + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 return - await temp_message.delete() + except Exception as e: + exception_id = f'{message.message_thread_id}-{message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + await message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + await temp_message.delete() + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + return diff --git a/bot/keyboards/__init__.py b/bot/keyboards/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/keyboards/exception.py b/bot/keyboards/exception.py new file mode 100644 index 0000000..ca70ca1 --- /dev/null +++ b/bot/keyboards/exception.py @@ -0,0 +1,9 @@ +from aiogram import types +from bot.callbacks.factories.exception import exception_callback + + +def get_exception_keyboard(e_id: str) -> types.InlineKeyboardMarkup: + buttons = [types.InlineKeyboardButton(text="Show full stack", callback_data=exception_callback.new(e_id=e_id))] + keyboard = types.InlineKeyboardMarkup() + keyboard.add(*buttons) + return keyboard diff --git a/bot/modules/api/objects/__init__.py b/bot/modules/api/objects/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/modules/api/objects/prompt_request.py b/bot/modules/api/objects/prompt_request.py new file mode 100644 index 0000000..3621d05 --- /dev/null +++ b/bot/modules/api/objects/prompt_request.py @@ -0,0 +1,14 @@ +import dataclasses + + +@dataclasses.dataclass +class Prompt: + prompt: str + negative_prompt: str = None + steps: int = 20 + cfg_scale: int = 7 + width: int = 768 + height: int = 768 + restore_faces: bool = True + sampler: str = "Euler a" + creator: int = None diff --git a/bot/modules/api/status.py b/bot/modules/api/status.py new file mode 100644 index 0000000..e033497 --- /dev/null +++ b/bot/modules/api/status.py @@ -0,0 +1,31 @@ +from bot.db import db, DBTables +import aiohttp +import asyncio +import time + + +async def job_exists(endpoint): + async with aiohttp.ClientSession() as session: + r = await session.get( + endpoint + "/sdapi/v1/progress", + json={ + "skip_current_image": True, + } + ) + if r.status != 200: + return None + return (await r.json()).get('state').get('job_count') > 0 + + +async def wait_for_status(ignore_exceptions: bool = False): + endpoint = db[DBTables.config].get('endpoint') + try: + while await job_exists(endpoint): + while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time(): + await asyncio.sleep(5) + db[DBTables.cooldown]['_last_time_status_checked'] = time.time() + return + except Exception as e: + if not ignore_exceptions: + raise e + return diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py index 963f361..429e127 100644 --- a/bot/modules/api/txt2img.py +++ b/bot/modules/api/txt2img.py @@ -1,26 +1,25 @@ import aiohttp from bot.db import db, DBTables +from .objects.prompt_request import Prompt import json import base64 -async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20, - cfg_scale: int = 7, width: int = 768, height: int = 768, - restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None: +async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None: endpoint = db[DBTables.config].get('endpoint') try: async with aiohttp.ClientSession() as session: r = await session.post( endpoint + "/sdapi/v1/txt2img", json={ - "prompt": prompt, - "steps": steps, - "cfg_scale": cfg_scale, - "width": width, - "height": height, - "restore_faces": restore_faces, - "negative_prompt": negative_prompt, - "sampler_index": sampler + "prompt": prompt.prompt, + "steps": prompt.steps, + "cfg_scale": prompt.cfg_scale, + "width": prompt.width, + "height": prompt.height, + "restore_faces": prompt.restore_faces, + "negative_prompt": prompt.negative_prompt, + "sampler_index": prompt.sampler } ) if r.status != 200: @@ -28,5 +27,6 @@ async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20, return [base64.b64decode((await r.json())["images"][0]), json.loads((await r.json())["info"])] except Exception as e: - assert e - return None + if not ignore_exceptions: + raise e + return diff --git a/bot/utils/trace_exception.py b/bot/utils/trace_exception.py new file mode 100644 index 0000000..a0e3a37 --- /dev/null +++ b/bot/utils/trace_exception.py @@ -0,0 +1,47 @@ +import os +import traceback +import contextlib +import re + + +class PrettyException: + def __init__(self, e: Exception): + self.pretty_exception = f'❌ Error! Report it to admins: \n' \ + f'🐊 {e.__traceback__.tb_frame.f_code.co_filename.replace(os.getcwd(), "")}' \ + f':{e.__traceback__.tb_frame.f_lineno} \n' \ + f'😍 {e.__class__.__name__} \n' \ + f'πŸ‘‰ {"".join(traceback.format_exception_only(e)).strip()} \n\n' \ + f'⬇️ Trace: \n' \ + f'{self.get_full_stack()}' + + @staticmethod + def get_full_stack(): + full_stack = traceback.format_exc().replace( + "Traceback (most recent call last):\n", "" + ) + + line_regex = r' File "(.*?)", line ([0-9]+), in (.+)' + + def format_line(line: str) -> str: + filename_, lineno_, name_ = re.search(line_regex, line).groups() + with contextlib.suppress(Exception): + filename_ = os.path.basename(filename_) + + return ( + f"🀯 {filename_}:{lineno_} (in" + f" {name_} call)" + ) + + full_stack = "\n".join( + [ + format_line(line) + if re.search(line_regex, line) + else f"{line}" + for line in full_stack.splitlines() + ] + ) + + return full_stack + + def __str__(self): + return self.pretty_exception diff --git a/main.py b/main.py index e9812cc..6265fcb 100644 --- a/main.py +++ b/main.py @@ -4,11 +4,13 @@ from rich import print async def main(): print(BARS_APP_ID) - import bot.handlers.register from bot.common import dp + import bot.handlers.register + import bot.callbacks.register from bot.utils.commands import set_commands - bot.handlers.register.import_handlers() + bot.handlers.register.register_handlers() + bot.callbacks.register.register_callbacks() await set_commands() print('[green]Bot will start now[/]') await dp.skip_updates()