diff --git a/bot/handlers/admin/__init__.py b/bot/handlers/admin/__init__.py index c840d8a..fb1be1f 100644 --- a/bot/handlers/admin/__init__.py +++ b/bot/handlers/admin/__init__.py @@ -10,4 +10,6 @@ def register(): dp.register_message_handler(reset.resetqueue, commands='resetqueue') dp.register_message_handler(aliases.add_admin, commands='addadmin') dp.register_message_handler(aliases.remove_admin, commands='rmadmin') + dp.register_message_handler(aliases.add_whitelist, commands='addwhitelist') + dp.register_message_handler(aliases.remove_whitelist, commands='rmwhitelist') dp.register_message_handler(tools.hash_command, commands='hash') diff --git a/bot/handlers/admin/aliases.py b/bot/handlers/admin/aliases.py index b6d47aa..5521731 100644 --- a/bot/handlers/admin/aliases.py +++ b/bot/handlers/admin/aliases.py @@ -23,6 +23,75 @@ async def set_endpoint(message: types.Message, is_command: bool = True): await message.reply("✅ New url set") +@throttle(5) +async def add_whitelist(message: types.Message, is_command: bool = True): + if message.from_id != ADMIN: + await message.reply('❌ You are not permitted to do that. It is only for main admin') + return + + if not (message.get_args() if is_command else message.text).isdecimal() and not \ + hasattr(message.reply_to_message, 'text') and (message.chat.id >= 0): + await message.reply('❌ Put new whitelist chat ID to command arguments') + return + elif not (message.get_args() if is_command else message.text).isdecimal() and not \ + hasattr(message.reply_to_message, 'text') and (message.chat.id < 0): + ID = message.chat.id + await message.reply(f'Chat ID: {message.chat.id} Chat title: {message.chat.title}') + elif not (message.get_args() if is_command else message.text).isdecimal(): + ID = message.reply_to_message.from_id + elif not hasattr(message.reply_to_message, 'text'): + ID = int((message.get_args() if is_command else message.text)) + + if not isinstance(db[DBTables.config].get('whitelist'), list): + db[DBTables.config]['whitelist'] = list() + + if ID not in db[DBTables.config].get('whitelist'): + whitelist_ = db[DBTables.config].get('whitelist') + whitelist_.append(ID) + db[DBTables.config]['whitelist'] = whitelist_ + else: + await message.reply('❌ This whitelist is added already') + return + + await db[DBTables.config].write() + + await message.reply("✅ Added whitelist") + +@throttle(5) +async def remove_whitelist(message: types.Message, is_command: bool = True): + if message.from_id != ADMIN: + await message.reply('❌ You are not permitted to do that. It is only for main admin') + return + + if not (message.get_args() if is_command else message.text).isdecimal() and not \ + hasattr(message.reply_to_message, 'text') and (message.chat.id >= 0: + await message.reply('❌ Put whitelist ID to command arguments or answer to users message') + return + elif not (message.get_args() if is_command else message.text).isdecimal() and not \ + hasattr(message.reply_to_message, 'text') and (message.chat.id < 0): + ID = message.chat.id + await message.reply(f'Chat ID: {message.chat.id} Chat title: {message.chat.title}') + elif not (message.get_args() if is_command else message.text).isdecimal(): + ID = message.reply_to_message.from_id + elif not hasattr(message.reply_to_message, 'text'): + ID = int((message.get_args() if is_command else message.text)) + + if not isinstance(db[DBTables.config].get('whitelist'), list): + db[DBTables.config]['whitelist'] = list() + + if ID not in db[DBTables.config].get('whitelist'): + await message.reply('❌ This whitelist is not added') + return + else: + whitelist_ = db[DBTables.config].get('whitelist') + whitelist_.remove(ID) + db[DBTables.config]['whitelist'] = whitelist_ + + await db[DBTables.config].write() + + await message.reply("✅ Removed whitelist") + + @throttle(5) async def add_admin(message: types.Message, is_command: bool = True): diff --git a/bot/handlers/help_command/help_strings.py b/bot/handlers/help_command/help_strings.py index 686cf72..bb795da 100644 --- a/bot/handlers/help_command/help_strings.py +++ b/bot/handlers/help_command/help_strings.py @@ -17,5 +17,7 @@ help_data = { 'setmodel': '(global) Sets StableDiffusion model for all users. Can be used only once an hour', 'setendpoint': '(admin) Set StableDiffusion API endpoint', 'addadmin': '(admin) Add new admin - reply to message or type user ID', - 'rmadmin': '(admin) Remove admin - reply to message or type user ID' + 'rmadmin': '(admin) Remove admin - reply to message or type user ID', + 'addwhitelist': '(admin) Add new whitelist - reply to message or type user ID', + 'rmwhitelist': '(admin) Remove whitelist - reply to message or type user ID' } diff --git a/bot/handlers/initialize/start.py b/bot/handlers/initialize/start.py index 6780159..e832a48 100644 --- a/bot/handlers/initialize/start.py +++ b/bot/handlers/initialize/start.py @@ -18,6 +18,13 @@ async def start_command(message: types.Message): await db[DBTables.config].write() await message.reply(f'✅ Added {message.from_user.username} to admins. You can add other admins, ' f'check bot settings menu') + if ADMIN not in db[DBTables.config].get('whitelist'): + whitelist_ = db[DBTables.config].get('whitelist') + whitelist_.append(ADMIN) + db[DBTables.config]['whitelist'] = whitelist_ + await db[DBTables.config].write() + await message.reply(f'✅ Added {message.from_user.username} to whitelist. You can add other users to whitelist, ' + f'check bot settings menu') if db[DBTables.config].get('enabled') is None: db[DBTables.config]['enabled'] = True await message.reply(f'✅ Generation is enabled now') diff --git a/bot/handlers/register.py b/bot/handlers/register.py index de3a159..fb82a95 100644 --- a/bot/handlers/register.py +++ b/bot/handlers/register.py @@ -3,7 +3,7 @@ from rich import print def register_handlers(): from bot.handlers import ( - initialize, admin, help_command, txt2img, image_info, config + initialize, admin, help_command, txt2img, image_info, config, whitelist ) initialize.register() diff --git a/bot/handlers/txt2img/set_model.py b/bot/handlers/txt2img/set_model.py index ddb31e3..dcc5634 100644 --- a/bot/handlers/txt2img/set_model.py +++ b/bot/handlers/txt2img/set_model.py @@ -1,5 +1,5 @@ from aiogram import types -from bot.db import db, DBTables +from bot.db import db, DBTables, decrypt from bot.utils.cooldown import throttle from bot.keyboards.set_model import get_set_model_keyboard from bot.modules.api.models import get_models @@ -7,13 +7,16 @@ from bot.utils.errorable_command import wrap_exception @wrap_exception() -@throttle(cooldown=60*60, admin_ids=db[DBTables.config].get('admins'), by_id=False) +@throttle(cooldown=5*60, admin_ids=db[DBTables.config].get('admins'), by_id=False) async def set_model_command(message: types.Message): + if (message.chat.id not in db[DBTables.config]['whitelist'] and message.from_id not in db[DBTables.config]['whitelist']): + await message.reply('❌You are not on the white list, access denied. Contact admin @kilisauros for details') + return models = await get_models() if models is not None and len(models) > 0: db[DBTables.config]['models'] = models else: await message.reply('❌ No models available') return - + await message.reply("Examples of models (with additional info): https://telegra.ph/Opisanie-raboty-modelej-05-03") await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0)) diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py index c164474..1da8ee2 100644 --- a/bot/handlers/txt2img/txt2img.py +++ b/bot/handlers/txt2img/txt2img.py @@ -1,4 +1,5 @@ import re +from bot.common import bot from aiogram import types from bot.db import db, DBTables from bot.utils.cooldown import throttle @@ -8,6 +9,8 @@ from bot.modules.api.objects.prompt_request import Generated from bot.modules.api.status import wait_for_status from bot.keyboards.image_info import get_img_info_keyboard from bot.utils.errorable_command import wrap_exception +from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back) + @wrap_exception([ValueError], custom_loading=True) @@ -18,6 +21,10 @@ async def generate_command(message: types.Message): await message.reply('💔 Generation is disabled by admins now. Try again later') await temp_message.delete() return + elif (message.chat.id not in db[DBTables.config]['whitelist'] and message.from_id not in db[DBTables.config]['whitelist']): + await message.reply('❌You are not on the white list, access denied. Contact admin @kilisauros for details') + await temp_message.delete() + return try: prompt = get_prompt(user_id=message.from_id, @@ -39,6 +46,24 @@ async def generate_command(message: types.Message): db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 image = await txt2img(prompt) image_message = await message.reply_photo(photo=image[0]) + + #Send photo to SD Image Archive + + archive_message = f'User ID: {message.from_id} \n \ + User nickname: {message.from_user.full_name} \n \ + User username: @{message.from_user.username} \n \ + Chat ID: {message.chat.id} \n \ + Chat title: {message.chat.title} \n \ + Info: \n \ + 🖤 Prompt: {prompt.prompt} \n \ + 🐊 Negative: {prompt.negative_prompt} \n \ + 💫 Model: In development \n \ + 🪜 Steps: {prompt.steps} \n \ + 🧑‍🎨 CFG Scale: {prompt.cfg_scale} \n \ + 🖥️ Size: {prompt.width}x{prompt.height} \n \ + 😀 Restore faces: {prompt.restore_faces} \n \ + ⚒️ Sampler: {prompt.sampler} \n ' + await bot.send_photo(-929754401, photo=image[0], caption=archive_message) db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated( prompt=prompt, @@ -57,4 +82,4 @@ async def generate_command(message: types.Message): await message.reply(f'❌ Error! {e.args[0]}') await temp_message.delete() db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 - return + return \ No newline at end of file diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py index 77fba92..9000950 100644 --- a/bot/modules/api/models.py +++ b/bot/modules/api/models.py @@ -1,14 +1,16 @@ import aiohttp +import json from bot.db import db, DBTables, decrypt async def get_models() -> list | None: endpoint = decrypt(db[DBTables.config].get('endpoint')) async with aiohttp.ClientSession() as session: - r = await session.get(endpoint + "/sdapi/v1/sd-models") + r = await session.get(endpoint + '/sdapi/v1/sd-models') if r.status != 200: return None - return [x["title"] for x in await r.json()] + json_data = await r.json() + return [x["title"] for x in json_data] async def set_model(model_name: str): diff --git a/bot/utils/cooldown.py b/bot/utils/cooldown.py index 47bb29b..d997cf5 100644 --- a/bot/utils/cooldown.py +++ b/bot/utils/cooldown.py @@ -30,7 +30,7 @@ def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None): if not last_time: last_time = delta - if last_time <= delta: + if last_time <= delta or user_id in admin_ids: try: f_name_dict = db[DBTables.cooldown][func.__name__] f_name_dict[user_id] = now @@ -42,7 +42,7 @@ def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None): try: return asyncio.create_task(func(*args, **kwargs)) except Exception as e: - assert e + assert e else: return not_allowed(*args, cooldown, by_id) diff --git a/bot/utils/private_keyboard.py b/bot/utils/private_keyboard.py index 4a69b35..ce0cf62 100644 --- a/bot/utils/private_keyboard.py +++ b/bot/utils/private_keyboard.py @@ -1,10 +1,17 @@ from aiogram import types +from bot.config import ADMIN async def other_user(call: types.CallbackQuery) -> bool: if not hasattr(call.message.reply_to_message, 'from_id'): await call.answer('Error, original call was removed', show_alert=True) return True + elif call.from_user.id == ADMIN: + print(f"call.from_user.id: {call.from_user.id}") + print(f"ADMIN: {ADMIN}") + print("User is Admin") + return False + elif call.message.reply_to_message.from_id != call.from_user.id: await call.answer('It is not your menu!', show_alert=True) return True