feat: mvp
This commit is contained in:
215
dubbing/steps/translate.py
Normal file
215
dubbing/steps/translate.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import asyncio
|
||||
import json
|
||||
from rich.console import Console
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
)
|
||||
from pydantic_ai import Agent, NativeOutput
|
||||
from pydantic_ai.models.google import GoogleModel
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import Segment, TranslatedSegment, TranslationBatch, Language
|
||||
from dubbing.config import settings, TranslationPrompts
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class TranslateStep(PipelineStep):
|
||||
name = "Translate"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paths,
|
||||
model_name: str = "gemini-2.0-flash-lite",
|
||||
language: Language = Language.RU,
|
||||
):
|
||||
super().__init__(paths)
|
||||
self.model_name = model_name
|
||||
self.language = language
|
||||
|
||||
def is_cached(self) -> bool:
|
||||
return self.paths.translated_json.exists()
|
||||
|
||||
def clean(self) -> None:
|
||||
if self.paths.translated_json.exists():
|
||||
self.paths.translated_json.unlink()
|
||||
|
||||
def _load_segments(self) -> list[Segment]:
|
||||
with open(self.paths.segments_json, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return [Segment(**s) for s in data]
|
||||
|
||||
def _get_system_prompt(self, stage: int = 1) -> str:
|
||||
if self.language == Language.EN:
|
||||
return TranslationPrompts.EN
|
||||
elif self.language == Language.EN_RU:
|
||||
return (
|
||||
TranslationPrompts.EN_RU_STAGE1
|
||||
if stage == 1
|
||||
else TranslationPrompts.EN_RU_STAGE2
|
||||
)
|
||||
return TranslationPrompts.RU
|
||||
|
||||
def _get_translate_command(self, stage: int = 1) -> str:
|
||||
if self.language == Language.EN:
|
||||
return "Translate:"
|
||||
elif self.language == Language.EN_RU:
|
||||
return "Translate:" if stage == 1 else "Переведи:"
|
||||
return "Переведи:"
|
||||
|
||||
def _get_context_header(self, stage: int = 1) -> str:
|
||||
if self.language == Language.EN:
|
||||
return "Context:"
|
||||
elif self.language == Language.EN_RU:
|
||||
return "Context:" if stage == 1 else "Контекст:"
|
||||
return "Контекст:"
|
||||
|
||||
def _create_agent(self, stage: int = 1) -> Agent:
|
||||
provider = GoogleProvider(api_key=settings.gemini_api_key)
|
||||
model = GoogleModel(self.model_name, provider=provider)
|
||||
|
||||
return Agent(
|
||||
model,
|
||||
output_type=NativeOutput(TranslationBatch),
|
||||
system_prompt=self._get_system_prompt(stage),
|
||||
)
|
||||
|
||||
async def _translate_chunk(
|
||||
self,
|
||||
agent: Agent,
|
||||
chunk: list[Segment],
|
||||
chunk_idx: int,
|
||||
context: str,
|
||||
semaphore: asyncio.Semaphore,
|
||||
stage: int = 1,
|
||||
) -> list[TranslatedSegment]:
|
||||
async with semaphore:
|
||||
# For stage 2, use translated field as source
|
||||
if stage == 2:
|
||||
items = "\n".join([f"{i}: {s.translated}" for i, s in enumerate(chunk)])
|
||||
else:
|
||||
items = "\n".join([f"{i}: {s.text}" for i, s in enumerate(chunk)])
|
||||
prompt = f"{context}{self._get_translate_command(stage)}\n\n{items}"
|
||||
|
||||
try:
|
||||
result = await agent.run(prompt)
|
||||
translated = []
|
||||
for i, seg in enumerate(chunk):
|
||||
translated_text = ""
|
||||
for t in result.output.translations:
|
||||
if t.id == i:
|
||||
translated_text = t.translated
|
||||
break
|
||||
translated.append(
|
||||
TranslatedSegment(
|
||||
start=seg.start,
|
||||
end=seg.end,
|
||||
text=seg.text,
|
||||
translated=translated_text
|
||||
or (seg.translated if stage == 2 else seg.text),
|
||||
)
|
||||
)
|
||||
return translated
|
||||
except Exception as e:
|
||||
console.print(f"[red]Chunk {chunk_idx} error: {e}[/]")
|
||||
fallback = seg.translated if stage == 2 else seg.text
|
||||
return [
|
||||
TranslatedSegment(
|
||||
start=s.start, end=s.end, text=s.text, translated=fallback
|
||||
)
|
||||
for s in chunk
|
||||
]
|
||||
|
||||
async def _translate_parallel(
|
||||
self, agent: Agent, segments: list, stage: int = 1, desc: str = "Translating..."
|
||||
) -> list[TranslatedSegment]:
|
||||
chunk_size = settings.translation_chunk_size
|
||||
concurrency = settings.translation_concurrency
|
||||
|
||||
chunks = [
|
||||
segments[i : i + chunk_size] for i in range(0, len(segments), chunk_size)
|
||||
]
|
||||
semaphore = asyncio.Semaphore(concurrency)
|
||||
|
||||
all_results: list[TranslatedSegment] = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task(desc, total=len(chunks))
|
||||
|
||||
for batch_start in range(0, len(chunks), concurrency):
|
||||
batch = chunks[batch_start : batch_start + concurrency]
|
||||
|
||||
context = ""
|
||||
if all_results:
|
||||
prev = all_results[-5:]
|
||||
if stage == 2:
|
||||
ctx_lines = [f"• {s.translated}" for s in prev]
|
||||
else:
|
||||
ctx_lines = [f"• {s.text} → {s.translated}" for s in prev]
|
||||
context = (
|
||||
f"{self._get_context_header(stage)}\n"
|
||||
+ "\n".join(ctx_lines)
|
||||
+ "\n\n"
|
||||
)
|
||||
|
||||
tasks = [
|
||||
self._translate_chunk(
|
||||
agent, chunk, batch_start + idx, context, semaphore, stage
|
||||
)
|
||||
for idx, chunk in enumerate(batch)
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in results:
|
||||
if isinstance(result, list):
|
||||
all_results.extend(result)
|
||||
progress.advance(task)
|
||||
|
||||
return all_results
|
||||
|
||||
async def run(self) -> None:
|
||||
segments = self._load_segments()
|
||||
|
||||
if self.language == Language.EN_RU:
|
||||
# Stage 1: Chinese -> English
|
||||
console.print(f"[cyan]Stage 1: Chinese → English ({self.model_name})...[/]")
|
||||
agent1 = self._create_agent(stage=1)
|
||||
intermediate = await self._translate_parallel(
|
||||
agent1, segments, stage=1, desc="Zh→En..."
|
||||
)
|
||||
|
||||
# Stage 2: English -> Russian
|
||||
console.print(f"[cyan]Stage 2: English → Russian ({self.model_name})...[/]")
|
||||
agent2 = self._create_agent(stage=2)
|
||||
translated = await self._translate_parallel(
|
||||
agent2, intermediate, stage=2, desc="En→Ru..."
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[cyan]Translating to {self.language.value.upper()} with {self.model_name}...[/]"
|
||||
)
|
||||
agent = self._create_agent()
|
||||
translated = await self._translate_parallel(agent, segments)
|
||||
|
||||
with open(self.paths.translated_json, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
[s.model_dump() for s in translated], f, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
missing = sum(1 for s in translated if s.translated == s.text)
|
||||
if missing:
|
||||
console.print(
|
||||
f"[yellow]Warning: {missing} segments may not be translated[/]"
|
||||
)
|
||||
|
||||
console.print(f"[green]✓ Translated {len(translated)} segments[/]")
|
||||
Reference in New Issue
Block a user