import json
import re
from pathlib import Path

import edge_tts
import torch
import torchaudio
from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)

from dubbing.steps.base import PipelineStep
from dubbing.models import TranslatedSegment, Language
from dubbing.config import TTSConfig

console = Console()

# Pre-compiled cleaning patterns, hoisted so they are built once instead of
# per segment. NOTE: `\w` is Unicode-aware in Python's re, so these classes
# are deliberately permissive; the explicit ranges are kept for clarity and
# to preserve the original behavior byte-for-byte.
_RU_STRIP_RE = re.compile(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]")
_EN_STRIP_RE = re.compile(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]")
_WS_RE = re.compile(r"\s+")
_HAS_CYRILLIC_RE = re.compile(r"[а-яА-ЯёЁ]")
_HAS_LATIN_RE = re.compile(r"[a-zA-Z]")


class TTSStep(PipelineStep):
    """Pipeline step that synthesizes speech for translated segments.

    Russian output (``Language.RU`` and ``Language.EN_RU``) uses the Silero
    TTS model loaded via ``torch.hub``; English output uses Microsoft Edge
    TTS. One WAV file per segment is written to ``paths.tts_dir`` as
    ``seg_NNNN.wav``; indices of skipped segments leave gaps in the
    numbering.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        super().__init__(paths)
        self.language = language
        # Lazily loaded by _load_silero_model() on first Russian synthesis.
        self._silero_model = None

    def is_cached(self) -> bool:
        """Return True if at least one synthesized WAV already exists."""
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        """Delete all previously generated WAV files."""
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load translated segments from the translation step's JSON output."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Strips characters outside the allowed set, collapses whitespace, and
        returns None when nothing pronounceable remains (no Cyrillic letters).
        """
        text = _RU_STRIP_RE.sub("", text)
        text = _WS_RE.sub(" ", text).strip()
        if not text or not _HAS_CYRILLIC_RE.search(text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Edge TTS).

        Strips characters outside the allowed set, collapses whitespace, and
        returns None when no Latin letters remain.
        """
        text = _EN_STRIP_RE.sub("", text)
        text = _WS_RE.sub(" ", text).strip()
        if not text or not _HAS_LATIN_RE.search(text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text for the active output language."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both produce Russian speech.
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Load (once) and return the Silero Russian TTS model on CPU."""
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            # torch.hub.load returns (model, example_text); keep only the model.
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech with Silero; return False on failure.

        Silero raises ValueError for text it cannot voice; such segments are
        reported by the caller as ``tts_error``.
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech with Edge TTS; return False on failure.

        Edge TTS emits MP3, which is converted to WAV via pydub. The
        intermediate MP3 is removed in all cases — previously it was left
        behind whenever saving or conversion raised.
        """
        mp3_path = output_path.with_suffix(".mp3")
        try:
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))

            # Convert MP3 to WAV using pydub
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            return True
        except Exception as e:
            console.print(f"[red]Edge TTS error: {e}[/]")
            return False
        finally:
            # Always clean up the temp MP3, even on failure.
            mp3_path.unlink(missing_ok=True)

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Dispatch synthesis to the language-appropriate engine."""
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS.
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate one WAV per translated segment, reporting skipped ones."""
        engine = (
            TTSConfig.EN_ENGINE
            if self.language == Language.EN
            else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")

        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)
        segments = self._load_translated()
        # (segment index, original translated text, reason) for each skip.
        skipped: list[tuple[int, str, str]] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))
            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue

                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    skipped.append((i, seg.translated, "tts_error"))
                progress.advance(task)

        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                console.print(f"  [dim]{idx}:[/] {text[:60]}... [red]({reason})[/]")