diff --git a/bot/callbacks/factories/image_info.py b/bot/callbacks/factories/image_info.py new file mode 100644 index 0000000..417468c --- /dev/null +++ b/bot/callbacks/factories/image_info.py @@ -0,0 +1,7 @@ +from aiogram.utils.callback_data import CallbackData + + +prompt_only = CallbackData("prompt_only", "p_id") +full_prompt = CallbackData("full_prompt", "p_id") +import_prompt = CallbackData("import_prompt", "p_id") +back = CallbackData("img_info_back", "p_id") diff --git a/bot/callbacks/image_info.py b/bot/callbacks/image_info.py new file mode 100644 index 0000000..54e07f0 --- /dev/null +++ b/bot/callbacks/image_info.py @@ -0,0 +1,79 @@ +from bot.common import dp +from bot.db import db, DBTables +from aiogram import types +from .factories.image_info import full_prompt, prompt_only, import_prompt, back +from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard +from bot.utils.cooldown import throttle +from bot.utils.private_keyboard import other_user +from bot.modules.api.objects.prompt_request import Prompt + + +async def on_back(call: types.CallbackQuery, callback_data: dict): + p_id = callback_data['p_id'] + if await other_user(call): + return + + await call.message.edit_text( + "Image was generated using this bot", + parse_mode='html', + reply_markup=get_img_info_keyboard(p_id) + ) + + +@throttle(5) +async def on_prompt_only(call: types.CallbackQuery, callback_data: dict): + p_id = callback_data['p_id'] + if await other_user(call): + return + + prompt: Prompt = db[DBTables.generated].get(p_id) + + await call.message.edit_text( + f"πŸ–€ Prompt: {prompt.prompt} \n" + f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}", + parse_mode='html', + reply_markup=get_img_back_keyboard(p_id) + ) + + +@throttle(5) +async def on_full_info(call: types.CallbackQuery, callback_data: dict): + p_id = callback_data['p_id'] + if await other_user(call): + return + + prompt: Prompt = db[DBTables.generated].get(p_id) + + await call.message.edit_text( + f"πŸ–€ Prompt: {prompt.prompt} \n" + f"🐊 Negative: {prompt.negative_prompt} \n" + f"πŸͺœ Steps: {prompt.steps} \n" + f"πŸ§‘β€πŸŽ¨ CFG Scale: {prompt.cfg_scale} \n" + f"πŸ–₯️ Size: {prompt.width}x{prompt.height} \n" + f"πŸ˜€ Restore faces: {'on' if prompt.restore_faces else 'off'} \n" + f"βš’οΈ Sampler: {prompt.sampler}", + parse_mode='html', + reply_markup=get_img_back_keyboard(p_id) + ) + + +@throttle(5) +async def on_import(call: types.CallbackQuery, callback_data: dict): + p_id = callback_data['p_id'] + if await other_user(call): + return + + prompt: Prompt = db[DBTables.generated].get(p_id) + + await call.message.edit_text( + f"πŸ˜Άβ€πŸŒ«οΈ Not implemented yet", + parse_mode='html', + reply_markup=get_img_back_keyboard(p_id) + ) + + +def register(): + dp.register_callback_query_handler(on_prompt_only, prompt_only.filter()) + dp.register_callback_query_handler(on_back, back.filter()) + dp.register_callback_query_handler(on_full_info, full_prompt.filter()) + dp.register_callback_query_handler(on_import, import_prompt.filter()) diff --git a/bot/callbacks/register.py b/bot/callbacks/register.py index 306c40b..325a454 100644 --- a/bot/callbacks/register.py +++ b/bot/callbacks/register.py @@ -3,9 +3,11 @@ from rich import print def register_callbacks(): from bot.callbacks import ( - exception + exception, + image_info ) exception.register() + image_info.register() print('[gray]All callbacks registered[/]') diff --git a/bot/config.py b/bot/config.py index cac6699..1f99d4d 100644 --- a/bot/config.py +++ b/bot/config.py @@ -6,6 +6,7 @@ load_dotenv() TOKEN = os.getenv('TOKEN') ADMIN = int(os.getenv('ADMIN')) DB_CHAT = os.getenv('DB_CHAT') +ENCRYPTION_KEY = os.getenv('ENCRYPTION_KEY').encode() _DB_PATH = os.getenv('DB_PATH') DB = _DB_PATH + '/db' DBMETA = _DB_PATH + '/dbmeta' diff --git a/bot/db/__init__.py b/bot/db/__init__.py index f7a640d..ccfda76 100644 --- a/bot/db/__init__.py +++ b/bot/db/__init__.py @@ -1,2 +1,3 @@ from .db import db from .db_model import DBTables +from .encryption import encrypt, decrypt diff --git a/bot/db/db.py b/bot/db/db.py index 5025ce5..7f6381b 100644 --- a/bot/db/db.py +++ b/bot/db/db.py @@ -11,5 +11,6 @@ db = { 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'), 'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'), 'queue': DBDict(DB, autocommit=True, tablename='queue'), - 'generated': DBDict(DB, autocommit=True, tablename='generated') + 'generated': DBDict(DB, autocommit=True, tablename='generated'), + 'prompts': DBDict(DB, autocommit=True, tablename='prompts') } diff --git a/bot/db/db_model.py b/bot/db/db_model.py index 43411fb..970a0e1 100644 --- a/bot/db/db_model.py +++ b/bot/db/db_model.py @@ -8,12 +8,13 @@ from .meta import DBMeta class DBTables: - tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated'] + tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated', 'prompts'] config = "config" cooldown = "cooldown" exceptions = "exceptions" queue = "queue" generated = "generated" + prompts = "prompts" class DBDict(SqliteDict): diff --git a/bot/db/encryption.py b/bot/db/encryption.py new file mode 100644 index 0000000..eee52c4 --- /dev/null +++ b/bot/db/encryption.py @@ -0,0 +1,25 @@ +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +import base64 +from bot.config import ENCRYPTION_KEY, BARS_APP_ID + + +fernet = Fernet( + base64.urlsafe_b64encode( + PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + iterations=390000, + salt=BARS_APP_ID.encode() + ).derive(ENCRYPTION_KEY) + ) +) + + +def encrypt(s: str) -> bytes: + return fernet.encrypt(s.encode()) + + +def decrypt(s: bytes) -> str: + return fernet.decrypt(s).decode() diff --git a/bot/handlers/admin/aliases.py b/bot/handlers/admin/aliases.py index 10acff5..aee96ed 100644 --- a/bot/handlers/admin/aliases.py +++ b/bot/handlers/admin/aliases.py @@ -1,5 +1,5 @@ from aiogram import types -from bot.db import db, DBTables +from bot.db import db, DBTables, encrypt import validators from bot.config import ADMIN from bot.utils.cooldown import throttle @@ -16,7 +16,7 @@ async def set_endpoint(message: types.Message): await message.reply("❌ Specify correct url for endpoint") return - db[DBTables.config]['endpoint'] = message.get_args() + db[DBTables.config]['endpoint'] = encrypt(message.get_args()) await db[DBTables.config].write() diff --git a/bot/handlers/image_info/image_info.py b/bot/handlers/image_info/image_info.py index e487c28..93f6be2 100644 --- a/bot/handlers/image_info/image_info.py +++ b/bot/handlers/image_info/image_info.py @@ -2,6 +2,7 @@ from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle from bot.keyboards.exception import get_exception_keyboard +from bot.keyboards.image_info import get_img_info_keyboard from bot.utils.trace_exception import PrettyException @@ -12,14 +13,15 @@ async def imginfo(message: types.Message): await message.reply('❌ Reply with this command on picture', parse_mode='html') return - if not (original_r := db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id)): + if not db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id): await message.reply('❌ This picture wasn\'t generated using this bot ' 'or doesn\'t exist in database. Note this only works on ' 'files forwarded from bot.', parse_mode='html') return - await message.reply(str(original_r)) - # TODO: Pretty print this + await message.reply("Image was generated using this bot", reply_markup=get_img_info_keyboard( + message.reply_to_message.photo[0].file_unique_id + )) except IndexError: await message.reply('❌ Reply with this command on PICTURE', parse_mode='html') diff --git a/bot/handlers/txt2img/set_settings.py b/bot/handlers/txt2img/set_settings.py new file mode 100644 index 0000000..91883e9 --- /dev/null +++ b/bot/handlers/txt2img/set_settings.py @@ -0,0 +1,34 @@ +from aiogram import types +from bot.db import db, DBTables +from bot.utils.cooldown import throttle +from bot.modules.api.objects.prompt_request import Prompt +from bot.keyboards.exception import get_exception_keyboard +from bot.utils.trace_exception import PrettyException + + +@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) +async def set_prompt_command(message: types.Message): + temp_message = await message.reply("⏳ Setting prompt...") + if not message.get_args(): + await temp_message.edit_text("πŸ˜Άβ€πŸŒ«οΈ Specify prompt for this command. Check /help setprompt") + return + + try: + prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args())) + prompt.prompt = message.get_args() + prompt.creator = message.from_id + db[DBTables.prompts][message.from_id] = prompt + + await db[DBTables.config].write() + + await message.reply('βœ… Default prompt set') + await temp_message.delete() + + except Exception as e: + exception_id = f'{message.message_thread_id}-{message.message_id}' + db[DBTables.exceptions][exception_id] = PrettyException(e) + await message.reply('❌ Error happened while processing your request', parse_mode='html', + reply_markup=get_exception_keyboard(exception_id)) + await temp_message.delete() + db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 + return diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py index 66aa0bf..47c7a34 100644 --- a/bot/handlers/txt2img/txt2img.py +++ b/bot/handlers/txt2img/txt2img.py @@ -12,9 +12,13 @@ from aiohttp import ClientConnectorError @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) async def txt2img_comand(message: types.Message): temp_message = await message.reply("⏳ Enqueued...") - if not message.get_args(): - await temp_message.edit_text("πŸ˜Άβ€πŸŒ«οΈ Specify prompt for this command. Check /help txt2img") - return + + prompt: Prompt = db[DBTables.prompts].get(message.from_id) + if not prompt: + if message.get_args(): + db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id) + + # TODO: Move it to other module try: db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1 diff --git a/bot/keyboards/image_info.py b/bot/keyboards/image_info.py new file mode 100644 index 0000000..ff4ea45 --- /dev/null +++ b/bot/keyboards/image_info.py @@ -0,0 +1,18 @@ +from aiogram import types +from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back) + + +def get_img_info_keyboard(p_id: str) -> types.InlineKeyboardMarkup: + buttons = [types.InlineKeyboardButton(text="πŸ“‹ Show prompts", callback_data=prompt_only.new(p_id=p_id)), + types.InlineKeyboardButton(text="🧿 Show full info", callback_data=full_prompt.new(p_id=p_id)), + types.InlineKeyboardButton(text="πŸͺ„ Import prompt", callback_data=import_prompt.new(p_id=p_id))] + keyboard = types.InlineKeyboardMarkup(row_width=2) + keyboard.add(*buttons) + return keyboard + + +def get_img_back_keyboard(p_id: str) -> types.InlineKeyboardMarkup: + buttons = [types.InlineKeyboardButton(text="πŸ‘ˆ Back", callback_data=back.new(p_id=p_id))] + keyboard = types.InlineKeyboardMarkup() + keyboard.add(*buttons) + return keyboard diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py index 2de3006..bdf73cc 100644 --- a/bot/modules/api/models.py +++ b/bot/modules/api/models.py @@ -1,10 +1,10 @@ import aiohttp -from bot.db import db, DBTables +from bot.db import db, DBTables, decrypt from rich import print async def get_models(): - endpoint = db[DBTables.config].get('endpoint') + endpoint = decrypt(db[DBTables.config].get('endpoint')) try: async with aiohttp.ClientSession() as session: r = await session.get(endpoint + "/sdapi/v1/sd-models") @@ -17,7 +17,7 @@ async def get_models(): async def set_model(model_name: str): - endpoint = db[DBTables.config].get('endpoint') + endpoint = decrypt(db[DBTables.config].get('endpoint')) try: async with aiohttp.ClientSession() as session: r = await session.post(endpoint + "/sdapi/v1/options", json={ diff --git a/bot/modules/api/status.py b/bot/modules/api/status.py index e033497..9322681 100644 --- a/bot/modules/api/status.py +++ b/bot/modules/api/status.py @@ -1,4 +1,4 @@ -from bot.db import db, DBTables +from bot.db import db, DBTables, decrypt import aiohttp import asyncio import time @@ -18,7 +18,7 @@ async def job_exists(endpoint): async def wait_for_status(ignore_exceptions: bool = False): - endpoint = db[DBTables.config].get('endpoint') + endpoint = decrypt(db[DBTables.config].get('endpoint')) try: while await job_exists(endpoint): while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time(): diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py index 429e127..29982a9 100644 --- a/bot/modules/api/txt2img.py +++ b/bot/modules/api/txt2img.py @@ -1,12 +1,12 @@ import aiohttp -from bot.db import db, DBTables +from bot.db import db, DBTables, decrypt from .objects.prompt_request import Prompt import json import base64 async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None: - endpoint = db[DBTables.config].get('endpoint') + endpoint = decrypt(db[DBTables.config].get('endpoint')) try: async with aiohttp.ClientSession() as session: r = await session.post( diff --git a/bot/utils/cooldown.py b/bot/utils/cooldown.py index dfec6b6..47bb29b 100644 --- a/bot/utils/cooldown.py +++ b/bot/utils/cooldown.py @@ -6,11 +6,11 @@ from aiogram import types def not_allowed(message: types.Message, cd: int, by_id: bool): + text = f"❌ Wait for cooldown ({cd}s for this command) " \ + f"{'. Please note that this cooldown is for all users' if not by_id else ''}" return asyncio.create_task(message.reply( - text= - f"❌ Wait for cooldown ({cd}s for this command)" - f"{'. Please note that this cooldown is for all users' if not by_id else ''}" - )) + text=text + ) if hasattr(message, 'reply') else message.answer(text=text, show_alert=True)) def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None): diff --git a/bot/utils/private_keyboard.py b/bot/utils/private_keyboard.py new file mode 100644 index 0000000..4a69b35 --- /dev/null +++ b/bot/utils/private_keyboard.py @@ -0,0 +1,12 @@ +from aiogram import types + + +async def other_user(call: types.CallbackQuery) -> bool: + if not hasattr(call.message.reply_to_message, 'from_id'): + await call.answer('Error, original call was removed', show_alert=True) + return True + elif call.message.reply_to_message.from_id != call.from_user.id: + await call.answer('It is not your menu!', show_alert=True) + return True + + return False diff --git a/requirements.txt b/requirements.txt index 644d592..0195544 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ python-dotenv rich aiohttp validators -sqlitedict \ No newline at end of file +sqlitedict +cryptography \ No newline at end of file