Initial commit

This commit is contained in:
BarsTiger
2023-11-16 22:49:59 +02:00
commit 653711aba3
20 changed files with 386 additions and 0 deletions

6
.gitignore vendored Normal file
View File

@@ -0,0 +1,6 @@
/build
/dist
/models/*
*.osz
poetry.lock

5
README.md Normal file
View File

@@ -0,0 +1,5 @@
# osu!dreamer gui
Interface for [osu!dreamer](https://github.com/jaswon/osu-dreamer) (inference only)
Just PoC, don't take code quality seriously

45
dreamer.spec Normal file
View File

@@ -0,0 +1,45 @@
# -*- mode: python ; coding: utf-8 -*-
import nicegui
from pathlib import Path
a = Analysis(
['dreamer_gui.py'],
pathex=[],
binaries=[],
datas=[(Path(nicegui.__file__).parent, 'nicegui')],
hiddenimports=[],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
noarchive=False,
)
pyz = PYZ(a.pure)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='osu!dreamer-gui',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=False,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
coll = COLLECT(
exe,
a.binaries,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='osu!dreamer-gui',
)

5
dreamer_gui.py Normal file
View File

@@ -0,0 +1,5 @@
from osu_dreamer_gui import main
if __name__ in {'__main__', '__mp_main__'}:
main()

View File

@@ -0,0 +1,17 @@
from nicegui import ui
from . import gui
from rich.traceback import install
def main():
install(show_locals=True)
ui.run(
title='osu!dreamer',
native=True,
dark=True,
reload=False,
storage_secret='...',
window_size=(800, 670),
)

View File

@@ -0,0 +1 @@
from .views import home

View File

View File

@@ -0,0 +1,29 @@
from nicegui import ui, app
from ..handlers.on_create import on_create
class CreateButton:
create_button: ui.button
download_button: ui.button
loading: ui.spinner
def place(self):
with ui.row():
self.create_button = ui.button(
'Create map', on_click=on_create
).bind_enabled_from(app.storage.user, 'can_be_created')
self.download_button = ui.button(
'Save map',
on_click=lambda: ui.download(
src=app.storage.user.get('mapset_path')
)
).bind_enabled_from(app.storage.user, 'can_be_saved')
self.loading = ui.spinner(
type='audio',
size='2rem'
).bind_visibility_from(app.storage.user, 'is_loading')
createbutton = CreateButton()

View File

@@ -0,0 +1,18 @@
from nicegui import ui
from ..handlers.mp3_choose_upload import on_upload
class MP3Choose:
uploader: ui.upload
def place(self):
self.uploader = ui.upload(
label='Upload audio',
max_files=1,
on_upload=on_upload,
max_total_size=100 * 1_048_576,
auto_upload=True,
).classes('w-full')
mp3choose = MP3Choose()

View File

@@ -0,0 +1,54 @@
from nicegui import ui
from nicegui import app
import os
class ParamsBoxes:
title: ui.input
artist: ui.input
bpm: ui.number
num_samples = ui.number
sample_steps = ui.number
model_name = ui.input
def place(self):
storage = app.storage.user
self.title = ui.input(
label='Title',
placeholder='Song title',
).bind_value(storage, 'detected_title').classes('w-full')
self.artist = ui.input(
label='Artist',
placeholder='Song artist',
).bind_value(storage, 'detected_artist').classes('w-full')
self.bpm = ui.number(
label='BPM',
placeholder='Song BPM',
value=180,
).bind_value(storage, 'bpm').classes('w-full')
self.num_samples = ui.number(
label='Number of samples',
placeholder='Number of samples',
value=2,
).classes('w-full').bind_value(storage, 'num_samples')
self.sample_steps = ui.number(
label='Sample steps',
placeholder='Sample steps',
value=32,
).classes('w-full').bind_value(storage, 'sample_steps')
self.model_name = ui.input(
label='Model name',
placeholder='Model name',
value=[item for item in filter(
lambda x: x.endswith('.ckpt'), os.listdir('models')
)][0],
).classes('w-full').bind_value(storage, 'model_name')
paramsboxes = ParamsBoxes()

View File

View File

@@ -0,0 +1,29 @@
from nicegui import ui, events, app
import mutagen
from ...modules.encoder.encoder import dump
async def on_upload(e: events.UploadEventArguments):
storage = app.storage.user
storage['bpm'] = None
storage['detected_title'] = None
storage['detected_artist'] = None
tags = mutagen.File(e.content, easy=True)
if 'bpm' in tags:
storage['bpm'] = int(float(tags['bpm'][0]))
if 'title' in tags:
storage['detected_title'] = tags['title'][0]
if 'artist' in tags:
storage['detected_artist'] = tags['artist'][0]
storage['filename'] = e.name
storage['audio_content'] = dump(e.content.read())
storage['can_be_created'] = True
storage['can_be_saved'] = False

View File

@@ -0,0 +1,57 @@
from nicegui import ui, app, run
from io import BytesIO
from ...modules.encoder.encoder import load
def creator(storage):
import torch
from osu_dreamer.model import Model
from ...modules.generate import generate_mapset
from rich.traceback import install
install(show_locals=True)
model = Model.load_from_checkpoint(
f'models/{storage["model_name"]}',
sample_steps=storage['sample_steps'],
).eval()
if torch.cuda.is_available():
model = model.cuda()
audio = BytesIO(load(storage['audio_content']))
audio.name = storage['filename']
audio.seek(0)
return generate_mapset(
model,
audio,
storage['bpm'],
storage['num_samples'],
storage['detected_title'], storage['detected_artist'],
).name
async def on_create():
storage = app.storage.user
storage['is_loading'] = True
storage['can_be_created'] = False
try:
mapset_path = await run.cpu_bound(creator, dict(
model_name=storage['model_name'],
audio_content=storage['audio_content'],
bpm=storage['bpm'],
num_samples=storage['num_samples'],
sample_steps=storage['sample_steps'],
detected_title=storage['detected_title'],
detected_artist=storage['detected_artist'],
filename=storage['filename'],
))
storage['mapset_path'] = mapset_path
storage['can_be_saved'] = True
except Exception as e:
ui.notify(f'Error {e}')
storage['is_loading'] = False
storage['can_be_created'] = True

View File

View File

@@ -0,0 +1,15 @@
from nicegui import ui, app
from ..elements.mp3_choose import mp3choose
from ..elements.params_boxes import paramsboxes
from ..elements.create_button import createbutton
@ui.page('/')
async def homepage():
app.storage.user['can_be_created'] = False
app.storage.user['is_loading'] = False
app.storage.user['can_be_saved'] = False
mp3choose.place()
paramsboxes.place()
createbutton.place()

View File

View File

@@ -0,0 +1,9 @@
import base64
def dump(obj):
return base64.b64encode(obj).decode()
def load(obj):
return base64.b64decode(obj.encode())

View File

@@ -0,0 +1,67 @@
import random
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import torch
import librosa
from osu_dreamer.data import load_audio, HOP_LEN, SR, N_FFT
from osu_dreamer.signal import to_beatmap as signal_to_map
from io import BytesIO
def generate_mapset(
model,
audio_file: BytesIO,
timing: int,
num_samples: int,
title: str,
artist: str,
):
metadata = dict(
audio_filename=audio_file.name,
title=title,
artist=artist,
)
# load audio
# ======
dev = next(model.parameters()).device
a = torch.tensor(load_audio(audio_file), device=dev)
audio_file.seek(0)
frame_times = librosa.frames_to_time(
np.arange(a.shape[-1]),
hop_length=HOP_LEN,
n_fft=N_FFT,
sr=SR,
) * 1000
# generate maps
# ======
pred_signals = model(a.repeat(num_samples, 1, 1)).cpu().numpy()
# package mapset
# ======
random_hex_string = lambda num: hex(random.randrange(16 ** num))[2:]
while True:
mapset = Path(f"_{random_hex_string(7)} {artist} - {title}.osz")
if not mapset.exists():
break
with ZipFile(mapset, 'x') as mapset_archive:
mapset_archive.writestr(audio_file.name, audio_file.read())
for i, pred_signal in enumerate(pred_signals):
mapset_archive.writestr(
f"{artist} - {title} (osu!dreamer) [version {i}].osu",
signal_to_map(
dict(**metadata, version=f"version {i}"),
pred_signal, frame_times, timing,
),
)
return mapset

29
pyproject.toml Normal file
View File

@@ -0,0 +1,29 @@
[tool.poetry]
name = "osu-dreamer-gui"
version = "0.1.0"
description = ""
authors = ["BarsTiger <zxcbarstiger@gmail.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "~3.8"
osu-dreamer = { git = "https://github.com/jaswon/osu-dreamer", tag = "v4.0" }
torch = [
{ url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-win_amd64.whl", platform = "win32" },
{ url = "https://download.pytorch.org/whl/cu118/torch-2.1.1%2Bcu118-cp38-cp38-linux_x86_64.whl", platform = "linux" }
]
torchaudio = { version = "^2.1.1+cu118", source = "torchcuda" }
nicegui = "^1.4.2"
pywebview = "^4.4.1"
pyinstaller = "^6.2.0"
rich = "^13.7.0"
[[tool.poetry.source]]
name = "torchcuda"
url = "https://download.pytorch.org/whl/cu118"
priority = "supplemental"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"