Testing text2img command (will be replaced, only for tests); added exception handling, queue and database table with generated images

This commit is contained in:
BarsTiger
2023-02-20 16:41:55 +02:00
parent 3bd34009ae
commit d56f1d386f
21 changed files with 219 additions and 34 deletions

View File

View File

@@ -0,0 +1,15 @@
from bot.common import dp
from bot.db import db, DBTables
from aiogram import types
from .factories.exception import exception_callback
async def on_exception(call: types.CallbackQuery, callback_data: dict):
e_id = callback_data['e_id']
e = db[DBTables.exceptions][e_id]
del db[DBTables.exceptions][e_id]
await call.message.edit_text(e, parse_mode='html')
def register():
dp.register_callback_query_handler(on_exception, exception_callback.filter())

View File

View File

@@ -0,0 +1,4 @@
from aiogram.utils.callback_data import CallbackData
exception_callback = CallbackData("full_exception", "e_id")

11
bot/callbacks/register.py Normal file
View File

@@ -0,0 +1,11 @@
from rich import print
def register_callbacks():
from bot.callbacks import (
exception
)
exception.register()
print('[gray]All callbacks registered[/]')

View File

@@ -8,5 +8,8 @@ if not os.path.isfile(DB):
db = { db = {
'config': DBDict(DB, autocommit=True, tablename='config'), 'config': DBDict(DB, autocommit=True, tablename='config'),
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown') 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
'queue': DBDict(DB, autocommit=True, tablename='queue'),
'generated': DBDict(DB, autocommit=True, tablename='generated')
} }

View File

@@ -8,9 +8,12 @@ from .meta import DBMeta
class DBTables: class DBTables:
tables = ['config', 'cooldown'] tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated']
config = "config" config = "config"
cooldown = "cooldown" cooldown = "cooldown"
exceptions = "exceptions"
queue = "queue"
generated = "generated"
class DBDict(SqliteDict): class DBDict(SqliteDict):

View File

@@ -1,6 +1,8 @@
from bot.common import dp from bot.common import dp
from .aliases import * from .aliases import *
from .reset import *
def register(): def register():
dp.register_message_handler(set_endpoint, commands='setendpoint') dp.register_message_handler(set_endpoint, commands='setendpoint')
dp.register_message_handler(reset.resetqueue, commands='resetqueue')

View File

@@ -0,0 +1,18 @@
from aiogram import types
from bot.db import db, DBTables
from bot.config import ADMIN
from bot.utils.cooldown import throttle
@throttle(5)
async def resetqueue(message: types.Message):
if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN:
await message.reply('❌ You are not permitted to do that. '
'It is only for this bot instance maintainers and admins')
return
db[DBTables.queue]['n'] = 0
await db[DBTables.config].write()
await message.reply("✅ Reset queue")

View File

@@ -4,5 +4,6 @@ from bot.db.pull_db import pull
async def sync_db_filter(message: Message): async def sync_db_filter(message: Message):
await pull() await pull()
if message.is_command():
await message.reply(f'🔄️ Bot database synchronised because of restart. ' await message.reply(f'🔄️ Bot database synchronised because of restart. '
f'If you tried to run a command, run it again') f'If you tried to run a command, run it again')

View File

@@ -1,7 +1,7 @@
from rich import print from rich import print
def import_handlers(): def register_handlers():
from bot.handlers import ( from bot.handlers import (
initialize, admin, help_command, txt2img initialize, admin, help_command, txt2img
) )
@@ -11,4 +11,4 @@ def import_handlers():
help_command.register() help_command.register()
txt2img.register() txt2img.register()
print('[gray]All handlers imported[/]') print('[gray]All handlers registered[/]')

View File

@@ -1,6 +1,6 @@
from bot.common import dp from bot.common import dp
from .txt2img import * from .txt2img import txt2img_comand
def register(): def register():
dp.register_message_handler(txt2img, commands='txt2img') dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')

View File

@@ -2,25 +2,50 @@ from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img from bot.modules.api.txt2img import txt2img
from bot.modules.api.objects.prompt_request import Prompt
from bot.modules.api.status import wait_for_status
from bot.keyboards.exception import get_exception_keyboard
from bot.utils.trace_exception import PrettyException
from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message): async def txt2img_comand(message: types.Message):
temp_message = await message.reply("Generating image...") temp_message = await message.reply("Enqueued...")
if not message.get_args(): if not message.get_args():
await temp_message.edit_text("Specify prompt for this command. Check /help txt2img") await temp_message.edit_text("😶‍🌫️ Specify prompt for this command. Check /help txt2img")
return return
try: try:
image = await txt2img(message.get_args()) db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
await message.reply_photo(photo=image[0], caption=str( await temp_message.edit_text(f"⏳ Enqueued in position {db[DBTables.queue].get('n', 0)}...")
image[1]["infotexts"][0]))
except Exception as e: await wait_for_status()
assert e
await message.reply("We ran into error while processing your request. StableDiffusion models may not be " await temp_message.edit_text(f"⌛ Generating...")
"configured on specified endpoint or server with StableDiffusion may be turned " prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
"off. Ask admins of this bot instance if you have contacts for further info") image = await txt2img(prompt)
await temp_message.delete() image_message = await message.reply_photo(photo=image[0])
return
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
await temp_message.delete() await temp_message.delete()
await db[DBTables.config].write()
except ClientConnectorError:
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
'or turned off')
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e)
await message.reply('❌ Error happened while processing your request', parse_mode='html',
reply_markup=get_exception_keyboard(exception_id))
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return

View File

View File

@@ -0,0 +1,9 @@
from aiogram import types
from bot.callbacks.factories.exception import exception_callback
def get_exception_keyboard(e_id: str) -> types.InlineKeyboardMarkup:
buttons = [types.InlineKeyboardButton(text="Show full stack", callback_data=exception_callback.new(e_id=e_id))]
keyboard = types.InlineKeyboardMarkup()
keyboard.add(*buttons)
return keyboard

View File

View File

@@ -0,0 +1,14 @@
import dataclasses
@dataclasses.dataclass
class Prompt:
prompt: str
negative_prompt: str = None
steps: int = 20
cfg_scale: int = 7
width: int = 768
height: int = 768
restore_faces: bool = True
sampler: str = "Euler a"
creator: int = None

31
bot/modules/api/status.py Normal file
View File

@@ -0,0 +1,31 @@
from bot.db import db, DBTables
import aiohttp
import asyncio
import time
async def job_exists(endpoint):
async with aiohttp.ClientSession() as session:
r = await session.get(
endpoint + "/sdapi/v1/progress",
json={
"skip_current_image": True,
}
)
if r.status != 200:
return None
return (await r.json()).get('state').get('job_count') > 0
async def wait_for_status(ignore_exceptions: bool = False):
endpoint = db[DBTables.config].get('endpoint')
try:
while await job_exists(endpoint):
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
await asyncio.sleep(5)
db[DBTables.cooldown]['_last_time_status_checked'] = time.time()
return
except Exception as e:
if not ignore_exceptions:
raise e
return

View File

@@ -1,26 +1,25 @@
import aiohttp import aiohttp
from bot.db import db, DBTables from bot.db import db, DBTables
from .objects.prompt_request import Prompt
import json import json
import base64 import base64
async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20, async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
cfg_scale: int = 7, width: int = 768, height: int = 768,
restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None:
endpoint = db[DBTables.config].get('endpoint') endpoint = db[DBTables.config].get('endpoint')
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.post( r = await session.post(
endpoint + "/sdapi/v1/txt2img", endpoint + "/sdapi/v1/txt2img",
json={ json={
"prompt": prompt, "prompt": prompt.prompt,
"steps": steps, "steps": prompt.steps,
"cfg_scale": cfg_scale, "cfg_scale": prompt.cfg_scale,
"width": width, "width": prompt.width,
"height": height, "height": prompt.height,
"restore_faces": restore_faces, "restore_faces": prompt.restore_faces,
"negative_prompt": negative_prompt, "negative_prompt": prompt.negative_prompt,
"sampler_index": sampler "sampler_index": prompt.sampler
} }
) )
if r.status != 200: if r.status != 200:
@@ -28,5 +27,6 @@ async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
return [base64.b64decode((await r.json())["images"][0]), return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])] json.loads((await r.json())["info"])]
except Exception as e: except Exception as e:
assert e if not ignore_exceptions:
return None raise e
return

View File

@@ -0,0 +1,47 @@
import os
import traceback
import contextlib
import re
class PrettyException:
def __init__(self, e: Exception):
self.pretty_exception = f'❌ Error! Report it to admins: \n' \
f'🐊 <code>{e.__traceback__.tb_frame.f_code.co_filename.replace(os.getcwd(), "")}' \
f'</code>:{e.__traceback__.tb_frame.f_lineno} \n' \
f'😍 {e.__class__.__name__} \n' \
f'👉 {"".join(traceback.format_exception_only(e)).strip()} \n\n' \
f'⬇️ Trace: \n' \
f'{self.get_full_stack()}'
@staticmethod
def get_full_stack():
full_stack = traceback.format_exc().replace(
"Traceback (most recent call last):\n", ""
)
line_regex = r' File "(.*?)", line ([0-9]+), in (.+)'
def format_line(line: str) -> str:
filename_, lineno_, name_ = re.search(line_regex, line).groups()
with contextlib.suppress(Exception):
filename_ = os.path.basename(filename_)
return (
f"🤯 <code>{filename_}:{lineno_}</code> (<b>in</b>"
f" <code>{name_}</code> call)"
)
full_stack = "\n".join(
[
format_line(line)
if re.search(line_regex, line)
else f"<code>{line}</code>"
for line in full_stack.splitlines()
]
)
return full_stack
def __str__(self):
return self.pretty_exception

View File

@@ -4,11 +4,13 @@ from rich import print
async def main(): async def main():
print(BARS_APP_ID) print(BARS_APP_ID)
import bot.handlers.register
from bot.common import dp from bot.common import dp
import bot.handlers.register
import bot.callbacks.register
from bot.utils.commands import set_commands from bot.utils.commands import set_commands
bot.handlers.register.import_handlers() bot.handlers.register.register_handlers()
bot.callbacks.register.register_callbacks()
await set_commands() await set_commands()
print('[green]Bot will start now[/]') print('[green]Bot will start now[/]')
await dp.skip_updates() await dp.skip_updates()