feat: mvp
This commit is contained in:
15
dubbing/steps/__init__.py
Normal file
15
dubbing/steps/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from dubbing.steps.extract_audio import ExtractAudioStep
|
||||
from dubbing.steps.asr import ASRStep
|
||||
from dubbing.steps.translate import TranslateStep
|
||||
from dubbing.steps.tts import TTSStep
|
||||
from dubbing.steps.mix import MixStep
|
||||
from dubbing.steps.finalize import FinalizeStep
|
||||
|
||||
__all__ = [
|
||||
"ExtractAudioStep",
|
||||
"ASRStep",
|
||||
"TranslateStep",
|
||||
"TTSStep",
|
||||
"MixStep",
|
||||
"FinalizeStep",
|
||||
]
|
||||
78
dubbing/steps/asr.py
Normal file
78
dubbing/steps/asr.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from rich.console import Console
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import Segment
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class ASRStep(PipelineStep):
    """Transcribe the source audio into timed Chinese text segments.

    Runs FunASR's SenseVoice model with VAD, groups the word-level output
    into sentence-like segments at Chinese punctuation marks, and writes
    the result to ``paths.segments_json``.
    """

    name = "ASR"

    def is_cached(self) -> bool:
        """Return True when the segments JSON has already been produced."""
        return self.paths.segments_json.exists()

    def clean(self) -> None:
        """Remove the cached segments JSON, if any."""
        # missing_ok avoids the racy exists()-then-unlink() pattern.
        self.paths.segments_json.unlink(missing_ok=True)

    def _group_chinese_segments(
        self, words: list[str], timestamps: list[list[int]]
    ) -> list[Segment]:
        """Merge word-level ASR output into punctuation-delimited segments.

        Args:
            words: Recognized tokens in order.
            timestamps: Per-token ``[start_ms, end_ms]`` pairs, parallel
                to ``words``.

        Returns:
            Segments whose text runs up to (and including) a punctuation
            mark; start/end come from the run's first and last token.
        """
        segments: list[Segment] = []
        current_text = ""
        current_start = None
        current_end = None
        # Chinese sentence/clause marks that trigger a segment flush.
        punctuation = {"。", ",", "!", "?", ";", ":", "…", "、"}

        for word, ts in zip(words, timestamps):
            if current_start is None:
                current_start = ts[0]
            current_text += word
            current_end = ts[1]

            if word in punctuation:
                segments.append(
                    Segment(start=current_start, end=current_end, text=current_text)
                )
                current_text = ""
                current_start = None

        # Flush trailing text that did not end with punctuation.
        if current_text:
            segments.append(
                Segment(start=current_start, end=current_end, text=current_text)
            )

        return segments

    async def run(self) -> None:
        """Run speech recognition and persist the grouped segments."""
        console.print("[cyan]Running speech recognition...[/]")

        # Imported lazily: funasr pulls in heavy ML dependencies.
        from funasr import AutoModel

        model = AutoModel(
            model="iic/SenseVoiceSmall",
            device="mps",  # NOTE(review): Apple-GPU only — confirm portability
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
        )

        result = model.generate(
            input=str(self.paths.source_audio),
            language="zh",
            use_itn=True,
            batch_size_s=60,
            merge_length_s=15,
            output_timestamp=True,
        )

        segments = self._group_chinese_segments(
            result[0]["words"], result[0]["timestamp"]
        )

        with open(self.paths.segments_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in segments], f, ensure_ascii=False, indent=2
            )

        console.print(f"[green]✓ Found {len(segments)} segments[/]")
|
||||
21
dubbing/steps/base.py
Normal file
21
dubbing/steps/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dubbing.models import ProjectPaths
|
||||
|
||||
|
||||
class PipelineStep(ABC):
    """Abstract base for one stage of the dubbing pipeline.

    A concrete step reports whether its output already exists
    (``is_cached``), deletes that output (``clean``), and produces
    it (``run``).
    """

    # Human-readable label used when reporting on this step.
    name: str = "Step"

    def __init__(self, paths: "ProjectPaths"):
        # Shared project file layout; steps resolve all of their
        # inputs and outputs through this object.
        self.paths = paths

    @abstractmethod
    def is_cached(self) -> bool:
        """Return True when this step's output artifact is already present."""
        ...

    @abstractmethod
    def clean(self) -> None:
        """Delete this step's cached output so the step can be re-run."""
        ...

    @abstractmethod
    async def run(self) -> None:
        """Execute the step, producing its output artifact(s)."""
        ...
|
||||
38
dubbing/steps/extract_audio.py
Normal file
38
dubbing/steps/extract_audio.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import subprocess
|
||||
from rich.console import Console
|
||||
from dubbing.steps.base import PipelineStep
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class ExtractAudioStep(PipelineStep):
    """Extract the audio track from the source video as an MP3 file."""

    name = "Extract audio"

    def is_cached(self) -> bool:
        """Return True when the extracted audio file already exists."""
        return self.paths.source_audio.exists()

    def clean(self) -> None:
        """Delete the extracted audio file, if present."""
        # missing_ok avoids the racy exists()-then-unlink() pattern.
        self.paths.source_audio.unlink(missing_ok=True)

    async def run(self) -> None:
        """Demux and encode the video's audio with ffmpeg.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits non-zero; the
                captured stderr is attached to the exception.
        """
        console.print("[cyan]Extracting audio from video...[/]")

        subprocess.run(
            [
                "ffmpeg",
                "-y",  # overwrite any stale output
                "-i",
                str(self.paths.source_video),
                "-vn",  # drop the video stream
                "-acodec",
                "libmp3lame",
                "-q:a",
                "2",  # high-quality VBR preset
                str(self.paths.source_audio),
            ],
            capture_output=True,
            check=True,
        )

        console.print("[green]✓ Audio extracted[/]")
|
||||
44
dubbing/steps/finalize.py
Normal file
44
dubbing/steps/finalize.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import subprocess
|
||||
from rich.console import Console
|
||||
from dubbing.steps.base import PipelineStep
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class FinalizeStep(PipelineStep):
    """Mux the dubbed audio track into the original video container."""

    name = "Finalize"

    def is_cached(self) -> bool:
        """Return True when the final video has already been written."""
        return self.paths.result_video.exists()

    def clean(self) -> None:
        """Delete the final video, if present."""
        # missing_ok avoids the racy exists()-then-unlink() pattern.
        self.paths.result_video.unlink(missing_ok=True)

    async def run(self) -> None:
        """Replace the video's audio with the dubbed track using ffmpeg.

        Both streams are stream-copied (no re-encode); ``-shortest`` trims
        the output to the shorter of the two inputs.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits non-zero; the
                captured stderr is attached to the exception.
        """
        console.print("[cyan]Creating final video...[/]")

        subprocess.run(
            [
                "ffmpeg",
                "-y",  # overwrite any stale output
                "-i",
                str(self.paths.source_video),
                "-i",
                str(self.paths.dubbed_audio),
                "-map",
                "0:v",  # video from the source file
                "-map",
                "1:a",  # audio from the dubbed track
                "-c:v",
                "copy",
                "-c:a",
                # NOTE(review): stream-copies the MP3 audio into the video
                # container — confirm target container/player support.
                "copy",
                "-shortest",
                str(self.paths.result_video),
            ],
            capture_output=True,
            check=True,
        )

        console.print(f"[green]✓ Created {self.paths.result_video}[/]")
|
||||
156
dubbing/steps/mix.py
Normal file
156
dubbing/steps/mix.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import json
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from pydub import AudioSegment
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import TranslatedSegment
|
||||
from dubbing.config import settings
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class MixStep(PipelineStep):
    """Mix TTS speech over the (attenuated) original audio track.

    Reads the translated segments, time-compresses each TTS clip that
    overruns its slot, overlays every clip at its segment start time on
    top of the volume-reduced original, and exports the result as MP3 to
    ``paths.dubbed_audio``.
    """

    name = "Mix"

    def is_cached(self) -> bool:
        """Return True when the mixed audio file already exists."""
        return self.paths.dubbed_audio.exists()

    def clean(self) -> None:
        """Delete the mixed audio and any sped-up intermediate clips."""
        if self.paths.dubbed_audio.exists():
            self.paths.dubbed_audio.unlink()
        # "*_fast.wav" files are derived artifacts produced by
        # _speedup_file; the base "seg_*.wav" clips belong to the TTS step.
        for f in self.paths.tts_dir.glob("*_fast.wav"):
            f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments produced by the translate step."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _speedup_file(self, input_path: Path, output_path: Path, speed: float) -> None:
        """Time-compress an audio file by ``speed`` using ffmpeg's atempo.

        atempo only accepts factors up to 2.0 per instance, so larger
        speed-ups are expressed as a chain of atempo filters.
        ``speed`` is expected to be > 1.0 (callers only request compression).
        """
        filters = []
        remaining = speed
        while remaining > 1.0:
            if remaining > 2.0:
                filters.append("atempo=2.0")
                remaining /= 2.0
            else:
                filters.append(f"atempo={remaining:.3f}")
                remaining = 1.0

        # NOTE(review): no check=True — an ffmpeg failure is silently
        # ignored and the caller will fail later when reading output_path.
        subprocess.run(
            [
                "ffmpeg",
                "-y",
                "-i",
                str(input_path),
                "-filter:a",
                ",".join(filters),
                "-vn",
                str(output_path),
            ],
            capture_output=True,
        )

    def _process_segment(
        self, i: int, seg: TranslatedSegment, target_sr: int, channels: int
    ) -> tuple[int, np.ndarray | None]:
        """Prepare one TTS clip for mixing.

        Speeds the clip up when it overruns its time slot (capped at
        ``settings.max_speedup``), matches channel count, and resamples to
        ``target_sr`` by nearest-neighbor index mapping.

        Returns:
            ``(segment start in ms, int16 sample array)``; the array is
            None when no TTS clip exists for this segment.
        """
        audio_path = self.paths.tts_dir / f"seg_{i:04d}.wav"
        fast_path = self.paths.tts_dir / f"seg_{i:04d}_fast.wav"

        if not audio_path.exists():
            return (int(seg.start), None)

        if fast_path.exists():
            # A sped-up version was already produced on a previous run.
            data, sr = sf.read(str(fast_path), dtype="int16")
        else:
            data, sr = sf.read(str(audio_path), dtype="int16")
            duration_ms = len(data) / sr * 1000
            available_ms = seg.end - seg.start

            # Only compress when the clip overruns a non-trivial slot
            # (>100 ms); tiny slots would demand absurd speed-ups.
            if duration_ms > available_ms > 100:
                speedup_ratio = min(duration_ms / available_ms, settings.max_speedup)
                self._speedup_file(audio_path, fast_path, speedup_ratio)
                data, sr = sf.read(str(fast_path), dtype="int16")

        # Match the channel layout of the original track.
        if len(data.shape) == 1 and channels == 2:
            data = np.column_stack([data, data])
        elif len(data.shape) == 2 and channels == 1:
            data = data[:, 0]

        # Crude nearest-neighbor resample to the original's sample rate.
        if sr != target_sr:
            ratio = target_sr / sr
            new_len = int(len(data) * ratio)
            indices = np.linspace(0, len(data) - 1, new_len).astype(int)
            data = data[indices]

        return (int(seg.start), data)

    async def run(self) -> None:
        """Overlay all TTS clips on the quieted original and export MP3."""
        console.print("[cyan]Mixing audio tracks...[/]")

        segments = self._load_translated()

        original = AudioSegment.from_mp3(str(self.paths.source_audio))
        # Duck the original track so the dub sits on top of it.
        original_quiet = original + settings.original_volume_db
        original_samples = np.array(
            original_quiet.get_array_of_samples(), dtype=np.float32
        )
        sample_rate = original_quiet.frame_rate
        channels = original_quiet.channels

        # pydub returns interleaved samples; reshape to (frames, 2).
        if channels == 2:
            original_samples = original_samples.reshape(-1, 2)

        # Separate accumulation buffer for the dub track; summed at the end.
        dubbed_samples = np.zeros_like(original_samples, dtype=np.float32)

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Processing & mixing...", total=len(segments))

            # Clip preparation (file I/O + ffmpeg) fans out to threads;
            # the buffer accumulation below stays on this thread, so no
            # locking is needed on dubbed_samples.
            with ThreadPoolExecutor(max_workers=8) as executor:
                futures = {
                    executor.submit(
                        self._process_segment, i, seg, sample_rate, channels
                    ): i
                    for i, seg in enumerate(segments)
                }
                for future in as_completed(futures):
                    position_ms, data = future.result()
                    if data is not None:
                        start_sample = int(position_ms * sample_rate / 1000)
                        # Clamp clips that run past the end of the track.
                        end_sample = min(start_sample + len(data), len(dubbed_samples))
                        length = end_sample - start_sample
                        dubbed_samples[start_sample:end_sample] += data[:length].astype(
                            np.float32
                        )
                    progress.advance(task)

        console.print("[cyan]Exporting MP3...[/]")
        final_samples = original_samples + dubbed_samples
        # Hard-clip back into the int16 range before re-encoding.
        final_samples = np.clip(final_samples, -32768, 32767).astype(np.int16)

        final = AudioSegment(
            final_samples.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,  # int16 samples
            channels=channels,
        )
        final.export(str(self.paths.dubbed_audio), format="mp3")

        console.print("[green]✓ Audio mixed[/]")
|
||||
215
dubbing/steps/translate.py
Normal file
215
dubbing/steps/translate.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import json
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
from pydantic_ai import Agent, NativeOutput
|
||||
from pydantic_ai.models.google import GoogleModel
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
|
||||
from dubbing.config import settings, TranslationPrompts
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class TranslateStep(PipelineStep):
    """Translate ASR segments with a Gemini model via pydantic-ai.

    Supports direct zh→ru / zh→en translation plus a two-stage
    zh→en→ru pipeline (``Language.EN_RU``) where stage-1 output feeds
    stage 2. Reads ``paths.segments_json``, writes ``paths.translated_json``.
    """

    name = "Translate"

    def __init__(
        self,
        paths,
        model_name: str = "gemini-2.0-flash-lite",
        language: Language = Language.RU,
    ):
        super().__init__(paths)
        self.model_name = model_name
        self.language = language

    def is_cached(self) -> bool:
        """Return True when the translated JSON already exists."""
        return self.paths.translated_json.exists()

    def clean(self) -> None:
        """Delete the translated JSON, if present."""
        # missing_ok avoids the racy exists()-then-unlink() pattern.
        self.paths.translated_json.unlink(missing_ok=True)

    def _load_segments(self) -> list[Segment]:
        """Load the ASR segments produced by the ASR step."""
        with open(self.paths.segments_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [Segment(**s) for s in data]

    def _get_system_prompt(self, stage: int = 1) -> str:
        """Pick the system prompt for the target language and stage."""
        if self.language == Language.EN:
            return TranslationPrompts.EN
        elif self.language == Language.EN_RU:
            return (
                TranslationPrompts.EN_RU_STAGE1
                if stage == 1
                else TranslationPrompts.EN_RU_STAGE2
            )
        return TranslationPrompts.RU

    def _get_translate_command(self, stage: int = 1) -> str:
        """Imperative line prefixed to each chunk, in the prompt's language."""
        if self.language == Language.EN:
            return "Translate:"
        elif self.language == Language.EN_RU:
            return "Translate:" if stage == 1 else "Переведи:"
        return "Переведи:"

    def _get_context_header(self, stage: int = 1) -> str:
        """Header introducing previously-translated context lines."""
        if self.language == Language.EN:
            return "Context:"
        elif self.language == Language.EN_RU:
            return "Context:" if stage == 1 else "Контекст:"
        return "Контекст:"

    def _create_agent(self, stage: int = 1) -> Agent:
        """Build a pydantic-ai agent with structured TranslationBatch output."""
        provider = GoogleProvider(api_key=settings.gemini_api_key)
        model = GoogleModel(self.model_name, provider=provider)

        return Agent(
            model,
            output_type=NativeOutput(TranslationBatch),
            system_prompt=self._get_system_prompt(stage),
        )

    async def _translate_chunk(
        self,
        agent: Agent,
        chunk: list[Segment],
        chunk_idx: int,
        context: str,
        semaphore: asyncio.Semaphore,
        stage: int = 1,
    ) -> list[TranslatedSegment]:
        """Translate one chunk of segments under the concurrency semaphore.

        On any failure the chunk falls back per-segment to its own source
        text (stage 1) or its own stage-1 translation (stage 2), so a bad
        chunk never loses timing information.
        """
        async with semaphore:
            # For stage 2, the stage-1 translation is the source text.
            if stage == 2:
                items = "\n".join([f"{i}: {s.translated}" for i, s in enumerate(chunk)])
            else:
                items = "\n".join([f"{i}: {s.text}" for i, s in enumerate(chunk)])
            prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"

            try:
                result = await agent.run(prompt)
                # Index returned ids once instead of scanning the list per
                # segment; keep the first occurrence of a duplicated id.
                by_id: dict[int, str] = {}
                for t in result.output.translations:
                    by_id.setdefault(t.id, t.translated)
                translated = []
                for i, seg in enumerate(chunk):
                    translated.append(
                        TranslatedSegment(
                            start=seg.start,
                            end=seg.end,
                            text=seg.text,
                            translated=by_id.get(i, "")
                            or (seg.translated if stage == 2 else seg.text),
                        )
                    )
                return translated
            except Exception as e:
                console.print(f"[red]Chunk {chunk_idx} error: {e}[/]")
                # BUGFIX: the fallback used to read the stale loop variable
                # `seg`, raising NameError when agent.run failed before the
                # loop ran, and otherwise copying one segment's text onto
                # the whole chunk. Fall back per-segment instead.
                return [
                    TranslatedSegment(
                        start=s.start,
                        end=s.end,
                        text=s.text,
                        translated=s.translated if stage == 2 else s.text,
                    )
                    for s in chunk
                ]

    async def _translate_parallel(
        self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
    ) -> list[TranslatedSegment]:
        """Translate all segments chunk-wise with bounded concurrency.

        Chunks are processed in batches of ``translation_concurrency`` so
        each batch can seed its prompt with the tail of the previous
        batch's results as context.
        """
        chunk_size = settings.translation_chunk_size
        concurrency = settings.translation_concurrency

        chunks = [
            segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
        ]
        semaphore = asyncio.Semaphore(concurrency)

        all_results: list[TranslatedSegment] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(desc, total=len(chunks))

            for batch_start in range(0, len(chunks), concurrency):
                batch = chunks[batch_start : batch_start + concurrency]

                # Seed the prompt with the last few finished translations
                # for continuity across chunk boundaries.
                context = ""
                if all_results:
                    prev = all_results[-5:]
                    if stage == 2:
                        ctx_lines = [f"• {s.translated}" for s in prev]
                    else:
                        ctx_lines = [f"• {s.text} → {s.translated}" for s in prev]
                    context = (
                        f"{self._get_context_header(stage)}\n"
                        + "\n".join(ctx_lines)
                        + "\n\n"
                    )

                tasks = [
                    self._translate_chunk(
                        agent, chunk, batch_start + idx, context, semaphore, stage
                    )
                    for idx, chunk in enumerate(batch)
                ]
                results = await asyncio.gather(*tasks, return_exceptions=True)

                for result in results:
                    # Exceptions already produced a per-chunk fallback inside
                    # _translate_chunk; anything else is dropped here.
                    if isinstance(result, list):
                        all_results.extend(result)
                    progress.advance(task)

        return all_results

    async def run(self) -> None:
        """Translate all segments and write the result JSON."""
        segments = self._load_segments()

        if self.language == Language.EN_RU:
            # Stage 1: Chinese -> English
            console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
            agent1 = self._create_agent(stage=1)
            intermediate = await self._translate_parallel(
                agent1, segments, stage=1, desc="Zh→En..."
            )

            # Stage 2: English -> Russian
            console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
            agent2 = self._create_agent(stage=2)
            translated = await self._translate_parallel(
                agent2, intermediate, stage=2, desc="En→Ru..."
            )
        else:
            console.print(
                f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
            )
            agent = self._create_agent()
            translated = await self._translate_parallel(agent, segments)

        with open(self.paths.translated_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
            )

        # A translation equal to its source usually means the fallback fired.
        missing = sum(1 for s in translated if s.translated == s.text)
        if missing:
            console.print(
                f"[yellow]Warning: {missing} segments may not be translated[/]"
            )

        console.print(f"[green]✓ Translated {len(translated)} segments[/]")
|
||||
162
dubbing/steps/tts.py
Normal file
162
dubbing/steps/tts.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
import re
|
||||
import torch
|
||||
import torchaudio
|
||||
import edge_tts
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import TranslatedSegment, Language
|
||||
from dubbing.config import TTSConfig
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class TTSStep(PipelineStep):
    """Synthesize speech for every translated segment.

    Russian output (``RU`` and ``EN_RU``) uses the Silero model via
    torch.hub; English output uses Edge TTS. One ``seg_NNNN.wav`` file is
    written per segment into ``paths.tts_dir``.
    """

    name = "TTS"

    def __init__(self, paths, language: Language = Language.RU):
        super().__init__(paths)
        self.language = language
        # Silero model is loaded lazily on first Russian synthesis.
        self._silero_model = None

    def is_cached(self) -> bool:
        """Return True when at least one TTS wav file exists."""
        if not self.paths.tts_dir.exists():
            return False
        return any(self.paths.tts_dir.glob("*.wav"))

    def clean(self) -> None:
        """Delete all generated wav files (including sped-up variants)."""
        if self.paths.tts_dir.exists():
            for f in self.paths.tts_dir.glob("*.wav"):
                f.unlink()

    def _load_translated(self) -> list[TranslatedSegment]:
        """Load the translated segments produced by the translate step."""
        with open(self.paths.translated_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [TranslatedSegment(**s) for s in data]

    def _clean_text_russian(self, text: str) -> str | None:
        """Clean text for Russian TTS (Silero).

        Strips characters outside a word/punctuation whitelist and collapses
        whitespace. Returns None when nothing speakable (no Cyrillic) remains.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"«»а-яА-ЯёЁ]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[а-яА-ЯёЁ]", text):
            return None
        return text

    def _clean_text_english(self, text: str) -> str | None:
        """Clean text for English TTS (Edge TTS).

        Same whitelist approach as the Russian cleaner; returns None when
        no Latin letters remain.
        """
        text = re.sub(r"[^\w\s.,!?;:\-—–\'\"a-zA-Z0-9]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        if not text or not re.search(r"[a-zA-Z]", text):
            return None
        return text

    def _clean_text(self, text: str) -> str | None:
        """Clean text based on language."""
        if self.language == Language.EN:
            return self._clean_text_english(text)
        # RU and EN_RU both output Russian
        return self._clean_text_russian(text)

    def _load_silero_model(self):
        """Load (once) and return the Silero TTS model for Russian."""
        if self._silero_model is None:
            console.print("[dim]Loading Silero TTS model...[/]")
            device = torch.device("cpu")
            self._silero_model, _ = torch.hub.load(
                repo_or_dir="snakers4/silero-models",
                model="silero_tts",
                language="ru",
                speaker="v4_ru",
            )
            self._silero_model.to(device)
        return self._silero_model

    def _synthesize_russian(self, text: str, output_path: Path) -> bool:
        """Synthesize Russian speech using Silero.

        Returns False when Silero rejects the text (ValueError); the
        segment is then reported as skipped by run().
        """
        model = self._load_silero_model()
        try:
            audio = model.apply_tts(
                text=text,
                speaker=TTSConfig.RU_VOICE,
                sample_rate=TTSConfig.RU_SAMPLE_RATE,
            )
            # apply_tts returns a 1-D tensor; torchaudio expects (channels, n).
            torchaudio.save(
                str(output_path), audio.unsqueeze(0), TTSConfig.RU_SAMPLE_RATE
            )
            return True
        except ValueError:
            return False

    async def _synthesize_english(self, text: str, output_path: Path) -> bool:
        """Synthesize English speech using Edge TTS.

        Edge TTS emits MP3, which is converted to WAV so downstream steps
        see a uniform format. Returns False on any failure.
        """
        try:
            mp3_path = output_path.with_suffix(".mp3")
            communicate = edge_tts.Communicate(text, TTSConfig.EN_VOICE)
            await communicate.save(str(mp3_path))

            # Convert MP3 to WAV using pydub
            from pydub import AudioSegment

            audio = AudioSegment.from_mp3(str(mp3_path))
            audio.export(str(output_path), format="wav")
            mp3_path.unlink()
            return True
        except Exception as e:
            console.print(f"[red]Edge TTS error: {e}[/]")
            return False

    async def _synthesize(self, text: str, output_path: Path) -> bool:
        """Synthesize speech based on language; returns True on success."""
        if self.language == Language.EN:
            return await self._synthesize_english(text, output_path)
        # RU and EN_RU both use Russian TTS
        return self._synthesize_russian(text, output_path)

    async def run(self) -> None:
        """Generate one wav per segment, reporting skipped segments."""
        engine = (
            TTSConfig.EN_ENGINE if self.language == Language.EN else TTSConfig.RU_ENGINE
        )  # EN_RU uses RU engine
        console.print(f"[cyan]Generating TTS audio ({engine})...[/]")

        self.paths.tts_dir.mkdir(parents=True, exist_ok=True)

        segments = self._load_translated()

        # Each entry: (segment index, original text, reason).
        skipped = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Generating speech...", total=len(segments))

            for i, seg in enumerate(segments):
                clean_text = self._clean_text(seg.translated)
                if not clean_text:
                    skipped.append((i, seg.translated, "no_text"))
                    progress.advance(task)
                    continue

                path = self.paths.tts_dir / f"seg_{i:04d}.wav"
                success = await self._synthesize(clean_text, path)
                if not success:
                    skipped.append((i, seg.translated, "tts_error"))

                progress.advance(task)

        generated = len(segments) - len(skipped)
        console.print(f"[green]✓ Generated {generated} audio files[/]")
        if skipped:
            console.print(f"[yellow]Skipped {len(skipped)} segments:[/]")
            for idx, text, reason in skipped:
                console.print(f"  [dim]{idx}:[/] {text[:60]}... [red]({reason})[/]")
|
||||
Reference in New Issue
Block a user