58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
from nicegui import ui, app, run
|
|
from io import BytesIO
|
|
from ...modules.encoder.encoder import load
|
|
|
|
|
|
def creator(storage):
|
|
import torch
|
|
from osu_dreamer.model import Model
|
|
from ...modules.generate import generate_mapset
|
|
from rich.traceback import install
|
|
|
|
install(show_locals=True)
|
|
|
|
model = Model.load_from_checkpoint(
|
|
f'models/{storage["model_name"]}',
|
|
sample_steps=storage['sample_steps'],
|
|
).eval()
|
|
|
|
if torch.cuda.is_available():
|
|
model = model.cuda()
|
|
|
|
audio = BytesIO(load(storage['audio_content']))
|
|
audio.name = storage['filename']
|
|
audio.seek(0)
|
|
return generate_mapset(
|
|
model,
|
|
audio,
|
|
storage['bpm'],
|
|
storage['num_samples'],
|
|
storage['detected_title'], storage['detected_artist'],
|
|
).name
|
|
|
|
|
|
async def on_create():
|
|
storage = app.storage.user
|
|
|
|
storage['is_loading'] = True
|
|
storage['can_be_created'] = False
|
|
|
|
try:
|
|
mapset_path = await run.cpu_bound(creator, dict(
|
|
model_name=storage['model_name'],
|
|
audio_content=storage['audio_content'],
|
|
bpm=int(storage['bpm']),
|
|
num_samples=int(storage['num_samples']),
|
|
sample_steps=int(storage['sample_steps']),
|
|
detected_title=storage['detected_title'],
|
|
detected_artist=storage['detected_artist'],
|
|
filename=storage['filename'],
|
|
))
|
|
storage['mapset_path'] = mapset_path
|
|
storage['can_be_saved'] = True
|
|
except Exception as e:
|
|
ui.notify(f'Error {e}')
|
|
|
|
storage['is_loading'] = False
|
|
storage['can_be_created'] = True
|