feat: mvp
This commit is contained in:
162
dubbing/steps/tts.py
Normal file
162
dubbing/steps/tts.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
import re
from pathlib import Path

import edge_tts
import torch
import torchaudio
from rich.console import Console
from rich.markup import escape
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)

from dubbing.config import TTSConfig
from dubbing.models import TranslatedSegment, Language
from dubbing.steps.base import PipelineStep
|
||||
|
||||
# Module-wide Rich console: every status line and the progress bar in this
# file are written through this single instance.
console = Console()
|
||||
|
||||
|
||||
class TTSStep(PipelineStep):
    """Pipeline step that synthesizes one WAV file per translated segment.

    ``Language.EN`` output is produced with Edge TTS; every other language
    value falls through to the Russian Silero model.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        """Remember the target language and defer loading the Silero model.

        Args:
            paths: Project paths object forwarded to ``PipelineStep``
                (presumably exposed there as ``self.paths``); this step reads
                ``tts_dir`` and ``translated_json`` from it.
            language: Output language; selects both the text-cleaning rules
                and the synthesis engine.
        """
        super().__init__(paths)
        self.language = language
        # Populated on first use by _load_silero_model().
        self._silero_model = None

    def is_cached(self) -> bool:
        """Return True when the TTS directory holds at least one ``*.wav``.

        NOTE(review): a single WAV is enough, so an interrupted run also
        counts as cached.
        """
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        """Delete every ``*.wav`` in the TTS directory, if the directory exists."""
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Read ``translated_json`` and build one ``TranslatedSegment`` per entry.

        Each JSON entry is expanded as keyword arguments, so it must be a
        mapping whose keys match the ``TranslatedSegment`` constructor.
        """
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Removes characters outside the allowed set, collapses whitespace runs
        to single spaces and trims the ends.

        Returns:
            The cleaned text, or ``None`` when nothing is left or the result
            contains no Cyrillic letter.
        """
        # NOTE: ``\w`` is Unicode-aware for str patterns, so Latin letters and
        # digits survive this filter as well.
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[а-яА-ЯёЁ]", text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Edge TTS).

        Removes characters outside the allowed set, collapses whitespace runs
        to single spaces and trims the ends.

        Returns:
            The cleaned text, or ``None`` when nothing is left or the result
            contains no ASCII letter.
        """
        # NOTE: ``\w`` is Unicode-aware for str patterns, so non-ASCII letters
        # are kept too; only the presence check below is ASCII-only.
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[a-zA-Z]", text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text with the rules matching ``self.language``."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both output Russian
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Return the Silero Russian TTS model, loading it on first call.

        The model is fetched through ``torch.hub`` once, moved to the CPU
        device and then reused from ``self._silero_model``.
        """
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            # The hub entry point returns a pair; only the first element (the
            # model) is kept.
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech using Silero and save it as WAV.

        Args:
            text: Already-cleaned text to speak.
            output_path: Destination WAV path.

        Returns:
            True on success, False when a ``ValueError`` is raised during
            synthesis or saving. Other exceptions propagate.
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            # unsqueeze(0) adds a leading axis — presumably the channel
            # dimension torchaudio.save expects; confirm against apply_tts.
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech using Edge TTS and convert it to WAV.

        Edge TTS writes an MP3 next to ``output_path``; pydub converts it to
        WAV, and the intermediate MP3 is removed whether or not the
        conversion succeeded.

        Args:
            text: Already-cleaned text to speak.
            output_path: Destination WAV path.

        Returns:
            True on success, False when synthesis or conversion failed (the
            error is printed to the console).
        """
        mp3_path = output_path.with_suffix(".mp3")
        try:
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))

            # Convert MP3 to WAV using pydub. Imported at call time, so the
            # module itself loads without pydub (presumably deliberate).
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            return True
        except Exception as e:
            # Escape the message: brackets in it would otherwise be parsed as
            # Rich markup and could raise or vanish from the output.
            console.print(f"[red]Edge TTS error: {escape(str(e))}[/]")
            return False
        finally:
            # Never leave the temporary MP3 behind, including on failure.
            try:
                mp3_path.unlink(missing_ok=True)
            except OSError as e:
                console.print(
                    f"[yellow]Could not remove temporary file "
                    f"{escape(str(mp3_path))}: {escape(str(e))}[/]"
                )

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Synthesize speech with the engine matching ``self.language``.

        The Russian branch is a plain synchronous call, so it runs to
        completion without yielding to the event loop.
        """
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate ``seg_NNNN.wav`` for every translated segment.

        Segments whose cleaned text is empty, or whose synthesis fails, are
        skipped and listed in a summary printed at the end. File numbering
        follows the segment's index in the translated list, so skipped
        segments leave gaps.
        """
        engine = (
            TTSConfig.EN_ENGINE if self.language == Language.EN else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")

        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)

        segments = self._load_translated()

        # Entries are (segment index, original translated text, reason).
        skipped = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))

            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue

                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    skipped.append((i, seg.translated, "tts_error"))

                progress.advance(task)

        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                # Segment text is arbitrary (e.g. "[Music]"); escape it so
                # Rich prints it literally instead of parsing it as markup.
                preview = escape(text[:60])
                # Only show an ellipsis when the text was actually cut.
                ellipsis = "..." if len(text) > 60 else ""
                console.print(f" [dim]{idx}:[/] {preview}{ellipsis} [red]({reason})[/]")
|
||||
Reference in New Issue
Block a user