feat: mvp

This commit is contained in:
h
2026-02-01 16:07:59 +01:00
commit 4ef769597a
20 changed files with 3566 additions and 0 deletions

15
dubbing/steps/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
# Public API of the dubbing.steps package: re-export each concrete pipeline
# step so callers can write `from dubbing.steps import ASRStep` instead of
# importing from the individual submodules.
from dubbing.steps.extract_audio import ExtractAudioStep
from dubbing.steps.asr import ASRStep
from dubbing.steps.translate import TranslateStep
from dubbing.steps.tts import TTSStep
from dubbing.steps.mix import MixStep
from dubbing.steps.finalize import FinalizeStep

# Declared in pipeline order for readability; also pins the package's
# public surface for `from dubbing.steps import *`.
__all__ = [
    "ExtractAudioStep",
    "ASRStep",
    "TranslateStep",
    "TTSStep",
    "MixStep",
    "FinalizeStep",
]

78
dubbing/steps/asr.py Normal file
View File

@@ -0,0 +1,78 @@
import json
from rich.console import Console
from dubbing.steps.base import PipelineStep
from dubbing.models import Segment
console = Console()
class ASRStep(PipelineStep):
    """Transcribe the extracted audio with SenseVoice and write segments.json.

    Groups the model's word-level output into sentence-level ``Segment``
    records (start/end timestamps in ms) split at Chinese punctuation.
    """

    name = "ASR"

    def is_cached(self) -> bool:
        # The step is complete once the segments file exists.
        return self.paths.segments_json.exists()

    def clean(self) -> None:
        if self.paths.segments_json.exists():
            self.paths.segments_json.unlink()

    def _group_chinese_segments(
        self, words: list[str], timestamps: list[list[int]]
    ) -> list[Segment]:
        """Group word-level ASR output into sentences split at punctuation.

        ``timestamps[i]`` is the ``[start_ms, end_ms]`` pair for ``words[i]``.
        A trailing run of words with no closing punctuation is flushed as a
        final segment.
        """
        segments: list[Segment] = []
        current_text = ""
        current_start: int | None = None
        current_end: int | None = None
        # Full-width Chinese punctuation used as segment boundaries.
        # Bug fix: this set previously held eight EMPTY strings (almost
        # certainly an encoding loss of these characters), which collapses to
        # {""} and never matches any word — so no segmentation ever happened.
        # NOTE(review): restored marks are a best-effort reconstruction —
        # confirm against the original source.
        punctuation = {"。", "，", "！", "？", "；", "：", "、", "…"}
        for word, ts in zip(words, timestamps):
            if current_start is None:
                current_start = ts[0]
            current_text += word
            current_end = ts[1]
            if word in punctuation:
                segments.append(
                    Segment(start=current_start, end=current_end, text=current_text)
                )
                current_text = ""
                current_start = None
        if current_text:
            # Flush the final partial sentence (no closing punctuation).
            segments.append(
                Segment(start=current_start, end=current_end, text=current_text)
            )
        return segments

    async def run(self) -> None:
        """Run SenseVoice ASR (with VAD) and persist segments to JSON."""
        console.print("[cyan]Running speech recognition...[/]")
        from funasr import AutoModel

        # NOTE(review): device is hard-coded to Apple "mps"; consider making
        # it configurable for non-macOS hosts.
        model = AutoModel(
            model="iic/SenseVoiceSmall",
            device="mps",
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
        )
        result = model.generate(
            input=str(self.paths.source_audio),
            language="zh",
            use_itn=True,
            batch_size_s=60,
            merge_length_s=15,
            output_timestamp=True,
        )
        segments = self._group_chinese_segments(
            result[0]["words"], result[0]["timestamp"]
        )
        with open(self.paths.segments_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in segments], f, ensure_ascii=False, indent=2
            )
        console.print(f"[green]✓ Found {len(segments)} segments[/]")

21
dubbing/steps/base.py Normal file
View File

@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
from dubbing.models import ProjectPaths
class PipelineStep(ABC):
    """Abstract base class for one stage of the dubbing pipeline.

    A concrete step receives the project's path bundle and implements
    cache detection, cleanup of its own outputs, and the async work itself.
    """

    # Human-readable step label; concrete subclasses override it.
    name: str = "Step"

    def __init__(self, paths: ProjectPaths):
        self.paths = paths

    @abstractmethod
    def is_cached(self) -> bool:
        """Return True when this step's outputs already exist on disk."""

    @abstractmethod
    def clean(self) -> None:
        """Delete this step's outputs so it can be re-run from scratch."""

    @abstractmethod
    async def run(self) -> None:
        """Execute the step's work."""

View File

@@ -0,0 +1,38 @@
import subprocess
from rich.console import Console
from dubbing.steps.base import PipelineStep
console = Console()
class ExtractAudioStep(PipelineStep):
    """Extract the audio track from the source video into an MP3 file."""

    name = "Extract audio"

    def is_cached(self) -> bool:
        """Cached once the extracted audio file is on disk."""
        return self.paths.source_audio.exists()

    def clean(self) -> None:
        """Remove the extracted audio so the step can run again."""
        if self.paths.source_audio.exists():
            self.paths.source_audio.unlink()

    async def run(self) -> None:
        """Invoke ffmpeg to strip video and encode the audio as MP3."""
        console.print("[cyan]Extracting audio from video...[/]")
        # -vn drops the video stream; -q:a 2 selects high-quality LAME VBR.
        command = ["ffmpeg", "-y"]
        command += ["-i", str(self.paths.source_video)]
        command += ["-vn", "-acodec", "libmp3lame", "-q:a", "2"]
        command.append(str(self.paths.source_audio))
        # check=True raises CalledProcessError (with captured stderr) on failure.
        subprocess.run(command, capture_output=True, check=True)
        console.print("[green]✓ Audio extracted[/]")

44
dubbing/steps/finalize.py Normal file
View File

@@ -0,0 +1,44 @@
import subprocess
from rich.console import Console
from dubbing.steps.base import PipelineStep
console = Console()
class FinalizeStep(PipelineStep):
    """Mux the dubbed audio back onto the original video stream."""

    name = "Finalize"

    def is_cached(self) -> bool:
        """Cached once the final muxed video exists."""
        return self.paths.result_video.exists()

    def clean(self) -> None:
        """Remove the final video so the step can be re-run."""
        if self.paths.result_video.exists():
            self.paths.result_video.unlink()

    async def run(self) -> None:
        """Combine source video + dubbed audio with stream copies (no re-encode)."""
        console.print("[cyan]Creating final video...[/]")
        # Take the video stream (0:v) from the source file and the audio
        # stream (1:a) from the dubbed mix; -shortest trims to the shorter.
        args = ["ffmpeg", "-y"]
        args += ["-i", str(self.paths.source_video)]
        args += ["-i", str(self.paths.dubbed_audio)]
        args += ["-map", "0:v", "-map", "1:a"]
        args += ["-c:v", "copy", "-c:a", "copy", "-shortest"]
        args.append(str(self.paths.result_video))
        subprocess.run(args, capture_output=True, check=True)
        console.print(f"[green]✓ Created {self.paths.result_video}[/]")

156
dubbing/steps/mix.py Normal file
View File

@@ -0,0 +1,156 @@
import json
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import soundfile as sf
from pydub import AudioSegment
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from dubbing.steps.base import PipelineStep
from dubbing.models import TranslatedSegment
from dubbing.config import settings
console = Console()
class MixStep(PipelineStep):
    """Overlay the generated TTS clips onto the (attenuated) original audio.

    Reads translated.json for segment timings, time-compresses any clip that
    overruns its slot, mixes everything into one float buffer, and exports
    the result as MP3.
    """

    name = "Mix"

    def is_cached(self) -> bool:
        # Complete once the mixed audio file exists.
        return self.paths.dubbed_audio.exists()

    def clean(self) -> None:
        if self.paths.dubbed_audio.exists():
            self.paths.dubbed_audio.unlink()
        # Also drop cached speed-adjusted clips so they are rebuilt next run.
        for f in self.paths.tts_dir.glob("*_fast.wav"):
            f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments written by the translate step."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _speedup_file(self, input_path: Path, output_path: Path, speed: float) -> None:
        """Time-compress an audio file by ``speed``x via ffmpeg's atempo filter.

        atempo accepts at most 2.0 per instance, so larger factors are built
        as a chain of atempo=2.0 filters plus a final fractional remainder.
        Assumes speed > 1.0 (callers only request compression).
        """
        filters = []
        remaining = speed
        while remaining > 1.0:
            if remaining > 2.0:
                filters.append("atempo=2.0")
                remaining /= 2.0
            else:
                filters.append(f"atempo={remaining:.3f}")
                remaining = 1.0
        # NOTE(review): no check=True here — a failed ffmpeg run leaves no
        # output file and the caller's subsequent sf.read would raise instead.
        subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                str(input_path),
                "-filter:a",
                ",".join(filters),
                "-vn",
                str(output_path),
            ],
            capture_output=True,
        )

    def _process_segment(
        self, i: int, seg: TranslatedSegment, target_sr: int, channels: int
    ) -> tuple[int, np.ndarray | None]:
        """Prepare TTS clip i for mixing: speed-fit, channel-match, resample.

        Returns (position_ms, samples) where position_ms is the segment start
        used as the mix offset, and samples is None when no clip exists.
        """
        audio_path = self.paths.tts_dir / f"seg_{i:04d}.wav"
        fast_path = self.paths.tts_dir / f"seg_{i:04d}_fast.wav"
        if not audio_path.exists():
            return (int(seg.start), None)
        # Prefer a previously speed-adjusted clip if one is cached on disk.
        if fast_path.exists():
            data, sr = sf.read(str(fast_path), dtype="int16")
        else:
            data, sr = sf.read(str(audio_path), dtype="int16")
        duration_ms = len(data) / sr * 1000
        available_ms = seg.end - seg.start
        # Compress clips that overrun their slot (slots of <=100 ms are left
        # alone), with the speedup factor capped at settings.max_speedup.
        if duration_ms > available_ms > 100:
            speedup_ratio = min(duration_ms / available_ms, settings.max_speedup)
            self._speedup_file(audio_path, fast_path, speedup_ratio)
            data, sr = sf.read(str(fast_path), dtype="int16")
        # Match the channel layout of the original track.
        if len(data.shape) == 1 and channels == 2:
            data = np.column_stack([data, data])
        elif len(data.shape) == 2 and channels == 1:
            data = data[:, 0]
        # Nearest-sample resample to the original track's rate (crude but
        # adequate for speech overlays).
        if sr != target_sr:
            ratio = target_sr / sr
            new_len = int(len(data) * ratio)
            indices = np.linspace(0, len(data) - 1, new_len).astype(int)
            data = data[indices]
        return (int(seg.start), data)

    async def run(self) -> None:
        """Mix all processed TTS clips over the attenuated original track."""
        console.print("[cyan]Mixing audio tracks...[/]")
        segments = self._load_translated()
        original = AudioSegment.from_mp3(str(self.paths.source_audio))
        # Duck the original (original_volume_db is a dB offset, presumably
        # negative) so the dub sits clearly on top of it.
        original_quiet = original + settings.original_volume_db
        original_samples = np.array(
            original_quiet.get_array_of_samples(), dtype=np.float32
        )
        sample_rate = original_quiet.frame_rate
        channels = original_quiet.channels
        if channels == 2:
            # pydub interleaves stereo samples; reshape to (n_frames, 2).
            original_samples = original_samples.reshape(-1, 2)
        dubbed_samples = np.zeros_like(original_samples, dtype=np.float32)
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Processing & mixing...", total=len(segments))
            # Segment preparation is dominated by ffmpeg subprocesses and
            # native-code resampling, so threads give real parallelism here.
            with ThreadPoolExecutor(max_workers=8) as executor:
                futures = {
                    executor.submit(
                        self._process_segment, i, seg, sample_rate, channels
                    ): i
                    for i, seg in enumerate(segments)
                }
                for future in as_completed(futures):
                    position_ms, data = future.result()
                    if data is not None:
                        start_sample = int(position_ms * sample_rate / 1000)
                        # Clamp clips that would run past the end of the track.
                        end_sample = min(start_sample + len(data), len(dubbed_samples))
                        length = end_sample - start_sample
                        dubbed_samples[start_sample:end_sample] += data[:length].astype(
                            np.float32
                        )
                    progress.advance(task)
        console.print("[cyan]Exporting MP3...[/]")
        # Sum in float32, then clip into the int16 range before export.
        final_samples = original_samples + dubbed_samples
        final_samples = np.clip(final_samples, -32768, 32767).astype(np.int16)
        final = AudioSegment(
            final_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,
            channels=channels,
        )
        final.export(str(self.paths.dubbed_audio), format="mp3")
        console.print("[green]✓ Audio mixed[/]")

215
dubbing/steps/translate.py Normal file
View File

@@ -0,0 +1,215 @@
import asyncio
import json
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from pydantic_ai import Agent, NativeOutput
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from dubbing.steps.base import PipelineStep
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
from dubbing.config import settings, TranslationPrompts
console = Console()
class TranslateStep(PipelineStep):
    """Translate ASR segments with a Gemini model via pydantic-ai.

    Supports direct zh→en / zh→ru translation, plus a two-stage zh→en→ru
    pipeline (Language.EN_RU) where stage 1 produces English intermediates
    and stage 2 translates those into Russian. Chunks are sent concurrently
    with a rolling context window of the last few results.
    """

    name = "Translate"

    def __init__(
        self,
        paths,
        model_name: str = "gemini-2.0-flash-lite",
        language: Language = Language.RU,
    ):
        super().__init__(paths)
        self.model_name = model_name
        self.language = language

    def is_cached(self) -> bool:
        return self.paths.translated_json.exists()

    def clean(self) -> None:
        if self.paths.translated_json.exists():
            self.paths.translated_json.unlink()

    def _load_segments(self) -> list[Segment]:
        """Load the ASR segments written by the ASR step."""
        with open(self.paths.segments_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [Segment(**s) for s in data]

    def _get_system_prompt(self, stage: int = 1) -> str:
        """Pick the system prompt for the target language (and EN_RU stage)."""
        if self.language == Language.EN:
            return TranslationPrompts.EN
        elif self.language == Language.EN_RU:
            return (
                TranslationPrompts.EN_RU_STAGE1
                if stage == 1
                else TranslationPrompts.EN_RU_STAGE2
            )
        return TranslationPrompts.RU

    def _get_translate_command(self, stage: int = 1) -> str:
        """Imperative 'translate' line prepended to each chunk prompt."""
        if self.language == Language.EN:
            return "Translate:"
        elif self.language == Language.EN_RU:
            return "Translate:" if stage == 1 else "Переведи:"
        return "Переведи:"

    def _get_context_header(self, stage: int = 1) -> str:
        """Header labelling the rolling-context block in the prompt."""
        if self.language == Language.EN:
            return "Context:"
        elif self.language == Language.EN_RU:
            return "Context:" if stage == 1 else "Контекст:"
        return "Контекст:"

    def _create_agent(self, stage: int = 1) -> Agent:
        """Build a pydantic-ai Agent with structured TranslationBatch output."""
        provider = GoogleProvider(api_key=settings.gemini_api_key)
        model = GoogleModel(self.model_name, provider=provider)
        return Agent(
            model,
            output_type=NativeOutput(TranslationBatch),
            system_prompt=self._get_system_prompt(stage),
        )

    async def _translate_chunk(
        self,
        agent: Agent,
        chunk: list[Segment],
        chunk_idx: int,
        context: str,
        semaphore: asyncio.Semaphore,
        stage: int = 1,
    ) -> list[TranslatedSegment]:
        """Translate one chunk of segments under the concurrency semaphore.

        Returns one TranslatedSegment per input segment, in order. Segments
        the model did not answer for — or the whole chunk on error — fall
        back to their previous text (stage-2 intermediate or source text).
        """
        async with semaphore:
            # For stage 2, use translated field as source
            if stage == 2:
                items = "\n".join([f"{i}: {s.translated}" for i, s in enumerate(chunk)])
            else:
                items = "\n".join([f"{i}: {s.text}" for i, s in enumerate(chunk)])
            prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"
            try:
                result = await agent.run(prompt)
                translated = []
                for i, seg in enumerate(chunk):
                    translated_text = ""
                    for t in result.output.translations:
                        if t.id == i:
                            translated_text = t.translated
                            break
                    translated.append(
                        TranslatedSegment(
                            start=seg.start,
                            end=seg.end,
                            text=seg.text,
                            translated=translated_text
                            or (seg.translated if stage == 2 else seg.text),
                        )
                    )
                return translated
            except Exception as e:
                console.print(f"[red]Chunk {chunk_idx} error: {e}[/]")
                # Bug fix: the previous fallback read the loop variable `seg`,
                # which is undefined when agent.run() itself raises (turning
                # the real error into a NameError) and, when defined, applied
                # one segment's text to the entire chunk. Fall back
                # per-segment instead.
                return [
                    TranslatedSegment(
                        start=s.start,
                        end=s.end,
                        text=s.text,
                        translated=s.translated if stage == 2 else s.text,
                    )
                    for s in chunk
                ]

    async def _translate_parallel(
        self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
    ) -> list[TranslatedSegment]:
        """Translate all segments in chunked, semaphore-limited batches.

        Each batch's prompt carries a context block built from the last five
        already-translated segments, so later chunks see preceding output.
        """
        chunk_size = settings.translation_chunk_size
        concurrency = settings.translation_concurrency
        chunks = [
            segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
        ]
        semaphore = asyncio.Semaphore(concurrency)
        all_results: list[TranslatedSegment] = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(desc, total=len(chunks))
            for batch_start in range(0, len(chunks), concurrency):
                batch = chunks[batch_start : batch_start + concurrency]
                context = ""
                if all_results:
                    # Rolling context: the tail of what has been produced so far.
                    prev = all_results[-5:]
                    if stage == 2:
                        ctx_lines = [f"{s.translated}" for s in prev]
                    else:
                        ctx_lines = [f"{s.text}{s.translated}" for s in prev]
                    context = (
                        f"{self._get_context_header(stage)}\n"
                        + "\n".join(ctx_lines)
                        + "\n\n"
                    )
                tasks = [
                    self._translate_chunk(
                        agent, chunk, batch_start + idx, context, semaphore, stage
                    )
                    for idx, chunk in enumerate(batch)
                ]
                # return_exceptions=True keeps one failed chunk from killing
                # the whole batch; non-list results are silently dropped.
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for result in results:
                    if isinstance(result, list):
                        all_results.extend(result)
                    progress.advance(task)
        return all_results

    async def run(self) -> None:
        """Translate all segments and write translated.json."""
        segments = self._load_segments()
        if self.language == Language.EN_RU:
            # Stage 1: Chinese -> English
            console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
            agent1 = self._create_agent(stage=1)
            intermediate = await self._translate_parallel(
                agent1, segments, stage=1, desc="Zh→En..."
            )
            # Stage 2: English -> Russian
            console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
            agent2 = self._create_agent(stage=2)
            translated = await self._translate_parallel(
                agent2, intermediate, stage=2, desc="En→Ru..."
            )
        else:
            console.print(
                f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
            )
            agent = self._create_agent()
            translated = await self._translate_parallel(agent, segments)
        with open(self.paths.translated_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
            )
        # Heuristic: output identical to the source text usually means the
        # model fell back rather than translated.
        missing = sum(1 for s in translated if s.translated == s.text)
        if missing:
            console.print(
                f"[yellow]Warning: {missing} segments may not be translated[/]"
            )
        console.print(f"[green]✓ Translated {len(translated)} segments[/]")

162
dubbing/steps/tts.py Normal file
View File

@@ -0,0 +1,162 @@
import json
import re
import torch
import torchaudio
import edge_tts
from pathlib import Path
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from dubbing.steps.base import PipelineStep
from dubbing.models import TranslatedSegment, Language
from dubbing.config import TTSConfig
console = Console()
class TTSStep(PipelineStep):
    """Synthesize per-segment speech clips from the translated text.

    Russian output (RU and EN_RU) uses the Silero TTS model; English output
    uses Microsoft Edge TTS. One WAV file is written per segment into
    paths.tts_dir, named ``seg_NNNN.wav``.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        super().__init__(paths)
        self.language = language
        # Lazily-loaded Silero model; populated on first Russian synthesis.
        self._silero_model = None

    def is_cached(self) -> bool:
        # NOTE(review): any single WAV counts as cached — a partially
        # completed run is treated as done.
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        # Removes all WAVs, including the mix step's *_fast.wav variants.
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments written by the translate step."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Strips characters outside a whitelist of word chars, punctuation and
        Cyrillic; returns None when nothing Cyrillic remains to speak.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[а-яА-ЯёЁ]", text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Piper).

        Same whitelist approach as the Russian variant, keyed on Latin
        letters; returns None when no Latin letters remain.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[a-zA-Z]", text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text based on language."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both output Russian
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Load Silero TTS model for Russian (cached on the instance)."""
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            # torch.hub.load returns (model, example_text) for silero_tts.
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech using Silero; returns False on failure.

        Only ValueError is caught — Silero raises it for unspeakable input;
        other exceptions propagate.
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            # apply_tts returns a 1-D tensor; add a channel dim for torchaudio.
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech using Edge TTS; returns False on failure."""
        try:
            # Edge TTS emits MP3; write it beside the target, then convert.
            mp3_path = output_path.with_suffix(".mp3")
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))
            # Convert MP3 to WAV using pydub
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            mp3_path.unlink()
            return True
        except Exception as e:
            console.print(f"[red]Edge TTS error: {e}[/]")
            return False

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Synthesize speech based on language."""
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate one WAV per segment, reporting skipped/failed segments."""
        engine = (
            TTSConfig.EN_ENGINE if self.language == Language.EN else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")
        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)
        segments = self._load_translated()
        # Collected as (index, original text, reason) for the summary below.
        skipped = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))
            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    # Nothing speakable after cleaning — skip this segment.
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue
                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    skipped.append((i, seg.translated, "tts_error"))
                progress.advance(task)
        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                console.print(f"  [dim]{idx}:[/] {text[:60]}... [red]({reason})[/]")