import asyncio
import json

from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)
from pydantic_ai import Agent, NativeOutput
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider

from dubbing.steps.base import PipelineStep
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
from dubbing.config import settings, TranslationPrompts

console = Console()


class TranslateStep(PipelineStep):
    """Pipeline step that translates segments with a Gemini model via pydantic_ai.

    Reads segments from ``paths.segments_json`` and writes translated segments
    to ``paths.translated_json``. ``Language.EN_RU`` runs two passes (stage 1,
    then stage 2 over the stage-1 output); every other language runs one pass.
    """

    name = "Translate"

    def __init__(
        self,
        paths,
        model_name: str = "gemini-2.0-flash-lite",
        language: Language = Language.RU,
    ):
        """Create the step.

        Args:
            paths: Project paths object; this step uses its ``segments_json``
                and ``translated_json`` attributes.
            model_name: Model name passed to ``GoogleModel``.
            language: Target language mode; selects prompts and command words.
        """
        super().__init__(paths)
        self.model_name = model_name
        self.language = language

    def is_cached(self) -> bool:
        """Return True if the translated output file already exists."""
        return self.paths.translated_json.exists()

    def clean(self) -> None:
        """Delete the translated output file if it exists."""
        if self.paths.translated_json.exists():
            self.paths.translated_json.unlink()

    def _load_segments(self) -> list[Segment]:
        """Load the source segments from ``paths.segments_json``."""
        with open(self.paths.segments_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [Segment(**s) for s in data]

    def _get_system_prompt(self, stage: int = 1) -> str:
        """Return the system prompt for the configured language and stage."""
        if self.language == Language.EN:
            return TranslationPrompts.EN
        elif self.language == Language.EN_RU:
            return (
                TranslationPrompts.EN_RU_STAGE1
                if stage == 1
                else TranslationPrompts.EN_RU_STAGE2
            )
        return TranslationPrompts.RU

    def _get_translate_command(self, stage: int = 1) -> str:
        """Return the instruction word placed before the numbered items."""
        if self.language == Language.EN:
            return "Translate:"
        elif self.language == Language.EN_RU:
            return "Translate:" if stage == 1 else "Переведи:"
        return "Переведи:"

    def _get_context_header(self, stage: int = 1) -> str:
        """Return the header line that introduces the context block."""
        if self.language == Language.EN:
            return "Context:"
        elif self.language == Language.EN_RU:
            return "Context:" if stage == 1 else "Контекст:"
        return "Контекст:"

    def _create_agent(self, stage: int = 1) -> Agent:
        """Build an agent whose native structured output is a ``TranslationBatch``."""
        provider = GoogleProvider(api_key=settings.gemini_api_key)
        model = GoogleModel(self.model_name, provider=provider)
        return Agent(
            model,
            output_type=NativeOutput(TranslationBatch),
            system_prompt=self._get_system_prompt(stage),
        )

    @staticmethod
    def _fallback_segments(chunk: list, stage: int) -> list[TranslatedSegment]:
        """Return *chunk* untranslated, each segment keeping its OWN source text.

        Stage 2 falls back to the segment's stage-1 ``translated`` value;
        stage 1 falls back to its original ``text``.
        """
        return [
            TranslatedSegment(
                start=s.start,
                end=s.end,
                text=s.text,
                translated=s.translated if stage == 2 else s.text,
            )
            for s in chunk
        ]

    async def _translate_chunk(
        self,
        agent: Agent,
        chunk: list[Segment],
        chunk_idx: int,
        context: str,
        semaphore: asyncio.Semaphore,
        stage: int = 1,
    ) -> list[TranslatedSegment]:
        """Translate one chunk with a single agent call.

        Items are sent as ``"<index>: <text>"`` lines and matched back by the
        ``id`` field of the returned translations. Segments the model omits, or
        returns empty, keep their fallback text. If the agent call raises, the
        error is logged and the whole chunk is returned untranslated.

        Returns:
            One ``TranslatedSegment`` per item of *chunk*, in order.
        """
        async with semaphore:
            # In stage 2 the chunk holds stage-1 output, so the intermediate
            # translation (not the original text) is what gets translated.
            if stage == 2:
                items = "\n".join(f"{i}: {s.translated}" for i, s in enumerate(chunk))
            else:
                items = "\n".join(f"{i}: {s.text}" for i, s in enumerate(chunk))
            prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"

            try:
                result = await agent.run(prompt)
            except Exception as e:
                # Best-effort: log and pass the chunk through rather than fail the run.
                console.print(f"[red]Chunk {chunk_idx} error: {e}[/]")
                return self._fallback_segments(chunk, stage)

            # Index model output by id; setdefault keeps the first match per id,
            # matching the previous first-hit linear search.
            by_id: dict = {}
            for t in result.output.translations:
                by_id.setdefault(t.id, t.translated)

            return [
                TranslatedSegment(
                    start=seg.start,
                    end=seg.end,
                    text=seg.text,
                    translated=by_id.get(i)
                    or (seg.translated if stage == 2 else seg.text),
                )
                for i, seg in enumerate(chunk)
            ]

    async def _translate_parallel(
        self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
    ) -> list[TranslatedSegment]:
        """Translate *segments* in chunks, several chunks concurrently.

        Chunks of ``settings.translation_chunk_size`` segments are processed in
        batches of ``settings.translation_concurrency``. Each batch is prompted
        with the last five results of earlier batches as context. Chunks whose
        task raised are logged and passed through untranslated.

        Returns:
            One ``TranslatedSegment`` per input segment, in input order.
        """
        chunk_size = settings.translation_chunk_size
        concurrency = settings.translation_concurrency
        chunks = [
            segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
        ]
        semaphore = asyncio.Semaphore(concurrency)
        all_results: list[TranslatedSegment] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(desc, total=len(chunks))

            for batch_start in range(0, len(chunks), concurrency):
                batch = chunks[batch_start : batch_start + concurrency]

                # Context comes from earlier batches only: chunks within a
                # batch run concurrently and cannot see each other's output.
                context = ""
                if all_results:
                    prev = all_results[-5:]
                    if stage == 2:
                        ctx_lines = [f"• {s.translated}" for s in prev]
                    else:
                        ctx_lines = [f"• {s.text} → {s.translated}" for s in prev]
                    context = (
                        f"{self._get_context_header(stage)}\n"
                        + "\n".join(ctx_lines)
                        + "\n\n"
                    )

                tasks = [
                    self._translate_chunk(
                        agent, chunk, batch_start + idx, context, semaphore, stage
                    )
                    for idx, chunk in enumerate(batch)
                ]
                results = await asyncio.gather(*tasks, return_exceptions=True)

                # gather preserves task order, so results line up with batch.
                for idx, (chunk, result) in enumerate(zip(batch, results)):
                    if isinstance(result, BaseException):
                        # Never drop segments: keep output aligned with input.
                        console.print(
                            f"[red]Chunk {batch_start + idx} error: {result}[/]"
                        )
                        result = self._fallback_segments(chunk, stage)
                    all_results.extend(result)
                    progress.advance(task)

        return all_results

    async def run(self) -> None:
        """Translate the loaded segments and write them to ``paths.translated_json``.

        For ``Language.EN_RU``, stage 1 (Chinese → English) runs over all
        segments, then stage 2 (English → Russian) runs over the stage-1 output.
        Other languages use a single stage-1 pass.
        """
        segments = self._load_segments()

        if self.language == Language.EN_RU:
            # Stage 1: Chinese -> English
            console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
            agent1 = self._create_agent(stage=1)
            intermediate = await self._translate_parallel(
                agent1, segments, stage=1, desc="Zh→En..."
            )

            # Stage 2: English -> Russian
            console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
            agent2 = self._create_agent(stage=2)
            translated = await self._translate_parallel(
                agent2, intermediate, stage=2, desc="En→Ru..."
            )
        else:
            console.print(
                f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
            )
            agent = self._create_agent()
            translated = await self._translate_parallel(agent, segments)

        with open(self.paths.translated_json, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
            )

        # Heuristic: a segment whose output equals its original text was
        # probably passed through untranslated. NOTE(review): in EN_RU mode a
        # stage-2 fallback leaves English text, which this check does not flag.
        missing = sum(1 for s in translated if s.translated == s.text)
        if missing:
            console.print(
                f"[yellow]Warning: {missing} segments may not be translated[/]"
            )

        console.print(f"[green]✓ Translated {len(translated)} segments[/]")