Files
osu-dreamer-gui/osu_dreamer_gui/gui/handlers/on_create.py
2023-11-17 11:43:12 +02:00

58 lines
1.6 KiB
Python

from nicegui import ui, app, run
from io import BytesIO
from ...modules.encoder.encoder import load
def creator(storage):
    """Load the selected model and generate a mapset from the stored audio.

    Heavy dependencies (torch, the osu_dreamer model, the generator) are
    imported inside the function rather than at module level, so they are
    only loaded where this function actually executes.

    Args:
        storage: Mapping providing the keys ``model_name``, ``sample_steps``,
            ``audio_content``, ``filename``, ``bpm``, ``num_samples``,
            ``detected_title`` and ``detected_artist``.

    Returns:
        The ``name`` attribute of the object returned by ``generate_mapset``
        (presumably the path of the generated mapset file -- confirm against
        ``generate_mapset``).
    """
    import torch
    from osu_dreamer.model import Model
    from ...modules.generate import generate_mapset
    from rich.traceback import install

    # Rich tracebacks (including local variables) for easier debugging.
    install(show_locals=True)

    checkpoint_path = f'models/{storage["model_name"]}'
    model = Model.load_from_checkpoint(
        checkpoint_path,
        sample_steps=storage['sample_steps'],
    )
    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    # Wrap the audio bytes in a named in-memory file object.
    # NOTE(review): `load` presumably decodes the stored payload to raw
    # bytes -- confirm against modules.encoder.encoder.
    audio = BytesIO(load(storage['audio_content']))
    audio.name = storage['filename']
    audio.seek(0)

    mapset = generate_mapset(
        model,
        audio,
        storage['bpm'],
        storage['num_samples'],
        storage['detected_title'],
        storage['detected_artist'],
    )
    return mapset.name
async def on_create():
    """Handle the "create" action: generate a mapset and update UI state.

    Marks the user's session as loading, runs ``creator`` through
    ``run.cpu_bound`` with a plain-dict snapshot of the relevant storage
    values, and on success stores the resulting mapset path and enables
    saving. Any ``Exception`` raised along the way is reported to the user
    via ``ui.notify`` instead of propagating.

    The loading flags are restored in a ``finally`` clause so the UI cannot
    get stuck in the loading state when the task is cancelled
    (``asyncio.CancelledError`` is not an ``Exception`` subclass) or when
    the notification itself fails.
    """
    storage = app.storage.user
    storage['is_loading'] = True
    storage['can_be_created'] = False
    try:
        # Pass a plain dict snapshot rather than the storage object itself;
        # numeric settings are normalised to int before being handed over.
        mapset_path = await run.cpu_bound(creator, dict(
            model_name=storage['model_name'],
            audio_content=storage['audio_content'],
            bpm=int(storage['bpm']),
            num_samples=int(storage['num_samples']),
            sample_steps=int(storage['sample_steps']),
            detected_title=storage['detected_title'],
            detected_artist=storage['detected_artist'],
            filename=storage['filename'],
        ))
        storage['mapset_path'] = mapset_path
        storage['can_be_saved'] = True
    except Exception as e:
        ui.notify(f'Error {e}')
    finally:
        # Always re-enable the UI, whatever the outcome.
        storage['is_loading'] = False
        storage['can_be_created'] = True