diff --git a/bot/callbacks/image_info.py b/bot/callbacks/image_info.py index 54e07f0..ae0c317 100644 --- a/bot/callbacks/image_info.py +++ b/bot/callbacks/image_info.py @@ -5,7 +5,7 @@ from .factories.image_info import full_prompt, prompt_only, import_prompt, back from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard from bot.utils.cooldown import throttle from bot.utils.private_keyboard import other_user -from bot.modules.api.objects.prompt_request import Prompt +from bot.modules.api.objects.prompt_request import Generated async def on_back(call: types.CallbackQuery, callback_data: dict): @@ -26,11 +26,11 @@ async def on_prompt_only(call: types.CallbackQuery, callback_data: dict): if await other_user(call): return - prompt: Prompt = db[DBTables.generated].get(p_id) + prompt: Generated = db[DBTables.generated].get(p_id) await call.message.edit_text( - f"πŸ–€ Prompt: {prompt.prompt} \n" - f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}", + f"πŸ–€ Prompt: {prompt.prompt.prompt} \n" + f"{f'🐊 Negative: {prompt.prompt.negative_prompt}' if prompt.prompt.negative_prompt else ''}", parse_mode='html', reply_markup=get_img_back_keyboard(p_id) ) @@ -42,16 +42,18 @@ async def on_full_info(call: types.CallbackQuery, callback_data: dict): if await other_user(call): return - prompt: Prompt = db[DBTables.generated].get(p_id) + prompt: Generated = db[DBTables.generated].get(p_id) await call.message.edit_text( - f"πŸ–€ Prompt: {prompt.prompt} \n" - f"🐊 Negative: {prompt.negative_prompt} \n" - f"πŸͺœ Steps: {prompt.steps} \n" - f"πŸ§‘β€πŸŽ¨ CFG Scale: {prompt.cfg_scale} \n" - f"πŸ–₯️ Size: {prompt.width}x{prompt.height} \n" - f"πŸ˜€ Restore faces: {'on' if prompt.restore_faces else 'off'} \n" - f"βš’οΈ Sampler: {prompt.sampler}", + f"πŸ–€ Prompt: {prompt.prompt.prompt} \n" + f"🐊 Negative: {prompt.prompt.negative_prompt} \n" + f"πŸ’« Model: {prompt.model} \n" + f"πŸͺœ Steps: {prompt.prompt.steps} \n" + f"πŸ§‘β€πŸŽ¨ CFG Scale: {prompt.prompt.cfg_scale} \n" + f"πŸ–₯️ Size: {prompt.prompt.width}x{prompt.prompt.height} \n" + f"πŸ˜€ Restore faces: {'on' if prompt.prompt.restore_faces else 'off'} \n" + f"βš’οΈ Sampler: {prompt.prompt.sampler} \n" + f"🌱 Seed: {prompt.seed}", parse_mode='html', reply_markup=get_img_back_keyboard(p_id) ) @@ -63,7 +65,7 @@ async def on_import(call: types.CallbackQuery, callback_data: dict): if await other_user(call): return - prompt: Prompt = db[DBTables.generated].get(p_id) + prompt: Generated = db[DBTables.generated].get(p_id) await call.message.edit_text( f"πŸ˜Άβ€πŸŒ«οΈ Not implemented yet", diff --git a/bot/handlers/help_command/help_strings.py b/bot/handlers/help_command/help_strings.py index a224509..d8b8962 100644 --- a/bot/handlers/help_command/help_strings.py +++ b/bot/handlers/help_command/help_strings.py @@ -1,4 +1,15 @@ help_data = { - 'setendpoint': '(admin) Set StableDiffusion API endpoint', - 'imginfo': 'Get information about image, that was generated using this bot' + 'generate': 'Generate picture using configuration set by user. You can pass prompt also in command arguments or ' + 'use it without arguments to generate picture with prompt, that was used for last generation', + 'imginfo': 'Get information about image, that was generated using this bot', + 'setprompt': 'Set default prompt for images, will be overwritten if you specify prompt in generate command', + 'setnegative': 'Set negative prompt', + 'setsize': 'Set size for image (in hxw format)', + 'setwidth': 'Set width for image', + 'setheight': 'Set height for image', + 'setsteps': 'Set sampling steps number', + 'setsampler': 'Set StableDiffusion sampler', + 'setscale': 'Set CFG Scale (prompt stringency)', + 'setfaces': 'Set restore faces mode', + 'setendpoint': '(admin) Set StableDiffusion API endpoint' } diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py index f75dcaf..520b4e8 100644 --- a/bot/handlers/txt2img/__init__.py +++ b/bot/handlers/txt2img/__init__.py @@ -1,6 +1,19 @@ from bot.common import dp -from .txt2img import txt2img_comand +from .txt2img import generate_command +from .set_settings import ( + set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command, + set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command +) def register(): - dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img') + dp.register_message_handler(txt2img.generate_command, commands='generate') + dp.register_message_handler(set_settings.set_prompt_command, commands='setprompt') + dp.register_message_handler(set_settings.set_height_command, commands='setheight') + dp.register_message_handler(set_settings.set_width_command, commands='setwidth') + dp.register_message_handler(set_settings.set_negative_prompt_command, commands='setnegative') + dp.register_message_handler(set_settings.set_size_command, commands='setsize') + dp.register_message_handler(set_settings.set_steps_command, commands='setsteps') + dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler') + dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale') + dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces') diff --git a/bot/handlers/txt2img/set_settings.py b/bot/handlers/txt2img/set_settings.py index 91883e9..8c9c3ef 100644 --- a/bot/handlers/txt2img/set_settings.py +++ b/bot/handlers/txt2img/set_settings.py @@ -6,22 +6,28 @@ from bot.keyboards.exception import get_exception_keyboard from bot.utils.trace_exception import PrettyException -@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) -async def set_prompt_command(message: types.Message): - temp_message = await message.reply("⏳ Setting prompt...") +async def _set_property(message: types.Message, prop: str, value=None): + temp_message = await message.reply(f"⏳ Setting {prop}...") if not message.get_args(): - await temp_message.edit_text("πŸ˜Άβ€πŸŒ«οΈ Specify prompt for this command. Check /help setprompt") + await temp_message.edit_text("πŸ˜Άβ€πŸŒ«οΈ Specify arguments for this command. Check /help") return try: - prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args())) - prompt.prompt = message.get_args() + prompt: Prompt = db[DBTables.prompts].get(message.from_id) + if prompt is None and prop != 'prompt': + await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. " + f"For example, it can be: masterpiece, best quality, 1girl, white hair, " + f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, " + f"jacket, outdoors, streets", parse_mode='HTML') + return + + prompt.__setattr__(prop, message.get_args() if value is None else value) prompt.creator = message.from_id db[DBTables.prompts][message.from_id] = prompt await db[DBTables.config].write() - await message.reply('βœ… Default prompt set') + await message.reply(f'βœ… {prop} set') await temp_message.delete() except Exception as e: @@ -32,3 +38,134 @@ async def set_prompt_command(message: types.Message): await temp_message.delete() db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 return + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_prompt_command(message: types.Message): + await _set_property(message, 'prompt') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_negative_prompt_command(message: types.Message): + await _set_property(message, 'negative_prompt') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_steps_command(message: types.Message): + try: + _ = int(message.get_args()) + except Exception as e: + assert e + await message.reply('❌ Specify number as argument') + return + + if _ > 30: + await message.reply('❌ Specify number <= 30') + return + + await _set_property(message, 'steps') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_cfg_scale_command(message: types.Message): + try: + _ = int(message.get_args()) + except Exception as e: + assert e + await message.reply('❌ Specify number as argument') + return + + await _set_property(message, 'cfg_scale') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_width_command(message: types.Message): + try: + _ = int(message.get_args()) + except Exception as e: + assert e + await message.reply('❌ Specify number as argument') + return + + if _ > 768: + await message.reply('❌ Specify number <= 768') + return + + await _set_property(message, 'width') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_height_command(message: types.Message): + try: + _ = int(message.get_args()) + except Exception as e: + assert e + await message.reply('❌ Specify number as argument') + return + + if _ > 768: + await message.reply('❌ Specify number <= 768') + return + + await _set_property(message, 'height') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_restore_faces_command(message: types.Message): + try: + _ = bool(message.get_args()) + except Exception as e: + assert e + await message.reply('❌ Specify boolean True/False as argument', + parse_mode='HTML') + return + + await _set_property(message, 'restore_faces') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_sampler_command(message: types.Message): + temp_message = await message.reply('⏳ Getting samplers...') + from bot.modules.api.samplers import get_samplers + from aiohttp import ClientConnectorError + try: + if message.get_args() not in (samplers := await get_samplers()): + await message.reply( + f'❌ You can use only {", ".join(f"{x}" for x in samplers)}', + parse_mode='HTML' + ) + return + await temp_message.delete() + except ClientConnectorError: + await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect ' + 'or turned off') + except Exception as e: + exception_id = f'{message.message_thread_id}-{message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + await message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + await temp_message.delete() + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + return + + await _set_property(message, 'sampler') + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_size_command(message: types.Message): + try: + hxw = message.get_args().split('x') + height = int(hxw[0]) + width = int(hxw[1]) + except Exception as e: + assert e + await message.reply('❌ Specify size in hxw format, for example 512x512', + parse_mode='HTML') + return + + if height > 768 or width > 768: + await message.reply('❌ Specify numbers <= 768') + return + + await _set_property(message, 'height', height) + await _set_property(message, 'width', width) diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py index 47c7a34..680a6ac 100644 --- a/bot/handlers/txt2img/txt2img.py +++ b/bot/handlers/txt2img/txt2img.py @@ -1,24 +1,30 @@ +import re from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle from bot.modules.api.txt2img import txt2img -from bot.modules.api.objects.prompt_request import Prompt +from bot.modules.api.objects.get_prompt import get_prompt +from bot.modules.api.objects.prompt_request import Generated from bot.modules.api.status import wait_for_status from bot.keyboards.exception import get_exception_keyboard +from bot.keyboards.image_info import get_img_info_keyboard from bot.utils.trace_exception import PrettyException from aiohttp import ClientConnectorError @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) -async def txt2img_comand(message: types.Message): +async def generate_command(message: types.Message): temp_message = await message.reply("⏳ Enqueued...") - prompt: Prompt = db[DBTables.prompts].get(message.from_id) - if not prompt: - if message.get_args(): - db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id) - - # TODO: Move it to other module + try: + prompt = get_prompt(user_id=message.from_id, + prompt_string=message.get_args()) + except AttributeError: + await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. " + f"For example, it can be: masterpiece, best quality, 1girl, white hair, " + f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, jacket, " + f"outdoors, streets", parse_mode='HTML') + return try: db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1 @@ -27,12 +33,18 @@ async def txt2img_comand(message: types.Message): await wait_for_status() await temp_message.edit_text(f"βŒ› Generating...") - prompt = Prompt(prompt=message.get_args(), creator=message.from_id) image = await txt2img(prompt) image_message = await message.reply_photo(photo=image[0]) db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 - db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt + db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated( + prompt=prompt, + seed=image[1]['seed'], + model=re.search(r", Model: ([^,]+),", image[1]['infotexts'][0]).groups()[0] + ) + + await message.reply('Here is your image', + reply_markup=get_img_info_keyboard(image_message.photo[0].file_unique_id)) await temp_message.delete() @@ -45,6 +57,12 @@ async def txt2img_comand(message: types.Message): db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 return + except ValueError as e: + await message.reply(f'❌ Error! {e.args[0]}') + await temp_message.delete() + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + return + except Exception as e: exception_id = f'{message.message_thread_id}-{message.message_id}' db[DBTables.exceptions][exception_id] = PrettyException(e) diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py index bdf73cc..e36927b 100644 --- a/bot/modules/api/models.py +++ b/bot/modules/api/models.py @@ -1,31 +1,22 @@ import aiohttp from bot.db import db, DBTables, decrypt -from rich import print async def get_models(): endpoint = decrypt(db[DBTables.config].get('endpoint')) - try: - async with aiohttp.ClientSession() as session: - r = await session.get(endpoint + "/sdapi/v1/sd-models") - if r.status != 200: - return None - return [x["title"] for x in await r.json()] - except Exception as e: - print(e) - return None + async with aiohttp.ClientSession() as session: + r = await session.get(endpoint + "/sdapi/v1/sd-models") + if r.status != 200: + return None + return [x["title"] for x in await r.json()] async def set_model(model_name: str): endpoint = decrypt(db[DBTables.config].get('endpoint')) - try: - async with aiohttp.ClientSession() as session: - r = await session.post(endpoint + "/sdapi/v1/options", json={ - "sd_model_checkpoint": model_name - }) - if r.status != 200: - return False - return True - except Exception as e: - print(e) - return False + async with aiohttp.ClientSession() as session: + r = await session.post(endpoint + "/sdapi/v1/options", json={ + "sd_model_checkpoint": model_name + }) + if r.status != 200: + return False + return True diff --git a/bot/modules/api/objects/get_prompt.py b/bot/modules/api/objects/get_prompt.py new file mode 100644 index 0000000..2cecf52 --- /dev/null +++ b/bot/modules/api/objects/get_prompt.py @@ -0,0 +1,38 @@ +from bot.modules.api.objects.prompt_request import Prompt +from bot.db import db, DBTables + + +def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None, + cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None, + sampler: str = None) -> Prompt: + new_prompt: Prompt = db[DBTables.prompts].get(user_id) + creator = user_id + if not new_prompt: + if user_id and prompt_string: + db[DBTables.prompts][user_id] = Prompt( + prompt=prompt_string, + creator=user_id, + negative_prompt=negative_prompt, + steps=steps, + cfg_scale=cfg_scale, + width=width, + height=height, + restore_faces=restore_faces, + sampler=sampler, + ) + new_prompt: Prompt = db[DBTables.prompts].get(user_id) + else: + raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user') + + if prompt_string: + new_prompt.prompt = prompt_string + + for key in new_prompt.__dict__.keys(): + if key in locals().keys() and locals()[key]: + new_prompt.__setattr__(key, locals()[key]) + elif not new_prompt.__getattribute__(key): + new_prompt.__setattr__(key, Prompt('').__getattribute__(key)) + + db[DBTables.prompts][user_id] = new_prompt + + return new_prompt diff --git a/bot/modules/api/objects/prompt_request.py b/bot/modules/api/objects/prompt_request.py index 3621d05..841c6c6 100644 --- a/bot/modules/api/objects/prompt_request.py +++ b/bot/modules/api/objects/prompt_request.py @@ -12,3 +12,10 @@ class Prompt: restore_faces: bool = True sampler: str = "Euler a" creator: int = None + + +@dataclasses.dataclass +class Generated: + prompt: Prompt + seed: int + model: str diff --git a/bot/modules/api/samplers.py b/bot/modules/api/samplers.py new file mode 100644 index 0000000..4ebbc34 --- /dev/null +++ b/bot/modules/api/samplers.py @@ -0,0 +1,11 @@ +import aiohttp +from bot.db import db, DBTables, decrypt + + +async def get_samplers(): + endpoint = decrypt(db[DBTables.config].get('endpoint')) + async with aiohttp.ClientSession() as session: + r = await session.get(endpoint + "/sdapi/v1/samplers") + if r.status != 200: + return None + return [x["name"] for x in await r.json()] diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py index 29982a9..0b77ac4 100644 --- a/bot/modules/api/txt2img.py +++ b/bot/modules/api/txt2img.py @@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes "sampler_index": prompt.sampler } ) - if r.status != 200: + if r.status != 200 and ignore_exceptions: return None + elif r.status != 200: + raise ValueError((await r.json())['detail']) return [base64.b64decode((await r.json())["images"][0]), json.loads((await r.json())["info"])] except Exception as e: