diff --git a/bot/callbacks/__init__.py b/bot/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bot/callbacks/exception.py b/bot/callbacks/exception.py
new file mode 100644
index 0000000..3e4e264
--- /dev/null
+++ b/bot/callbacks/exception.py
@@ -0,0 +1,15 @@
+from bot.common import dp
+from bot.db import db, DBTables
+from aiogram import types
+from .factories.exception import exception_callback
+
+
+async def on_exception(call: types.CallbackQuery, callback_data: dict):
+ e_id = callback_data['e_id']
+ e = db[DBTables.exceptions][e_id]
+ del db[DBTables.exceptions][e_id]
+ await call.message.edit_text(e, parse_mode='html')
+
+
+def register():
+ dp.register_callback_query_handler(on_exception, exception_callback.filter())
diff --git a/bot/callbacks/factories/__init__.py b/bot/callbacks/factories/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bot/callbacks/factories/exception.py b/bot/callbacks/factories/exception.py
new file mode 100644
index 0000000..9adcc96
--- /dev/null
+++ b/bot/callbacks/factories/exception.py
@@ -0,0 +1,4 @@
+from aiogram.utils.callback_data import CallbackData
+
+
+exception_callback = CallbackData("full_exception", "e_id")
diff --git a/bot/callbacks/register.py b/bot/callbacks/register.py
new file mode 100644
index 0000000..306c40b
--- /dev/null
+++ b/bot/callbacks/register.py
@@ -0,0 +1,11 @@
+from rich import print
+
+
+def register_callbacks():
+ from bot.callbacks import (
+ exception
+ )
+
+ exception.register()
+
+ print('[gray]All callbacks registered[/]')
diff --git a/bot/db/db.py b/bot/db/db.py
index 837aba8..5025ce5 100644
--- a/bot/db/db.py
+++ b/bot/db/db.py
@@ -8,5 +8,8 @@ if not os.path.isfile(DB):
db = {
'config': DBDict(DB, autocommit=True, tablename='config'),
- 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown')
+ 'cooldown': DBDict(DB, autocommit=True, tablename='cooldown'),
+ 'exceptions': DBDict(DB, autocommit=True, tablename='exceptions'),
+ 'queue': DBDict(DB, autocommit=True, tablename='queue'),
+ 'generated': DBDict(DB, autocommit=True, tablename='generated')
}
diff --git a/bot/db/db_model.py b/bot/db/db_model.py
index ae32ef4..43411fb 100644
--- a/bot/db/db_model.py
+++ b/bot/db/db_model.py
@@ -8,9 +8,12 @@ from .meta import DBMeta
class DBTables:
- tables = ['config', 'cooldown']
+ tables = ['config', 'cooldown', 'exceptions', 'queue', 'generated']
config = "config"
cooldown = "cooldown"
+ exceptions = "exceptions"
+ queue = "queue"
+ generated = "generated"
class DBDict(SqliteDict):
diff --git a/bot/handlers/admin/__init__.py b/bot/handlers/admin/__init__.py
index 80c28e4..1e759b7 100644
--- a/bot/handlers/admin/__init__.py
+++ b/bot/handlers/admin/__init__.py
@@ -1,6 +1,8 @@
from bot.common import dp
from .aliases import *
+from .reset import *
def register():
dp.register_message_handler(set_endpoint, commands='setendpoint')
+ dp.register_message_handler(reset.resetqueue, commands='resetqueue')
diff --git a/bot/handlers/admin/reset.py b/bot/handlers/admin/reset.py
new file mode 100644
index 0000000..076a209
--- /dev/null
+++ b/bot/handlers/admin/reset.py
@@ -0,0 +1,18 @@
+from aiogram import types
+from bot.db import db, DBTables
+from bot.config import ADMIN
+from bot.utils.cooldown import throttle
+
+
+@throttle(5)
+async def resetqueue(message: types.Message):
+ if message.from_id not in db[DBTables.config].get('admins') and message.from_id != ADMIN:
+ await message.reply('β You are not permitted to do that. '
+ 'It is only for this bot instance maintainers and admins')
+ return
+
+ db[DBTables.queue]['n'] = 0
+
+ await db[DBTables.config].write()
+
+ await message.reply("β
Reset queue")
diff --git a/bot/handlers/initialize/all_messages.py b/bot/handlers/initialize/all_messages.py
index a0488bc..ffc15de 100644
--- a/bot/handlers/initialize/all_messages.py
+++ b/bot/handlers/initialize/all_messages.py
@@ -4,5 +4,6 @@ from bot.db.pull_db import pull
async def sync_db_filter(message: Message):
await pull()
- await message.reply(f'ποΈ Bot database synchronised because of restart. '
- f'If you tried to run a command, run it again')
+ if message.is_command():
+ await message.reply(f'ποΈ Bot database synchronised because of restart. '
+ f'If you tried to run a command, run it again')
diff --git a/bot/handlers/register.py b/bot/handlers/register.py
index 4ca332c..b050c6b 100644
--- a/bot/handlers/register.py
+++ b/bot/handlers/register.py
@@ -1,7 +1,7 @@
from rich import print
-def import_handlers():
+def register_handlers():
from bot.handlers import (
initialize, admin, help_command, txt2img
)
@@ -11,4 +11,4 @@ def import_handlers():
help_command.register()
txt2img.register()
- print('[gray]All handlers imported[/]')
+ print('[gray]All handlers registered[/]')
diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py
index 39e2ecf..f75dcaf 100644
--- a/bot/handlers/txt2img/__init__.py
+++ b/bot/handlers/txt2img/__init__.py
@@ -1,6 +1,6 @@
from bot.common import dp
-from .txt2img import *
+from .txt2img import txt2img_comand
def register():
- dp.register_message_handler(txt2img, commands='txt2img')
+ dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')
diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py
index 2212684..66aa0bf 100644
--- a/bot/handlers/txt2img/txt2img.py
+++ b/bot/handlers/txt2img/txt2img.py
@@ -2,25 +2,50 @@ from aiogram import types
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img
+from bot.modules.api.objects.prompt_request import Prompt
+from bot.modules.api.status import wait_for_status
+from bot.keyboards.exception import get_exception_keyboard
+from bot.utils.trace_exception import PrettyException
+from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message):
- temp_message = await message.reply("β³ Generating image...")
+ temp_message = await message.reply("β³ Enqueued...")
if not message.get_args():
- await temp_message.edit_text("Specify prompt for this command. Check /help txt2img")
+ await temp_message.edit_text("πΆβπ«οΈ Specify prompt for this command. Check /help txt2img")
return
try:
- image = await txt2img(message.get_args())
- await message.reply_photo(photo=image[0], caption=str(
- image[1]["infotexts"][0]))
- except Exception as e:
- assert e
- await message.reply("We ran into error while processing your request. StableDiffusion models may not be "
- "configured on specified endpoint or server with StableDiffusion may be turned "
- "off. Ask admins of this bot instance if you have contacts for further info")
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
+ await temp_message.edit_text(f"β³ Enqueued in position {db[DBTables.queue].get('n', 0)}...")
+
+ await wait_for_status()
+
+ await temp_message.edit_text(f"β Generating...")
+ prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
+ image = await txt2img(prompt)
+ image_message = await message.reply_photo(photo=image[0])
+
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
+ db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
+
await temp_message.delete()
+
+ await db[DBTables.config].write()
+
+ except ClientConnectorError:
+ await message.reply('β Error! Maybe, StableDiffusion API endpoint is incorrect '
+ 'or turned off')
+ await temp_message.delete()
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
- await temp_message.delete()
+ except Exception as e:
+ exception_id = f'{message.message_thread_id}-{message.message_id}'
+ db[DBTables.exceptions][exception_id] = PrettyException(e)
+ await message.reply('β Error happened while processing your request', parse_mode='html',
+ reply_markup=get_exception_keyboard(exception_id))
+ await temp_message.delete()
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
+ return
diff --git a/bot/keyboards/__init__.py b/bot/keyboards/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bot/keyboards/exception.py b/bot/keyboards/exception.py
new file mode 100644
index 0000000..ca70ca1
--- /dev/null
+++ b/bot/keyboards/exception.py
@@ -0,0 +1,9 @@
+from aiogram import types
+from bot.callbacks.factories.exception import exception_callback
+
+
+def get_exception_keyboard(e_id: str) -> types.InlineKeyboardMarkup:
+ buttons = [types.InlineKeyboardButton(text="Show full stack", callback_data=exception_callback.new(e_id=e_id))]
+ keyboard = types.InlineKeyboardMarkup()
+ keyboard.add(*buttons)
+ return keyboard
diff --git a/bot/modules/api/objects/__init__.py b/bot/modules/api/objects/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bot/modules/api/objects/prompt_request.py b/bot/modules/api/objects/prompt_request.py
new file mode 100644
index 0000000..3621d05
--- /dev/null
+++ b/bot/modules/api/objects/prompt_request.py
@@ -0,0 +1,14 @@
+import dataclasses
+
+
+@dataclasses.dataclass
+class Prompt:
+ prompt: str
+ negative_prompt: str = None
+ steps: int = 20
+ cfg_scale: int = 7
+ width: int = 768
+ height: int = 768
+ restore_faces: bool = True
+ sampler: str = "Euler a"
+ creator: int = None
diff --git a/bot/modules/api/status.py b/bot/modules/api/status.py
new file mode 100644
index 0000000..e033497
--- /dev/null
+++ b/bot/modules/api/status.py
@@ -0,0 +1,31 @@
+from bot.db import db, DBTables
+import aiohttp
+import asyncio
+import time
+
+
+async def job_exists(endpoint):
+ async with aiohttp.ClientSession() as session:
+ r = await session.get(
+ endpoint + "/sdapi/v1/progress",
+ json={
+ "skip_current_image": True,
+ }
+ )
+ if r.status != 200:
+ return None
+ return (await r.json()).get('state').get('job_count') > 0
+
+
+async def wait_for_status(ignore_exceptions: bool = False):
+ endpoint = db[DBTables.config].get('endpoint')
+ try:
+ while await job_exists(endpoint):
+ while db[DBTables.cooldown].get('_last_time_status_checked', 0) + 5 > time.time():
+ await asyncio.sleep(5)
+ db[DBTables.cooldown]['_last_time_status_checked'] = time.time()
+ return
+ except Exception as e:
+ if not ignore_exceptions:
+ raise e
+ return
diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py
index 963f361..429e127 100644
--- a/bot/modules/api/txt2img.py
+++ b/bot/modules/api/txt2img.py
@@ -1,26 +1,25 @@
import aiohttp
from bot.db import db, DBTables
+from .objects.prompt_request import Prompt
import json
import base64
-async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
- cfg_scale: int = 7, width: int = 768, height: int = 768,
- restore_faces: bool = True, sampler: str = "Euler a") -> list[bytes, dict] | None:
+async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes, dict] | None:
endpoint = db[DBTables.config].get('endpoint')
try:
async with aiohttp.ClientSession() as session:
r = await session.post(
endpoint + "/sdapi/v1/txt2img",
json={
- "prompt": prompt,
- "steps": steps,
- "cfg_scale": cfg_scale,
- "width": width,
- "height": height,
- "restore_faces": restore_faces,
- "negative_prompt": negative_prompt,
- "sampler_index": sampler
+ "prompt": prompt.prompt,
+ "steps": prompt.steps,
+ "cfg_scale": prompt.cfg_scale,
+ "width": prompt.width,
+ "height": prompt.height,
+ "restore_faces": prompt.restore_faces,
+ "negative_prompt": prompt.negative_prompt,
+ "sampler_index": prompt.sampler
}
)
if r.status != 200:
@@ -28,5 +27,6 @@ async def txt2img(prompt: str, negative_prompt: str = None, steps: int = 20,
return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])]
except Exception as e:
- assert e
- return None
+ if not ignore_exceptions:
+ raise e
+ return
diff --git a/bot/utils/trace_exception.py b/bot/utils/trace_exception.py
new file mode 100644
index 0000000..a0e3a37
--- /dev/null
+++ b/bot/utils/trace_exception.py
@@ -0,0 +1,47 @@
+import os
+import traceback
+import contextlib
+import re
+
+
+class PrettyException:
+ def __init__(self, e: Exception):
+ self.pretty_exception = f'β Error! Report it to admins: \n' \
+ f'π {e.__traceback__.tb_frame.f_code.co_filename.replace(os.getcwd(), "")}' \
+ f':{e.__traceback__.tb_frame.f_lineno} \n' \
+ f'π {e.__class__.__name__} \n' \
+ f'π {"".join(traceback.format_exception_only(e)).strip()} \n\n' \
+ f'β¬οΈ Trace: \n' \
+ f'{self.get_full_stack()}'
+
+ @staticmethod
+ def get_full_stack():
+ full_stack = traceback.format_exc().replace(
+ "Traceback (most recent call last):\n", ""
+ )
+
+ line_regex = r' File "(.*?)", line ([0-9]+), in (.+)'
+
+ def format_line(line: str) -> str:
+ filename_, lineno_, name_ = re.search(line_regex, line).groups()
+ with contextlib.suppress(Exception):
+ filename_ = os.path.basename(filename_)
+
+ return (
+ f"π€― {filename_}:{lineno_} (in"
+ f" {name_} call)"
+ )
+
+ full_stack = "\n".join(
+ [
+ format_line(line)
+ if re.search(line_regex, line)
+ else f"{line}"
+ for line in full_stack.splitlines()
+ ]
+ )
+
+ return full_stack
+
+ def __str__(self):
+ return self.pretty_exception
diff --git a/main.py b/main.py
index e9812cc..6265fcb 100644
--- a/main.py
+++ b/main.py
@@ -4,11 +4,13 @@ from rich import print
async def main():
print(BARS_APP_ID)
- import bot.handlers.register
from bot.common import dp
+ import bot.handlers.register
+ import bot.callbacks.register
from bot.utils.commands import set_commands
- bot.handlers.register.import_handlers()
+ bot.handlers.register.register_handlers()
+ bot.callbacks.register.register_callbacks()
await set_commands()
print('[green]Bot will start now[/]')
await dp.skip_updates()