Image info, endpoint encryption
This commit is contained in:
7
bot/callbacks/factories/image_info.py
Normal file
7
bot/callbacks/factories/image_info.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from aiogram.utils.callback_data import CallbackData
|
||||||
|
|
||||||
|
|
||||||
|
prompt_only = CallbackData("prompt_only", "p_id")
|
||||||
|
full_prompt = CallbackData("full_prompt", "p_id")
|
||||||
|
import_prompt = CallbackData("import_prompt", "p_id")
|
||||||
|
back = CallbackData("img_info_back", "p_id")
|
||||||
79
bot/callbacks/image_info.py
Normal file
79
bot/callbacks/image_info.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from bot.common import dp
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from aiogram import types
|
||||||
|
from .factories.image_info import full_prompt, prompt_only, import_prompt, back
|
||||||
|
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
|
||||||
|
from bot.utils.cooldown import throttle
|
||||||
|
from bot.utils.private_keyboard import other_user
|
||||||
|
from bot.modules.api.objects.prompt_request import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
async def on_back(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
p_id = callback_data['p_id']
|
||||||
|
if await other_user(call):
|
||||||
|
return
|
||||||
|
|
||||||
|
await call.message.edit_text(
|
||||||
|
"Image was generated using this bot",
|
||||||
|
parse_mode='html',
|
||||||
|
reply_markup=get_img_info_keyboard(p_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(5)
|
||||||
|
async def on_prompt_only(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
p_id = callback_data['p_id']
|
||||||
|
if await other_user(call):
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt: Prompt = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
|
await call.message.edit_text(
|
||||||
|
f"🖤 Prompt: {prompt.prompt} \n"
|
||||||
|
f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}",
|
||||||
|
parse_mode='html',
|
||||||
|
reply_markup=get_img_back_keyboard(p_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(5)
|
||||||
|
async def on_full_info(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
p_id = callback_data['p_id']
|
||||||
|
if await other_user(call):
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt: Prompt = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
|
await call.message.edit_text(
|
||||||
|
f"🖤 Prompt: {prompt.prompt} \n"
|
||||||
|
f"🐊 Negative: {prompt.negative_prompt} \n"
|
||||||
|
f"🪜 Steps: {prompt.steps} \n"
|
||||||
|
f"🧑🎨 CFG Scale: {prompt.cfg_scale} \n"
|
||||||
|
f"🖥️ Size: {prompt.width}x{prompt.height} \n"
|
||||||
|
f"😀 Restore faces: {'on' if prompt.restore_faces else 'off'} \n"
|
||||||
|
f"⚒️ Sampler: {prompt.sampler}",
|
||||||
|
parse_mode='html',
|
||||||
|
reply_markup=get_img_back_keyboard(p_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(5)
|
||||||
|
async def on_import(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
p_id = callback_data['p_id']
|
||||||
|
if await other_user(call):
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt: Prompt = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
|
await call.message.edit_text(
|
||||||
|
f"😶🌫️ Not implemented yet",
|
||||||
|
parse_mode='html',
|
||||||
|
reply_markup=get_img_back_keyboard(p_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register():
|
||||||
|
dp.register_callback_query_handler(on_prompt_only, prompt_only.filter())
|
||||||
|
dp.register_callback_query_handler(on_back, back.filter())
|
||||||
|
dp.register_callback_query_handler(on_full_info, full_prompt.filter())
|
||||||
|
dp.register_callback_query_handler(on_import, import_prompt.filter())
|
||||||
@@ -3,9 +3,11 @@ from rich import print
|
|||||||
|
|
||||||
def register_callbacks():
|
def register_callbacks():
|
||||||
from bot.callbacks import (
|
from bot.callbacks import (
|
||||||
exception
|
exception,
|
||||||
|
image_info
|
||||||
)
|
)
|
||||||
|
|
||||||
exception.register()
|
exception.register()
|
||||||
|
image_info.register()
|
||||||
|
|
||||||
print('[gray]All callbacks registered[/]')
|
print('[gray]All callbacks registered[/]')
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ load_dotenv()
|
|||||||
TOKEN = os.getenv('TOKEN')
|
TOKEN = os.getenv('TOKEN')
|
||||||
ADMIN = int(os.getenv('ADMIN'))
|
ADMIN = int(os.getenv('ADMIN'))
|
||||||
DB_CHAT = os.getenv('DB_CHAT')
|
DB_CHAT = os.getenv('DB_CHAT')
|
||||||
|
ENCRYPTION_KEY = os.getenv('ENCRYPTION_KEY').encode()
|
||||||
_DB_PATH = os.getenv('DB_PATH')
|
_DB_PATH = os.getenv('DB_PATH')
|
||||||
DB = _DB_PATH + '/db'
|
DB = _DB_PATH + '/db'
|
||||||
DBMETA = _DB_PATH + '/dbmeta'
|
DBMETA = _DB_PATH + '/dbmeta'
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .db import db
|
from .db import db
|
||||||
from .db_model import DBTables
|
from .db_model import DBTables
|
||||||
|
from .encryption import encrypt, decrypt
|
||||||
|
|||||||
@@ -11,5 +11,6 @@ db = {
|
|||||||
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
|
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
|
||||||
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
|
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
|
||||||
'queue': DBDict(DB, autocommit=True, tablename='queue'),
|
'queue': DBDict(DB, autocommit=True, tablename='queue'),
|
||||||
'generated': DBDict(DB, autocommit=True, tablename='generated')
|
'generated': DBDict(DB, autocommit=True, tablename='generated'),
|
||||||
|
'prompts': DBDict(DB, autocommit=True, tablename='prompts')
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ from .meta import DBMeta
|
|||||||
|
|
||||||
|
|
||||||
class DBTables:
|
class DBTables:
|
||||||
tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated']
|
tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated', 'prompts']
|
||||||
config = "config"
|
config = "config"
|
||||||
cooldown = "cooldown"
|
cooldown = "cooldown"
|
||||||
exceptions = "exceptions"
|
exceptions = "exceptions"
|
||||||
queue = "queue"
|
queue = "queue"
|
||||||
generated = "generated"
|
generated = "generated"
|
||||||
|
prompts = "prompts"
|
||||||
|
|
||||||
|
|
||||||
class DBDict(SqliteDict):
|
class DBDict(SqliteDict):
|
||||||
|
|||||||
25
bot/db/encryption.py
Normal file
25
bot/db/encryption.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
import base64
|
||||||
|
from bot.config import ENCRYPTION_KEY, BARS_APP_ID
|
||||||
|
|
||||||
|
|
||||||
|
fernet = Fernet(
|
||||||
|
base64.urlsafe_b64encode(
|
||||||
|
PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
iterations=390000,
|
||||||
|
salt=BARS_APP_ID.encode()
|
||||||
|
).derive(ENCRYPTION_KEY)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(s: str) -> bytes:
|
||||||
|
return fernet.encrypt(s.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(s: bytes) -> str:
|
||||||
|
return fernet.decrypt(s).decode()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from aiogram import types
|
from aiogram import types
|
||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables, encrypt
|
||||||
import validators
|
import validators
|
||||||
from bot.config import ADMIN
|
from bot.config import ADMIN
|
||||||
from bot.utils.cooldown import throttle
|
from bot.utils.cooldown import throttle
|
||||||
@@ -16,7 +16,7 @@ async def set_endpoint(message: types.Message):
|
|||||||
await message.reply("❌ Specify correct url for endpoint")
|
await message.reply("❌ Specify correct url for endpoint")
|
||||||
return
|
return
|
||||||
|
|
||||||
db[DBTables.config]['endpoint'] = message.get_args()
|
db[DBTables.config]['endpoint'] = encrypt(message.get_args())
|
||||||
|
|
||||||
await db[DBTables.config].write()
|
await db[DBTables.config].write()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from aiogram import types
|
|||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables
|
||||||
from bot.utils.cooldown import throttle
|
from bot.utils.cooldown import throttle
|
||||||
from bot.keyboards.exception import get_exception_keyboard
|
from bot.keyboards.exception import get_exception_keyboard
|
||||||
|
from bot.keyboards.image_info import get_img_info_keyboard
|
||||||
from bot.utils.trace_exception import PrettyException
|
from bot.utils.trace_exception import PrettyException
|
||||||
|
|
||||||
|
|
||||||
@@ -12,14 +13,15 @@ async def imginfo(message: types.Message):
|
|||||||
await message.reply('❌ Reply with this command on picture', parse_mode='html')
|
await message.reply('❌ Reply with this command on picture', parse_mode='html')
|
||||||
return
|
return
|
||||||
|
|
||||||
if not (original_r := db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id)):
|
if not db[DBTables.generated].get(message.reply_to_message.photo[0].file_unique_id):
|
||||||
await message.reply('❌ This picture wasn\'t generated using this bot '
|
await message.reply('❌ This picture wasn\'t generated using this bot '
|
||||||
'or doesn\'t exist in database. Note this only works on '
|
'or doesn\'t exist in database. Note this only works on '
|
||||||
'files forwarded from bot.', parse_mode='html')
|
'files forwarded from bot.', parse_mode='html')
|
||||||
return
|
return
|
||||||
|
|
||||||
await message.reply(str(original_r))
|
await message.reply("Image was generated using this bot", reply_markup=get_img_info_keyboard(
|
||||||
# TODO: Pretty print this
|
message.reply_to_message.photo[0].file_unique_id
|
||||||
|
))
|
||||||
|
|
||||||
except IndexError:
|
except IndexError:
|
||||||
await message.reply('❌ Reply with this command on PICTURE', parse_mode='html')
|
await message.reply('❌ Reply with this command on PICTURE', parse_mode='html')
|
||||||
|
|||||||
34
bot/handlers/txt2img/set_settings.py
Normal file
34
bot/handlers/txt2img/set_settings.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from aiogram import types
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from bot.utils.cooldown import throttle
|
||||||
|
from bot.modules.api.objects.prompt_request import Prompt
|
||||||
|
from bot.keyboards.exception import get_exception_keyboard
|
||||||
|
from bot.utils.trace_exception import PrettyException
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_prompt_command(message: types.Message):
|
||||||
|
temp_message = await message.reply("⏳ Setting prompt...")
|
||||||
|
if not message.get_args():
|
||||||
|
await temp_message.edit_text("😶🌫️ Specify prompt for this command. Check /help setprompt")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args()))
|
||||||
|
prompt.prompt = message.get_args()
|
||||||
|
prompt.creator = message.from_id
|
||||||
|
db[DBTables.prompts][message.from_id] = prompt
|
||||||
|
|
||||||
|
await db[DBTables.config].write()
|
||||||
|
|
||||||
|
await message.reply('✅ Default prompt set')
|
||||||
|
await temp_message.delete()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
||||||
|
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||||
|
await message.reply('❌ Error happened while processing your request', parse_mode='html',
|
||||||
|
reply_markup=get_exception_keyboard(exception_id))
|
||||||
|
await temp_message.delete()
|
||||||
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
|
return
|
||||||
@@ -12,9 +12,13 @@ from aiohttp import ClientConnectorError
|
|||||||
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
|
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
|
||||||
async def txt2img_comand(message: types.Message):
|
async def txt2img_comand(message: types.Message):
|
||||||
temp_message = await message.reply("⏳ Enqueued...")
|
temp_message = await message.reply("⏳ Enqueued...")
|
||||||
if not message.get_args():
|
|
||||||
await temp_message.edit_text("😶🌫️ Specify prompt for this command. Check /help txt2img")
|
prompt: Prompt = db[DBTables.prompts].get(message.from_id)
|
||||||
return
|
if not prompt:
|
||||||
|
if message.get_args():
|
||||||
|
db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id)
|
||||||
|
|
||||||
|
# TODO: Move it to other module
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
|
||||||
|
|||||||
18
bot/keyboards/image_info.py
Normal file
18
bot/keyboards/image_info.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from aiogram import types
|
||||||
|
from bot.callbacks.factories.image_info import (prompt_only, full_prompt, import_prompt, back)
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_info_keyboard(p_id: str) -> types.InlineKeyboardMarkup:
|
||||||
|
buttons = [types.InlineKeyboardButton(text="📋 Show prompts", callback_data=prompt_only.new(p_id=p_id)),
|
||||||
|
types.InlineKeyboardButton(text="🧿 Show full info", callback_data=full_prompt.new(p_id=p_id)),
|
||||||
|
types.InlineKeyboardButton(text="🪄 Import prompt", callback_data=import_prompt.new(p_id=p_id))]
|
||||||
|
keyboard = types.InlineKeyboardMarkup(row_width=2)
|
||||||
|
keyboard.add(*buttons)
|
||||||
|
return keyboard
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_back_keyboard(p_id: str) -> types.InlineKeyboardMarkup:
|
||||||
|
buttons = [types.InlineKeyboardButton(text="👈 Back", callback_data=back.new(p_id=p_id))]
|
||||||
|
keyboard = types.InlineKeyboardMarkup()
|
||||||
|
keyboard.add(*buttons)
|
||||||
|
return keyboard
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables, decrypt
|
||||||
from rich import print
|
from rich import print
|
||||||
|
|
||||||
|
|
||||||
async def get_models():
|
async def get_models():
|
||||||
endpoint = db[DBTables.config].get('endpoint')
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
||||||
@@ -17,7 +17,7 @@ async def get_models():
|
|||||||
|
|
||||||
|
|
||||||
async def set_model(model_name: str):
|
async def set_model(model_name: str):
|
||||||
endpoint = db[DBTables.config].get('endpoint')
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables, decrypt
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
@@ -18,7 +18,7 @@ async def job_exists(endpoint):
|
|||||||
|
|
||||||
|
|
||||||
async def wait_for_status(ignore_exceptions: bool = False):
|
async def wait_for_status(ignore_exceptions: bool = False):
|
||||||
endpoint = db[DBTables.config].get('endpoint')
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
try:
|
||||||
while await job_exists(endpoint):
|
while await job_exists(endpoint):
|
||||||
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
|
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables, decrypt
|
||||||
from .objects.prompt_request import Prompt
|
from .objects.prompt_request import Prompt
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
|
async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
|
||||||
endpoint = db[DBTables.config].get('endpoint')
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
r = await session.post(
|
r = await session.post(
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ from aiogram import types
|
|||||||
|
|
||||||
|
|
||||||
def not_allowed(message: types.Message, cd: int, by_id: bool):
|
def not_allowed(message: types.Message, cd: int, by_id: bool):
|
||||||
|
text = f"❌ Wait for cooldown ({cd}s for this command) " \
|
||||||
|
f"{'. Please note that this cooldown is for all users' if not by_id else ''}"
|
||||||
return asyncio.create_task(message.reply(
|
return asyncio.create_task(message.reply(
|
||||||
text=
|
text=text
|
||||||
f"❌ Wait for cooldown ({cd}s for this command)"
|
) if hasattr(message, 'reply') else message.answer(text=text, show_alert=True))
|
||||||
f"{'. Please note that this cooldown is for all users' if not by_id else ''}"
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):
|
def throttle(cooldown: int = 5, by_id: bool = True, admin_ids: list = None):
|
||||||
|
|||||||
12
bot/utils/private_keyboard.py
Normal file
12
bot/utils/private_keyboard.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from aiogram import types
|
||||||
|
|
||||||
|
|
||||||
|
async def other_user(call: types.CallbackQuery) -> bool:
|
||||||
|
if not hasattr(call.message.reply_to_message, 'from_id'):
|
||||||
|
await call.answer('Error, original call was removed', show_alert=True)
|
||||||
|
return True
|
||||||
|
elif call.message.reply_to_message.from_id != call.from_user.id:
|
||||||
|
await call.answer('It is not your menu!', show_alert=True)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
@@ -4,4 +4,5 @@ python-dotenv
|
|||||||
rich
|
rich
|
||||||
aiohttp
|
aiohttp
|
||||||
validators
|
validators
|
||||||
sqlitedict
|
sqlitedict
|
||||||
|
cryptography
|
||||||
Reference in New Issue
Block a user