diff --git a/bot/callbacks/set_model.py b/bot/callbacks/set_model.py index 5f194cb..d8ea4a0 100644 --- a/bot/callbacks/set_model.py +++ b/bot/callbacks/set_model.py @@ -4,35 +4,18 @@ from aiogram import types from .factories.set_model import set_model, set_model_page from bot.keyboards.set_model import get_set_model_keyboard from bot.utils.private_keyboard import other_user -from bot.keyboards.exception import get_exception_keyboard -from bot.utils.trace_exception import PrettyException -from aiohttp import ClientConnectorError from bot.modules.api import models +from bot.utils.errorable_command import wrap_exception +@wrap_exception() async def on_set_model(call: types.CallbackQuery, callback_data: dict): n = int(callback_data['n']) - temp_message = await call.message.answer('⏳ Setting model...') - try: - await models.set_model(db[DBTables.config]['models'][n]) + await call.message.edit_reply_markup(reply_markup=None) + await models.set_model(db[DBTables.config]['models'][n]) - except ClientConnectorError: - await call.answer('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' - 'or turned off', show_alert=True) - await call.message.delete() - await temp_message.delete() - return - - except Exception as e: - exception_id = f'{call.message.message_thread_id}-{call.message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await call.message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 - return - - await temp_message.answer('✅ Model set for all users!') - await temp_message.delete() + await call.message.answer('✅ Model set for all users!') + await call.message.delete() async def on_page_change(call: types.CallbackQuery, callback_data: dict): diff --git a/bot/handlers/admin/tools.py b/bot/handlers/admin/tools.py index f2602c6..ec85d2b 100644 --- a/bot/handlers/admin/tools.py +++ b/bot/handlers/admin/tools.py @@ -1,11 +1,11 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle -from bot.keyboards.exception import get_exception_keyboard -from bot.utils.trace_exception import PrettyException +from bot.utils.errorable_command import wrap_exception from bot.modules.get_hash.get_hash import get_hash +@wrap_exception([IndexError]) @throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) async def hash_command(message: types.Message): try: @@ -23,10 +23,3 @@ async def hash_command(message: types.Message): except IndexError: await message.reply('❌ Reply with this command on PICTURE OR FILE', parse_mode='html') - - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - return diff --git a/bot/handlers/image_info/image_info.py b/bot/handlers/image_info/image_info.py index 93f6be2..3ab7ca2 100644 --- a/bot/handlers/image_info/image_info.py +++ b/bot/handlers/image_info/image_info.py @@ -1,11 +1,11 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle -from bot.keyboards.exception import get_exception_keyboard from bot.keyboards.image_info import get_img_info_keyboard -from bot.utils.trace_exception import PrettyException +from bot.utils.errorable_command import wrap_exception +@wrap_exception([IndexError]) @throttle(cooldown=10, admin_ids=db[DBTables.config].get('admins')) async def imginfo(message: types.Message): try: @@ -25,10 +25,3 @@ async def imginfo(message: types.Message): except IndexError: await message.reply('❌ Reply with this command on PICTURE', parse_mode='html') - - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - return diff --git a/bot/handlers/txt2img/set_model.py b/bot/handlers/txt2img/set_model.py index 6d8e851..ddb31e3 100644 --- a/bot/handlers/txt2img/set_model.py +++ b/bot/handlers/txt2img/set_model.py @@ -2,36 +2,18 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle from bot.keyboards.set_model import get_set_model_keyboard -from bot.keyboards.exception import get_exception_keyboard -from bot.utils.trace_exception import PrettyException from bot.modules.api.models import get_models -from aiohttp import ClientConnectorError +from bot.utils.errorable_command import wrap_exception +@wrap_exception() @throttle(cooldown=60*60, admin_ids=db[DBTables.config].get('admins'), by_id=False) async def set_model_command(message: types.Message): - temp_message = await message.reply('⏳ Getting models...') - try: - models = await get_models() - if models is not None and len(models) > 0: - db[DBTables.config]['models'] = models - else: - await temp_message.delete() - await message.reply('❌ No models available') - return - - await temp_message.delete() - await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0)) - - except ClientConnectorError: - await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' - 'or turned off') - await temp_message.delete() + models = await get_models() + if models is not None and len(models) > 0: + db[DBTables.config]['models'] = models + else: + await message.reply('❌ No models available') return - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - return + await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0)) diff --git a/bot/handlers/txt2img/set_settings.py b/bot/handlers/txt2img/set_settings.py index 40f7267..5c8dfa2 100644 --- a/bot/handlers/txt2img/set_settings.py +++ b/bot/handlers/txt2img/set_settings.py @@ -2,44 +2,34 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle from bot.modules.api.objects.prompt_request import Prompt -from bot.keyboards.exception import get_exception_keyboard -from bot.utils.trace_exception import PrettyException +from bot.utils.errorable_command import wrap_exception +@wrap_exception(custom_loading=True) async def _set_property(message: types.Message, prop: str, value=None): temp_message = await message.reply(f"⏳ Setting {prop}...") if not message.get_args(): await temp_message.edit_text("😶‍🌫️ Specify arguments for this command. Check /help") return - try: - prompt: Prompt = db[DBTables.prompts].get(message.from_id) - if prompt is None and prop != 'prompt': - await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. " - f"For example, it can be: masterpiece, best quality, 1girl, white hair, " - f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, " - f"jacket, outdoors, streets", parse_mode='HTML') - return - elif prompt is None: - prompt = Prompt(message.get_args(), creator=message.from_id) - - prompt.__setattr__(prop, message.get_args() if value is None else value) - prompt.creator = message.from_id - db[DBTables.prompts][message.from_id] = prompt - - await db[DBTables.config].write() - - await message.reply(f'✅ {prop} set') - await temp_message.delete() - - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - await temp_message.delete() - db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + prompt: Prompt = db[DBTables.prompts].get(message.from_id) + if prompt is None and prop != 'prompt': + await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. " + f"For example, it can be: masterpiece, best quality, 1girl, white hair, " + f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, " + f"jacket, outdoors, streets", parse_mode='HTML') return + elif prompt is None: + prompt = Prompt(message.get_args(), creator=message.from_id) + + prompt.__setattr__(prop, message.get_args() if value is None else value) + prompt.creator = message.from_id + db[DBTables.prompts][message.from_id] = prompt + + await db[DBTables.config].write() + + await message.reply(f'✅ {prop} set') + await temp_message.delete() @throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) @@ -125,30 +115,15 @@ async def set_restore_faces_command(message: types.Message): await _set_property(message, 'restore_faces') +@wrap_exception() @throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) async def set_sampler_command(message: types.Message): - temp_message = await message.reply('⏳ Getting samplers...') from bot.modules.api.samplers import get_samplers - from aiohttp import ClientConnectorError - try: - if message.get_args() not in (samplers := await get_samplers()): - await message.reply( - f'❌ You can use only {", ".join(f"{x}" for x in samplers)}', - parse_mode='HTML' - ) - return - await temp_message.delete() - except ClientConnectorError: - await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' - 'or turned off') - await temp_message.delete() - return - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - await temp_message.delete() + if message.get_args() not in (samplers := await get_samplers()): + await message.reply( + f'❌ You can use only {", ".join(f"{x}" for x in samplers)}', + parse_mode='HTML' + ) return await _set_property(message, 'sampler') diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py index 680a6ac..e44a6c0 100644 --- a/bot/handlers/txt2img/txt2img.py +++ b/bot/handlers/txt2img/txt2img.py @@ -6,12 +6,11 @@ from bot.modules.api.txt2img import txt2img from bot.modules.api.objects.get_prompt import get_prompt from bot.modules.api.objects.prompt_request import Generated from bot.modules.api.status import wait_for_status -from bot.keyboards.exception import get_exception_keyboard from bot.keyboards.image_info import get_img_info_keyboard -from bot.utils.trace_exception import PrettyException -from aiohttp import ClientConnectorError +from bot.utils.errorable_command import wrap_exception +@wrap_exception([ValueError], custom_loading=True) @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) async def generate_command(message: types.Message): temp_message = await message.reply("⏳ Enqueued...") @@ -33,10 +32,10 @@ async def generate_command(message: types.Message): await wait_for_status() await temp_message.edit_text(f"⌛ Generating...") + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 image = await txt2img(prompt) image_message = await message.reply_photo(photo=image[0]) - db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated( prompt=prompt, seed=image[1]['seed'], @@ -50,24 +49,8 @@ async def generate_command(message: types.Message): await db[DBTables.config].write() - except ClientConnectorError: - await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' - 'or turned off') - await temp_message.delete() - db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 - return - except ValueError as e: await message.reply(f'❌ Error! {e.args[0]}') await temp_message.delete() db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 return - - except Exception as e: - exception_id = f'{message.message_thread_id}-{message.message_id}' - db[DBTables.exceptions][exception_id] = PrettyException(e) - await message.reply('❌ Error happened while processing your request', parse_mode='html', - reply_markup=get_exception_keyboard(exception_id)) - await temp_message.delete() - db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 - return diff --git a/bot/utils/errorable_command.py b/bot/utils/errorable_command.py new file mode 100644 index 0000000..7ca9cfe --- /dev/null +++ b/bot/utils/errorable_command.py @@ -0,0 +1,58 @@ +from functools import wraps +from bot.db import db, DBTables +from bot.keyboards.exception import get_exception_keyboard +from bot.utils.trace_exception import PrettyException +from aiohttp import ClientConnectorError +from aiogram import types + + +def wrap_exception(unhandled_types: list = None, custom_loading: bool = False): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + message = args[0] + if not custom_loading: + if isinstance(message, types.Message): + temp_message = await message.reply('⏳ Loading...') + elif isinstance(message, types.CallbackQuery): + temp_message = await message.message.answer('⏳ Loading...') + else: + raise AttributeError("This wrapper is only for commands!") + try: + _ = await func(*args, **kwargs) + if not custom_loading: + await temp_message.delete() + return _ + + except ClientConnectorError: + r_string = '❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' \ + 'or turned off' + if isinstance(message, types.Message): + await message.reply(r_string) + elif isinstance(message, types.CallbackQuery): + await message.message.answer(r_string) + + if not custom_loading: + await temp_message.delete() + return + + except Exception as e: + if not unhandled_types or e.__class__ not in unhandled_types: + exception_id = f'{message.message_thread_id}-{message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + if not custom_loading: + await temp_message.delete() + if isinstance(message, types.Message): + await message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + elif isinstance(message, types.CallbackQuery): + await message.message.reply('❌ Error happened while processing your request', + parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + return + else: + raise e + + return wrapper + + return decorator