import logging
from io import BytesIO

from nicegui import ui, app, run

from ...modules.encoder.encoder import load

logger = logging.getLogger(__name__)

# Keys copied out of the per-user storage and passed to ``creator``. A plain
# dict snapshot is sent to the worker rather than the storage object itself,
# presumably because ``run.cpu_bound`` must pickle its arguments.
_CREATOR_KEYS = (
    'model_name',
    'audio_content',
    'bpm',
    'num_samples',
    'sample_steps',
    'detected_title',
    'detected_artist',
    'filename',
)


def creator(storage):
    """Generate a mapset from the settings in ``storage`` and return its name.

    Run via ``run.cpu_bound`` from ``on_create``, so the heavy imports live
    inside the function body.

    Args:
        storage: plain dict holding the keys listed in ``_CREATOR_KEYS``.

    Returns:
        The ``.name`` attribute of the object returned by ``generate_mapset``;
        ``on_create`` stores it as ``mapset_path``.
    """
    # Imported lazily, presumably so they load in the worker process rather
    # than when this module is imported by the UI process.
    import torch
    from osu_dreamer.model import Model
    from ...modules.generate import generate_mapset
    from rich.traceback import install

    # Rich tracebacks (including local variables) for errors raised here.
    install(show_locals=True)

    model = Model.load_from_checkpoint(
        f'models/{storage["model_name"]}',
        sample_steps=storage['sample_steps'],
    ).eval()
    if torch.cuda.is_available():
        model = model.cuda()

    # NOTE(review): ``load`` is assumed to decode the stored audio payload back
    # to raw bytes — confirm against modules.encoder.encoder.
    audio = BytesIO(load(storage['audio_content']))
    # Give the in-memory file a name; presumably generate_mapset reads it
    # (e.g. for the extension) — verify.
    audio.name = storage['filename']
    audio.seek(0)

    return generate_mapset(
        model,
        audio,
        storage['bpm'],
        storage['num_samples'],
        storage['detected_title'],
        storage['detected_artist'],
    ).name


async def on_create():
    """Run ``creator`` off the event loop and update the user's state flags.

    While work is in progress ``is_loading`` is True and ``can_be_created`` is
    False. On success the result is stored in ``mapset_path`` and
    ``can_be_saved`` is set; on failure the error is logged and shown via
    ``ui.notify``. The busy flags are always restored afterwards.
    """
    storage = app.storage.user
    storage['is_loading'] = True
    storage['can_be_created'] = False
    try:
        mapset_path = await run.cpu_bound(
            creator,
            {key: storage[key] for key in _CREATOR_KEYS},
        )
        storage['mapset_path'] = mapset_path
        storage['can_be_saved'] = True
    except Exception as e:
        logger.exception('Mapset creation failed')
        ui.notify(f'Error {e}')
    finally:
        # Reset in ``finally`` so the flags are restored even on cancellation
        # (asyncio.CancelledError is not an ``Exception``); otherwise the
        # user's storage would be left stuck in the loading state.
        storage['is_loading'] = False
        storage['can_be_created'] = True