Testing text2img command (will be replaced, only for tests); added exception handling, queue and database table with generated images

This commit is contained in:
BarsTiger
2023-02-20 16:41:55 +02:00
parent 3bd34009ae
commit d56f1d386f
21 changed files with 219 additions and 34 deletions

View File

View File

@@ -0,0 +1,15 @@
from bot.common import dp
from bot.db import db, DBTables
from aiogram import types
from .factories.exception import exception_callback
async def on_exception(call: types.CallbackQuery, callback_data: dict):
e_id = callback_data['e_id']
e = db[DBTables.exceptions][e_id]
del db[DBTables.exceptions][e_id]
await call.message.edit_text(e, parse_mode='html')
def register():
dp.register_callback_query_handler(on_exception, exception_callback.filter())

View File

View File

@@ -0,0 +1,4 @@
from aiogram.utils.callback_data import CallbackData
exception_callback = CallbackData("full_exception", "e_id")

11
bot/callbacks/register.py Normal file
View File

@@ -0,0 +1,11 @@
from rich import print
def register_callbacks():
from bot.callbacks import (
exception
)
exception.register()
print('[gray]All callbacks registered[/]')

View File

@@ -8,5 +8,8 @@ if not os.path.isfile(DB):
db = {
'config': DBDict(DB, autocommit=True, tablename='config'),
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown')
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
'queue': DBDict(DB, autocommit=True, tablename='queue'),
'generated': DBDict(DB, autocommit=True, tablename='generated')
}

View File

@@ -8,9 +8,12 @@ from .meta import DBMeta
class DBTables:
tables = ['config', 'cooldown']
tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated']
config = "config"
cooldown = "cooldown"
exceptions = "exceptions"
queue = "queue"
generated = "generated"
class DBDict(SqliteDict):

View File

@@ -1,6 +1,8 @@
from bot.common import dp
from .aliases import *
from .reset import *
def register():
dp.register_message_handler(set_endpoint, commands='setendpoint')
dp.register_message_handler(reset.resetqueue, commands='resetqueue')

View File

@@ -0,0 +1,18 @@
from aiogram import types
from bot.db import db, DBTables
from bot.config import ADMIN
from bot.utils.cooldown import throttle
@throttle(5)
async def resetqueue(message: types.Message):
if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN:
await message.reply('❌ You are not permitted to do that. '
'It is only for this bot instance maintainers and admins')
return
db[DBTables.queue]['n'] = 0
await db[DBTables.config].write()
await message.reply("✅ Reset queue")

View File

@@ -4,5 +4,6 @@ from bot.db.pull_db import pull
async def sync_db_filter(message: Message):
await pull()
await message.reply(f'🔄️ Bot database synchronised because of restart. '
f'If you tried to run a command, run it again')
if message.is_command():
await message.reply(f'🔄️ Bot database synchronised because of restart. '
f'If you tried to run a command, run it again')

View File

@@ -1,7 +1,7 @@
from rich import print
def import_handlers():
def register_handlers():
from bot.handlers import (
initialize, admin, help_command, txt2img
)
@@ -11,4 +11,4 @@ def import_handlers():
help_command.register()
txt2img.register()
print('[gray]All handlers imported[/]')
print('[gray]All handlers registered[/]')

View File

@@ -1,6 +1,6 @@
from bot.common import dp
from .txt2img import *
from .txt2img import txt2img_comand
def register():
dp.register_message_handler(txt2img, commands='txt2img')
dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')

View File

@@ -2,25 +2,50 @@ from aiogram import types
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img
from bot.modules.api.objects.prompt_request import Prompt
from bot.modules.api.status import wait_for_status
from bot.keyboards.exception import get_exception_keyboard
from bot.utils.trace_exception import PrettyException
from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message):
temp_message = await message.reply("Generating image...")
temp_message = await message.reply("Enqueued...")
if not message.get_args():
await temp_message.edit_text("Specify prompt for this command. Check /help txt2img")
await temp_message.edit_text("😶‍🌫️ Specify prompt for this command. Check /help txt2img")
return
try:
image = await txt2img(message.get_args())
await message.reply_photo(photo=image[0], caption=str(
image[1]["infotexts"][0]))
except Exception as e:
assert e
await message.reply("We ran into error while processing your request. StableDiffusion models may not be "
"configured on specified endpoint or server with StableDiffusion may be turned "
"off. Ask admins of this bot instance if you have contacts for further info")
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
await temp_message.edit_text(f"⏳ Enqueued in position {db[DBTables.queue].get('n', 0)}...")
await wait_for_status()
await temp_message.edit_text(f"⌛ Generating...")
prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
image = await txt2img(prompt)
image_message = await message.reply_photo(photo=image[0])
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
await temp_message.delete()
await db[DBTables.config].write()
except ClientConnectorError:
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
'or turned off')
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
await temp_message.delete()
except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e)
await message.reply('❌ Error happened while processing your request', parse_mode='html',
reply_markup=get_exception_keyboard(exception_id))
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return

View File

View File

@@ -0,0 +1,9 @@
from aiogram import types
from bot.callbacks.factories.exception import exception_callback
def get_exception_keyboard(e_id: str) -> types.InlineKeyboardMarkup:
buttons = [types.InlineKeyboardButton(text="Show full stack", callback_data=exception_callback.new(e_id=e_id))]
keyboard = types.InlineKeyboardMarkup()
keyboard.add(*buttons)
return keyboard

View File

View File

@@ -0,0 +1,14 @@
import dataclasses
@dataclasses.dataclass
class Prompt:
prompt: str
negative_prompt: str = None
steps: int = 20
cfg_scale: int = 7
width: int = 768
height: int = 768
restore_faces: bool = True
sampler: str = "Euler a"
creator: int = None

31
bot/modules/api/status.py Normal file
View File

@@ -0,0 +1,31 @@
from bot.db import db, DBTables
import aiohttp
import asyncio
import time
async def job_exists(endpoint):
async with aiohttp.ClientSession() as session:
r = await session.get(
endpoint + "/sdapi/v1/progress",
json={
"skip_current_image": True,
}
)
if r.status != 200:
return None
return (await r.json()).get('state').get('job_count') > 0
async def wait_for_status(ignore_exceptions: bool = False):
endpoint = db[DBTables.config].get('endpoint')
try:
while await job_exists(endpoint):
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
await asyncio.sleep(5)
db[DBTables.cooldown]['_last_time_status_checked'] = time.time()
return
except Exception as e:
if not ignore_exceptions:
raise e
return

View File

@@ -1,26 +1,25 @@
import aiohttp
from bot.db import db, DBTables
from .objects.prompt_request import Prompt
import json
import base64
async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
cfg_scale: int = 7, width: int = 768, height: int = 768,
restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None:
async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
endpoint = db[DBTables.config].get('endpoint')
try:
async with aiohttp.ClientSession() as session:
r = await session.post(
endpoint + "/sdapi/v1/txt2img",
json={
"prompt": prompt,
"steps": steps,
"cfg_scale": cfg_scale,
"width": width,
"height": height,
"restore_faces": restore_faces,
"negative_prompt": negative_prompt,
"sampler_index": sampler
"prompt": prompt.prompt,
"steps": prompt.steps,
"cfg_scale": prompt.cfg_scale,
"width": prompt.width,
"height": prompt.height,
"restore_faces": prompt.restore_faces,
"negative_prompt": prompt.negative_prompt,
"sampler_index": prompt.sampler
}
)
if r.status != 200:
@@ -28,5 +27,6 @@ async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])]
except Exception as e:
assert e
return None
if not ignore_exceptions:
raise e
return

View File

@@ -0,0 +1,47 @@
import os
import traceback
import contextlib
import re
class PrettyException:
def __init__(self, e: Exception):
self.pretty_exception = f'❌ Error! Report it to admins: \n' \
f'🐊 <code>{e.__traceback__.tb_frame.f_code.co_filename.replace(os.getcwd(), "")}' \
f'</code>:{e.__traceback__.tb_frame.f_lineno} \n' \
f'😍 {e.__class__.__name__} \n' \
f'👉 {"".join(traceback.format_exception_only(e)).strip()} \n\n' \
f'⬇️ Trace: \n' \
f'{self.get_full_stack()}'
@staticmethod
def get_full_stack():
full_stack = traceback.format_exc().replace(
"Traceback (most recent call last):\n", ""
)
line_regex = r' File "(.*?)", line ([0-9]+), in (.+)'
def format_line(line: str) -> str:
filename_, lineno_, name_ = re.search(line_regex, line).groups()
with contextlib.suppress(Exception):
filename_ = os.path.basename(filename_)
return (
f"🤯 <code>{filename_}:{lineno_}</code> (<b>in</b>"
f" <code>{name_}</code> call)"
)
full_stack = "\n".join(
[
format_line(line)
if re.search(line_regex, line)
else f"<code>{line}</code>"
for line in full_stack.splitlines()
]
)
return full_stack
def __str__(self):
return self.pretty_exception