Files
FUCKINGCHINESEDRAMAS/dubbing/steps/translate.py
2026-02-01 16:07:59 +01:00

216 lines
7.7 KiB
Python

import asyncio
import json
from rich.console import Console
from rich.progress import (
Progress,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
)
from pydantic_ai import Agent, NativeOutput
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from dubbing.steps.base import PipelineStep
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
from dubbing.config import settings, TranslationPrompts
# Shared Rich console for this module's status messages, chunk errors and progress bars.
console = Console()
class TranslateStep(PipelineStep):
    """Pipeline step that translates transcript segments with a Gemini model.

    Reads segments from ``paths.segments_json`` and writes translated segments
    to ``paths.translated_json``. For ``Language.EN_RU`` the work runs in two
    stages (Chinese -> English, then English -> Russian); every other language
    is translated in a single pass.
    """

    name = "Translate"

    def __init__(
        self,
        paths,
        model_name: str = "gemini-2.0-flash-lite",
        language: Language = Language.RU,
    ):
        """Create the step.

        Args:
            paths: Path container handed to ``PipelineStep``; this step reads
                ``segments_json`` and writes ``translated_json`` from it.
            model_name: Gemini model identifier passed to ``GoogleModel``.
            language: Target language mode; see ``_get_system_prompt`` for how
                each value selects a prompt.
        """
        super().__init__(paths)
        self.model_name = model_name
        self.language = language

    def is_cached(self) -> bool:
        """Return True if the translated output file already exists."""
        return self.paths.translated_json.exists()

    def clean(self) -> None:
        """Delete the translated output file if it exists."""
        if self.paths.translated_json.exists():
            self.paths.translated_json.unlink()

    def _load_segments(self) -> list[Segment]:
        """Load the source segments from ``paths.segments_json``."""
        with open(self.paths.segments_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [Segment(**s) for s in data]

    def _get_system_prompt(self, stage: int = 1) -> str:
        """Return the system prompt for the configured language and stage.

        ``stage`` only matters for ``Language.EN_RU``. Any language other than
        ``EN`` / ``EN_RU`` uses the Russian prompt.
        """
        if self.language == Language.EN:
            return TranslationPrompts.EN
        if self.language == Language.EN_RU:
            return (
                TranslationPrompts.EN_RU_STAGE1
                if stage == 1
                else TranslationPrompts.EN_RU_STAGE2
            )
        return TranslationPrompts.RU

    def _get_translate_command(self, stage: int = 1) -> str:
        """Return the instruction line that precedes the numbered items."""
        if self.language == Language.EN:
            return "Translate:"
        if self.language == Language.EN_RU:
            return "Translate:" if stage == 1 else "Переведи:"
        return "Переведи:"

    def _get_context_header(self, stage: int = 1) -> str:
        """Return the header line that introduces previous-context lines."""
        if self.language == Language.EN:
            return "Context:"
        if self.language == Language.EN_RU:
            return "Context:" if stage == 1 else "Контекст:"
        return "Контекст:"

    def _create_agent(self, stage: int = 1) -> Agent:
        """Build a Gemini agent that returns a ``TranslationBatch`` for ``stage``."""
        provider = GoogleProvider(api_key=settings.gemini_api_key)
        model = GoogleModel(self.model_name, provider=provider)
        return Agent(
            model,
            output_type=NativeOutput(TranslationBatch),
            system_prompt=self._get_system_prompt(stage),
        )

    @staticmethod
    def _fallback_text(seg, stage: int):
        """Return the text to keep when ``seg`` gets no translation in ``stage``.

        Stage 2 keeps the stage-1 output (``seg.translated``); stage 1 keeps
        the original ``seg.text``.
        """
        return seg.translated if stage == 2 else seg.text

    def _fallback_chunk(self, chunk: list, stage: int) -> list[TranslatedSegment]:
        """Return one untranslated-fallback ``TranslatedSegment`` per segment."""
        return [
            TranslatedSegment(
                start=s.start,
                end=s.end,
                text=s.text,
                translated=self._fallback_text(s, stage),
            )
            for s in chunk
        ]

    def _build_context(self, prev: list[TranslatedSegment], stage: int) -> str:
        """Format recent results as a context preamble ("" when there are none)."""
        if not prev:
            return ""
        if stage == 2:
            ctx_lines = [f"{s.translated}" for s in prev]
        else:
            # Separate source and translation so the model can tell them apart.
            ctx_lines = [f"{s.text} → {s.translated}" for s in prev]
        return f"{self._get_context_header(stage)}\n" + "\n".join(ctx_lines) + "\n\n"

    async def _translate_chunk(
        self,
        agent: Agent,
        chunk: list[Segment],
        chunk_idx: int,
        context: str,
        semaphore: asyncio.Semaphore,
        stage: int = 1,
    ) -> list[TranslatedSegment]:
        """Translate one chunk of segments with a single model call.

        The chunk is sent as ``"<index>: <text>"`` lines. The model's
        ``translations`` are matched back to segments by ``id``. Segments the
        model omitted or returned empty keep their fallback text. If the model
        call raises, the error is logged and the whole chunk falls back.

        Returns:
            One ``TranslatedSegment`` per input segment, in input order.
        """
        async with semaphore:
            # Stage 2 translates the stage-1 output, not the original text.
            items = "\n".join(
                f"{i}: {s.translated if stage == 2 else s.text}"
                for i, s in enumerate(chunk)
            )
            prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"
            try:
                result = await agent.run(prompt)
            except Exception as e:
                # Best effort: one failed request must not abort the whole run.
                # Fallbacks are built per segment, never from a stale loop variable.
                console.print(f"[red]Chunk {chunk_idx} error: {e}[/]")
                return self._fallback_chunk(chunk, stage)

        # First entry wins for duplicate ids (same as a first-match linear scan).
        by_id = {}
        for t in result.output.translations:
            by_id.setdefault(t.id, t.translated)
        return [
            TranslatedSegment(
                start=seg.start,
                end=seg.end,
                text=seg.text,
                translated=by_id.get(i) or self._fallback_text(seg, stage),
            )
            for i, seg in enumerate(chunk)
        ]

    async def _translate_parallel(
        self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
    ) -> list[TranslatedSegment]:
        """Translate ``segments`` in chunks, several chunks at a time.

        Chunks of ``settings.translation_chunk_size`` segments are sent in waves
        of ``settings.translation_concurrency`` concurrent requests. Each wave
        gets the last five results of earlier waves as context.

        Returns:
            One ``TranslatedSegment`` per input segment, in input order. A chunk
            whose task raised keeps its fallback text instead of being dropped.
        """
        chunk_size = settings.translation_chunk_size
        concurrency = settings.translation_concurrency
        chunks = [
            segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
        ]
        # Waves are already capped at `concurrency`; the semaphore is an extra guard.
        semaphore = asyncio.Semaphore(concurrency)
        all_results: list[TranslatedSegment] = []
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(desc, total=len(chunks))
            for batch_start in range(0, len(chunks), concurrency):
                batch = chunks[batch_start : batch_start + concurrency]
                context = self._build_context(all_results[-5:], stage)
                tasks = [
                    self._translate_chunk(
                        agent, chunk, batch_start + idx, context, semaphore, stage
                    )
                    for idx, chunk in enumerate(batch)
                ]
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for offset, (chunk, result) in enumerate(zip(batch, results)):
                    if isinstance(result, BaseException):
                        # Keep the segments (and their timings) rather than losing them.
                        console.print(
                            f"[red]Chunk {batch_start + offset} error: {result}[/]"
                        )
                        result = self._fallback_chunk(chunk, stage)
                    all_results.extend(result)
                    progress.advance(task)
        return all_results

    async def run(self) -> None:
        """Translate all segments and write them to ``paths.translated_json``.

        Prints a warning with the number of segments that appear untranslated.
        """
        segments = self._load_segments()
        if self.language == Language.EN_RU:
            # Stage 1: Chinese -> English
            console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
            agent1 = self._create_agent(stage=1)
            intermediate = await self._translate_parallel(
                agent1, segments, stage=1, desc="Zh→En..."
            )
            # Stage 2: English -> Russian
            console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
            agent2 = self._create_agent(stage=2)
            translated = await self._translate_parallel(
                agent2, intermediate, stage=2, desc="En→Ru..."
            )
            # The final stage's input is the stage-1 output.
            stage_inputs = [s.translated for s in intermediate]
        else:
            console.print(
                f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
            )
            agent = self._create_agent()
            translated = await self._translate_parallel(agent, segments)
            stage_inputs = [s.text for s in segments]
        with open(self.paths.translated_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
            )
        # Untranslated means: still the original text, or the final stage
        # returned its input unchanged (e.g. stage 2 fell back to English).
        missing = sum(
            1
            for s, src in zip(translated, stage_inputs)
            if s.translated == s.text or s.translated == src
        )
        if missing:
            console.print(
                f"[yellow]Warning: {missing} segments may not be translated[/]"
            )
        console.print(f"[green]✓ Translated {len(translated)} segments[/]")