Setting generation properties

This commit is contained in:
BarsTiger
2023-02-26 22:16:05 +02:00
parent 8f0cd0ac05
commit 3c3d644137
10 changed files with 286 additions and 56 deletions

View File

@@ -1,31 +1,22 @@
import aiohttp
from bot.db import db, DBTables, decrypt
from rich import print
async def get_models():
endpoint = decrypt(db[DBTables.config].get('endpoint'))
try:
async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models")
if r.status != 200:
return None
return [x["title"] for x in await r.json()]
except Exception as e:
print(e)
return None
async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/sd-models")
if r.status != 200:
return None
return [x["title"] for x in await r.json()]
async def set_model(model_name: str):
endpoint = decrypt(db[DBTables.config].get('endpoint'))
try:
async with aiohttp.ClientSession() as session:
r = await session.post(endpoint + "/sdapi/v1/options", json={
"sd_model_checkpoint": model_name
})
if r.status != 200:
return False
return True
except Exception as e:
print(e)
return False
async with aiohttp.ClientSession() as session:
r = await session.post(endpoint + "/sdapi/v1/options", json={
"sd_model_checkpoint": model_name
})
if r.status != 200:
return False
return True

View File

@@ -0,0 +1,38 @@
from bot.modules.api.objects.prompt_request import Prompt
from bot.db import db, DBTables
def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None,
cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None,
sampler: str = None) -> Prompt:
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
creator = user_id
if not new_prompt:
if user_id and prompt_string:
db[DBTables.prompts][user_id] = Prompt(
prompt=prompt_string,
creator=user_id,
negative_prompt=negative_prompt,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
sampler=sampler,
)
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
else:
raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user')
if prompt_string:
new_prompt.prompt = prompt_string
for key in new_prompt.__dict__.keys():
if key in locals().keys() and locals()[key]:
new_prompt.__setattr__(key, locals()[key])
elif not new_prompt.__getattribute__(key):
new_prompt.__setattr__(key, Prompt('').__getattribute__(key))
db[DBTables.prompts][user_id] = new_prompt
return new_prompt

View File

@@ -12,3 +12,10 @@ class Prompt:
restore_faces: bool = True
sampler: str = "Euler a"
creator: int = None
@dataclasses.dataclass
class Generated:
prompt: Prompt
seed: int
model: str

View File

@@ -0,0 +1,11 @@
import aiohttp
from bot.db import db, DBTables, decrypt
async def get_samplers():
endpoint = decrypt(db[DBTables.config].get('endpoint'))
async with aiohttp.ClientSession() as session:
r = await session.get(endpoint + "/sdapi/v1/samplers")
if r.status != 200:
return None
return [x["name"] for x in await r.json()]

View File

@@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes
"sampler_index": prompt.sampler
}
)
if r.status != 200:
if r.status != 200 and ignore_exceptions:
return None
elif r.status != 200:
raise ValueError((await r.json())['detail'])
return [base64.b64decode((await r.json())["images"][0]),
json.loads((await r.json())["info"])]
except Exception as e: