Initial commit
This commit is contained in:
0
osu_dreamer_gui/modules/__init__.py
Normal file
0
osu_dreamer_gui/modules/__init__.py
Normal file
0
osu_dreamer_gui/modules/encoder/__init__.py
Normal file
0
osu_dreamer_gui/modules/encoder/__init__.py
Normal file
9
osu_dreamer_gui/modules/encoder/encoder.py
Normal file
9
osu_dreamer_gui/modules/encoder/encoder.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import base64
|
||||
|
||||
|
||||
def dump(obj):
|
||||
return base64.b64encode(obj).decode()
|
||||
|
||||
|
||||
def load(obj):
|
||||
return base64.b64decode(obj.encode())
|
||||
67
osu_dreamer_gui/modules/generate.py
Normal file
67
osu_dreamer_gui/modules/generate.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import librosa
|
||||
|
||||
from osu_dreamer.data import load_audio, HOP_LEN, SR, N_FFT
|
||||
from osu_dreamer.signal import to_beatmap as signal_to_map
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def generate_mapset(
|
||||
model,
|
||||
audio_file: BytesIO,
|
||||
timing: int,
|
||||
num_samples: int,
|
||||
title: str,
|
||||
artist: str,
|
||||
):
|
||||
metadata = dict(
|
||||
audio_filename=audio_file.name,
|
||||
title=title,
|
||||
artist=artist,
|
||||
)
|
||||
|
||||
# load audio
|
||||
# ======
|
||||
dev = next(model.parameters()).device
|
||||
a = torch.tensor(load_audio(audio_file), device=dev)
|
||||
audio_file.seek(0)
|
||||
|
||||
frame_times = librosa.frames_to_time(
|
||||
np.arange(a.shape[-1]),
|
||||
hop_length=HOP_LEN,
|
||||
n_fft=N_FFT,
|
||||
sr=SR,
|
||||
) * 1000
|
||||
|
||||
# generate maps
|
||||
# ======
|
||||
pred_signals = model(a.repeat(num_samples, 1, 1)).cpu().numpy()
|
||||
|
||||
# package mapset
|
||||
# ======
|
||||
random_hex_string = lambda num: hex(random.randrange(16 ** num))[2:]
|
||||
|
||||
while True:
|
||||
mapset = Path(f"_{random_hex_string(7)} {artist} - {title}.osz")
|
||||
if not mapset.exists():
|
||||
break
|
||||
|
||||
with ZipFile(mapset, 'x') as mapset_archive:
|
||||
mapset_archive.writestr(audio_file.name, audio_file.read())
|
||||
|
||||
for i, pred_signal in enumerate(pred_signals):
|
||||
mapset_archive.writestr(
|
||||
f"{artist} - {title} (osu!dreamer) [version {i}].osu",
|
||||
signal_to_map(
|
||||
dict(**metadata, version=f"version {i}"),
|
||||
pred_signal, frame_times, timing,
|
||||
),
|
||||
)
|
||||
|
||||
return mapset
|
||||
Reference in New Issue
Block a user