Initial commit
This commit is contained in:
1
osu_dreamer_gui/gui/__init__.py
Normal file
1
osu_dreamer_gui/gui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .views import home
|
||||
0
osu_dreamer_gui/gui/elements/__init__.py
Normal file
0
osu_dreamer_gui/gui/elements/__init__.py
Normal file
29
osu_dreamer_gui/gui/elements/create_button.py
Normal file
29
osu_dreamer_gui/gui/elements/create_button.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from nicegui import ui, app
|
||||
from ..handlers.on_create import on_create
|
||||
|
||||
|
||||
class CreateButton:
|
||||
create_button: ui.button
|
||||
download_button: ui.button
|
||||
loading: ui.spinner
|
||||
|
||||
def place(self):
|
||||
with ui.row():
|
||||
self.create_button = ui.button(
|
||||
'Create map', on_click=on_create
|
||||
).bind_enabled_from(app.storage.user, 'can_be_created')
|
||||
|
||||
self.download_button = ui.button(
|
||||
'Save map',
|
||||
on_click=lambda: ui.download(
|
||||
src=app.storage.user.get('mapset_path')
|
||||
)
|
||||
).bind_enabled_from(app.storage.user, 'can_be_saved')
|
||||
|
||||
self.loading = ui.spinner(
|
||||
type='audio',
|
||||
size='2rem'
|
||||
).bind_visibility_from(app.storage.user, 'is_loading')
|
||||
|
||||
|
||||
createbutton = CreateButton()
|
||||
18
osu_dreamer_gui/gui/elements/mp3_choose.py
Normal file
18
osu_dreamer_gui/gui/elements/mp3_choose.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from nicegui import ui
|
||||
from ..handlers.mp3_choose_upload import on_upload
|
||||
|
||||
|
||||
class MP3Choose:
|
||||
uploader: ui.upload
|
||||
|
||||
def place(self):
|
||||
self.uploader = ui.upload(
|
||||
label='Upload audio',
|
||||
max_files=1,
|
||||
on_upload=on_upload,
|
||||
max_total_size=100 * 1_048_576,
|
||||
auto_upload=True,
|
||||
).classes('w-full')
|
||||
|
||||
|
||||
mp3choose = MP3Choose()
|
||||
54
osu_dreamer_gui/gui/elements/params_boxes.py
Normal file
54
osu_dreamer_gui/gui/elements/params_boxes.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from nicegui import ui
|
||||
from nicegui import app
|
||||
import os
|
||||
|
||||
|
||||
class ParamsBoxes:
|
||||
title: ui.input
|
||||
artist: ui.input
|
||||
bpm: ui.number
|
||||
num_samples = ui.number
|
||||
sample_steps = ui.number
|
||||
model_name = ui.input
|
||||
|
||||
def place(self):
|
||||
storage = app.storage.user
|
||||
|
||||
self.title = ui.input(
|
||||
label='Title',
|
||||
placeholder='Song title',
|
||||
).bind_value(storage, 'detected_title').classes('w-full')
|
||||
|
||||
self.artist = ui.input(
|
||||
label='Artist',
|
||||
placeholder='Song artist',
|
||||
).bind_value(storage, 'detected_artist').classes('w-full')
|
||||
|
||||
self.bpm = ui.number(
|
||||
label='BPM',
|
||||
placeholder='Song BPM',
|
||||
value=180,
|
||||
).bind_value(storage, 'bpm').classes('w-full')
|
||||
|
||||
self.num_samples = ui.number(
|
||||
label='Number of samples',
|
||||
placeholder='Number of samples',
|
||||
value=2,
|
||||
).classes('w-full').bind_value(storage, 'num_samples')
|
||||
|
||||
self.sample_steps = ui.number(
|
||||
label='Sample steps',
|
||||
placeholder='Sample steps',
|
||||
value=32,
|
||||
).classes('w-full').bind_value(storage, 'sample_steps')
|
||||
|
||||
self.model_name = ui.input(
|
||||
label='Model name',
|
||||
placeholder='Model name',
|
||||
value=[item for item in filter(
|
||||
lambda x: x.endswith('.ckpt'), os.listdir('models')
|
||||
)][0],
|
||||
).classes('w-full').bind_value(storage, 'model_name')
|
||||
|
||||
|
||||
paramsboxes = ParamsBoxes()
|
||||
0
osu_dreamer_gui/gui/handlers/__init__.py
Normal file
0
osu_dreamer_gui/gui/handlers/__init__.py
Normal file
29
osu_dreamer_gui/gui/handlers/mp3_choose_upload.py
Normal file
29
osu_dreamer_gui/gui/handlers/mp3_choose_upload.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from nicegui import ui, events, app
|
||||
import mutagen
|
||||
|
||||
from ...modules.encoder.encoder import dump
|
||||
|
||||
|
||||
async def on_upload(e: events.UploadEventArguments):
|
||||
storage = app.storage.user
|
||||
|
||||
storage['bpm'] = None
|
||||
storage['detected_title'] = None
|
||||
storage['detected_artist'] = None
|
||||
|
||||
tags = mutagen.File(e.content, easy=True)
|
||||
|
||||
if 'bpm' in tags:
|
||||
storage['bpm'] = int(float(tags['bpm'][0]))
|
||||
|
||||
if 'title' in tags:
|
||||
storage['detected_title'] = tags['title'][0]
|
||||
|
||||
if 'artist' in tags:
|
||||
storage['detected_artist'] = tags['artist'][0]
|
||||
|
||||
storage['filename'] = e.name
|
||||
storage['audio_content'] = dump(e.content.read())
|
||||
|
||||
storage['can_be_created'] = True
|
||||
storage['can_be_saved'] = False
|
||||
57
osu_dreamer_gui/gui/handlers/on_create.py
Normal file
57
osu_dreamer_gui/gui/handlers/on_create.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from nicegui import ui, app, run
|
||||
from io import BytesIO
|
||||
from ...modules.encoder.encoder import load
|
||||
|
||||
|
||||
def creator(storage):
|
||||
import torch
|
||||
from osu_dreamer.model import Model
|
||||
from ...modules.generate import generate_mapset
|
||||
from rich.traceback import install
|
||||
|
||||
install(show_locals=True)
|
||||
|
||||
model = Model.load_from_checkpoint(
|
||||
f'models/{storage["model_name"]}',
|
||||
sample_steps=storage['sample_steps'],
|
||||
).eval()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
|
||||
audio = BytesIO(load(storage['audio_content']))
|
||||
audio.name = storage['filename']
|
||||
audio.seek(0)
|
||||
return generate_mapset(
|
||||
model,
|
||||
audio,
|
||||
storage['bpm'],
|
||||
storage['num_samples'],
|
||||
storage['detected_title'], storage['detected_artist'],
|
||||
).name
|
||||
|
||||
|
||||
async def on_create():
|
||||
storage = app.storage.user
|
||||
|
||||
storage['is_loading'] = True
|
||||
storage['can_be_created'] = False
|
||||
|
||||
try:
|
||||
mapset_path = await run.cpu_bound(creator, dict(
|
||||
model_name=storage['model_name'],
|
||||
audio_content=storage['audio_content'],
|
||||
bpm=storage['bpm'],
|
||||
num_samples=storage['num_samples'],
|
||||
sample_steps=storage['sample_steps'],
|
||||
detected_title=storage['detected_title'],
|
||||
detected_artist=storage['detected_artist'],
|
||||
filename=storage['filename'],
|
||||
))
|
||||
storage['mapset_path'] = mapset_path
|
||||
storage['can_be_saved'] = True
|
||||
except Exception as e:
|
||||
ui.notify(f'Error {e}')
|
||||
|
||||
storage['is_loading'] = False
|
||||
storage['can_be_created'] = True
|
||||
0
osu_dreamer_gui/gui/views/__init__.py
Normal file
0
osu_dreamer_gui/gui/views/__init__.py
Normal file
15
osu_dreamer_gui/gui/views/home.py
Normal file
15
osu_dreamer_gui/gui/views/home.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from nicegui import ui, app
|
||||
from ..elements.mp3_choose import mp3choose
|
||||
from ..elements.params_boxes import paramsboxes
|
||||
from ..elements.create_button import createbutton
|
||||
|
||||
|
||||
@ui.page('/')
|
||||
async def homepage():
|
||||
app.storage.user['can_be_created'] = False
|
||||
app.storage.user['is_loading'] = False
|
||||
app.storage.user['can_be_saved'] = False
|
||||
|
||||
mp3choose.place()
|
||||
paramsboxes.place()
|
||||
createbutton.place()
|
||||
Reference in New Issue
Block a user