Setting model, /status command for pinging host
This commit is contained in:
5
bot/callbacks/factories/set_model.py
Normal file
5
bot/callbacks/factories/set_model.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from aiogram.utils.callback_data import CallbackData
|
||||||
|
|
||||||
|
|
||||||
|
set_model = CallbackData("set_model", "n")
|
||||||
|
set_model_page = CallbackData("set_model_page", "page")
|
||||||
@@ -4,10 +4,12 @@ from rich import print
|
|||||||
def register_callbacks():
|
def register_callbacks():
|
||||||
from bot.callbacks import (
|
from bot.callbacks import (
|
||||||
exception,
|
exception,
|
||||||
image_info
|
image_info,
|
||||||
|
set_model
|
||||||
)
|
)
|
||||||
|
|
||||||
exception.register()
|
exception.register()
|
||||||
image_info.register()
|
image_info.register()
|
||||||
|
set_model.register()
|
||||||
|
|
||||||
print('[gray]All callbacks registered[/]')
|
print('[gray]All callbacks registered[/]')
|
||||||
|
|||||||
50
bot/callbacks/set_model.py
Normal file
50
bot/callbacks/set_model.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from bot.common import dp
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from aiogram import types
|
||||||
|
from .factories.set_model import set_model, set_model_page
|
||||||
|
from bot.keyboards.set_model import get_set_model_keyboard
|
||||||
|
from bot.utils.private_keyboard import other_user
|
||||||
|
from bot.keyboards.exception import get_exception_keyboard
|
||||||
|
from bot.utils.trace_exception import PrettyException
|
||||||
|
from aiohttp import ClientConnectorError
|
||||||
|
from bot.modules.api import models
|
||||||
|
|
||||||
|
|
||||||
|
async def on_set_model(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
n = int(callback_data['n'])
|
||||||
|
temp_message = await call.message.answer('⏳ Setting model...')
|
||||||
|
try:
|
||||||
|
await models.set_model(db[DBTables.config]['models'][n])
|
||||||
|
|
||||||
|
except ClientConnectorError:
|
||||||
|
await call.answer('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
|
||||||
|
'or turned off', show_alert=True)
|
||||||
|
await call.message.delete()
|
||||||
|
await temp_message.delete()
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
exception_id = f'{call.message.message_thread_id}-{call.message.message_id}'
|
||||||
|
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||||
|
await call.message.reply('❌ Error happened while processing your request', parse_mode='html',
|
||||||
|
reply_markup=get_exception_keyboard(exception_id))
|
||||||
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
|
return
|
||||||
|
|
||||||
|
await temp_message.answer('✅ Model set for all users!')
|
||||||
|
await temp_message.delete()
|
||||||
|
|
||||||
|
|
||||||
|
async def on_page_change(call: types.CallbackQuery, callback_data: dict):
|
||||||
|
page = callback_data['page']
|
||||||
|
if await other_user(call):
|
||||||
|
return
|
||||||
|
|
||||||
|
await call.message.edit_reply_markup(
|
||||||
|
get_set_model_keyboard(page)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register():
|
||||||
|
dp.register_callback_query_handler(on_set_model, set_model.filter())
|
||||||
|
dp.register_callback_query_handler(on_page_change, set_model_page.filter())
|
||||||
@@ -11,6 +11,8 @@ help_data = {
|
|||||||
'setsampler': 'Set StableDiffusion sampler',
|
'setsampler': 'Set StableDiffusion sampler',
|
||||||
'setscale': 'Set CFG Scale (prompt stringency)',
|
'setscale': 'Set CFG Scale (prompt stringency)',
|
||||||
'setfaces': 'Set restore faces mode',
|
'setfaces': 'Set restore faces mode',
|
||||||
|
'status': 'Ping API endpoint host',
|
||||||
|
'setmodel': '(global) Sets StableDiffusion model for all users. Can be used only once an hour',
|
||||||
'setendpoint': '(admin) Set StableDiffusion API endpoint',
|
'setendpoint': '(admin) Set StableDiffusion API endpoint',
|
||||||
'addadmin': '(admin) Add new admin - reply to message or type user ID',
|
'addadmin': '(admin) Add new admin - reply to message or type user ID',
|
||||||
'rmadmin': '(admin) Remove admin - reply to message or type user ID'
|
'rmadmin': '(admin) Remove admin - reply to message or type user ID'
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from aiogram.types import Message
|
from aiogram.types import Message
|
||||||
from bot.db.pull_db import pull
|
|
||||||
|
|
||||||
|
|
||||||
async def sync_db_filter(message: Message):
|
async def sync_db_filter(message: Message):
|
||||||
|
from bot.db.pull_db import pull
|
||||||
|
from bot.modules.api.ping import ping
|
||||||
await pull()
|
await pull()
|
||||||
if message.is_command():
|
if message.is_command():
|
||||||
await message.reply(f'🔄️ Bot database synchronised because of restart. '
|
await message.reply(f'🔄️ Bot database synchronised because of restart. '
|
||||||
f'If you tried to run a command, run it again')
|
f'If you tried to run a command, run it again')
|
||||||
|
if not await ping():
|
||||||
|
await message.reply('⚠️ Warning: StableDiffusion server is turned off or api endpoint is incorrect')
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from bot.common import dp
|
from bot.common import dp
|
||||||
from .txt2img import generate_command
|
from .txt2img import generate_command
|
||||||
|
from .set_model import set_model_command
|
||||||
|
from .status import get_status
|
||||||
from .set_settings import (
|
from .set_settings import (
|
||||||
set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command,
|
set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command,
|
||||||
set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command
|
set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command
|
||||||
@@ -17,3 +19,5 @@ def register():
|
|||||||
dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler')
|
dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler')
|
||||||
dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale')
|
dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale')
|
||||||
dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces')
|
dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces')
|
||||||
|
dp.register_message_handler(set_model.set_model_command, commands='setmodel')
|
||||||
|
dp.register_message_handler(status.get_status, commands='status')
|
||||||
|
|||||||
37
bot/handlers/txt2img/set_model.py
Normal file
37
bot/handlers/txt2img/set_model.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from aiogram import types
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from bot.utils.cooldown import throttle
|
||||||
|
from bot.keyboards.set_model import get_set_model_keyboard
|
||||||
|
from bot.keyboards.exception import get_exception_keyboard
|
||||||
|
from bot.utils.trace_exception import PrettyException
|
||||||
|
from bot.modules.api.models import get_models
|
||||||
|
from aiohttp import ClientConnectorError
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=60*60, admin_ids=db[DBTables.config].get('admins'), by_id=False)
|
||||||
|
async def set_model_command(message: types.Message):
|
||||||
|
temp_message = await message.reply('⏳ Getting models...')
|
||||||
|
try:
|
||||||
|
models = await get_models()
|
||||||
|
if models is not None and len(models) > 0:
|
||||||
|
db[DBTables.config]['models'] = models
|
||||||
|
else:
|
||||||
|
await temp_message.delete()
|
||||||
|
await message.reply('❌ No models available')
|
||||||
|
return
|
||||||
|
|
||||||
|
await temp_message.delete()
|
||||||
|
await message.reply("🪄 You can choose model from available:", reply_markup=get_set_model_keyboard(0))
|
||||||
|
|
||||||
|
except ClientConnectorError:
|
||||||
|
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
|
||||||
|
'or turned off')
|
||||||
|
await temp_message.delete()
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
||||||
|
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||||
|
await message.reply('❌ Error happened while processing your request', parse_mode='html',
|
||||||
|
reply_markup=get_exception_keyboard(exception_id))
|
||||||
|
return
|
||||||
20
bot/handlers/txt2img/status.py
Normal file
20
bot/handlers/txt2img/status.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from aiogram import types
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from bot.utils.cooldown import throttle
|
||||||
|
from bot.modules.api.ping import ping
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def get_status(message: types.Message):
|
||||||
|
temp_message = await message.reply('⏳ Sending request...')
|
||||||
|
try:
|
||||||
|
if await ping():
|
||||||
|
await message.reply('💚 Endpoint is UP')
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('💔 Endpoint is probably DOWN or incorrect')
|
||||||
|
|
||||||
|
await temp_message.delete()
|
||||||
31
bot/keyboards/set_model.py
Normal file
31
bot/keyboards/set_model.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from aiogram import types
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
from bot.callbacks.factories.set_model import set_model_page, set_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_set_model_keyboard(page: int) -> types.InlineKeyboardMarkup:
|
||||||
|
models = db[DBTables.config]['models']
|
||||||
|
navigation_buttons = list()
|
||||||
|
page = int(page)
|
||||||
|
|
||||||
|
if page > 0:
|
||||||
|
navigation_buttons.append(types.InlineKeyboardButton(
|
||||||
|
'<',
|
||||||
|
callback_data=set_model_page.new(page=page - 1)
|
||||||
|
))
|
||||||
|
if len([models[i:i + 5] for i in range(0, len(models), 5)]) > page + 1:
|
||||||
|
navigation_buttons.append(types.InlineKeyboardButton(
|
||||||
|
'>',
|
||||||
|
callback_data=set_model_page.new(page=page + 1)
|
||||||
|
))
|
||||||
|
|
||||||
|
models_buttons = [types.InlineKeyboardButton(models[i], callback_data=set_model.new(i)) for i in range(len(models))]
|
||||||
|
|
||||||
|
keyboard = types.InlineKeyboardMarkup(row_width=1)
|
||||||
|
if len(models) > 5:
|
||||||
|
keyboard.add(*[models_buttons[i:i + 5] for i in range(0, len(models_buttons), 5)][page])
|
||||||
|
keyboard.row(*navigation_buttons)
|
||||||
|
else:
|
||||||
|
keyboard.add(*models_buttons)
|
||||||
|
|
||||||
|
return keyboard
|
||||||
@@ -2,7 +2,7 @@ import aiohttp
|
|||||||
from bot.db import db, DBTables, decrypt
|
from bot.db import db, DBTables, decrypt
|
||||||
|
|
||||||
|
|
||||||
async def get_models():
|
async def get_models() -> list | None:
|
||||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
||||||
|
|||||||
14
bot/modules/api/ping.py
Normal file
14
bot/modules/api/ping.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import aiohttp
|
||||||
|
from bot.db import db, DBTables, decrypt
|
||||||
|
|
||||||
|
|
||||||
|
async def ping():
|
||||||
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
r = await session.head(endpoint)
|
||||||
|
if r.status != 200:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
except aiohttp.ClientConnectorError:
|
||||||
|
return False
|
||||||
Reference in New Issue
Block a user