Setting generation properties

This commit is contained in:
BarsTiger
2023-02-26 22:16:05 +02:00
parent 8f0cd0ac05
commit 3c3d644137
10 changed files with 286 additions and 56 deletions

View File

@@ -5,7 +5,7 @@ from .factories.image_info import full_prompt, prompt_only, import_prompt, back
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
from bot.utils.private_keyboard import other_user from bot.utils.private_keyboard import other_user
from bot.modules.api.objects.prompt_request import Prompt from bot.modules.api.objects.prompt_request import Generated
async def on_back(call: types.CallbackQuery, callback_data: dict): async def on_back(call: types.CallbackQuery, callback_data: dict):
@@ -26,11 +26,11 @@ async def on_prompt_only(call: types.CallbackQuery, callback_data: dict):
if await other_user(call): if await other_user(call):
return return
prompt: Prompt = db[DBTables.generated].get(p_id) prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text( await call.message.edit_text(
f"🖤 Prompt: {prompt.prompt} \n" f"🖤 Prompt: {prompt.prompt.prompt} \n"
f"{f'🐊 Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}", f"{f'🐊 Negative: {prompt.prompt.negative_prompt}' if prompt.prompt.negative_prompt else ''}",
parse_mode='html', parse_mode='html',
reply_markup=get_img_back_keyboard(p_id) reply_markup=get_img_back_keyboard(p_id)
) )
@@ -42,16 +42,18 @@ async def on_full_info(call: types.CallbackQuery, callback_data: dict):
if await other_user(call): if await other_user(call):
return return
prompt: Prompt = db[DBTables.generated].get(p_id) prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text( await call.message.edit_text(
f"🖤 Prompt: {prompt.prompt} \n" f"🖤 Prompt: {prompt.prompt.prompt} \n"
f"🐊 Negative: {prompt.negative_prompt} \n" f"🐊 Negative: {prompt.prompt.negative_prompt} \n"
f"🪜 Steps: {prompt.steps} \n" f"💫 Model: {prompt.model} \n"
f"🧑‍🎨 CFG Scale: {prompt.cfg_scale} \n" f"🪜 Steps: {prompt.prompt.steps} \n"
f"🖥️ Size: {prompt.width}x{prompt.height} \n" f"🧑‍🎨 CFG Scale: {prompt.prompt.cfg_scale} \n"
f"😀 Restore faces: {'on' if prompt.restore_faces else 'off'} \n" f"🖥️ Size: {prompt.prompt.width}x{prompt.prompt.height} \n"
f"⚒️ Sampler: {prompt.sampler}", f"😀 Restore faces: {'on' if prompt.prompt.restore_faces else 'off'} \n"
f"⚒️ Sampler: {prompt.prompt.sampler} \n"
f"🌱 Seed: {prompt.seed}",
parse_mode='html', parse_mode='html',
reply_markup=get_img_back_keyboard(p_id) reply_markup=get_img_back_keyboard(p_id)
) )
@@ -63,7 +65,7 @@ async def on_import(call: types.CallbackQuery, callback_data: dict):
if await other_user(call): if await other_user(call):
return return
prompt: Prompt = db[DBTables.generated].get(p_id) prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text( await call.message.edit_text(
f"😶‍🌫️ Not implemented yet", f"😶‍🌫️ Not implemented yet",

View File

@@ -1,4 +1,15 @@
help_data = { help_data = {
'setendpoint': '(admin) Set StableDiffusion API endpoint', 'generate': 'Generate picture using configuration set by user. You can pass prompt also in command arguments or '
'imginfo': 'Get information about image, that was generated using this bot' 'use it without arguments to generate picture with prompt, that was used for last generation',
'imginfo': 'Get information about image, that was generated using this bot',
'setprompt': 'Set default prompt for images, will be overwritten if you specify prompt in generate command',
'setnegative': 'Set negative prompt',
'setsize': 'Set size for image (in hxw format)',
'setwidth': 'Set width for image',
'setheight': 'Set height for image',
'setsteps': 'Set sampling steps number',
'setsampler': 'Set StableDiffusion sampler',
'setscale': 'Set CFG Scale (prompt stringency)',
'setfaces': 'Set restore faces mode',
'setendpoint': '(admin) Set StableDiffusion API endpoint'
} }

View File

@@ -1,6 +1,19 @@
from bot.common import dp from bot.common import dp
from .txt2img import txt2img_comand from .txt2img import generate_command
from .set_settings import (
set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command,
set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command
)
def register(): def register():
dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img') dp.register_message_handler(txt2img.generate_command, commands='generate')
dp.register_message_handler(set_settings.set_prompt_command, commands='setprompt')
dp.register_message_handler(set_settings.set_height_command, commands='setheight')
dp.register_message_handler(set_settings.set_width_command, commands='setwidth')
dp.register_message_handler(set_settings.set_negative_prompt_command, commands='setnegative')
dp.register_message_handler(set_settings.set_size_command, commands='setsize')
dp.register_message_handler(set_settings.set_steps_command, commands='setsteps')
dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler')
dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale')
dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces')

View File

@@ -6,22 +6,28 @@ from bot.keyboards.exception import get_exception_keyboard
from bot.utils.trace_exception import PrettyException from bot.utils.trace_exception import PrettyException
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins')) async def _set_property(message: types.Message, prop: str, value=None):
async def set_prompt_command(message: types.Message): temp_message = await message.reply(f"⏳ Setting {prop}...")
temp_message = await message.reply("⏳ Setting prompt...")
if not message.get_args(): if not message.get_args():
await temp_message.edit_text("😶‍🌫️ Specify prompt for this command. Check /help setprompt") await temp_message.edit_text("😶‍🌫️ Specify arguments for this command. Check /help")
return return
try: try:
prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args())) prompt: Prompt = db[DBTables.prompts].get(message.from_id)
prompt.prompt = message.get_args() if prompt is None and prop != 'prompt':
await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
f"For example, it can be: <code>masterpiece, best quality, 1girl, white hair, "
f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, "
f"jacket, outdoors, streets</code>", parse_mode='HTML')
return
prompt.__setattr__(prop, message.get_args() if value is None else value)
prompt.creator = message.from_id prompt.creator = message.from_id
db[DBTables.prompts][message.from_id] = prompt db[DBTables.prompts][message.from_id] = prompt
await db[DBTables.config].write() await db[DBTables.config].write()
await message.reply('Default prompt set') await message.reply(f'{prop} set')
await temp_message.delete() await temp_message.delete()
except Exception as e: except Exception as e:
@@ -32,3 +38,134 @@ async def set_prompt_command(message: types.Message):
await temp_message.delete() await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return return
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_prompt_command(message: types.Message):
await _set_property(message, 'prompt')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_negative_prompt_command(message: types.Message):
await _set_property(message, 'negative_prompt')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_steps_command(message: types.Message):
try:
_ = int(message.get_args())
except Exception as e:
assert e
await message.reply('❌ Specify number as argument')
return
if _ > 30:
await message.reply('❌ Specify number <= 30')
return
await _set_property(message, 'steps')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_cfg_scale_command(message: types.Message):
try:
_ = int(message.get_args())
except Exception as e:
assert e
await message.reply('❌ Specify number as argument')
return
await _set_property(message, 'cfg_scale')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_width_command(message: types.Message):
try:
_ = int(message.get_args())
except Exception as e:
assert e
await message.reply('❌ Specify number as argument')
return
if _ > 768:
await message.reply('❌ Specify number <= 768')
return
await _set_property(message, 'width')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_height_command(message: types.Message):
try:
_ = int(message.get_args())
except Exception as e:
assert e
await message.reply('❌ Specify number as argument')
return
if _ > 768:
await message.reply('❌ Specify number <= 768')
return
await _set_property(message, 'height')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_restore_faces_command(message: types.Message):
try:
_ = bool(message.get_args())
except Exception as e:
assert e
await message.reply('❌ Specify boolean <code>True</code>/<code>False</code> as argument',
parse_mode='HTML')
return
await _set_property(message, 'restore_faces')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_sampler_command(message: types.Message):
temp_message = await message.reply('⏳ Getting samplers...')
from bot.modules.api.samplers import get_samplers
from aiohttp import ClientConnectorError
try:
if message.get_args() not in (samplers := await get_samplers()):
await message.reply(
f'❌ You can use only {", ".join(f"<code>{x}</code>" for x in samplers)}',
parse_mode='HTML'
)
return
await temp_message.delete()
except ClientConnectorError:
await message.reply('❌ Error! Maybe, StableDiffusion API endpoint is incorrect '
'or turned off')
except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e)
await message.reply('❌ Error happened while processing your request', parse_mode='html',
reply_markup=get_exception_keyboard(exception_id))
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
await _set_property(message, 'sampler')
@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
async def set_size_command(message: types.Message):
try:
hxw = message.get_args().split('x')
height = int(hxw[0])
width = int(hxw[1])
except Exception as e:
assert e
await message.reply('❌ Specify size in <code>hxw</code> format, for example <code>512x512</code>',
parse_mode='HTML')
return
if height > 768 or width > 768:
await message.reply('❌ Specify numbers <= 768')
return
await _set_property(message, 'height', height)
await _set_property(message, 'width', width)

View File

@@ -1,24 +1,30 @@
import re
from aiogram import types from aiogram import types
from bot.db import db, DBTables from bot.db import db, DBTables
from bot.utils.cooldown import throttle from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img from bot.modules.api.txt2img import txt2img
from bot.modules.api.objects.prompt_request import Prompt from bot.modules.api.objects.get_prompt import get_prompt
from bot.modules.api.objects.prompt_request import Generated
from bot.modules.api.status import wait_for_status from bot.modules.api.status import wait_for_status
from bot.keyboards.exception import get_exception_keyboard from bot.keyboards.exception import get_exception_keyboard
from bot.keyboards.image_info import get_img_info_keyboard
from bot.utils.trace_exception import PrettyException from bot.utils.trace_exception import PrettyException
from aiohttp import ClientConnectorError from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins')) @throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
async def txt2img_comand(message: types.Message): async def generate_command(message: types.Message):
temp_message = await message.reply("⏳ Enqueued...") temp_message = await message.reply("⏳ Enqueued...")
prompt: Prompt = db[DBTables.prompts].get(message.from_id) try:
if not prompt: prompt = get_prompt(user_id=message.from_id,
if message.get_args(): prompt_string=message.get_args())
db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id) except AttributeError:
await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
# TODO: Move it to other module f"For example, it can be: <code>masterpiece, best quality, 1girl, white hair, "
f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, jacket, "
f"outdoors, streets</code>", parse_mode='HTML')
return
try: try:
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
@@ -27,12 +33,18 @@ async def txt2img_comand(message: types.Message):
await wait_for_status() await wait_for_status()
await temp_message.edit_text(f"⌛ Generating...") await temp_message.edit_text(f"⌛ Generating...")
prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
image = await txt2img(prompt) image = await txt2img(prompt)
image_message = await message.reply_photo(photo=image[0]) image_message = await message.reply_photo(photo=image[0])
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated(
prompt=prompt,
seed=image[1]['seed'],
model=re.search(r", Model: ([^,]+),", image[1]['infotexts'][0]).groups()[0]
)
await message.reply('Here is your image',
reply_markup=get_img_info_keyboard(image_message.photo[0].file_unique_id))
await temp_message.delete() await temp_message.delete()
@@ -45,6 +57,12 @@ async def txt2img_comand(message: types.Message):
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1 db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return return
except ValueError as e:
await message.reply(f'❌ Error! {e.args[0]}')
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
except Exception as e: except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}' exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e) db[DBTables.exceptions][exception_id] = PrettyException(e)

View File

@@ -1,24 +1,18 @@
import aiohttp import aiohttp
from bot.db import db, DBTables, decrypt from bot.db import db, DBTables, decrypt
from rich import print
async def get_models(): async def get_models():
endpoint = decrypt(db[DBTables.config].get('endpoint')) endpoint = decrypt(db[DBTables.config].get('endpoint'))
try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models") r = await session.get(endpoint + "/sdapi/v1/sd-models")
if r.status != 200: if r.status != 200:
return None return None
return [x["title"] for x in await r.json()] return [x["title"] for x in await r.json()]
except Exception as e:
print(e)
return None
async def set_model(model_name: str): async def set_model(model_name: str):
endpoint = decrypt(db[DBTables.config].get('endpoint')) endpoint = decrypt(db[DBTables.config].get('endpoint'))
try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
r = await session.post(endpoint + "/sdapi/v1/options", json={ r = await session.post(endpoint + "/sdapi/v1/options", json={
"sd_model_checkpoint": model_name "sd_model_checkpoint": model_name
@@ -26,6 +20,3 @@ async def set_model(model_name: str):
if r.status != 200: if r.status != 200:
return False return False
return True return True
except Exception as e:
print(e)
return False

View File

@@ -0,0 +1,38 @@
from bot.modules.api.objects.prompt_request import Prompt
from bot.db import db, DBTables
def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None,
cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None,
sampler: str = None) -> Prompt:
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
creator = user_id
if not new_prompt:
if user_id and prompt_string:
db[DBTables.prompts][user_id] = Prompt(
prompt=prompt_string,
creator=user_id,
negative_prompt=negative_prompt,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
sampler=sampler,
)
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
else:
raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user')
if prompt_string:
new_prompt.prompt = prompt_string
for key in new_prompt.__dict__.keys():
if key in locals().keys() and locals()[key]:
new_prompt.__setattr__(key, locals()[key])
elif not new_prompt.__getattribute__(key):
new_prompt.__setattr__(key, Prompt('').__getattribute__(key))
db[DBTables.prompts][user_id] = new_prompt
return new_prompt

View File

@@ -12,3 +12,10 @@ class Prompt:
restore_faces: bool = True restore_faces: bool = True
sampler: str = "Euler a" sampler: str = "Euler a"
creator: int = None creator: int = None
@dataclasses.dataclass
class Generated:
prompt: Prompt
seed: int
model: str

View File

@@ -0,0 +1,11 @@
import aiohttp
from bot.db import db, DBTables, decrypt
async def get_samplers():
endpoint = decrypt(db[DBTables.config].get('endpoint'))
async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/samplers")
if r.status != 200:
return None
return [x["name"] for x in await r.json()]

View File

@@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes
"sampler_index": prompt.sampler "sampler_index": prompt.sampler
} }
) )
if r.status != 200: if r.status != 200 and ignore_exceptions:
return None return None
elif r.status != 200:
raise ValueError((await r.json())['detail'])
return [base64.b64decode((await r.json())["images"][0]), return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])] json.loads((await r.json())["info"])]
except Exception as e: except Exception as e: