Initial commit
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
/build
|
||||||
|
/dist
|
||||||
|
/models/*
|
||||||
|
|
||||||
|
*.osz
|
||||||
|
poetry.lock
|
||||||
5
README.md
Normal file
5
README.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# osu!dreamer gui
|
||||||
|
|
||||||
|
Interface for [osu!dreamer](https://github.com/jaswon/osu-dreamer) (inference only)
|
||||||
|
|
||||||
|
Just PoC, don't take code quality seriously
|
||||||
45
dreamer.spec
Normal file
45
dreamer.spec
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
import nicegui
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
a = Analysis(
|
||||||
|
['dreamer_gui.py'],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=[(Path(nicegui.__file__).parent, 'nicegui')],
|
||||||
|
hiddenimports=[],
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=[],
|
||||||
|
noarchive=False,
|
||||||
|
)
|
||||||
|
pyz = PYZ(a.pure)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
[],
|
||||||
|
exclude_binaries=True,
|
||||||
|
name='osu!dreamer-gui',
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
console=False,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
||||||
|
coll = COLLECT(
|
||||||
|
exe,
|
||||||
|
a.binaries,
|
||||||
|
a.datas,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
upx_exclude=[],
|
||||||
|
name='osu!dreamer-gui',
|
||||||
|
)
|
||||||
5
dreamer_gui.py
Normal file
5
dreamer_gui.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from osu_dreamer_gui import main
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ in {'__main__', '__mp_main__'}:
|
||||||
|
main()
|
||||||
17
osu_dreamer_gui/__init__.py
Normal file
17
osu_dreamer_gui/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from nicegui import ui
|
||||||
|
|
||||||
|
from . import gui
|
||||||
|
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
install(show_locals=True)
|
||||||
|
ui.run(
|
||||||
|
title='osu!dreamer',
|
||||||
|
native=True,
|
||||||
|
dark=True,
|
||||||
|
reload=False,
|
||||||
|
storage_secret='...',
|
||||||
|
window_size=(800, 670),
|
||||||
|
)
|
||||||
1
osu_dreamer_gui/gui/__init__.py
Normal file
1
osu_dreamer_gui/gui/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .views import home
|
||||||
0
osu_dreamer_gui/gui/elements/__init__.py
Normal file
0
osu_dreamer_gui/gui/elements/__init__.py
Normal file
29
osu_dreamer_gui/gui/elements/create_button.py
Normal file
29
osu_dreamer_gui/gui/elements/create_button.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from nicegui import ui, app
|
||||||
|
from ..handlers.on_create import on_create
|
||||||
|
|
||||||
|
|
||||||
|
class CreateButton:
|
||||||
|
create_button: ui.button
|
||||||
|
download_button: ui.button
|
||||||
|
loading: ui.spinner
|
||||||
|
|
||||||
|
def place(self):
|
||||||
|
with ui.row():
|
||||||
|
self.create_button = ui.button(
|
||||||
|
'Create map', on_click=on_create
|
||||||
|
).bind_enabled_from(app.storage.user, 'can_be_created')
|
||||||
|
|
||||||
|
self.download_button = ui.button(
|
||||||
|
'Save map',
|
||||||
|
on_click=lambda: ui.download(
|
||||||
|
src=app.storage.user.get('mapset_path')
|
||||||
|
)
|
||||||
|
).bind_enabled_from(app.storage.user, 'can_be_saved')
|
||||||
|
|
||||||
|
self.loading = ui.spinner(
|
||||||
|
type='audio',
|
||||||
|
size='2rem'
|
||||||
|
).bind_visibility_from(app.storage.user, 'is_loading')
|
||||||
|
|
||||||
|
|
||||||
|
createbutton = CreateButton()
|
||||||
18
osu_dreamer_gui/gui/elements/mp3_choose.py
Normal file
18
osu_dreamer_gui/gui/elements/mp3_choose.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from nicegui import ui
|
||||||
|
from ..handlers.mp3_choose_upload import on_upload
|
||||||
|
|
||||||
|
|
||||||
|
class MP3Choose:
|
||||||
|
uploader: ui.upload
|
||||||
|
|
||||||
|
def place(self):
|
||||||
|
self.uploader = ui.upload(
|
||||||
|
label='Upload audio',
|
||||||
|
max_files=1,
|
||||||
|
on_upload=on_upload,
|
||||||
|
max_total_size=100 * 1_048_576,
|
||||||
|
auto_upload=True,
|
||||||
|
).classes('w-full')
|
||||||
|
|
||||||
|
|
||||||
|
mp3choose = MP3Choose()
|
||||||
54
osu_dreamer_gui/gui/elements/params_boxes.py
Normal file
54
osu_dreamer_gui/gui/elements/params_boxes.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from nicegui import ui
|
||||||
|
from nicegui import app
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class ParamsBoxes:
|
||||||
|
title: ui.input
|
||||||
|
artist: ui.input
|
||||||
|
bpm: ui.number
|
||||||
|
num_samples = ui.number
|
||||||
|
sample_steps = ui.number
|
||||||
|
model_name = ui.input
|
||||||
|
|
||||||
|
def place(self):
|
||||||
|
storage = app.storage.user
|
||||||
|
|
||||||
|
self.title = ui.input(
|
||||||
|
label='Title',
|
||||||
|
placeholder='Song title',
|
||||||
|
).bind_value(storage, 'detected_title').classes('w-full')
|
||||||
|
|
||||||
|
self.artist = ui.input(
|
||||||
|
label='Artist',
|
||||||
|
placeholder='Song artist',
|
||||||
|
).bind_value(storage, 'detected_artist').classes('w-full')
|
||||||
|
|
||||||
|
self.bpm = ui.number(
|
||||||
|
label='BPM',
|
||||||
|
placeholder='Song BPM',
|
||||||
|
value=180,
|
||||||
|
).bind_value(storage, 'bpm').classes('w-full')
|
||||||
|
|
||||||
|
self.num_samples = ui.number(
|
||||||
|
label='Number of samples',
|
||||||
|
placeholder='Number of samples',
|
||||||
|
value=2,
|
||||||
|
).classes('w-full').bind_value(storage, 'num_samples')
|
||||||
|
|
||||||
|
self.sample_steps = ui.number(
|
||||||
|
label='Sample steps',
|
||||||
|
placeholder='Sample steps',
|
||||||
|
value=32,
|
||||||
|
).classes('w-full').bind_value(storage, 'sample_steps')
|
||||||
|
|
||||||
|
self.model_name = ui.input(
|
||||||
|
label='Model name',
|
||||||
|
placeholder='Model name',
|
||||||
|
value=[item for item in filter(
|
||||||
|
lambda x: x.endswith('.ckpt'), os.listdir('models')
|
||||||
|
)][0],
|
||||||
|
).classes('w-full').bind_value(storage, 'model_name')
|
||||||
|
|
||||||
|
|
||||||
|
paramsboxes = ParamsBoxes()
|
||||||
0
osu_dreamer_gui/gui/handlers/__init__.py
Normal file
0
osu_dreamer_gui/gui/handlers/__init__.py
Normal file
29
osu_dreamer_gui/gui/handlers/mp3_choose_upload.py
Normal file
29
osu_dreamer_gui/gui/handlers/mp3_choose_upload.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from nicegui import ui, events, app
|
||||||
|
import mutagen
|
||||||
|
|
||||||
|
from ...modules.encoder.encoder import dump
|
||||||
|
|
||||||
|
|
||||||
|
async def on_upload(e: events.UploadEventArguments):
|
||||||
|
storage = app.storage.user
|
||||||
|
|
||||||
|
storage['bpm'] = None
|
||||||
|
storage['detected_title'] = None
|
||||||
|
storage['detected_artist'] = None
|
||||||
|
|
||||||
|
tags = mutagen.File(e.content, easy=True)
|
||||||
|
|
||||||
|
if 'bpm' in tags:
|
||||||
|
storage['bpm'] = int(float(tags['bpm'][0]))
|
||||||
|
|
||||||
|
if 'title' in tags:
|
||||||
|
storage['detected_title'] = tags['title'][0]
|
||||||
|
|
||||||
|
if 'artist' in tags:
|
||||||
|
storage['detected_artist'] = tags['artist'][0]
|
||||||
|
|
||||||
|
storage['filename'] = e.name
|
||||||
|
storage['audio_content'] = dump(e.content.read())
|
||||||
|
|
||||||
|
storage['can_be_created'] = True
|
||||||
|
storage['can_be_saved'] = False
|
||||||
57
osu_dreamer_gui/gui/handlers/on_create.py
Normal file
57
osu_dreamer_gui/gui/handlers/on_create.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from nicegui import ui, app, run
|
||||||
|
from io import BytesIO
|
||||||
|
from ...modules.encoder.encoder import load
|
||||||
|
|
||||||
|
|
||||||
|
def creator(storage):
|
||||||
|
import torch
|
||||||
|
from osu_dreamer.model import Model
|
||||||
|
from ...modules.generate import generate_mapset
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(show_locals=True)
|
||||||
|
|
||||||
|
model = Model.load_from_checkpoint(
|
||||||
|
f'models/{storage["model_name"]}',
|
||||||
|
sample_steps=storage['sample_steps'],
|
||||||
|
).eval()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
audio = BytesIO(load(storage['audio_content']))
|
||||||
|
audio.name = storage['filename']
|
||||||
|
audio.seek(0)
|
||||||
|
return generate_mapset(
|
||||||
|
model,
|
||||||
|
audio,
|
||||||
|
storage['bpm'],
|
||||||
|
storage['num_samples'],
|
||||||
|
storage['detected_title'], storage['detected_artist'],
|
||||||
|
).name
|
||||||
|
|
||||||
|
|
||||||
|
async def on_create():
|
||||||
|
storage = app.storage.user
|
||||||
|
|
||||||
|
storage['is_loading'] = True
|
||||||
|
storage['can_be_created'] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
mapset_path = await run.cpu_bound(creator, dict(
|
||||||
|
model_name=storage['model_name'],
|
||||||
|
audio_content=storage['audio_content'],
|
||||||
|
bpm=storage['bpm'],
|
||||||
|
num_samples=storage['num_samples'],
|
||||||
|
sample_steps=storage['sample_steps'],
|
||||||
|
detected_title=storage['detected_title'],
|
||||||
|
detected_artist=storage['detected_artist'],
|
||||||
|
filename=storage['filename'],
|
||||||
|
))
|
||||||
|
storage['mapset_path'] = mapset_path
|
||||||
|
storage['can_be_saved'] = True
|
||||||
|
except Exception as e:
|
||||||
|
ui.notify(f'Error {e}')
|
||||||
|
|
||||||
|
storage['is_loading'] = False
|
||||||
|
storage['can_be_created'] = True
|
||||||
0
osu_dreamer_gui/gui/views/__init__.py
Normal file
0
osu_dreamer_gui/gui/views/__init__.py
Normal file
15
osu_dreamer_gui/gui/views/home.py
Normal file
15
osu_dreamer_gui/gui/views/home.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from nicegui import ui, app
|
||||||
|
from ..elements.mp3_choose import mp3choose
|
||||||
|
from ..elements.params_boxes import paramsboxes
|
||||||
|
from ..elements.create_button import createbutton
|
||||||
|
|
||||||
|
|
||||||
|
@ui.page('/')
|
||||||
|
async def homepage():
|
||||||
|
app.storage.user['can_be_created'] = False
|
||||||
|
app.storage.user['is_loading'] = False
|
||||||
|
app.storage.user['can_be_saved'] = False
|
||||||
|
|
||||||
|
mp3choose.place()
|
||||||
|
paramsboxes.place()
|
||||||
|
createbutton.place()
|
||||||
0
osu_dreamer_gui/modules/__init__.py
Normal file
0
osu_dreamer_gui/modules/__init__.py
Normal file
0
osu_dreamer_gui/modules/encoder/__init__.py
Normal file
0
osu_dreamer_gui/modules/encoder/__init__.py
Normal file
9
osu_dreamer_gui/modules/encoder/encoder.py
Normal file
9
osu_dreamer_gui/modules/encoder/encoder.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
def dump(obj):
|
||||||
|
return base64.b64encode(obj).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def load(obj):
|
||||||
|
return base64.b64decode(obj.encode())
|
||||||
67
osu_dreamer_gui/modules/generate.py
Normal file
67
osu_dreamer_gui/modules/generate.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
from osu_dreamer.data import load_audio, HOP_LEN, SR, N_FFT
|
||||||
|
from osu_dreamer.signal import to_beatmap as signal_to_map
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
def generate_mapset(
|
||||||
|
model,
|
||||||
|
audio_file: BytesIO,
|
||||||
|
timing: int,
|
||||||
|
num_samples: int,
|
||||||
|
title: str,
|
||||||
|
artist: str,
|
||||||
|
):
|
||||||
|
metadata = dict(
|
||||||
|
audio_filename=audio_file.name,
|
||||||
|
title=title,
|
||||||
|
artist=artist,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load audio
|
||||||
|
# ======
|
||||||
|
dev = next(model.parameters()).device
|
||||||
|
a = torch.tensor(load_audio(audio_file), device=dev)
|
||||||
|
audio_file.seek(0)
|
||||||
|
|
||||||
|
frame_times = librosa.frames_to_time(
|
||||||
|
np.arange(a.shape[-1]),
|
||||||
|
hop_length=HOP_LEN,
|
||||||
|
n_fft=N_FFT,
|
||||||
|
sr=SR,
|
||||||
|
) * 1000
|
||||||
|
|
||||||
|
# generate maps
|
||||||
|
# ======
|
||||||
|
pred_signals = model(a.repeat(num_samples, 1, 1)).cpu().numpy()
|
||||||
|
|
||||||
|
# package mapset
|
||||||
|
# ======
|
||||||
|
random_hex_string = lambda num: hex(random.randrange(16 ** num))[2:]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
mapset = Path(f"_{random_hex_string(7)} {artist} - {title}.osz")
|
||||||
|
if not mapset.exists():
|
||||||
|
break
|
||||||
|
|
||||||
|
with ZipFile(mapset, 'x') as mapset_archive:
|
||||||
|
mapset_archive.writestr(audio_file.name, audio_file.read())
|
||||||
|
|
||||||
|
for i, pred_signal in enumerate(pred_signals):
|
||||||
|
mapset_archive.writestr(
|
||||||
|
f"{artist} - {title} (osu!dreamer) [version {i}].osu",
|
||||||
|
signal_to_map(
|
||||||
|
dict(**metadata, version=f"version {i}"),
|
||||||
|
pred_signal, frame_times, timing,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return mapset
|
||||||
29
pyproject.toml
Normal file
29
pyproject.toml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "osu-dreamer-gui"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = ["BarsTiger <zxcbarstiger@gmail.com>"]
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "~3.8"
|
||||||
|
osu-dreamer = { git = "https://github.com/jaswon/osu-dreamer", tag = "v4.0" }
|
||||||
|
torch = [
|
||||||
|
{ url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-win_amd64.whl", platform = "win32" },
|
||||||
|
{ url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-linux_x86_64.whl", platform = "linux" }
|
||||||
|
]
|
||||||
|
torchaudio = { version = "^2.1.1+cu118", source = "torchcuda" }
|
||||||
|
nicegui = "^1.4.2"
|
||||||
|
pywebview = "^4.4.1"
|
||||||
|
pyinstaller = "^6.2.0"
|
||||||
|
rich = "^13.7.0"
|
||||||
|
|
||||||
|
|
||||||
|
[[tool.poetry.source]]
|
||||||
|
name = "torchcuda"
|
||||||
|
url = "https://download.pytorch.org/whl/cu118"
|
||||||
|
priority = "supplemental"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
Reference in New Issue
Block a user