feat: mvp
This commit is contained in:
156
dubbing/steps/mix.py
Normal file
156
dubbing/steps/mix.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from pydub import AudioSegment
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import TranslatedSegment
|
||||
from dubbing.config import settings
|
||||
|
||||
# Module-level rich console; MixStep.run uses it for status messages and the progress bar.
console = Console()
|
||||
|
||||
|
||||
class MixStep(PipelineStep):
    """Pipeline step that overlays the synthesized TTS segments on the source audio.

    The source track is attenuated by ``settings.original_volume_db``, every
    ``seg_XXXX.wav`` found in ``self.paths.tts_dir`` is placed at its segment's
    start position (sped up with ffmpeg when it is longer than its slot), and
    the sum is exported as MP3 to ``self.paths.dubbed_audio``.
    """

    name = "Mix"

    def is_cached(self) -> bool:
        """Return True when the mixed output file already exists."""
        return self.paths.dubbed_audio.exists()

    def clean(self) -> None:
        """Delete the mixed output and every sped-up ``*_fast.wav`` intermediate."""
        self.paths.dubbed_audio.unlink(missing_ok=True)
        for fast_file in self.paths.tts_dir.glob("*_fast.wav"):
            fast_file.unlink(missing_ok=True)

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments from ``self.paths.translated_json``."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    @staticmethod
    def _atempo_chain(speed: float) -> str:
        """Build the ffmpeg ``atempo`` filter chain for a speed-up factor.

        Factors above 2.0 are split into chained stages of at most 2.0 each
        (the same decomposition the step has always used). A factor of 1.0 or
        less yields a single no-op ``atempo=1.000`` stage so the filter
        argument is never empty.
        """
        stages = []
        remaining = speed
        while remaining > 2.0:
            stages.append("atempo=2.0")
            remaining /= 2.0
        # After the loop `remaining` is at most 2.0; clamp so this helper
        # never slows the clip down.
        stages.append(f"atempo={max(remaining, 1.0):.3f}")
        return ",".join(stages)

    def _speedup_file(self, input_path: Path, output_path: Path, speed: float) -> None:
        """Write a copy of ``input_path`` sped up by ``speed`` to ``output_path``.

        Args:
            input_path: Audio file to read.
            output_path: Destination file; overwritten if present (``-y``).
            speed: Tempo multiplier; values <= 1.0 leave the tempo unchanged.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status.
        """
        result = subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                str(input_path),
                "-filter:a",
                self._atempo_chain(speed),
                "-vn",
                str(output_path),
            ],
            capture_output=True,
        )
        if result.returncode != 0:
            # Surface the failure here instead of letting a later read of the
            # missing/partial output file fail with an unrelated error.
            stderr = result.stderr.decode("utf-8", errors="replace").strip()
            raise RuntimeError(
                f"ffmpeg failed to speed up {input_path} "
                f"(exit code {result.returncode}): {stderr[-500:]}"
            )

    def _process_segment(
        self, i: int, seg: TranslatedSegment, target_sr: int, channels: int
    ) -> tuple[int, np.ndarray | None]:
        """Load the TTS clip for segment ``i`` and fit it to the mix format.

        Args:
            i: Segment index; selects ``seg_{i:04d}.wav`` in the TTS directory.
            seg: Segment whose ``start``/``end`` give the slot for the clip
                (used here as milliseconds — TODO confirm against the model).
            target_sr: Sample rate of the mix; the clip is resampled to it.
            channels: Channel count of the mix (1 or 2 are handled).

        Returns:
            ``(start_position, samples)`` where ``samples`` is an int16 array,
            or ``(start_position, None)`` when the clip file is missing.
        """
        audio_path = self.paths.tts_dir / f"seg_{i:04d}.wav"
        fast_path = self.paths.tts_dir / f"seg_{i:04d}_fast.wav"

        if not audio_path.exists():
            return (int(seg.start), None)

        # Always measure the ORIGINAL clip. Deriving the ratio from a cached
        # sped-up file and then applying it to the original would produce a
        # slower (longer) result than the first run did.
        info = sf.info(str(audio_path))
        duration_ms = info.frames / info.samplerate * 1000
        available_ms = seg.end - seg.start

        source_path = audio_path
        if duration_ms > available_ms > 100:
            if not fast_path.exists():
                speedup_ratio = min(duration_ms / available_ms, settings.max_speedup)
                self._speedup_file(audio_path, fast_path, speedup_ratio)
            source_path = fast_path

        data, sr = sf.read(str(source_path), dtype="int16")

        # Match the channel layout of the mix.
        if data.ndim == 1 and channels == 2:
            data = np.column_stack([data, data])
        elif data.ndim == 2 and channels == 1:
            data = data[:, 0]

        # Nearest-sample resampling to the mix sample rate.
        if sr != target_sr:
            new_len = int(len(data) * target_sr / sr)
            indices = np.linspace(0, len(data) - 1, new_len).astype(int)
            data = data[indices]

        return (int(seg.start), data)

    @staticmethod
    def _overlay(track: np.ndarray, clip: np.ndarray, start_sample: int) -> None:
        """Add ``clip`` into ``track`` in place, starting at ``start_sample``.

        Any part of the clip that falls before the start or past the end of
        ``track`` is dropped; a clip lying entirely outside is ignored.
        """
        if start_sample < 0:
            clip = clip[-start_sample:]
            start_sample = 0
        end_sample = min(start_sample + len(clip), len(track))
        if end_sample <= start_sample:
            return
        track[start_sample:end_sample] += clip[: end_sample - start_sample].astype(
            np.float32
        )

    async def run(self) -> None:
        """Mix all available TTS clips over the attenuated source and export MP3."""
        console.print("[cyan]Mixing audio tracks...[/]")

        segments = self._load_translated()

        original = AudioSegment.from_mp3(str(self.paths.source_audio))
        original_quiet = original + settings.original_volume_db
        original_samples = np.array(
            original_quiet.get_array_of_samples(), dtype=np.float32
        )
        sample_rate = original_quiet.frame_rate
        channels = original_quiet.channels

        if channels == 2:
            # Interleaved L/R samples -> one row per frame.
            original_samples = original_samples.reshape(-1, 2)

        dubbed_samples = np.zeros_like(original_samples, dtype=np.float32)

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Processing & mixing...", total=len(segments))

            with ThreadPoolExecutor(max_workers=8) as executor:
                futures = [
                    executor.submit(
                        self._process_segment, i, seg, sample_rate, channels
                    )
                    for i, seg in enumerate(segments)
                ]
                # Clips are summed on this thread only; workers just load
                # and convert audio.
                for future in as_completed(futures):
                    position_ms, data = future.result()
                    if data is not None:
                        start_sample = int(position_ms * sample_rate / 1000)
                        self._overlay(dubbed_samples, data, start_sample)
                    progress.advance(task)

        console.print("[cyan]Exporting MP3...[/]")
        final_samples = original_samples + dubbed_samples
        # NOTE(review): the int16 clip and sample_width=2 assume the decoded
        # source is 16-bit — confirm for non-standard inputs.
        final_samples = np.clip(final_samples, -32768, 32767).astype(np.int16)

        final = AudioSegment(
            final_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,
            channels=channels,
        )
        final.export(str(self.paths.dubbed_audio), format="mp3")

        console.print("[green]✓ Audio mixed[/]")
|
||||
Reference in New Issue
Block a user