Merge pull request #13 from vasmarfas/master

Multiple commits
This commit is contained in:
BarsTiger
2023-05-05 14:09:43 +03:00
committed by GitHub
10 changed files with 127 additions and 10 deletions

View File

@@ -10,4 +10,6 @@ def register():
dp.register_message_handler(reset.resetqueue, commands='resetqueue') dp.register_message_handler(reset.resetqueue, commands='resetqueue')
dp.register_message_handler(aliases.add_admin, commands='addadmin') dp.register_message_handler(aliases.add_admin, commands='addadmin')
dp.register_message_handler(aliases.remove_admin, commands='rmadmin') dp.register_message_handler(aliases.remove_admin, commands='rmadmin')
dp.register_message_handler(aliases.add_whitelist, commands='addwhitelist')
dp.register_message_handler(aliases.remove_whitelist, commands='rmwhitelist')
dp.register_message_handler(tools.hash_command, commands='hash') dp.register_message_handler(tools.hash_command, commands='hash')

View File

@@ -23,6 +23,75 @@ async def set_endpoint(message: types.Message, is_command: bool = True):
await message.reply("✅ New url set") await message.reply("✅ New url set")
@throttle(5)
async def add_whitelist(message: types.Message, is_command: bool = True):
if message.from_id != ADMIN:
await message.reply('❌ You are not permitted to do that. It is only for main admin')
return
if not (message.get_args() if is_command else message.text).isdecimal() and not \
hasattr(message.reply_to_message, 'text') and (message.chat.id >= 0):
await message.reply('❌ Put new whitelist chat ID to command arguments')
return
elif not (message.get_args() if is_command else message.text).isdecimal() and not \
hasattr(message.reply_to_message, 'text') and (message.chat.id < 0):
ID = message.chat.id
await message.reply(f'Chat ID: {message.chat.id} Chat title: {message.chat.title}')
elif not (message.get_args() if is_command else message.text).isdecimal():
ID = message.reply_to_message.from_id
elif not hasattr(message.reply_to_message, 'text'):
ID = int((message.get_args() if is_command else message.text))
if not isinstance(db[DBTables.config].get('whitelist'), list):
db[DBTables.config]['whitelist'] = list()
if ID not in db[DBTables.config].get('whitelist'):
whitelist_ = db[DBTables.config].get('whitelist')
whitelist_.append(ID)
db[DBTables.config]['whitelist'] = whitelist_
else:
await message.reply('❌ This whitelist is added already')
return
await db[DBTables.config].write()
await message.reply("✅ Added whitelist")
@throttle(5)
async def remove_whitelist(message: types.Message, is_command: bool = True):
if message.from_id != ADMIN:
await message.reply('❌ You are not permitted to do that. It is only for main admin')
return
if not (message.get_args() if is_command else message.text).isdecimal() and not \
hasattr(message.reply_to_message, 'text') and (message.chat.id >= 0:
await message.reply('❌ Put whitelist ID to command arguments or answer to users message')
return
elif not (message.get_args() if is_command else message.text).isdecimal() and not \
hasattr(message.reply_to_message, 'text') and (message.chat.id < 0):
ID = message.chat.id
await message.reply(f'Chat ID: {message.chat.id} Chat title: {message.chat.title}')
elif not (message.get_args() if is_command else message.text).isdecimal():
ID = message.reply_to_message.from_id
elif not hasattr(message.reply_to_message, 'text'):
ID = int((message.get_args() if is_command else message.text))
if not isinstance(db[DBTables.config].get('whitelist'), list):
db[DBTables.config]['whitelist'] = list()
if ID not in db[DBTables.config].get('whitelist'):
await message.reply('❌ This whitelist is not added')
return
else:
whitelist_ = db[DBTables.config].get('whitelist')
whitelist_.remove(ID)
db[DBTables.config]['whitelist'] = whitelist_
await db[DBTables.config].write()
await message.reply("✅ Removed whitelist")
@throttle(5) @throttle(5)
async def add_admin(message: types.Message, is_command: bool = True): async def add_admin(message: types.Message, is_command: bool = True):

View File

@@ -17,5 +17,7 @@ help_data = {
'setmodel': '(global) Sets StableDiffusion model for all users. Can be used only once an hour', 'setmodel': '(global) Sets StableDiffusion model for all users. Can be used only once an hour',
'setendpoint': '(admin) Set StableDiffusion API endpoint', 'setendpoint': '(admin) Set StableDiffusion API endpoint',
'addadmin': '(admin) Add new admin - reply to message or type user ID', 'addadmin': '(admin) Add new admin - reply to message or type user ID',
'rmadmin': '(admin) Remove admin - reply to message or type user ID' 'rmadmin': '(admin) Remove admin - reply to message or type user ID',
'addwhitelist': '(admin) Add new whitelist - reply to message or type user ID',
'rmwhitelist': '(admin) Remove whitelist - reply to message or type user ID'
} }

View File

@@ -18,6 +18,13 @@ async def start_command(message: types.Message):
await db[DBTables.config].write() await db[DBTables.config].write()
await message.reply(f'✅ Added {message.from_user.username} to admins. You can add other admins, ' await message.reply(f'✅ Added {message.from_user.username} to admins. You can add other admins, '
f'check bot settings menu') f'check bot settings menu')
if ADMIN not in db[DBTables.config].get('whitelist'):
whitelist_ = db[DBTables.config].get('whitelist')
whitelist_.append(ADMIN)
db[DBTables.config]['whitelist'] = whitelist_
await db[DBTables.config].write()
await message.reply(f'✅ Added {message.from_user.username} to whitelist. You can add other users to whitelist, '
f'check bot settings menu')
if db[DBTables.config].get('enabled') is None: if db[DBTables.config].get('enabled') is None:
db[DBTables.config]['enabled'] = True db[DBTables.config]['enabled'] = True
await message.reply(f'✅ Generation is enabled now') await message.reply(f'✅ Generation is enabled now')

View File

@@ -3,7 +3,7 @@ from rich import print
def register_handlers(): def register_handlers():
from bot.handlers import ( from bot.handlers import (
initialize, admin, help_command, txt2img, image_info, config initialize, admin, help_command, txt2img, image_info, config, whitelist
) )
initialize.register() initialize.register()

View File

@@ -1,5 +1,5 @@
from aiogram import types from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables, decrypt
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
from bot.keyboards.set_model import get_set_model_keyboard from bot.keyboards.set_model import get_set_model_keyboard
from bot.modules.api.models import get_models from bot.modules.api.models import get_models
@@ -7,13 +7,16 @@ from bot.utils.errorable_command import wrap_exception
@wrap_exception() @wrap_exception()
@throttle(cooldown=60*60, admin_ids=db[DBTables.config].get('admins'), by_id=False) @throttle(cooldown=5*60, admin_ids=db[DBTables.config].get('admins'), by_id=False)
async def set_model_command(message: types.Message): async def set_model_command(message: types.Message):
if (message.chat.id not in db[DBTables.config]['whitelist'] and message.from_id not in db[DBTables.config]['whitelist']):
await message.reply('❌You are not on the white list, access denied. Contact admin @kilisauros for details')
return
models = await get_models() models = await get_models()
if models is not None and len(models) > 0: if models is not None and len(models) > 0:
db[DBTables.config]['models'] = models db[DBTables.config]['models'] = models
else: else:
await message.reply('❌ No models available') await message.reply('❌ No models available')
return return
await message.reply("Examples of models (with additional info): https://telegra.ph/Opisanie-raboty-modelej-05-03")
await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0)) await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0))

View File

@@ -1,4 +1,5 @@
import re import re
from bot.common import bot
from aiogram import types from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
@@ -8,6 +9,8 @@ from bot.modules.api.objects.prompt_request import Generated
from bot.modules.api.status import wait_for_status from bot.modules.api.status import wait_for_status
from bot.keyboards.image_info import get_img_info_keyboard from bot.keyboards.image_info import get_img_info_keyboard
from bot.utils.errorable_command import wrap_exception from bot.utils.errorable_command import wrap_exception
from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back)
@wrap_exception([ValueError], custom_loading=True) @wrap_exception([ValueError], custom_loading=True)
@@ -18,6 +21,10 @@ async def generate_command(message: types.Message):
await message.reply('💔 Generation is disabled by admins now. Try again later') await message.reply('💔 Generation is disabled by admins now. Try again later')
await temp_message.delete() await temp_message.delete()
return return
elif (message.chat.id not in db[DBTables.config]['whitelist'] and message.from_id not in db[DBTables.config]['whitelist']):
await message.reply('❌You are not on the white list, access denied. Contact admin @kilisauros for details')
await temp_message.delete()
return
try: try:
prompt = get_prompt(user_id=message.from_id, prompt = get_prompt(user_id=message.from_id,
@@ -39,6 +46,24 @@ async def generate_command(message: types.Message):
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
image = await txt2img(prompt) image = await txt2img(prompt)
image_message = await message.reply_photo(photo=image[0]) image_message = await message.reply_photo(photo=image[0])
#Send photo to SD Image Archive
archive_message = f'User ID: {message.from_id} \n \
User nickname: {message.from_user.full_name} \n \
User username: @{message.from_user.username} \n \
Chat ID: {message.chat.id} \n \
Chat title: {message.chat.title} \n \
Info: \n \
🖤 Prompt: {prompt.prompt} \n \
🐊 Negative: {prompt.negative_prompt} \n \
💫 Model: In development \n \
🪜 Steps: {prompt.steps} \n \
🧑‍🎨 CFG Scale: {prompt.cfg_scale} \n \
🖥️ Size: {prompt.width}x{prompt.height} \n \
😀 Restore faces: {prompt.restore_faces} \n \
⚒️ Sampler: {prompt.sampler} \n '
await bot.send_photo(-929754401, photo=image[0], caption=archive_message)
db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated( db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated(
prompt=prompt, prompt=prompt,
@@ -57,4 +82,4 @@ async def generate_command(message: types.Message):
await message.reply(f'❌ Error! {e.args[0]}') await message.reply(f'❌ Error! {e.args[0]}')
await temp_message.delete() await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return return

View File

@@ -1,14 +1,16 @@
import aiohttp import aiohttp
import json
from bot.db import db, DBTables, decrypt from bot.db import db, DBTables, decrypt
async def get_models() -> list | None: async def get_models() -> list | None:
endpoint = decrypt(db[DBTables.config].get('endpoint')) endpoint = decrypt(db[DBTables.config].get('endpoint'))
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models") r = await session.get(endpoint + '/sdapi/v1/sd-models')
if r.status != 200: if r.status != 200:
return None return None
return [x["title"] for x in await r.json()] json_data = await r.json()
return [x["title"] for x in json_data]
async def set_model(model_name: str): async def set_model(model_name: str):

View File

@@ -30,7 +30,7 @@ def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):
if not last_time: if not last_time:
last_time = delta last_time = delta
if last_time <= delta: if last_time <= delta or user_id in admin_ids:
try: try:
f_name_dict = db[DBTables.cooldown][func.__name__] f_name_dict = db[DBTables.cooldown][func.__name__]
f_name_dict[user_id] = now f_name_dict[user_id] = now
@@ -42,7 +42,7 @@ def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):
try: try:
return asyncio.create_task(func(*args, **kwargs)) return asyncio.create_task(func(*args, **kwargs))
except Exception as e: except Exception as e:
assert e assert e
else: else:
return not_allowed(*args, cooldown, by_id) return not_allowed(*args, cooldown, by_id)

View File

@@ -1,10 +1,17 @@
from aiogram import types from aiogram import types
from bot.config import ADMIN
async def other_user(call: types.CallbackQuery) -> bool: async def other_user(call: types.CallbackQuery) -> bool:
if not hasattr(call.message.reply_to_message, 'from_id'): if not hasattr(call.message.reply_to_message, 'from_id'):
await call.answer('Error, original call was removed', show_alert=True) await call.answer('Error, original call was removed', show_alert=True)
return True return True
elif call.from_user.id == ADMIN:
print(f"call.from_user.id: {call.from_user.id}")
print(f"ADMIN: {ADMIN}")
print("User is Admin")
return False
elif call.message.reply_to_message.from_id != call.from_user.id: elif call.message.reply_to_message.from_id != call.from_user.id:
await call.answer('It is not your menu!', show_alert=True) await call.answer('It is not your menu!', show_alert=True)
return True return True