Setting generation properties
This commit is contained in:
@@ -5,7 +5,7 @@ from .factories.image_info import full_prompt, prompt_only, import_prompt, back
|
|||||||
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
|
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
|
||||||
from bot.utils.cooldown import throttle
|
from bot.utils.cooldown import throttle
|
||||||
from bot.utils.private_keyboard import other_user
|
from bot.utils.private_keyboard import other_user
|
||||||
from bot.modules.api.objects.prompt_request import Prompt
|
from bot.modules.api.objects.prompt_request import Generated
|
||||||
|
|
||||||
|
|
||||||
async def on_back(call: types.CallbackQuery, callback_data: dict):
|
async def on_back(call: types.CallbackQuery, callback_data: dict):
|
||||||
@@ -26,11 +26,11 @@ async def on_prompt_only(call: types.CallbackQuery, callback_data: dict):
|
|||||||
if await other_user(call):
|
if await other_user(call):
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt: Prompt = db[DBTables.generated].get(p_id)
|
prompt: Generated = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
await call.message.edit_text(
|
await call.message.edit_text(
|
||||||
f"🖤 Prompt: {prompt.prompt} \n"
|
f"🖤 Prompt: {prompt.prompt.prompt} \n"
|
||||||
f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}",
|
f"{f'🐊 Negative: {prompt.prompt.negative_prompt}' if prompt.prompt.negative_prompt else ''}",
|
||||||
parse_mode='html',
|
parse_mode='html',
|
||||||
reply_markup=get_img_back_keyboard(p_id)
|
reply_markup=get_img_back_keyboard(p_id)
|
||||||
)
|
)
|
||||||
@@ -42,16 +42,18 @@ async def on_full_info(call: types.CallbackQuery, callback_data: dict):
|
|||||||
if await other_user(call):
|
if await other_user(call):
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt: Prompt = db[DBTables.generated].get(p_id)
|
prompt: Generated = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
await call.message.edit_text(
|
await call.message.edit_text(
|
||||||
f"🖤 Prompt: {prompt.prompt} \n"
|
f"🖤 Prompt: {prompt.prompt.prompt} \n"
|
||||||
f"🐊 Negative: {prompt.negative_prompt} \n"
|
f"🐊 Negative: {prompt.prompt.negative_prompt} \n"
|
||||||
f"🪜 Steps: {prompt.steps} \n"
|
f"💫 Model: {prompt.model} \n"
|
||||||
f"🧑🎨 CFG Scale: {prompt.cfg_scale} \n"
|
f"🪜 Steps: {prompt.prompt.steps} \n"
|
||||||
f"🖥️ Size: {prompt.width}x{prompt.height} \n"
|
f"🧑🎨 CFG Scale: {prompt.prompt.cfg_scale} \n"
|
||||||
f"😀 Restore faces: {'on' if prompt.restore_faces else 'off'} \n"
|
f"🖥️ Size: {prompt.prompt.width}x{prompt.prompt.height} \n"
|
||||||
f"⚒️ Sampler: {prompt.sampler}",
|
f"😀 Restore faces: {'on' if prompt.prompt.restore_faces else 'off'} \n"
|
||||||
|
f"⚒️ Sampler: {prompt.prompt.sampler} \n"
|
||||||
|
f"🌱 Seed: {prompt.seed}",
|
||||||
parse_mode='html',
|
parse_mode='html',
|
||||||
reply_markup=get_img_back_keyboard(p_id)
|
reply_markup=get_img_back_keyboard(p_id)
|
||||||
)
|
)
|
||||||
@@ -63,7 +65,7 @@ async def on_import(call: types.CallbackQuery, callback_data: dict):
|
|||||||
if await other_user(call):
|
if await other_user(call):
|
||||||
return
|
return
|
||||||
|
|
||||||
prompt: Prompt = db[DBTables.generated].get(p_id)
|
prompt: Generated = db[DBTables.generated].get(p_id)
|
||||||
|
|
||||||
await call.message.edit_text(
|
await call.message.edit_text(
|
||||||
f"😶🌫️ Not implemented yet",
|
f"😶🌫️ Not implemented yet",
|
||||||
|
|||||||
@@ -1,4 +1,15 @@
|
|||||||
help_data = {
|
help_data = {
|
||||||
'setendpoint': '(admin) Set StableDiffusion API endpoint',
|
'generate': 'Generate picture using configuration set by user. You can pass prompt also in command arguments or '
|
||||||
'imginfo': 'Get information about image, that was generated using this bot'
|
'use it without arguments to generate picture with prompt, that was used for last generation',
|
||||||
|
'imginfo': 'Get information about image, that was generated using this bot',
|
||||||
|
'setprompt': 'Set default prompt for images, will be overwritten if you specify prompt in generate command',
|
||||||
|
'setnegative': 'Set negative prompt',
|
||||||
|
'setsize': 'Set size for image (in hxw format)',
|
||||||
|
'setwidth': 'Set width for image',
|
||||||
|
'setheight': 'Set height for image',
|
||||||
|
'setsteps': 'Set sampling steps number',
|
||||||
|
'setsampler': 'Set StableDiffusion sampler',
|
||||||
|
'setscale': 'Set CFG Scale (prompt stringency)',
|
||||||
|
'setfaces': 'Set restore faces mode',
|
||||||
|
'setendpoint': '(admin) Set StableDiffusion API endpoint'
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,19 @@
|
|||||||
from bot.common import dp
|
from bot.common import dp
|
||||||
from .txt2img import txt2img_comand
|
from .txt2img import generate_command
|
||||||
|
from .set_settings import (
|
||||||
|
set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command,
|
||||||
|
set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def register():
|
def register():
|
||||||
dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')
|
dp.register_message_handler(txt2img.generate_command, commands='generate')
|
||||||
|
dp.register_message_handler(set_settings.set_prompt_command, commands='setprompt')
|
||||||
|
dp.register_message_handler(set_settings.set_height_command, commands='setheight')
|
||||||
|
dp.register_message_handler(set_settings.set_width_command, commands='setwidth')
|
||||||
|
dp.register_message_handler(set_settings.set_negative_prompt_command, commands='setnegative')
|
||||||
|
dp.register_message_handler(set_settings.set_size_command, commands='setsize')
|
||||||
|
dp.register_message_handler(set_settings.set_steps_command, commands='setsteps')
|
||||||
|
dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler')
|
||||||
|
dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale')
|
||||||
|
dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces')
|
||||||
|
|||||||
@@ -6,22 +6,28 @@ from bot.keyboards.exception import get_exception_keyboard
|
|||||||
from bot.utils.trace_exception import PrettyException
|
from bot.utils.trace_exception import PrettyException
|
||||||
|
|
||||||
|
|
||||||
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
async def _set_property(message: types.Message, prop: str, value=None):
|
||||||
async def set_prompt_command(message: types.Message):
|
temp_message = await message.reply(f"⏳ Setting {prop}...")
|
||||||
temp_message = await message.reply("⏳ Setting prompt...")
|
|
||||||
if not message.get_args():
|
if not message.get_args():
|
||||||
await temp_message.edit_text("😶🌫️ Specify prompt for this command. Check /help setprompt")
|
await temp_message.edit_text("😶🌫️ Specify arguments for this command. Check /help")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args()))
|
prompt: Prompt = db[DBTables.prompts].get(message.from_id)
|
||||||
prompt.prompt = message.get_args()
|
if prompt is None and prop != 'prompt':
|
||||||
|
await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
|
||||||
|
f"For example, it can be: <code>masterpiece, best quality, 1girl, white hair, "
|
||||||
|
f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, "
|
||||||
|
f"jacket, outdoors, streets</code>", parse_mode='HTML')
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt.__setattr__(prop, message.get_args() if value is None else value)
|
||||||
prompt.creator = message.from_id
|
prompt.creator = message.from_id
|
||||||
db[DBTables.prompts][message.from_id] = prompt
|
db[DBTables.prompts][message.from_id] = prompt
|
||||||
|
|
||||||
await db[DBTables.config].write()
|
await db[DBTables.config].write()
|
||||||
|
|
||||||
await message.reply('✅ Default prompt set')
|
await message.reply(f'✅ {prop} set')
|
||||||
await temp_message.delete()
|
await temp_message.delete()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -32,3 +38,134 @@ async def set_prompt_command(message: types.Message):
|
|||||||
await temp_message.delete()
|
await temp_message.delete()
|
||||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_prompt_command(message: types.Message):
|
||||||
|
await _set_property(message, 'prompt')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_negative_prompt_command(message: types.Message):
|
||||||
|
await _set_property(message, 'negative_prompt')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_steps_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
_ = int(message.get_args())
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify number as argument')
|
||||||
|
return
|
||||||
|
|
||||||
|
if _ > 30:
|
||||||
|
await message.reply('❌ Specify number <= 30')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'steps')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_cfg_scale_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
_ = int(message.get_args())
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify number as argument')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'cfg_scale')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_width_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
_ = int(message.get_args())
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify number as argument')
|
||||||
|
return
|
||||||
|
|
||||||
|
if _ > 768:
|
||||||
|
await message.reply('❌ Specify number <= 768')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'width')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_height_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
_ = int(message.get_args())
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify number as argument')
|
||||||
|
return
|
||||||
|
|
||||||
|
if _ > 768:
|
||||||
|
await message.reply('❌ Specify number <= 768')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'height')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_restore_faces_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
_ = bool(message.get_args())
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify boolean <code>True</code>/<code>False</code> as argument',
|
||||||
|
parse_mode='HTML')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'restore_faces')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_sampler_command(message: types.Message):
|
||||||
|
temp_message = await message.reply('⏳ Getting samplers...')
|
||||||
|
from bot.modules.api.samplers import get_samplers
|
||||||
|
from aiohttp import ClientConnectorError
|
||||||
|
try:
|
||||||
|
if message.get_args() not in (samplers := await get_samplers()):
|
||||||
|
await message.reply(
|
||||||
|
f'❌ You can use only {", ".join(f"<code>{x}</code>" for x in samplers)}',
|
||||||
|
parse_mode='HTML'
|
||||||
|
)
|
||||||
|
return
|
||||||
|
await temp_message.delete()
|
||||||
|
except ClientConnectorError:
|
||||||
|
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
|
||||||
|
'or turned off')
|
||||||
|
except Exception as e:
|
||||||
|
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
||||||
|
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||||
|
await message.reply('❌ Error happened while processing your request', parse_mode='html',
|
||||||
|
reply_markup=get_exception_keyboard(exception_id))
|
||||||
|
await temp_message.delete()
|
||||||
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'sampler')
|
||||||
|
|
||||||
|
|
||||||
|
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
|
||||||
|
async def set_size_command(message: types.Message):
|
||||||
|
try:
|
||||||
|
hxw = message.get_args().split('x')
|
||||||
|
height = int(hxw[0])
|
||||||
|
width = int(hxw[1])
|
||||||
|
except Exception as e:
|
||||||
|
assert e
|
||||||
|
await message.reply('❌ Specify size in <code>hxw</code> format, for example <code>512x512</code>',
|
||||||
|
parse_mode='HTML')
|
||||||
|
return
|
||||||
|
|
||||||
|
if height > 768 or width > 768:
|
||||||
|
await message.reply('❌ Specify numbers <= 768')
|
||||||
|
return
|
||||||
|
|
||||||
|
await _set_property(message, 'height', height)
|
||||||
|
await _set_property(message, 'width', width)
|
||||||
|
|||||||
@@ -1,24 +1,30 @@
|
|||||||
|
import re
|
||||||
from aiogram import types
|
from aiogram import types
|
||||||
from bot.db import db, DBTables
|
from bot.db import db, DBTables
|
||||||
from bot.utils.cooldown import throttle
|
from bot.utils.cooldown import throttle
|
||||||
from bot.modules.api.txt2img import txt2img
|
from bot.modules.api.txt2img import txt2img
|
||||||
from bot.modules.api.objects.prompt_request import Prompt
|
from bot.modules.api.objects.get_prompt import get_prompt
|
||||||
|
from bot.modules.api.objects.prompt_request import Generated
|
||||||
from bot.modules.api.status import wait_for_status
|
from bot.modules.api.status import wait_for_status
|
||||||
from bot.keyboards.exception import get_exception_keyboard
|
from bot.keyboards.exception import get_exception_keyboard
|
||||||
|
from bot.keyboards.image_info import get_img_info_keyboard
|
||||||
from bot.utils.trace_exception import PrettyException
|
from bot.utils.trace_exception import PrettyException
|
||||||
from aiohttp import ClientConnectorError
|
from aiohttp import ClientConnectorError
|
||||||
|
|
||||||
|
|
||||||
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
|
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
|
||||||
async def txt2img_comand(message: types.Message):
|
async def generate_command(message: types.Message):
|
||||||
temp_message = await message.reply("⏳ Enqueued...")
|
temp_message = await message.reply("⏳ Enqueued...")
|
||||||
|
|
||||||
prompt: Prompt = db[DBTables.prompts].get(message.from_id)
|
try:
|
||||||
if not prompt:
|
prompt = get_prompt(user_id=message.from_id,
|
||||||
if message.get_args():
|
prompt_string=message.get_args())
|
||||||
db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id)
|
except AttributeError:
|
||||||
|
await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
|
||||||
# TODO: Move it to other module
|
f"For example, it can be: <code>masterpiece, best quality, 1girl, white hair, "
|
||||||
|
f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, jacket, "
|
||||||
|
f"outdoors, streets</code>", parse_mode='HTML')
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
|
||||||
@@ -27,12 +33,18 @@ async def txt2img_comand(message: types.Message):
|
|||||||
await wait_for_status()
|
await wait_for_status()
|
||||||
|
|
||||||
await temp_message.edit_text(f"⌛ Generating...")
|
await temp_message.edit_text(f"⌛ Generating...")
|
||||||
prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
|
|
||||||
image = await txt2img(prompt)
|
image = await txt2img(prompt)
|
||||||
image_message = await message.reply_photo(photo=image[0])
|
image_message = await message.reply_photo(photo=image[0])
|
||||||
|
|
||||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
|
db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated(
|
||||||
|
prompt=prompt,
|
||||||
|
seed=image[1]['seed'],
|
||||||
|
model=re.search(r", Model: ([^,]+),", image[1]['infotexts'][0]).groups()[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
await message.reply('Here is your image',
|
||||||
|
reply_markup=get_img_info_keyboard(image_message.photo[0].file_unique_id))
|
||||||
|
|
||||||
await temp_message.delete()
|
await temp_message.delete()
|
||||||
|
|
||||||
@@ -45,6 +57,12 @@ async def txt2img_comand(message: types.Message):
|
|||||||
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
return
|
return
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
await message.reply(f'❌ Error! {e.args[0]}')
|
||||||
|
await temp_message.delete()
|
||||||
|
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
|
||||||
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
exception_id = f'{message.message_thread_id}-{message.message_id}'
|
||||||
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
db[DBTables.exceptions][exception_id] = PrettyException(e)
|
||||||
|
|||||||
@@ -1,31 +1,22 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from bot.db import db, DBTables, decrypt
|
from bot.db import db, DBTables, decrypt
|
||||||
from rich import print
|
|
||||||
|
|
||||||
|
|
||||||
async def get_models():
|
async def get_models():
|
||||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with aiohttp.ClientSession() as session:
|
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
||||||
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
if r.status != 200:
|
||||||
if r.status != 200:
|
return None
|
||||||
return None
|
return [x["title"] for x in await r.json()]
|
||||||
return [x["title"] for x in await r.json()]
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def set_model(model_name: str):
|
async def set_model(model_name: str):
|
||||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
try:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with aiohttp.ClientSession() as session:
|
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
||||||
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
"sd_model_checkpoint": model_name
|
||||||
"sd_model_checkpoint": model_name
|
})
|
||||||
})
|
if r.status != 200:
|
||||||
if r.status != 200:
|
return False
|
||||||
return False
|
return True
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
return False
|
|
||||||
|
|||||||
38
bot/modules/api/objects/get_prompt.py
Normal file
38
bot/modules/api/objects/get_prompt.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from bot.modules.api.objects.prompt_request import Prompt
|
||||||
|
from bot.db import db, DBTables
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None,
|
||||||
|
cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None,
|
||||||
|
sampler: str = None) -> Prompt:
|
||||||
|
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
|
||||||
|
creator = user_id
|
||||||
|
if not new_prompt:
|
||||||
|
if user_id and prompt_string:
|
||||||
|
db[DBTables.prompts][user_id] = Prompt(
|
||||||
|
prompt=prompt_string,
|
||||||
|
creator=user_id,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
steps=steps,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
restore_faces=restore_faces,
|
||||||
|
sampler=sampler,
|
||||||
|
)
|
||||||
|
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
|
||||||
|
else:
|
||||||
|
raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user')
|
||||||
|
|
||||||
|
if prompt_string:
|
||||||
|
new_prompt.prompt = prompt_string
|
||||||
|
|
||||||
|
for key in new_prompt.__dict__.keys():
|
||||||
|
if key in locals().keys() and locals()[key]:
|
||||||
|
new_prompt.__setattr__(key, locals()[key])
|
||||||
|
elif not new_prompt.__getattribute__(key):
|
||||||
|
new_prompt.__setattr__(key, Prompt('').__getattribute__(key))
|
||||||
|
|
||||||
|
db[DBTables.prompts][user_id] = new_prompt
|
||||||
|
|
||||||
|
return new_prompt
|
||||||
@@ -12,3 +12,10 @@ class Prompt:
|
|||||||
restore_faces: bool = True
|
restore_faces: bool = True
|
||||||
sampler: str = "Euler a"
|
sampler: str = "Euler a"
|
||||||
creator: int = None
|
creator: int = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Generated:
|
||||||
|
prompt: Prompt
|
||||||
|
seed: int
|
||||||
|
model: str
|
||||||
|
|||||||
11
bot/modules/api/samplers.py
Normal file
11
bot/modules/api/samplers.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import aiohttp
|
||||||
|
from bot.db import db, DBTables, decrypt
|
||||||
|
|
||||||
|
|
||||||
|
async def get_samplers():
|
||||||
|
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
r = await session.get(endpoint + "/sdapi/v1/samplers")
|
||||||
|
if r.status != 200:
|
||||||
|
return None
|
||||||
|
return [x["name"] for x in await r.json()]
|
||||||
@@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes
|
|||||||
"sampler_index": prompt.sampler
|
"sampler_index": prompt.sampler
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if r.status != 200:
|
if r.status != 200 and ignore_exceptions:
|
||||||
return None
|
return None
|
||||||
|
elif r.status != 200:
|
||||||
|
raise ValueError((await r.json())['detail'])
|
||||||
return [base64.b64decode((await r.json())["images"][0]),
|
return [base64.b64decode((await r.json())["images"][0]),
|
||||||
json.loads((await r.json())["info"])]
|
json.loads((await r.json())["info"])]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user