diff --git a/bot/callbacks/image_info.py b/bot/callbacks/image_info.py
index 54e07f0..ae0c317 100644
--- a/bot/callbacks/image_info.py
+++ b/bot/callbacks/image_info.py
@@ -5,7 +5,7 @@ from .factories.image_info import full_prompt, prompt_only, import_prompt, back
from bot.keyboards.image_info import get_img_info_keyboard, get_img_back_keyboard
from bot.utils.cooldown import throttle
from bot.utils.private_keyboard import other_user
-from bot.modules.api.objects.prompt_request import Prompt
+from bot.modules.api.objects.prompt_request import Generated
async def on_back(call: types.CallbackQuery, callback_data: dict):
@@ -26,11 +26,11 @@ async def on_prompt_only(call: types.CallbackQuery, callback_data: dict):
if await other_user(call):
return
- prompt: Prompt = db[DBTables.generated].get(p_id)
+ prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text(
- f"π€ Prompt: {prompt.prompt} \n"
- f"{f'π Negative: {prompt.negative_prompt}' if prompt.negative_prompt else ''}",
+ f"π€ Prompt: {prompt.prompt.prompt} \n"
+ f"{f'π Negative: {prompt.prompt.negative_prompt}' if prompt.prompt.negative_prompt else ''}",
parse_mode='html',
reply_markup=get_img_back_keyboard(p_id)
)
@@ -42,16 +42,18 @@ async def on_full_info(call: types.CallbackQuery, callback_data: dict):
if await other_user(call):
return
- prompt: Prompt = db[DBTables.generated].get(p_id)
+ prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text(
- f"π€ Prompt: {prompt.prompt} \n"
- f"π Negative: {prompt.negative_prompt} \n"
- f"πͺ Steps: {prompt.steps} \n"
- f"π§βπ¨ CFG Scale: {prompt.cfg_scale} \n"
- f"π₯οΈ Size: {prompt.width}x{prompt.height} \n"
- f"π Restore faces: {'on' if prompt.restore_faces else 'off'} \n"
- f"βοΈ Sampler: {prompt.sampler}",
+ f"π€ Prompt: {prompt.prompt.prompt} \n"
+ f"π Negative: {prompt.prompt.negative_prompt} \n"
+ f"π« Model: {prompt.model} \n"
+ f"πͺ Steps: {prompt.prompt.steps} \n"
+ f"π§βπ¨ CFG Scale: {prompt.prompt.cfg_scale} \n"
+ f"π₯οΈ Size: {prompt.prompt.width}x{prompt.prompt.height} \n"
+ f"π Restore faces: {'on' if prompt.prompt.restore_faces else 'off'} \n"
+ f"βοΈ Sampler: {prompt.prompt.sampler} \n"
+ f"π± Seed: {prompt.seed}",
parse_mode='html',
reply_markup=get_img_back_keyboard(p_id)
)
@@ -63,7 +65,7 @@ async def on_import(call: types.CallbackQuery, callback_data: dict):
if await other_user(call):
return
- prompt: Prompt = db[DBTables.generated].get(p_id)
+ prompt: Generated = db[DBTables.generated].get(p_id)
await call.message.edit_text(
f"πΆβπ«οΈ Not implemented yet",
diff --git a/bot/handlers/help_command/help_strings.py b/bot/handlers/help_command/help_strings.py
index a224509..d8b8962 100644
--- a/bot/handlers/help_command/help_strings.py
+++ b/bot/handlers/help_command/help_strings.py
@@ -1,4 +1,15 @@
help_data = {
- 'setendpoint': '(admin) Set StableDiffusion API endpoint',
- 'imginfo': 'Get information about image, that was generated using this bot'
+ 'generate': 'Generate picture using configuration set by user. You can pass prompt also in command arguments or '
+ 'use it without arguments to generate picture with prompt, that was used for last generation',
+ 'imginfo': 'Get information about image, that was generated using this bot',
+ 'setprompt': 'Set default prompt for images, will be overwritten if you specify prompt in generate command',
+ 'setnegative': 'Set negative prompt',
+ 'setsize': 'Set size for image (in hxw format)',
+ 'setwidth': 'Set width for image',
+ 'setheight': 'Set height for image',
+ 'setsteps': 'Set sampling steps number',
+ 'setsampler': 'Set StableDiffusion sampler',
+ 'setscale': 'Set CFG Scale (prompt stringency)',
+ 'setfaces': 'Set restore faces mode',
+ 'setendpoint': '(admin) Set StableDiffusion API endpoint'
}
diff --git a/bot/handlers/txt2img/__init__.py b/bot/handlers/txt2img/__init__.py
index f75dcaf..520b4e8 100644
--- a/bot/handlers/txt2img/__init__.py
+++ b/bot/handlers/txt2img/__init__.py
@@ -1,6 +1,19 @@
from bot.common import dp
-from .txt2img import txt2img_comand
+from .txt2img import generate_command
+from .set_settings import (
+ set_height_command, set_negative_prompt_command, set_size_command, set_steps_command, set_width_command,
+ set_prompt_command, set_sampler_command, set_cfg_scale_command, set_restore_faces_command
+)
def register():
- dp.register_message_handler(txt2img.txt2img_comand, commands='txt2img')
+ dp.register_message_handler(txt2img.generate_command, commands='generate')
+ dp.register_message_handler(set_settings.set_prompt_command, commands='setprompt')
+ dp.register_message_handler(set_settings.set_height_command, commands='setheight')
+ dp.register_message_handler(set_settings.set_width_command, commands='setwidth')
+ dp.register_message_handler(set_settings.set_negative_prompt_command, commands='setnegative')
+ dp.register_message_handler(set_settings.set_size_command, commands='setsize')
+ dp.register_message_handler(set_settings.set_steps_command, commands='setsteps')
+ dp.register_message_handler(set_settings.set_sampler_command, commands='setsampler')
+ dp.register_message_handler(set_settings.set_cfg_scale_command, commands='setscale')
+ dp.register_message_handler(set_settings.set_restore_faces_command, commands='setfaces')
diff --git a/bot/handlers/txt2img/set_settings.py b/bot/handlers/txt2img/set_settings.py
index 91883e9..8c9c3ef 100644
--- a/bot/handlers/txt2img/set_settings.py
+++ b/bot/handlers/txt2img/set_settings.py
@@ -6,22 +6,28 @@ from bot.keyboards.exception import get_exception_keyboard
from bot.utils.trace_exception import PrettyException
-@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
-async def set_prompt_command(message: types.Message):
- temp_message = await message.reply("β³ Setting prompt...")
+async def _set_property(message: types.Message, prop: str, value=None):
+ temp_message = await message.reply(f"β³ Setting {prop}...")
if not message.get_args():
- await temp_message.edit_text("πΆβπ«οΈ Specify prompt for this command. Check /help setprompt")
+ await temp_message.edit_text("πΆβπ«οΈ Specify arguments for this command. Check /help")
return
try:
- prompt: Prompt = db[DBTables.prompts].get(message.from_id, Prompt(message.get_args()))
- prompt.prompt = message.get_args()
+ prompt: Prompt = db[DBTables.prompts].get(message.from_id)
+ if prompt is None and prop != 'prompt':
+ await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
+ f"For example, it can be: masterpiece, best quality, 1girl, white hair, "
+ f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, "
+ f"jacket, outdoors, streets", parse_mode='HTML')
+ return
+
+ prompt.__setattr__(prop, message.get_args() if value is None else value)
prompt.creator = message.from_id
db[DBTables.prompts][message.from_id] = prompt
await db[DBTables.config].write()
- await message.reply('β
Default prompt set')
+ await message.reply(f'β
{prop} set')
await temp_message.delete()
except Exception as e:
@@ -32,3 +38,134 @@ async def set_prompt_command(message: types.Message):
await temp_message.delete()
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_prompt_command(message: types.Message):
+ await _set_property(message, 'prompt')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_negative_prompt_command(message: types.Message):
+ await _set_property(message, 'negative_prompt')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_steps_command(message: types.Message):
+ try:
+ _ = int(message.get_args())
+ except Exception as e:
+ assert e
+ await message.reply('β Specify number as argument')
+ return
+
+ if _ > 30:
+ await message.reply('β Specify number <= 30')
+ return
+
+ await _set_property(message, 'steps')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_cfg_scale_command(message: types.Message):
+ try:
+ _ = int(message.get_args())
+ except Exception as e:
+ assert e
+ await message.reply('β Specify number as argument')
+ return
+
+ await _set_property(message, 'cfg_scale')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_width_command(message: types.Message):
+ try:
+ _ = int(message.get_args())
+ except Exception as e:
+ assert e
+ await message.reply('β Specify number as argument')
+ return
+
+ if _ > 768:
+ await message.reply('β Specify number <= 768')
+ return
+
+ await _set_property(message, 'width')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_height_command(message: types.Message):
+ try:
+ _ = int(message.get_args())
+ except Exception as e:
+ assert e
+ await message.reply('β Specify number as argument')
+ return
+
+ if _ > 768:
+ await message.reply('β Specify number <= 768')
+ return
+
+ await _set_property(message, 'height')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_restore_faces_command(message: types.Message):
+ try:
+ _ = bool(message.get_args())
+ except Exception as e:
+ assert e
+ await message.reply('β Specify boolean True/False as argument',
+ parse_mode='HTML')
+ return
+
+ await _set_property(message, 'restore_faces')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_sampler_command(message: types.Message):
+ temp_message = await message.reply('β³ Getting samplers...')
+ from bot.modules.api.samplers import get_samplers
+ from aiohttp import ClientConnectorError
+ try:
+ if message.get_args() not in (samplers := await get_samplers()):
+ await message.reply(
+ f'β You can use only {", ".join(f"{x}" for x in samplers)}',
+ parse_mode='HTML'
+ )
+ return
+ await temp_message.delete()
+ except ClientConnectorError:
+ await message.reply('β Error! Maybe, StableDiffusion API endpoint is incorrect '
+ 'or turned off')
+ except Exception as e:
+ exception_id = f'{message.message_thread_id}-{message.message_id}'
+ db[DBTables.exceptions][exception_id] = PrettyException(e)
+ await message.reply('β Error happened while processing your request', parse_mode='html',
+ reply_markup=get_exception_keyboard(exception_id))
+ await temp_message.delete()
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
+ return
+
+ await _set_property(message, 'sampler')
+
+
+@throttle(cooldown=5, admin_ids=db[DBTables.config].get('admins'))
+async def set_size_command(message: types.Message):
+ try:
+ hxw = message.get_args().split('x')
+ height = int(hxw[0])
+ width = int(hxw[1])
+ except Exception as e:
+ assert e
+ await message.reply('β Specify size in hxw format, for example 512x512',
+ parse_mode='HTML')
+ return
+
+ if height > 768 or width > 768:
+ await message.reply('β Specify numbers <= 768')
+ return
+
+ await _set_property(message, 'height', height)
+ await _set_property(message, 'width', width)
diff --git a/bot/handlers/txt2img/txt2img.py b/bot/handlers/txt2img/txt2img.py
index 47c7a34..680a6ac 100644
--- a/bot/handlers/txt2img/txt2img.py
+++ b/bot/handlers/txt2img/txt2img.py
@@ -1,24 +1,30 @@
+import re
from aiogram import types
from bot.db import db, DBTables
from bot.utils.cooldown import throttle
from bot.modules.api.txt2img import txt2img
-from bot.modules.api.objects.prompt_request import Prompt
+from bot.modules.api.objects.get_prompt import get_prompt
+from bot.modules.api.objects.prompt_request import Generated
from bot.modules.api.status import wait_for_status
from bot.keyboards.exception import get_exception_keyboard
+from bot.keyboards.image_info import get_img_info_keyboard
from bot.utils.trace_exception import PrettyException
from aiohttp import ClientConnectorError
@throttle(cooldown=30, admin_ids=db[DBTables.config].get('admins'))
-async def txt2img_comand(message: types.Message):
+async def generate_command(message: types.Message):
temp_message = await message.reply("β³ Enqueued...")
- prompt: Prompt = db[DBTables.prompts].get(message.from_id)
- if not prompt:
- if message.get_args():
- db[DBTables.prompts][message.from_id] = Prompt(message.get_args(), creator=message.from_id)
-
- # TODO: Move it to other module
+ try:
+ prompt = get_prompt(user_id=message.from_id,
+ prompt_string=message.get_args())
+ except AttributeError:
+ await temp_message.edit_text(f"You didn't created any prompt. Specify prompt text at least first time. "
+ f"For example, it can be: masterpiece, best quality, 1girl, white hair, "
+ f"medium hair, cat ears, closed eyes, looking at viewer, :3, cute, scarf, jacket, "
+ f"outdoors, streets", parse_mode='HTML')
+ return
try:
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 0) + 1
@@ -27,12 +33,18 @@ async def txt2img_comand(message: types.Message):
await wait_for_status()
await temp_message.edit_text(f"β Generating...")
- prompt = Prompt(prompt=message.get_args(), creator=message.from_id)
image = await txt2img(prompt)
image_message = await message.reply_photo(photo=image[0])
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
- db[DBTables.generated][image_message.photo[0].file_unique_id] = prompt
+ db[DBTables.generated][image_message.photo[0].file_unique_id] = Generated(
+ prompt=prompt,
+ seed=image[1]['seed'],
+ model=re.search(r", Model: ([^,]+),", image[1]['infotexts'][0]).groups()[0]
+ )
+
+ await message.reply('Here is your image',
+ reply_markup=get_img_info_keyboard(image_message.photo[0].file_unique_id))
await temp_message.delete()
@@ -45,6 +57,12 @@ async def txt2img_comand(message: types.Message):
db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
return
+ except ValueError as e:
+ await message.reply(f'β Error! {e.args[0]}')
+ await temp_message.delete()
+ db[DBTables.queue]['n'] = db[DBTables.queue].get('n', 1) - 1
+ return
+
except Exception as e:
exception_id = f'{message.message_thread_id}-{message.message_id}'
db[DBTables.exceptions][exception_id] = PrettyException(e)
diff --git a/bot/modules/api/models.py b/bot/modules/api/models.py
index bdf73cc..e36927b 100644
--- a/bot/modules/api/models.py
+++ b/bot/modules/api/models.py
@@ -1,31 +1,22 @@
import aiohttp
from bot.db import db, DBTables, decrypt
-from rich import print
async def get_models():
endpoint = decrypt(db[DBTables.config].get('endpoint'))
- try:
- async with aiohttp.ClientSession() as session:
- r = await session.get(endpoint + "/sdapi/v1/sd-models")
- if r.status != 200:
- return None
- return [x["title"] for x in await r.json()]
- except Exception as e:
- print(e)
- return None
+ async with aiohttp.ClientSession() as session:
+ r = await session.get(endpoint + "/sdapi/v1/sd-models")
+ if r.status != 200:
+ return None
+ return [x["title"] for x in await r.json()]
async def set_model(model_name: str):
endpoint = decrypt(db[DBTables.config].get('endpoint'))
- try:
- async with aiohttp.ClientSession() as session:
- r = await session.post(endpoint + "/sdapi/v1/options", json={
- "sd_model_checkpoint": model_name
- })
- if r.status != 200:
- return False
- return True
- except Exception as e:
- print(e)
- return False
+ async with aiohttp.ClientSession() as session:
+ r = await session.post(endpoint + "/sdapi/v1/options", json={
+ "sd_model_checkpoint": model_name
+ })
+ if r.status != 200:
+ return False
+ return True
diff --git a/bot/modules/api/objects/get_prompt.py b/bot/modules/api/objects/get_prompt.py
new file mode 100644
index 0000000..2cecf52
--- /dev/null
+++ b/bot/modules/api/objects/get_prompt.py
@@ -0,0 +1,38 @@
+from bot.modules.api.objects.prompt_request import Prompt
+from bot.db import db, DBTables
+
+
+def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None,
+ cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None,
+ sampler: str = None) -> Prompt:
+ new_prompt: Prompt = db[DBTables.prompts].get(user_id)
+ creator = user_id
+ if not new_prompt:
+ if user_id and prompt_string:
+ db[DBTables.prompts][user_id] = Prompt(
+ prompt=prompt_string,
+ creator=user_id,
+ negative_prompt=negative_prompt,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ restore_faces=restore_faces,
+ sampler=sampler,
+ )
+ new_prompt: Prompt = db[DBTables.prompts].get(user_id)
+ else:
+ raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user')
+
+ if prompt_string:
+ new_prompt.prompt = prompt_string
+
+ for key in new_prompt.__dict__.keys():
+ if key in locals().keys() and locals()[key]:
+ new_prompt.__setattr__(key, locals()[key])
+ elif not new_prompt.__getattribute__(key):
+ new_prompt.__setattr__(key, Prompt('').__getattribute__(key))
+
+ db[DBTables.prompts][user_id] = new_prompt
+
+ return new_prompt
diff --git a/bot/modules/api/objects/prompt_request.py b/bot/modules/api/objects/prompt_request.py
index 3621d05..841c6c6 100644
--- a/bot/modules/api/objects/prompt_request.py
+++ b/bot/modules/api/objects/prompt_request.py
@@ -12,3 +12,10 @@ class Prompt:
restore_faces: bool = True
sampler: str = "Euler a"
creator: int = None
+
+
+@dataclasses.dataclass
+class Generated:
+ prompt: Prompt
+ seed: int
+ model: str
diff --git a/bot/modules/api/samplers.py b/bot/modules/api/samplers.py
new file mode 100644
index 0000000..4ebbc34
--- /dev/null
+++ b/bot/modules/api/samplers.py
@@ -0,0 +1,11 @@
+import aiohttp
+from bot.db import db, DBTables, decrypt
+
+
+async def get_samplers():
+ endpoint = decrypt(db[DBTables.config].get('endpoint'))
+ async with aiohttp.ClientSession() as session:
+ r = await session.get(endpoint + "/sdapi/v1/samplers")
+ if r.status != 200:
+ return None
+ return [x["name"] for x in await r.json()]
diff --git a/bot/modules/api/txt2img.py b/bot/modules/api/txt2img.py
index 29982a9..0b77ac4 100644
--- a/bot/modules/api/txt2img.py
+++ b/bot/modules/api/txt2img.py
@@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes
"sampler_index": prompt.sampler
}
)
- if r.status != 200:
+ if r.status != 200 and ignore_exceptions:
return None
+ elif r.status != 200:
+ raise ValueError((await r.json())['detail'])
return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])]
except Exception as e: