216 lines
7.7 KiB
Python
216 lines
7.7 KiB
Python
import asyncio
import json

from rich.console import Console
from rich.markup import escape
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)
from pydantic_ai import Agent, NativeOutput
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider

from dubbing.steps.base import PipelineStep
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
from dubbing.config import settings, TranslationPrompts
|
|
|
|
# Shared rich console: all status messages and the progress bar in this
# module are written through it.
console = Console()
|
|
|
|
|
|
class TranslateStep(PipelineStep):
    """Pipeline step that translates transcribed segments with a Gemini model.

    Reads segments from ``paths.segments_json``, translates their text in
    chunks through a pydantic-ai ``Agent`` that returns a structured
    ``TranslationBatch``, and writes the result to ``paths.translated_json``.

    With ``Language.EN_RU`` translation runs in two stages (reported to the
    user as Chinese -> English, then English -> Russian). The second stage
    re-translates the first stage's output.
    """

    name = "Translate"

    # How many of the most recently translated segments are fed to the next
    # batch of chunks as context.
    _CONTEXT_SEGMENTS = 5

    def __init__(
        self,
        paths,
        model_name: str = "gemini-2.0-flash-lite",
        language: Language = Language.RU,
    ):
        """Configure the step.

        Args:
            paths: Paths object forwarded to ``PipelineStep``. This step reads
                ``segments_json`` and writes ``translated_json``.
            model_name: Model identifier passed to ``GoogleModel``.
            language: Target language mode. ``Language.EN_RU`` enables the
                two-stage translation.
        """
        super().__init__(paths)
        self.model_name = model_name
        self.language = language

    def is_cached(self) -> bool:
        """Return True if the translated output file already exists."""
        return self.paths.translated_json.exists()

    def clean(self) -> None:
        """Delete the translated output file if it exists."""
        if self.paths.translated_json.exists():
            self.paths.translated_json.unlink()

    def _load_segments(self) -> list[Segment]:
        """Load the input segments from ``paths.segments_json``."""
        with open(self.paths.segments_json, "r", encoding="utf-8") as f:
            data = json.load(f)
        return [Segment(**s) for s in data]

    def _get_system_prompt(self, stage: int = 1) -> str:
        """Return the system prompt for the configured language and stage."""
        if self.language == Language.EN:
            return TranslationPrompts.EN
        elif self.language == Language.EN_RU:
            return (
                TranslationPrompts.EN_RU_STAGE1
                if stage == 1
                else TranslationPrompts.EN_RU_STAGE2
            )
        return TranslationPrompts.RU

    def _get_translate_command(self, stage: int = 1) -> str:
        """Return the instruction line placed before the numbered items."""
        if self.language == Language.EN:
            return "Translate:"
        elif self.language == Language.EN_RU:
            return "Translate:" if stage == 1 else "Переведи:"
        return "Переведи:"

    def _get_context_header(self, stage: int = 1) -> str:
        """Return the heading placed above the previous-translation context."""
        if self.language == Language.EN:
            return "Context:"
        elif self.language == Language.EN_RU:
            return "Context:" if stage == 1 else "Контекст:"
        return "Контекст:"

    def _create_agent(self, stage: int = 1) -> Agent:
        """Build an agent for ``self.model_name`` with structured batch output."""
        provider = GoogleProvider(api_key=settings.gemini_api_key)
        model = GoogleModel(self.model_name, provider=provider)

        return Agent(
            model,
            output_type=NativeOutput(TranslationBatch),
            system_prompt=self._get_system_prompt(stage),
        )

    @staticmethod
    def _source_text(seg, stage: int) -> str:
        """Return the text a segment is translated from in this stage.

        Stage 2 translates the previous stage's output (``seg.translated``).
        Otherwise the original ``seg.text`` is used. The same value is the
        fallback when translation of the segment fails.
        """
        return seg.translated if stage == 2 else seg.text

    def _fallback_chunk(self, chunk: list, stage: int) -> list[TranslatedSegment]:
        """Return ``chunk`` untranslated, each segment keeping its own source text."""
        return [
            TranslatedSegment(
                start=s.start,
                end=s.end,
                text=s.text,
                translated=self._source_text(s, stage),
            )
            for s in chunk
        ]

    async def _translate_chunk(
        self,
        agent: Agent,
        chunk: list[Segment],
        chunk_idx: int,
        context: str,
        semaphore: asyncio.Semaphore,
        stage: int = 1,
    ) -> list[TranslatedSegment]:
        """Translate one chunk of segments with a single agent call.

        Segments are sent as ``"<index>: <text>"`` lines, and the model's
        ``translations`` are matched back to them by ``id``.

        Args:
            agent: Agent created by ``_create_agent`` for this stage.
            chunk: Segments to translate, in timeline order.
            chunk_idx: Chunk number, used only in error messages.
            context: Prompt prefix with previous translations (may be empty).
            semaphore: Limits concurrent agent calls.
            stage: 1, or 2 for the second pass of ``Language.EN_RU``.

        Returns:
            One ``TranslatedSegment`` per input segment, in the same order.
            A segment the model did not return (or returned empty) keeps its
            source text. If the call fails with an ``Exception``, the whole
            chunk keeps its source text, so output always stays aligned with
            ``chunk``.
        """
        async with semaphore:
            items = "\n".join(
                f"{i}: {self._source_text(seg, stage)}" for i, seg in enumerate(chunk)
            )
            prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"

            try:
                result = await agent.run(prompt)

                # Index model output by id. setdefault keeps the first entry
                # when the model repeats an id.
                by_id = {}
                for t in result.output.translations:
                    by_id.setdefault(t.id, t.translated)

                return [
                    TranslatedSegment(
                        start=seg.start,
                        end=seg.end,
                        text=seg.text,
                        translated=by_id.get(i) or self._source_text(seg, stage),
                    )
                    for i, seg in enumerate(chunk)
                ]
            except Exception as e:
                # Best effort: report and keep the source text. Escape the
                # message because error text can contain "[...]", which rich
                # would otherwise parse as markup.
                console.print(f"[red]Chunk {chunk_idx} error: {escape(str(e))}[/]")
                return self._fallback_chunk(chunk, stage)

    def _build_context(self, done: list[TranslatedSegment], stage: int) -> str:
        """Format the last few translated segments as a prompt prefix.

        Returns an empty string when nothing has been translated yet.
        """
        if not done:
            return ""
        prev = done[-self._CONTEXT_SEGMENTS :]
        if stage == 2:
            ctx_lines = [f"• {s.translated}" for s in prev]
        else:
            ctx_lines = [f"• {s.text} → {s.translated}" for s in prev]
        return f"{self._get_context_header(stage)}\n" + "\n".join(ctx_lines) + "\n\n"

    async def _translate_parallel(
        self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
    ) -> list[TranslatedSegment]:
        """Translate ``segments`` in concurrent batches of chunks.

        Chunks of ``settings.translation_chunk_size`` segments are processed
        in batches of ``settings.translation_concurrency``. Each batch gets as
        context the tail of everything translated so far.

        Returns:
            Exactly one ``TranslatedSegment`` per input segment, in input
            order. A failed chunk contributes its untranslated segments rather
            than being dropped, which would misalign everything after it.
        """
        # Guard against non-positive settings: a step of 0 makes range() raise,
        # and Semaphore(0) would block forever.
        chunk_size = max(1, settings.translation_chunk_size)
        concurrency = max(1, settings.translation_concurrency)

        chunks = [
            segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
        ]
        # Batches are already `concurrency` chunks wide, so this only acts as
        # a safety cap on simultaneous agent calls.
        semaphore = asyncio.Semaphore(concurrency)

        all_results: list[TranslatedSegment] = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            console=console,
        ) as progress:
            task = progress.add_task(desc, total=len(chunks))

            for batch_start in range(0, len(chunks), concurrency):
                batch = chunks[batch_start : batch_start + concurrency]
                # Batches run sequentially so each one can see the previous
                # batch's translations as context.
                context = self._build_context(all_results, stage)

                results = await asyncio.gather(
                    *(
                        self._translate_chunk(
                            agent, chunk, batch_start + idx, context, semaphore, stage
                        )
                        for idx, chunk in enumerate(batch)
                    ),
                    return_exceptions=True,
                )

                for idx, (chunk, result) in enumerate(zip(batch, results)):
                    if isinstance(result, BaseException):
                        console.print(
                            f"[red]Chunk {batch_start + idx} failed: "
                            f"{escape(repr(result))}[/]"
                        )
                        result = self._fallback_chunk(chunk, stage)
                    all_results.extend(result)
                    progress.advance(task)

        return all_results

    def _write_translated(self, translated: list[TranslatedSegment]) -> None:
        """Write results to ``paths.translated_json`` atomically.

        The data is written to a sibling ``.tmp`` file and then renamed over
        the target. An interrupted run therefore never leaves a partial file
        that ``is_cached`` would treat as a finished result. This assumes
        ``translated_json`` is a ``pathlib.Path``, which the ``.exists()`` and
        ``.unlink()`` calls elsewhere suggest.
        """
        target = self.paths.translated_json
        tmp = target.with_name(target.name + ".tmp")
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(
                [s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
            )
        tmp.replace(target)

    async def run(self) -> None:
        """Translate all segments and write them to ``paths.translated_json``."""
        segments = self._load_segments()

        if self.language == Language.EN_RU:
            # Stage 1: Chinese -> English
            console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
            agent1 = self._create_agent(stage=1)
            intermediate = await self._translate_parallel(
                agent1, segments, stage=1, desc="Zh→En..."
            )

            # Stage 2: English -> Russian, translating stage 1's output
            console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
            agent2 = self._create_agent(stage=2)
            translated = await self._translate_parallel(
                agent2, intermediate, stage=2, desc="En→Ru..."
            )
        else:
            console.print(
                f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
            )
            agent = self._create_agent()
            translated = await self._translate_parallel(agent, segments)

        self._write_translated(translated)

        # Heuristic: a segment whose output equals its original text was
        # probably not translated.
        # NOTE(review): in EN_RU mode a stage-2 fallback keeps the English
        # intermediate text, which differs from `text`, so such segments are
        # not counted here.
        missing = sum(1 for s in translated if s.translated == s.text)
        if missing:
            console.print(
                f"[yellow]Warning: {missing} segments may not be translated[/]"
            )

        console.print(f"[green]✓ Translated {len(translated)} segments[/]")
|