Testing text2img command (will be replaced, only for tests); added exception handling, queue and database table with generated images
This commit is contained in:
0
bot/callbacks/__init__.py
Normal file
0
bot/callbacks/__init__.py
Normal file
15
bot/callbacks/exception.py
Normal file
15
bot/callbacks/exception.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from bot.common import dp
|
||||
from bot.db import db, DBTables
|
||||
from aiogram import types
|
||||
from .factories.exception import exception_callback
|
||||
|
||||
|
||||
async def on_exception(call: types.CallbackQuery, callback_data: dict):
|
||||
e_id = callback_data['e_id']
|
||||
e = db[DBTables.exceptions][e_id]
|
||||
del db[DBTables.exceptions][e_id]
|
||||
await call.message.edit_text(e, parse_mode='html')
|
||||
|
||||
|
||||
def register():
|
||||
dp.register_callback_query_handler(on_exception, exception_callback.filter())
|
||||
0
bot/callbacks/factories/__init__.py
Normal file
0
bot/callbacks/factories/__init__.py
Normal file
4
bot/callbacks/factories/exception.py
Normal file
4
bot/callbacks/factories/exception.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from aiogram.utils.callback_data import CallbackData
|
||||
|
||||
|
||||
exception_callback = CallbackData("full_exception", "e_id")
|
||||
11
bot/callbacks/register.py
Normal file
11
bot/callbacks/register.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from rich import print
|
||||
|
||||
|
||||
def register_callbacks():
|
||||
from bot.callbacks import (
|
||||
exception
|
||||
)
|
||||
|
||||
exception.register()
|
||||
|
||||
print('[gray]All callbacks registered[/]')
|
||||
@@ -8,5 +8,8 @@ if not os.path.isfile(DB):
|
||||
|
||||
db = {
|
||||
'config': DBDict(DB, autocommit=True, tablename='config'),
|
||||
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown')
|
||||
'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
|
||||
'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
|
||||
'queue': DBDict(DB, autocommit=True, tablename='queue'),
|
||||
'generated': DBDict(DB, autocommit=True, tablename='generated')
|
||||
}
|
||||
|
||||
@@ -8,9 +8,12 @@ from .meta import DBMeta
|
||||
|
||||
|
||||
class DBTables:
|
||||
tables = ['config', 'cooldown']
|
||||
tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated']
|
||||
config = "config"
|
||||
cooldown = "cooldown"
|
||||
exceptions = "exceptions"
|
||||
queue = "queue"
|
||||
generated = "generated"
|
||||
|
||||
|
||||
class DBDict(SqliteDict):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from bot.common import dp
|
||||
from .aliases import *
|
||||
from .reset import *
|
||||
|
||||
|
||||
def register():
|
||||
dp.register_message_handler(set_endpoint, commands='setendpoint')
|
||||
dp.register_message_handler(reset.resetqueue, commands='resetqueue')
|
||||
|
||||
18
bot/handlers/admin/reset.py
Normal file
18
bot/handlers/admin/reset.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from aiogram import types
|
||||
from bot.db import db, DBTables
|
||||
from bot.config import ADMIN
|
||||
from bot.utils.cooldown import throttle
|
||||
|
||||
|
||||
@throttle(5)
|
||||
async def resetqueue(message: types.Message):
|
||||
if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN:
|
||||
await message.reply('❌ You are not permitted to do that. '
|
||||
'It is only for this bot instance maintainers and admins')
|
||||
return
|
||||
|
||||
db[DBTables.queue]['n'] = 0
|
||||
|
||||
await db[DBTables.config].write()
|
||||
|
||||
await message.reply("✅ Reset queue")
|
||||
@@ -4,5 +4,6 @@ from bot.db.pull_db import pull
|
||||
|
||||
async def sync_db_filter(message: Message):
|
||||
await pull()
|
||||
await message.reply(f'🔄️ Bot database synchronised because of restart. '
|
||||
f'If you tried to run a command, run it again')
|
||||
if message.is_command():
|
||||
await message.reply(f'🔄️ Bot database synchronised because of restart. '
|
||||
f'If you tried to run a command, run it again')
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from rich import print
|
||||
|
||||
|
||||
def import_handlers():
|
||||
def register_handlers():
|
||||
from bot.handlers import (
|
||||
initialize, admin, help_command, txt2img
|
||||
)
|
||||
@@ -11,4 +11,4 @@ def import_handlers():
|
||||
help_command.register()
|
||||
txt2img.register()
|
||||
|
||||
print('[gray]All handlers imported[/]')
|
||||
print('[gray]All handlers registered[/]')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from bot.common import dp
|
||||
from .txt2img import *
|
||||
from .txt2img import txt2img_comand
|
||||
|
||||
|
||||
def register():
|
||||
dp.register_message_handler(txt2img, commands='txt2img')
|
||||
dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')
|
||||
|
||||
@@ -2,25 +2,50 @@ from aiogram import types
|
||||
from bot.db import db, DBTables
|
||||
from bot.utils.cooldown import throttle
|
||||
from bot.modules.api.txt2img import txt2img
|
||||
from bot.modules.api.objects.prompt_request import Prompt
|
||||
from bot.modules.api.status import wait_for_status
|
||||
from bot.keyboards.exception import get_exception_keyboard
|
||||
from bot.utils.trace_exception import PrettyException
|
||||
from aiohttp import ClientConnectorError
|
||||
|
||||
|
||||
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
|
||||
async def txt2img_comand(message: types.Message):
|
||||
temp_message = await message.reply("⏳ Generating image...")
|
||||
temp_message = await message.reply("⏳ Enqueued...")
|
||||
if not message.get_args():
|
||||
await temp_message.edit_text("Specify prompt for this command. Check /help txt2img")
|
||||
await temp_message.edit_text("😶🌫️ Specify prompt for this command. Check /help txt2img")
|
||||
return
|
||||
|
||||
try:
|
||||
image = await txt2img(message.get_args())
|
||||
await message.reply_photo(photo=image[0], caption=str(
|
||||
image[1]["infotexts"][0]))
|
||||
except Exception as e:
|
||||
assert e
|
||||
await message.reply("We ran into error while processing your request. StableDiffusion models may not be "
|
||||
"configured on specified endpoint or server with StableDiffusion may be turned "
|
||||
"off. Ask admins of this bot instance if you have contacts for further info")
|
||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
|
||||
await temp_message.edit_text(f"⏳ Enqueued in position {db[DBTables.queue].get('n', 0)}...")
|
||||
|
||||
await wait_for_status()
|
||||
|
||||
await temp_message.edit_text(f"⌛ Generating...")
|
||||
prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
|
||||
image = await txt2img(prompt)
|
||||
image_message = await message.reply_photo(photo=image[0])
|
||||
|
||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||
db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
|
||||
|
||||
await temp_message.delete()
|
||||
|
||||
await db[DBTables.config].write()
|
||||
|
||||
except ClientConnectorError:
|
||||
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
|
||||
'or turned off')
|
||||
await temp_message.delete()
|
||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||
return
|
||||
|
||||
await temp_message.delete()
|
||||
except Exception as e:
|
||||
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
||||
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||
await message.reply('❌ Error happened while processing your request', parse_mode='html',
|
||||
reply_markup=get_exception_keyboard(exception_id))
|
||||
await temp_message.delete()
|
||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||
return
|
||||
|
||||
0
bot/keyboards/__init__.py
Normal file
0
bot/keyboards/__init__.py
Normal file
9
bot/keyboards/exception.py
Normal file
9
bot/keyboards/exception.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from aiogram import types
|
||||
from bot.callbacks.factories.exception import exception_callback
|
||||
|
||||
|
||||
def get_exception_keyboard(e_id: str) -> types.InlineKeyboardMarkup:
|
||||
buttons = [types.InlineKeyboardButton(text="Show full stack", callback_data=exception_callback.new(e_id=e_id))]
|
||||
keyboard = types.InlineKeyboardMarkup()
|
||||
keyboard.add(*buttons)
|
||||
return keyboard
|
||||
0
bot/modules/api/objects/__init__.py
Normal file
0
bot/modules/api/objects/__init__.py
Normal file
14
bot/modules/api/objects/prompt_request.py
Normal file
14
bot/modules/api/objects/prompt_request.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Prompt:
|
||||
prompt: str
|
||||
negative_prompt: str = None
|
||||
steps: int = 20
|
||||
cfg_scale: int = 7
|
||||
width: int = 768
|
||||
height: int = 768
|
||||
restore_faces: bool = True
|
||||
sampler: str = "Euler a"
|
||||
creator: int = None
|
||||
31
bot/modules/api/status.py
Normal file
31
bot/modules/api/status.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from bot.db import db, DBTables
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
async def job_exists(endpoint):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.get(
|
||||
endpoint + "/sdapi/v1/progress",
|
||||
json={
|
||||
"skip_current_image": True,
|
||||
}
|
||||
)
|
||||
if r.status != 200:
|
||||
return None
|
||||
return (await r.json()).get('state').get('job_count') > 0
|
||||
|
||||
|
||||
async def wait_for_status(ignore_exceptions: bool = False):
|
||||
endpoint = db[DBTables.config].get('endpoint')
|
||||
try:
|
||||
while await job_exists(endpoint):
|
||||
while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
|
||||
await asyncio.sleep(5)
|
||||
db[DBTables.cooldown]['_last_time_status_checked'] = time.time()
|
||||
return
|
||||
except Exception as e:
|
||||
if not ignore_exceptions:
|
||||
raise e
|
||||
return
|
||||
@@ -1,26 +1,25 @@
|
||||
import aiohttp
|
||||
from bot.db import db, DBTables
|
||||
from .objects.prompt_request import Prompt
|
||||
import json
|
||||
import base64
|
||||
|
||||
|
||||
async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
|
||||
cfg_scale: int = 7, width: int = 768, height: int = 768,
|
||||
restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None:
|
||||
async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
|
||||
endpoint = db[DBTables.config].get('endpoint')
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.post(
|
||||
endpoint + "/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"restore_faces": restore_faces,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": sampler
|
||||
"prompt": prompt.prompt,
|
||||
"steps": prompt.steps,
|
||||
"cfg_scale": prompt.cfg_scale,
|
||||
"width": prompt.width,
|
||||
"height": prompt.height,
|
||||
"restore_faces": prompt.restore_faces,
|
||||
"negative_prompt": prompt.negative_prompt,
|
||||
"sampler_index": prompt.sampler
|
||||
}
|
||||
)
|
||||
if r.status != 200:
|
||||
@@ -28,5 +27,6 @@ async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
|
||||
return [base64.b64decode((await r.json())["images"][0]),
|
||||
json.loads((await r.json())["info"])]
|
||||
except Exception as e:
|
||||
assert e
|
||||
return None
|
||||
if not ignore_exceptions:
|
||||
raise e
|
||||
return
|
||||
|
||||
47
bot/utils/trace_exception.py
Normal file
47
bot/utils/trace_exception.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import traceback
|
||||
import contextlib
|
||||
import re
|
||||
|
||||
|
||||
class PrettyException:
|
||||
def __init__(self, e: Exception):
|
||||
self.pretty_exception = f'❌ Error! Report it to admins: \n' \
|
||||
f'🐊 <code>{e.__traceback__.tb_frame.f_code.co_filename.replace(os.getcwd(), "")}' \
|
||||
f'</code>:{e.__traceback__.tb_frame.f_lineno} \n' \
|
||||
f'😍 {e.__class__.__name__} \n' \
|
||||
f'👉 {"".join(traceback.format_exception_only(e)).strip()} \n\n' \
|
||||
f'⬇️ Trace: \n' \
|
||||
f'{self.get_full_stack()}'
|
||||
|
||||
@staticmethod
|
||||
def get_full_stack():
|
||||
full_stack = traceback.format_exc().replace(
|
||||
"Traceback (most recent call last):\n", ""
|
||||
)
|
||||
|
||||
line_regex = r' File "(.*?)", line ([0-9]+), in (.+)'
|
||||
|
||||
def format_line(line: str) -> str:
|
||||
filename_, lineno_, name_ = re.search(line_regex, line).groups()
|
||||
with contextlib.suppress(Exception):
|
||||
filename_ = os.path.basename(filename_)
|
||||
|
||||
return (
|
||||
f"🤯 <code>{filename_}:{lineno_}</code> (<b>in</b>"
|
||||
f" <code>{name_}</code> call)"
|
||||
)
|
||||
|
||||
full_stack = "\n".join(
|
||||
[
|
||||
format_line(line)
|
||||
if re.search(line_regex, line)
|
||||
else f"<code>{line}</code>"
|
||||
for line in full_stack.splitlines()
|
||||
]
|
||||
)
|
||||
|
||||
return full_stack
|
||||
|
||||
def __str__(self):
|
||||
return self.pretty_exception
|
||||
6
main.py
6
main.py
@@ -4,11 +4,13 @@ from rich import print
|
||||
|
||||
async def main():
|
||||
print(BARS_APP_ID)
|
||||
import bot.handlers.register
|
||||
from bot.common import dp
|
||||
import bot.handlers.register
|
||||
import bot.callbacks.register
|
||||
from bot.utils.commands import set_commands
|
||||
|
||||
bot.handlers.register.import_handlers()
|
||||
bot.handlers.register.register_handlers()
|
||||
bot.callbacks.register.register_callbacks()
|
||||
await set_commands()
|
||||
print('[green]Bot will start now[/]')
|
||||
await dp.skip_updates()
|
||||
|
||||
Reference in New Issue
Block a user