79 lines
2.2 KiB
Python
79 lines
2.2 KiB
Python
import json
|
||
from rich.console import Console
|
||
from dubbing.steps.base import PipelineStep
|
||
from dubbing.models import Segment
|
||
|
||
# Module-level Rich console used for this module's status messages.
console = Console()
|
||
|
||
|
||
class ASRStep(PipelineStep):
    """Pipeline step that transcribes the source audio into timed segments.

    ``run`` feeds ``self.paths.source_audio`` to the FunASR
    ``iic/SenseVoiceSmall`` model (with the ``fsmn-vad`` VAD model), groups
    the recognized tokens into punctuation-delimited segments and writes them
    as JSON to ``self.paths.segments_json``. The existence of that file is
    what ``is_cached`` reports.
    """

    name = "ASR"

    # Tokens that close a segment. Both the fullwidth (CJK) and the ASCII form
    # of each mark are accepted so that segmentation does not depend on which
    # form the recognizer emits. The fullwidth marks are written as escapes to
    # keep the exact code points unambiguous.
    _SEGMENT_BREAKS = frozenset(
        {
            "。",
            "…",
            "、",
            ",",
            "!",
            "?",
            ";",
            ":",
            "\uff0c",  # fullwidth comma
            "\uff01",  # fullwidth exclamation mark
            "\uff1f",  # fullwidth question mark
            "\uff1b",  # fullwidth semicolon
            "\uff1a",  # fullwidth colon
        }
    )

    def is_cached(self) -> bool:
        """Return True when the segments JSON file already exists."""
        return self.paths.segments_json.exists()

    def clean(self) -> None:
        """Delete the segments JSON file if it is present."""
        # missing_ok avoids the check-then-delete race of exists() + unlink().
        self.paths.segments_json.unlink(missing_ok=True)

    def _group_chinese_segments(
        self, words: list[str], timestamps: list[list[int]]
    ) -> list[Segment]:
        """Group recognized tokens into punctuation-delimited segments.

        Tokens are concatenated in order. Whenever a token is one of the
        segment-breaking punctuation marks, the accumulated text (including
        that mark) is emitted as a ``Segment``. Text left over after the last
        token is emitted as a final segment.

        Args:
            words: Recognized tokens, in order.
            timestamps: One ``[start, end]`` pair per token, parallel to
                ``words`` (units are whatever the recognizer reports;
                presumably milliseconds -- TODO confirm).

        Returns:
            Segments whose ``start`` is the start of their first token and
            whose ``end`` is the end of their last token.
        """
        segments: list[Segment] = []
        pieces: list[str] = []
        start = None
        end = None

        # zip() stops at the shorter input, so unpaired trailing items are
        # ignored rather than raising.
        for word, ts in zip(words, timestamps):
            if start is None:
                start = ts[0]
            pieces.append(word)
            end = ts[1]

            if word in self._SEGMENT_BREAKS:
                segments.append(
                    Segment(start=start, end=end, text="".join(pieces))
                )
                pieces = []
                start = None

        # Flush whatever followed the last punctuation mark.
        tail = "".join(pieces)
        if tail:
            segments.append(Segment(start=start, end=end, text=tail))

        return segments

    async def run(self) -> None:
        """Transcribe the source audio and write the segments JSON file.

        NOTE(review): model construction and ``generate`` are plain
        synchronous calls, so this coroutine does not yield to the event loop
        while they run.
        """
        console.print("[cyan]Running speech recognition...[/]")

        # Imported here rather than at module level, so funasr is only loaded
        # when the step actually runs.
        from funasr import AutoModel

        # NOTE(review): the device is hard-coded to "mps"; presumably this
        # only works on hardware that provides that backend -- confirm.
        model = AutoModel(
            model="iic/SenseVoiceSmall",
            device="mps",
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
        )

        result = model.generate(
            input=str(self.paths.source_audio),
            language="zh",
            use_itn=True,
            batch_size_s=60,
            merge_length_s=15,
            output_timestamp=True,
        )

        # Only the first result entry is used.
        first = result[0]
        segments = self._group_chinese_segments(first["words"], first["timestamp"])

        # Serialize before touching the file: open(..., "w") truncates
        # immediately, so a failure during serialization would otherwise leave
        # an empty or partial file that is_cached() reports as a valid cache.
        payload = json.dumps(
            [s.model_dump() for s in segments], ensure_ascii=False, indent=2
        )
        with open(self.paths.segments_json, "w", encoding="utf-8") as f:
            f.write(payload)

        console.print(f"[green]✓ Found {len(segments)} segments[/]")
|