Image info, endpoint encryption

This commit is contained in:
BarsTiger
2023-02-24 23:41:29 +02:00
parent 846e65e2ce
commit 8f0cd0ac05
19 changed files with 211 additions and 23 deletions

View File

@@ -0,0 +1,7 @@
from aiogram.utils.callback_data import CallbackData
prompt_only = CallbackData("prompt_only", "p_id")
full_prompt = CallbackData("full_prompt", "p_id")
import_prompt = CallbackData("import_prompt", "p_id")
back = CallbackData("img_info_back", "p_id")

View File

@@ -0,0 +1,79 @@
from bot.common import dp
from bot.db import db, DBTables
from aiogram import types
from .factories.image_info import full_prompt, prompt_only, import_prompt, back
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
from bot.utils.cooldown import throttle
from bot.utils.private_keyboard import other_user
from bot.modules.api.objects.prompt_request import Prompt
async def on_back(call: types.CallbackQuery, callback_data: dict):
p_id = callback_data['p_id']
if await other_user(call):
return
await call.message.edit_text(
"Image was generated using this bot",
parse_mode='html',
reply_markup=get_img_info_keyboard(p_id)
)
@throttle(5)
async def on_prompt_only(call: types.CallbackQuery, callback_data: dict):
p_id = callback_data['p_id']
if await other_user(call):
return
prompt: Prompt = db[DBTables.generated].get(p_id)
await call.message.edit_text(
f"🖤 Prompt: {prompt.prompt} \n"
f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}",
parse_mode='html',
reply_markup=get_img_back_keyboard(p_id)
)
@throttle(5)
async def on_full_info(call: types.CallbackQuery, callback_data: dict):
p_id = callback_data['p_id']
if await other_user(call):
return
prompt: Prompt = db[DBTables.generated].get(p_id)
await call.message.edit_text(
f"🖤 Prompt: {prompt.prompt} \n"
f"🐊 Negative: {prompt.negative_prompt} \n"
f"🪜 Steps: {prompt.steps} \n"
f"🧑‍🎨 CFG Scale: {prompt.cfg_scale} \n"
f"🖥️ Size: {prompt.width}x{prompt.height} \n"
f"😀 Restore faces: {'on' if prompt.restore_faces else 'off'} \n"
f"⚒️ Sampler: {prompt.sampler}",
parse_mode='html',
reply_markup=get_img_back_keyboard(p_id)
)
@throttle(5)
async def on_import(call: types.CallbackQuery, callback_data: dict):
p_id = callback_data['p_id']
if await other_user(call):
return
prompt: Prompt = db[DBTables.generated].get(p_id)
await call.message.edit_text(
f"😶‍🌫️ Not implemented yet",
parse_mode='html',
reply_markup=get_img_back_keyboard(p_id)
)
def register():
dp.register_callback_query_handler(on_prompt_only, prompt_only.filter())
dp.register_callback_query_handler(on_back, back.filter())
dp.register_callback_query_handler(on_full_info, full_prompt.filter())
dp.register_callback_query_handler(on_import, import_prompt.filter())

View File

@@ -3,9 +3,11 @@ from rich import print
def register_callbacks(): def register_callbacks():
from bot.callbacks import ( from bot.callbacks import (
exception exception,
image_info
) )
exception.register() exception.register()
image_info.register()
print('[gray]All callbacks registered[/]') print('[gray]All callbacks registered[/]')

View File

@@ -6,6 +6,7 @@ load_dotenv()
TOKEN = os.getenv('TOKEN') TOKEN = os.getenv('TOKEN')
ADMIN = int(os.getenv('ADMIN')) ADMIN = int(os.getenv('ADMIN'))
DB_CHAT = os.getenv('DB_CHAT') DB_CHAT = os.getenv('DB_CHAT')
ENCRYPTION_KEY = os.getenv('ENCRYPTION_KEY').encode()
_DB_PATH = os.getenv('DB_PATH') _DB_PATH = os.getenv('DB_PATH')
DB = _DB_PATH + '/db' DB = _DB_PATH + '/db'
DBMETA = _DB_PATH + '/dbmeta' DBMETA = _DB_PATH + '/dbmeta'

View File

@@ -1,2 +1,3 @@
from .db import db from .db import db
from .db_model import DBTables from .db_model import DBTables
from .encryption import encrypt, decrypt

View File

@@ -11,5 +11,6 @@ db = {
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'), 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'), 'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
'queue': DBDict(DB, autocommit=True, tablename='queue'), 'queue': DBDict(DB, autocommit=True, tablename='queue'),
'generated': DBDict(DB, autocommit=True, tablename='generated') 'generated': DBDict(DB, autocommit=True, tablename='generated'),
'prompts': DBDict(DB, autocommit=True, tablename='prompts')
} }

View File

@@ -8,12 +8,13 @@ from .meta import DBMeta
class DBTables: class DBTables:
tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated'] tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated', 'prompts']
config = "config" config = "config"
cooldown = "cooldown" cooldown = "cooldown"
exceptions = "exceptions" exceptions = "exceptions"
queue = "queue" queue = "queue"
generated = "generated" generated = "generated"
prompts = "prompts"
class DBDict(SqliteDict): class DBDict(SqliteDict):

25
bot/db/encryption.py Normal file
View File

@@ -0,0 +1,25 @@
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
from bot.config import ENCRYPTION_KEY, BARS_APP_ID
fernet = Fernet(
base64.urlsafe_b64encode(
PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
iterations=390000,
salt=BARS_APP_ID.encode()
).derive(ENCRYPTION_KEY)
)
)
def encrypt(s: str) -> bytes:
return fernet.encrypt(s.encode())
def decrypt(s: bytes) -> str:
return fernet.decrypt(s).decode()

View File

@@ -1,5 +1,5 @@
from aiogram import types from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables, encrypt
import validators import validators
from bot.config import ADMIN from bot.config import ADMIN
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
@@ -16,7 +16,7 @@ async def set_endpoint(message: types.Message):
await message.reply("❌ Specify correct url for endpoint") await message.reply("❌ Specify correct url for endpoint")
return return
db[DBTables.config]['endpoint'] = message.get_args() db[DBTables.config]['endpoint'] = encrypt(message.get_args())
await db[DBTables.config].write() await db[DBTables.config].write()

View File

@@ -2,6 +2,7 @@ from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
from bot.keyboards.exception import get_exception_keyboard from bot.keyboards.exception import get_exception_keyboard
from bot.keyboards.image_info import get_img_info_keyboard
from bot.utils.trace_exception import PrettyException from bot.utils.trace_exception import PrettyException
@@ -12,14 +13,15 @@ async def imginfo(message: types.Message):
await message.reply('❌ Reply with this command on picture', parse_mode='html') await message.reply('❌ Reply with this command on picture', parse_mode='html')
return return
if not (original_r := db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id)): if not db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id):
await message.reply('❌ This picture wasn\'t generated using this bot ' await message.reply('❌ This picture wasn\'t generated using this bot '
'or doesn\'t exist in database. Note this only works on ' 'or doesn\'t exist in database. Note this only works on '
'files forwarded from bot.', parse_mode='html') 'files forwarded from bot.', parse_mode='html')
return return
await message.reply(str(original_r)) await message.reply("Image was generated using this bot", reply_markup=get_img_info_keyboard(
# TODO: Pretty print this message.reply_to_message.photo[0].file_unique_id
))
except IndexError: except IndexError:
await message.reply('❌ Reply with this command on PICTURE', parse_mode='html') await message.reply('❌ Reply with this command on PICTURE', parse_mode='html')

View File

@@ -0,0 +1,34 @@
from aiogram import types
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
from bot.modules.api.objects.prompt_request import Prompt
from bot.keyboards.exception import get_exception_keyboard
from bot.utils.trace_exception import PrettyException
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_prompt_command(message: types.Message):
temp_message = await message.reply("⏳ Setting prompt...")
if not message.get_args():
await temp_message.edit_text("😶‍🌫️ Specify prompt for this command. Check /help setprompt")
return
try:
prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args()))
prompt.prompt = message.get_args()
prompt.creator = message.from_id
db[DBTables.prompts][message.from_id] = prompt
await db[DBTables.config].write()
await message.reply('✅ Default prompt set')
await temp_message.delete()
except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e)
await message.reply('❌ Error happened while processing your request', parse_mode='html',
reply_markup=get_exception_keyboard(exception_id))
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return

View File

@@ -12,9 +12,13 @@ from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message): async def txt2img_comand(message: types.Message):
temp_message = await message.reply("⏳ Enqueued...") temp_message = await message.reply("⏳ Enqueued...")
if not message.get_args():
await temp_message.edit_text("😶‍🌫️ Specify prompt for this command. Check /help txt2img") prompt: Prompt = db[DBTables.prompts].get(message.from_id)
return if not prompt:
if message.get_args():
db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id)
# TODO: Move it to other module
try: try:
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1

View File

@@ -0,0 +1,18 @@
from aiogram import types
from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back)
def get_img_info_keyboard(p_id: str) -> types.InlineKeyboardMarkup:
buttons = [types.InlineKeyboardButton(text="📋 Show prompts", callback_data=prompt_only.new(p_id=p_id)),
types.InlineKeyboardButton(text="🧿 Show full info", callback_data=full_prompt.new(p_id=p_id)),
types.InlineKeyboardButton(text="🪄 Import prompt", callback_data=import_prompt.new(p_id=p_id))]
keyboard = types.InlineKeyboardMarkup(row_width=2)
keyboard.add(*buttons)
return keyboard
def get_img_back_keyboard(p_id: str) -> types.InlineKeyboardMarkup:
buttons = [types.InlineKeyboardButton(text="👈 Back", callback_data=back.new(p_id=p_id))]
keyboard = types.InlineKeyboardMarkup()
keyboard.add(*buttons)
return keyboard

View File

@@ -1,10 +1,10 @@
import aiohttp import aiohttp
from bot.db import db, DBTables from bot.db import db, DBTables, decrypt
from rich import print from rich import print
async def get_models(): async def get_models():
endpoint = db[DBTables.config].get('endpoint') endpoint = decrypt(db[DBTables.config].get('endpoint'))
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models") r = await session.get(endpoint + "/sdapi/v1/sd-models")
@@ -17,7 +17,7 @@ async def get_models():
async def set_model(model_name: str): async def set_model(model_name: str):
endpoint = db[DBTables.config].get('endpoint') endpoint = decrypt(db[DBTables.config].get('endpoint'))
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.post(endpoint + "/sdapi/v1/options", json={ r = await session.post(endpoint + "/sdapi/v1/options", json={

View File

@@ -1,4 +1,4 @@
from bot.db import db, DBTables from bot.db import db, DBTables, decrypt
import aiohttp import aiohttp
import asyncio import asyncio
import time import time
@@ -18,7 +18,7 @@ async def job_exists(endpoint):
async def wait_for_status(ignore_exceptions: bool = False): async def wait_for_status(ignore_exceptions: bool = False):
endpoint = db[DBTables.config].get('endpoint') endpoint = decrypt(db[DBTables.config].get('endpoint'))
try: try:
while await job_exists(endpoint): while await job_exists(endpoint):
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time(): while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():

View File

@@ -1,12 +1,12 @@
import aiohttp import aiohttp
from bot.db import db, DBTables from bot.db import db, DBTables, decrypt
from .objects.prompt_request import Prompt from .objects.prompt_request import Prompt
import json import json
import base64 import base64
async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None: async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
endpoint = db[DBTables.config].get('endpoint') endpoint = decrypt(db[DBTables.config].get('endpoint'))
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.post( r = await session.post(

View File

@@ -6,11 +6,11 @@ from aiogram import types
def not_allowed(message: types.Message, cd: int, by_id: bool): def not_allowed(message: types.Message, cd: int, by_id: bool):
return asyncio.create_task(message.reply( text = f"❌ Wait for cooldown ({cd}s for this command) " \
text=
f"❌ Wait for cooldown ({cd}s for this command)"
f"{'. Please note that this cooldown is for all users' if not by_id else ''}" f"{'. Please note that this cooldown is for all users' if not by_id else ''}"
)) return asyncio.create_task(message.reply(
text=text
) if hasattr(message, 'reply') else message.answer(text=text, show_alert=True))
def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None): def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):

View File

@@ -0,0 +1,12 @@
from aiogram import types
async def other_user(call: types.CallbackQuery) -> bool:
if not hasattr(call.message.reply_to_message, 'from_id'):
await call.answer('Error, original call was removed', show_alert=True)
return True
elif call.message.reply_to_message.from_id != call.from_user.id:
await call.answer('It is not your menu!', show_alert=True)
return True
return False

View File

@@ -5,3 +5,4 @@ rich
aiohttp aiohttp
validators validators
sqlitedict sqlitedict
cryptography