163 lines
5.7 KiB
Python
163 lines
5.7 KiB
Python
import json
|
||
import re
|
||
import torch
|
||
import torchaudio
|
||
import edge_tts
|
||
from pathlib import Path
|
||
from rich.console import Console
|
||
from rich.progress import (
|
||
Progress,
|
||
SpinnerColumn,
|
||
TextColumn,
|
||
BarColumn,
|
||
TaskProgressColumn,
|
||
)
|
||
from dubbing.steps.base import PipelineStep
|
||
from dubbing.models import TranslatedSegment, Language
|
||
from dubbing.config import TTSConfig
|
||
# Module-level Rich console shared by all status/progress output in this step.
console = Console()
class TTSStep(PipelineStep):
    """Generate per-segment speech audio from translated text.

    Reads translated segments from ``paths.translated_json`` and writes one
    WAV file per segment (``seg_0000.wav``, ...) into ``paths.tts_dir``.
    Russian output (RU and EN_RU targets) is synthesized locally with
    Silero; English output uses Microsoft Edge TTS via ``edge_tts``
    (network call), with the MP3 result converted to WAV through pydub.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        """Initialize the step.

        Args:
            paths: Pipeline path container (provides ``tts_dir`` and
                ``translated_json``) — passed through to ``PipelineStep``.
            language: Target output language; selects the engine and the
                text-cleaning rules.
        """
        super().__init__(paths)
        self.language = language
        self._silero_model = None  # lazily loaded in _load_silero_model()

    def is_cached(self) -> bool:
        """Return True if at least one generated WAV already exists."""
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        """Delete previously generated WAV files (invalidates the cache)."""
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments produced by the translation step."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Strips characters outside a whitelist of word characters and common
        punctuation, collapses whitespace, and returns None when nothing
        speakable remains (no Cyrillic letters).
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[а-яА-ЯёЁ]", text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Edge TTS).

        Mirrors ``_clean_text_russian`` but requires at least one Latin
        letter for the text to count as speakable.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[a-zA-Z]", text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text based on the target language."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both output Russian
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Load and memoize the Silero TTS model for Russian.

        Runs on CPU. The model is cached on the instance so later segments
        do not re-download or re-initialize it.
        """
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech with Silero; return True on success.

        ``apply_tts`` raises ValueError for text it cannot voice; that is
        treated as a soft failure so the caller can skip the segment.
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech with Edge TTS; return True on success.

        Edge TTS emits MP3, so the result is converted to WAV with pydub.
        The intermediate MP3 is always removed, even on failure.
        """
        mp3_path = output_path.with_suffix(".mp3")
        try:
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))

            # Convert MP3 to WAV using pydub (imported lazily — only the
            # English path needs it).
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            return True
        except Exception as e:
            console.print(f"[red]Edge TTS error: {e}[/]")
            return False
        finally:
            # Fix: previously the temp MP3 leaked when conversion failed.
            mp3_path.unlink(missing_ok=True)

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Dispatch synthesis to the engine for the configured language."""
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate one WAV per translated segment, with progress display.

        Segments whose text cleans to nothing are skipped ("no_text");
        synthesis failures are skipped ("tts_error") and any partial output
        file is removed so ``is_cached()`` never reports a broken result.
        """
        engine = (
            TTSConfig.EN_ENGINE if self.language == Language.EN else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")

        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)

        segments = self._load_translated()

        skipped = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))

            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue

                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    # Fix: drop any partial file so a failed segment cannot
                    # satisfy the is_cached() check on the next run.
                    path.unlink(missing_ok=True)
                    skipped.append((i, seg.translated, "tts_error"))

                progress.advance(task)

        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                console.print(f" [dim]{idx}:[/] {text[:60]}... [red]({reason})[/]")