Files
2026-02-01 16:07:59 +01:00

163 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import re
import torch
import torchaudio
import edge_tts
from pathlib import Path
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from dubbing.steps.base import PipelineStep
from dubbing.models import TranslatedSegment, Language
from dubbing.config import TTSConfig
# Module-level rich console used for every status, progress and error message below.
console = Console()
class TTSStep(PipelineStep):
    """Pipeline step that synthesizes one WAV file per translated segment.

    ``Language.EN`` is synthesized with Edge TTS (MP3 converted to WAV via
    pydub). Every other language value (RU, EN_RU) is synthesized as Russian
    speech with a Silero model loaded through ``torch.hub``.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        """Store the target language and defer loading of the Silero model.

        Args:
            paths: Project paths object; this step reads ``paths.tts_dir`` and
                ``paths.translated_json``.
            language: Output language selecting the cleaning rules and engine.
        """
        super().__init__(paths)
        self.language = language
        # Loaded lazily on first Russian synthesis (see _load_silero_model).
        self._silero_model = None

    def is_cached(self) -> bool:
        """Return True when the TTS directory exists and holds any ``*.wav`` file."""
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        """Delete every ``*.wav`` file in the TTS directory, if it exists."""
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Read ``paths.translated_json`` and build one TranslatedSegment per entry."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Strips characters outside the allowed set, collapses whitespace, and
        returns None when nothing remains or no Cyrillic letter is present.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[а-яА-ЯёЁ]", text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Edge TTS).

        Strips characters outside the allowed set, collapses whitespace, and
        returns None when nothing remains or no ASCII letter is present.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[a-zA-Z]", text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text based on language; returns None if nothing speakable is left."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both output Russian
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Load the Silero Russian TTS model once and return the cached instance."""
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech using Silero and save it as WAV.

        Returns:
            True on success, False when synthesis or saving raises ValueError.
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech using Edge TTS and convert it to WAV.

        Edge TTS writes an MP3 next to ``output_path``; pydub converts it to
        WAV. The intermediate MP3 is removed whether or not conversion
        succeeds, so failed segments do not leave stray files behind.

        Returns:
            True on success, False if any step raised (the error is printed).
        """
        from rich.markup import escape

        mp3_path = output_path.with_suffix(".mp3")
        try:
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))
            # Convert MP3 to WAV using pydub
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            return True
        except Exception as e:
            # Escape the message: brackets in it would otherwise be parsed as
            # rich markup and could be dropped or raise a MarkupError.
            console.print(f"[red]Edge TTS error: {escape(str(e))}[/]")
            return False
        finally:
            # Best-effort removal of the intermediate MP3 on every exit path.
            try:
                mp3_path.unlink(missing_ok=True)
            except OSError:
                pass

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Synthesize speech based on language; returns True on success."""
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate ``seg_NNNN.wav`` for each translated segment and report results.

        Creates the TTS directory if needed, cleans each segment's text,
        synthesizes it, and prints how many files were generated along with
        every skipped segment and the reason (``no_text`` or ``tts_error``).
        """
        from rich.markup import escape

        engine = (
            TTSConfig.EN_ENGINE if self.language == Language.EN else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")
        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)
        segments = self._load_translated()
        skipped = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))
            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue
                # File index follows the segment index, so skipped segments
                # leave gaps in the numbering.
                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    skipped.append((i, seg.translated, "tts_error"))
                progress.advance(task)
        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                # Slice first, then escape, so segment text containing "[" is
                # shown literally instead of being interpreted as rich markup.
                console.print(
                    f" [dim]{idx}:[/] {escape(text[:60])}... [red]({reason})[/]"
                )