From a5ce82f1482663d3d4ebdb738a934b50a742375c Mon Sep 17 00:00:00 2001 From: BarsTiger Date: Mon, 6 Mar 2023 22:53:30 +0200 Subject: [PATCH] Setting model, /status command for pinging host --- bot/callbacks/factories/set_model.py | 5 +++ bot/callbacks/register.py | 4 +- bot/callbacks/set_model.py | 50 +++++++++++++++++++++++ bot/handlers/help_command/help_strings.py | 2 + bot/handlers/initialize/all_messages.py | 5 ++- bot/handlers/txt2img/__init__.py | 4 ++ bot/handlers/txt2img/set_model.py | 37 +++++++++++++++++ bot/handlers/txt2img/status.py | 20 +++++++++ bot/keyboards/set_model.py | 31 ++++++++++++++ bot/modules/api/models.py | 2 +- bot/modules/api/ping.py | 14 +++++++ 11 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 bot/callbacks/factories/set_model.py create mode 100644 bot/callbacks/set_model.py create mode 100644 bot/handlers/txt2img/set_model.py create mode 100644 bot/handlers/txt2img/status.py create mode 100644 bot/keyboards/set_model.py create mode 100644 bot/modules/api/ping.py diff --git a/bot/callbacks/factories/set_model.py b/bot/callbacks/factories/set_model.py new file mode 100644 index 0000000..c08a7e9 --- /dev/null +++ b/bot/callbacks/factories/set_model.py @@ -0,0 +1,5 @@ +from aiogram.utils.callback_data import CallbackData + + +set_model = CallbackData("set_model", "n") +set_model_page = CallbackData("set_model_page", "page") diff --git a/bot/callbacks/register.py b/bot/callbacks/register.py index 325a454..11850dc 100644 --- a/bot/callbacks/register.py +++ b/bot/callbacks/register.py @@ -4,10 +4,12 @@ from rich import print def register_callbacks(): from bot.callbacks import ( exception, - image_info + image_info, + set_model ) exception.register() image_info.register() + set_model.register() print('[gray]All callbacks registered[/]') diff --git a/bot/callbacks/set_model.py b/bot/callbacks/set_model.py new file mode 100644 index 0000000..5f194cb --- /dev/null +++ b/bot/callbacks/set_model.py @@ -0,0 +1,50 @@ +from bot.common import dp +from bot.db import db, DBTables +from aiogram import types +from .factories.set_model import set_model, set_model_page +from bot.keyboards.set_model import get_set_model_keyboard +from bot.utils.private_keyboard import other_user +from bot.keyboards.exception import get_exception_keyboard +from bot.utils.trace_exception import PrettyException +from aiohttp import ClientConnectorError +from bot.modules.api import models + + +async def on_set_model(call: types.CallbackQuery, callback_data: dict): + n = int(callback_data['n']) + temp_message = await call.message.answer('⏳ Setting model...') + try: + await models.set_model(db[DBTables.config]['models'][n]) + + except ClientConnectorError: + await call.answer('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' + 'or turned off', show_alert=True) + await call.message.delete() + await temp_message.delete() + return + + except Exception as e: + exception_id = f'{call.message.message_thread_id}-{call.message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + await call.message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + return + + await temp_message.answer('✅ Model set for all users!') + await temp_message.delete() + + +async def on_page_change(call: types.CallbackQuery, callback_data: dict): + page = callback_data['page'] + if await other_user(call): + return + + await call.message.edit_reply_markup( + get_set_model_keyboard(page) + ) + + +def register(): + dp.register_callback_query_handler(on_set_model, set_model.filter()) + dp.register_callback_query_handler(on_page_change, set_model_page.filter()) diff --git a/bot/handlers/help_command/help_strings.py b/bot/handlers/help_command/help_strings.py index fb7f743..2218f39 100644 --- a/bot/handlers/help_command/help_strings.py +++ b/bot/handlers/help_command/help_strings.py @@ -11,6 +11,8 @@ help_data = { 'setsampler': 'Set StableDiffusion sampler', 'setscale': 'Set CFG Scale (prompt stringency)', 'setfaces': 'Set restore faces mode', + 'status': 'Ping API endpoint host', + 'setmodel': '(global) Sets StableDiffusion model for all users. Can be used only once an hour', 'setendpoint': '(admin) Set StableDiffusion API endpoint', 'addadmin': '(admin) Add new admin - reply to message or type user ID', 'rmadmin': '(admin) Remove admin - reply to message or type user ID' diff --git a/bot/handlers/initialize/all_messages.py b/bot/handlers/initialize/all_messages.py index ffc15de..16a8aa3 100644 --- a/bot/handlers/initialize/all_messages.py +++ b/bot/handlers/initialize/all_messages.py @@ -1,9 +1,12 @@ from aiogram.types import Message -from bot.db.pull_db import pull async def sync_db_filter(message: Message): + from bot.db.pull_db import pull + from bot.modules.api.ping import ping await pull() if message.is_command(): await message.reply(f'🔄️ Bot database synchronised because of restart. ' f'If you tried to run a command, run it again') + if not await ping(): + await message.reply('⚠️ Warning: StableDiffusion server is turned off or api endpoint is incorrect') diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py index 520b4e8..081769e 100644 --- a/bot/handlers/txt2img/__init__.py +++ b/bot/handlers/txt2img/__init__.py @@ -1,5 +1,7 @@ from bot.common import dp from .txt2img import generate_command +from .set_model import set_model_command +from .status import get_status from .set_settings import ( set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command, set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command @@ -17,3 +19,5 @@ def register(): dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler') dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale') dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces') + dp.register_message_handler(set_model.set_model_command, commands='setmodel') + dp.register_message_handler(status.get_status, commands='status') diff --git a/bot/handlers/txt2img/set_model.py b/bot/handlers/txt2img/set_model.py new file mode 100644 index 0000000..6d8e851 --- /dev/null +++ b/bot/handlers/txt2img/set_model.py @@ -0,0 +1,37 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.utils.cooldown import throttle +from bot.keyboards.set_model import get_set_model_keyboard +from bot.keyboards.exception import get_exception_keyboard +from bot.utils.trace_exception import PrettyException +from bot.modules.api.models import get_models +from aiohttp import ClientConnectorError + + +@throttle(cooldown=60*60, admin_ids=db[DBTables.config].get('admins'), by_id=False) +async def set_model_command(message: types.Message): + temp_message = await message.reply('⏳ Getting models...') + try: + models = await get_models() + if models is not None and len(models) > 0: + db[DBTables.config]['models'] = models + else: + await temp_message.delete() + await message.reply('❌ No models available') + return + + await temp_message.delete() + await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0)) + + except ClientConnectorError: + await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' + 'or turned off') + await temp_message.delete() + return + + except Exception as e: + exception_id = f'{message.message_thread_id}-{message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + await message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + return diff --git a/bot/handlers/txt2img/status.py b/bot/handlers/txt2img/status.py new file mode 100644 index 0000000..82b8634 --- /dev/null +++ b/bot/handlers/txt2img/status.py @@ -0,0 +1,20 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.utils.cooldown import throttle +from bot.modules.api.ping import ping + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def get_status(message: types.Message): + temp_message = await message.reply('⏳ Sending request...') + try: + if await ping(): + await message.reply('💚 Endpoint is UP') + else: + raise Exception + + except Exception as e: + assert e + await message.reply('💔 Endpoint is probably DOWN or incorrect') + + await temp_message.delete() diff --git a/bot/keyboards/set_model.py b/bot/keyboards/set_model.py new file mode 100644 index 0000000..8db0b89 --- /dev/null +++ b/bot/keyboards/set_model.py @@ -0,0 +1,31 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.callbacks.factories.set_model import set_model_page, set_model + + +def get_set_model_keyboard(page: int) -> types.InlineKeyboardMarkup: + models = db[DBTables.config]['models'] + navigation_buttons = list() + page = int(page) + + if page > 0: + navigation_buttons.append(types.InlineKeyboardButton( + '<', + callback_data=set_model_page.new(page=page - 1) + )) + if len([models[i:i + 5] for i in range(0, len(models), 5)]) > page + 1: + navigation_buttons.append(types.InlineKeyboardButton( + '>', + callback_data=set_model_page.new(page=page + 1) + )) + + models_buttons = [types.InlineKeyboardButton(models[i], callback_data=set_model.new(i)) for i in range(len(models))] + + keyboard = types.InlineKeyboardMarkup(row_width=1) + if len(models) > 5: + keyboard.add(*[models_buttons[i:i + 5] for i in range(0, len(models_buttons), 5)][page]) + keyboard.row(*navigation_buttons) + else: + keyboard.add(*models_buttons) + + return keyboard diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py index e36927b..77fba92 100644 --- a/bot/modules/api/models.py +++ b/bot/modules/api/models.py @@ -2,7 +2,7 @@ import aiohttp from bot.db import db, DBTables, decrypt -async def get_models(): +async def get_models() -> list | None: endpoint = decrypt(db[DBTables.config].get('endpoint')) async with aiohttp.ClientSession() as session: r = await session.get(endpoint + "/sdapi/v1/sd-models") diff --git a/bot/modules/api/ping.py b/bot/modules/api/ping.py new file mode 100644 index 0000000..3b9b860 --- /dev/null +++ b/bot/modules/api/ping.py @@ -0,0 +1,14 @@ +import aiohttp +from bot.db import db, DBTables, decrypt + + +async def ping(): + endpoint = decrypt(db[DBTables.config].get('endpoint')) + try: + async with aiohttp.ClientSession() as session: + r = await session.head(endpoint) + if r.status != 200: + return False + return True + except aiohttp.ClientConnectorError: + return False