Setting generation properties
This commit is contained in:
@@ -1,31 +1,22 @@
|
||||
import aiohttp
|
||||
from bot.db import db, DBTables, decrypt
|
||||
from rich import print
|
||||
|
||||
|
||||
async def get_models():
|
||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
||||
if r.status != 200:
|
||||
return None
|
||||
return [x["title"] for x in await r.json()]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.get(endpoint + "/sdapi/v1/sd-models")
|
||||
if r.status != 200:
|
||||
return None
|
||||
return [x["title"] for x in await r.json()]
|
||||
|
||||
|
||||
async def set_model(model_name: str):
|
||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
||||
"sd_model_checkpoint": model_name
|
||||
})
|
||||
if r.status != 200:
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return False
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.post(endpoint + "/sdapi/v1/options", json={
|
||||
"sd_model_checkpoint": model_name
|
||||
})
|
||||
if r.status != 200:
|
||||
return False
|
||||
return True
|
||||
|
||||
38
bot/modules/api/objects/get_prompt.py
Normal file
38
bot/modules/api/objects/get_prompt.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from bot.modules.api.objects.prompt_request import Prompt
|
||||
from bot.db import db, DBTables
|
||||
|
||||
|
||||
def get_prompt(user_id: int, prompt_string: str = None, negative_prompt: str = None, steps: int = None,
|
||||
cfg_scale: int = None, width: int = None, height: int = None, restore_faces: bool = None,
|
||||
sampler: str = None) -> Prompt:
|
||||
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
|
||||
creator = user_id
|
||||
if not new_prompt:
|
||||
if user_id and prompt_string:
|
||||
db[DBTables.prompts][user_id] = Prompt(
|
||||
prompt=prompt_string,
|
||||
creator=user_id,
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
restore_faces=restore_faces,
|
||||
sampler=sampler,
|
||||
)
|
||||
new_prompt: Prompt = db[DBTables.prompts].get(user_id)
|
||||
else:
|
||||
raise AttributeError('No prompt string specified and prompt doesn\'t exist for this user')
|
||||
|
||||
if prompt_string:
|
||||
new_prompt.prompt = prompt_string
|
||||
|
||||
for key in new_prompt.__dict__.keys():
|
||||
if key in locals().keys() and locals()[key]:
|
||||
new_prompt.__setattr__(key, locals()[key])
|
||||
elif not new_prompt.__getattribute__(key):
|
||||
new_prompt.__setattr__(key, Prompt('').__getattribute__(key))
|
||||
|
||||
db[DBTables.prompts][user_id] = new_prompt
|
||||
|
||||
return new_prompt
|
||||
@@ -12,3 +12,10 @@ class Prompt:
|
||||
restore_faces: bool = True
|
||||
sampler: str = "Euler a"
|
||||
creator: int = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Generated:
|
||||
prompt: Prompt
|
||||
seed: int
|
||||
model: str
|
||||
|
||||
11
bot/modules/api/samplers.py
Normal file
11
bot/modules/api/samplers.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import aiohttp
|
||||
from bot.db import db, DBTables, decrypt
|
||||
|
||||
|
||||
async def get_samplers():
|
||||
endpoint = decrypt(db[DBTables.config].get('endpoint'))
|
||||
async with aiohttp.ClientSession() as session:
|
||||
r = await session.get(endpoint + "/sdapi/v1/samplers")
|
||||
if r.status != 200:
|
||||
return None
|
||||
return [x["name"] for x in await r.json()]
|
||||
@@ -22,8 +22,10 @@ async def txt2img(prompt: Prompt, ignore_exceptions: bool = False) -> list[bytes
|
||||
"sampler_index": prompt.sampler
|
||||
}
|
||||
)
|
||||
if r.status != 200:
|
||||
if r.status != 200 and ignore_exceptions:
|
||||
return None
|
||||
elif r.status != 200:
|
||||
raise ValueError((await r.json())['detail'])
|
||||
return [base64.b64decode((await r.json())["images"][0]),
|
||||
json.loads((await r.json())["info"])]
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user