commit 653711aba39a9bdca7e3d2aacdc5c079cee460c2 Author: BarsTiger Date: Thu Nov 16 22:49:59 2023 +0200 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6d470ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/build +/dist +/models/* + +*.osz +poetry.lock diff --git a/README.md b/README.md new file mode 100644 index 0000000..94c38be --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# osu!dreamer gui + +Interface for [osu!dreamer](https://github.com/jaswon/osu-dreamer) (inference only) + +Just PoC, don't take code quality seriously diff --git a/dreamer.spec b/dreamer.spec new file mode 100644 index 0000000..480eee5 --- /dev/null +++ b/dreamer.spec @@ -0,0 +1,45 @@ +# -*- mode: python ; coding: utf-8 -*- +import nicegui +from pathlib import Path + + +a = Analysis( + ['dreamer_gui.py'], + pathex=[], + binaries=[], + datas=[(Path(nicegui.__file__).parent, 'nicegui')], + hiddenimports=[], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='osu!dreamer-gui', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=True, + upx_exclude=[], + name='osu!dreamer-gui', +) diff --git a/dreamer_gui.py b/dreamer_gui.py new file mode 100644 index 0000000..6293a76 --- /dev/null +++ b/dreamer_gui.py @@ -0,0 +1,5 @@ +from osu_dreamer_gui import main + + +if __name__ in {'__main__', '__mp_main__'}: + main() diff --git a/osu_dreamer_gui/__init__.py b/osu_dreamer_gui/__init__.py new file mode 100644 index 0000000..6b33fdc --- /dev/null +++ b/osu_dreamer_gui/__init__.py @@ -0,0 +1,17 @@ +from nicegui import ui + +from . import gui + +from rich.traceback import install + + +def main(): + install(show_locals=True) + ui.run( + title='osu!dreamer', + native=True, + dark=True, + reload=False, + storage_secret='...', + window_size=(800, 670), + ) diff --git a/osu_dreamer_gui/gui/__init__.py b/osu_dreamer_gui/gui/__init__.py new file mode 100644 index 0000000..c04565f --- /dev/null +++ b/osu_dreamer_gui/gui/__init__.py @@ -0,0 +1 @@ +from .views import home diff --git a/osu_dreamer_gui/gui/elements/__init__.py b/osu_dreamer_gui/gui/elements/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/osu_dreamer_gui/gui/elements/create_button.py b/osu_dreamer_gui/gui/elements/create_button.py new file mode 100644 index 0000000..ab711b1 --- /dev/null +++ b/osu_dreamer_gui/gui/elements/create_button.py @@ -0,0 +1,29 @@ +from nicegui import ui, app +from ..handlers.on_create import on_create + + +class CreateButton: + create_button: ui.button + download_button: ui.button + loading: ui.spinner + + def place(self): + with ui.row(): + self.create_button = ui.button( + 'Create map', on_click=on_create + ).bind_enabled_from(app.storage.user, 'can_be_created') + + self.download_button = ui.button( + 'Save map', + on_click=lambda: ui.download( + src=app.storage.user.get('mapset_path') + ) + ).bind_enabled_from(app.storage.user, 'can_be_saved') + + self.loading = ui.spinner( + type='audio', + size='2rem' + ).bind_visibility_from(app.storage.user, 'is_loading') + + +createbutton = CreateButton() diff --git a/osu_dreamer_gui/gui/elements/mp3_choose.py b/osu_dreamer_gui/gui/elements/mp3_choose.py new file mode 100644 index 0000000..5027d60 --- /dev/null +++ b/osu_dreamer_gui/gui/elements/mp3_choose.py @@ -0,0 +1,18 @@ +from nicegui import ui +from ..handlers.mp3_choose_upload import on_upload + + +class MP3Choose: + uploader: ui.upload + + def place(self): + self.uploader = ui.upload( + label='Upload audio', + max_files=1, + on_upload=on_upload, + max_total_size=100 * 1_048_576, + auto_upload=True, + ).classes('w-full') + + +mp3choose = MP3Choose() diff --git a/osu_dreamer_gui/gui/elements/params_boxes.py b/osu_dreamer_gui/gui/elements/params_boxes.py new file mode 100644 index 0000000..dfcd889 --- /dev/null +++ b/osu_dreamer_gui/gui/elements/params_boxes.py @@ -0,0 +1,54 @@ +from nicegui import ui +from nicegui import app +import os + + +class ParamsBoxes: + title: ui.input + artist: ui.input + bpm: ui.number + num_samples = ui.number + sample_steps = ui.number + model_name = ui.input + + def place(self): + storage = app.storage.user + + self.title = ui.input( + label='Title', + placeholder='Song title', + ).bind_value(storage, 'detected_title').classes('w-full') + + self.artist = ui.input( + label='Artist', + placeholder='Song artist', + ).bind_value(storage, 'detected_artist').classes('w-full') + + self.bpm = ui.number( + label='BPM', + placeholder='Song BPM', + value=180, + ).bind_value(storage, 'bpm').classes('w-full') + + self.num_samples = ui.number( + label='Number of samples', + placeholder='Number of samples', + value=2, + ).classes('w-full').bind_value(storage, 'num_samples') + + self.sample_steps = ui.number( + label='Sample steps', + placeholder='Sample steps', + value=32, + ).classes('w-full').bind_value(storage, 'sample_steps') + + self.model_name = ui.input( + label='Model name', + placeholder='Model name', + value=[item for item in filter( + lambda x: x.endswith('.ckpt'), os.listdir('models') + )][0], + ).classes('w-full').bind_value(storage, 'model_name') + + +paramsboxes = ParamsBoxes() diff --git a/osu_dreamer_gui/gui/handlers/__init__.py b/osu_dreamer_gui/gui/handlers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/osu_dreamer_gui/gui/handlers/mp3_choose_upload.py b/osu_dreamer_gui/gui/handlers/mp3_choose_upload.py new file mode 100644 index 0000000..cf387b4 --- /dev/null +++ b/osu_dreamer_gui/gui/handlers/mp3_choose_upload.py @@ -0,0 +1,29 @@ +from nicegui import ui, events, app +import mutagen + +from ...modules.encoder.encoder import dump + + +async def on_upload(e: events.UploadEventArguments): + storage = app.storage.user + + storage['bpm'] = None + storage['detected_title'] = None + storage['detected_artist'] = None + + tags = mutagen.File(e.content, easy=True) + + if 'bpm' in tags: + storage['bpm'] = int(float(tags['bpm'][0])) + + if 'title' in tags: + storage['detected_title'] = tags['title'][0] + + if 'artist' in tags: + storage['detected_artist'] = tags['artist'][0] + + storage['filename'] = e.name + storage['audio_content'] = dump(e.content.read()) + + storage['can_be_created'] = True + storage['can_be_saved'] = False diff --git a/osu_dreamer_gui/gui/handlers/on_create.py b/osu_dreamer_gui/gui/handlers/on_create.py new file mode 100644 index 0000000..99b40b2 --- /dev/null +++ b/osu_dreamer_gui/gui/handlers/on_create.py @@ -0,0 +1,57 @@ +from nicegui import ui, app, run +from io import BytesIO +from ...modules.encoder.encoder import load + + +def creator(storage): + import torch + from osu_dreamer.model import Model + from ...modules.generate import generate_mapset + from rich.traceback import install + + install(show_locals=True) + + model = Model.load_from_checkpoint( + f'models/{storage["model_name"]}', + sample_steps=storage['sample_steps'], + ).eval() + + if torch.cuda.is_available(): + model = model.cuda() + + audio = BytesIO(load(storage['audio_content'])) + audio.name = storage['filename'] + audio.seek(0) + return generate_mapset( + model, + audio, + storage['bpm'], + storage['num_samples'], + storage['detected_title'], storage['detected_artist'], + ).name + + +async def on_create(): + storage = app.storage.user + + storage['is_loading'] = True + storage['can_be_created'] = False + + try: + mapset_path = await run.cpu_bound(creator, dict( + model_name=storage['model_name'], + audio_content=storage['audio_content'], + bpm=storage['bpm'], + num_samples=storage['num_samples'], + sample_steps=storage['sample_steps'], + detected_title=storage['detected_title'], + detected_artist=storage['detected_artist'], + filename=storage['filename'], + )) + storage['mapset_path'] = mapset_path + storage['can_be_saved'] = True + except Exception as e: + ui.notify(f'Error {e}') + + storage['is_loading'] = False + storage['can_be_created'] = True diff --git a/osu_dreamer_gui/gui/views/__init__.py b/osu_dreamer_gui/gui/views/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/osu_dreamer_gui/gui/views/home.py b/osu_dreamer_gui/gui/views/home.py new file mode 100644 index 0000000..95c37a3 --- /dev/null +++ b/osu_dreamer_gui/gui/views/home.py @@ -0,0 +1,15 @@ +from nicegui import ui, app +from ..elements.mp3_choose import mp3choose +from ..elements.params_boxes import paramsboxes +from ..elements.create_button import createbutton + + +@ui.page('/') +async def homepage(): + app.storage.user['can_be_created'] = False + app.storage.user['is_loading'] = False + app.storage.user['can_be_saved'] = False + + mp3choose.place() + paramsboxes.place() + createbutton.place() diff --git a/osu_dreamer_gui/modules/__init__.py b/osu_dreamer_gui/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/osu_dreamer_gui/modules/encoder/__init__.py b/osu_dreamer_gui/modules/encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/osu_dreamer_gui/modules/encoder/encoder.py b/osu_dreamer_gui/modules/encoder/encoder.py new file mode 100644 index 0000000..429dd66 --- /dev/null +++ b/osu_dreamer_gui/modules/encoder/encoder.py @@ -0,0 +1,9 @@ +import base64 + + +def dump(obj): + return base64.b64encode(obj).decode() + + +def load(obj): + return base64.b64decode(obj.encode()) diff --git a/osu_dreamer_gui/modules/generate.py b/osu_dreamer_gui/modules/generate.py new file mode 100644 index 0000000..7781d53 --- /dev/null +++ b/osu_dreamer_gui/modules/generate.py @@ -0,0 +1,67 @@ +import random +from pathlib import Path +from zipfile import ZipFile + +import numpy as np +import torch +import librosa + +from osu_dreamer.data import load_audio, HOP_LEN, SR, N_FFT +from osu_dreamer.signal import to_beatmap as signal_to_map + +from io import BytesIO + + +def generate_mapset( + model, + audio_file: BytesIO, + timing: int, + num_samples: int, + title: str, + artist: str, +): + metadata = dict( + audio_filename=audio_file.name, + title=title, + artist=artist, + ) + + # load audio + # ====== + dev = next(model.parameters()).device + a = torch.tensor(load_audio(audio_file), device=dev) + audio_file.seek(0) + + frame_times = librosa.frames_to_time( + np.arange(a.shape[-1]), + hop_length=HOP_LEN, + n_fft=N_FFT, + sr=SR, + ) * 1000 + + # generate maps + # ====== + pred_signals = model(a.repeat(num_samples, 1, 1)).cpu().numpy() + + # package mapset + # ====== + random_hex_string = lambda num: hex(random.randrange(16 ** num))[2:] + + while True: + mapset = Path(f"_{random_hex_string(7)} {artist} - {title}.osz") + if not mapset.exists(): + break + + with ZipFile(mapset, 'x') as mapset_archive: + mapset_archive.writestr(audio_file.name, audio_file.read()) + + for i, pred_signal in enumerate(pred_signals): + mapset_archive.writestr( + f"{artist} - {title} (osu!dreamer) [version {i}].osu", + signal_to_map( + dict(**metadata, version=f"version {i}"), + pred_signal, frame_times, timing, + ), + ) + + return mapset diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1def1cb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,29 @@ +[tool.poetry] +name = "osu-dreamer-gui" +version = "0.1.0" +description = "" +authors = ["BarsTiger "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "~3.8" +osu-dreamer = { git = "https://github.com/jaswon/osu-dreamer", tag = "v4.0" } +torch = [ + { url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-win_amd64.whl", platform = "win32" }, + { url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-linux_x86_64.whl", platform = "linux" } +] +torchaudio = { version = "^2.1.1+cu118", source = "torchcuda" } +nicegui = "^1.4.2" +pywebview = "^4.4.1" +pyinstaller = "^6.2.0" +rich = "^13.7.0" + + +[[tool.poetry.source]] +name = "torchcuda" +url = "https://download.pytorch.org/whl/cu118" +priority = "supplemental" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api"