feat: mvp
This commit is contained in:
78
dubbing/steps/asr.py
Normal file
78
dubbing/steps/asr.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from rich.console import Console
|
||||
from dubbing.steps.base import PipelineStep
|
||||
from dubbing.models import Segment
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class ASRStep(PipelineStep):
|
||||
name = "ASR"
|
||||
|
||||
def is_cached(self) -> bool:
|
||||
return self.paths.segments_json.exists()
|
||||
|
||||
def clean(self) -> None:
|
||||
if self.paths.segments_json.exists():
|
||||
self.paths.segments_json.unlink()
|
||||
|
||||
def _group_chinese_segments(
|
||||
self, words: list[str], timestamps: list[list[int]]
|
||||
) -> list[Segment]:
|
||||
segments = []
|
||||
current_text = ""
|
||||
current_start = None
|
||||
current_end = None
|
||||
punctuation = {"。", ",", "!", "?", ";", ":", "…", "、"}
|
||||
|
||||
for word, ts in zip(words, timestamps):
|
||||
if current_start is None:
|
||||
current_start = ts[0]
|
||||
current_text += word
|
||||
current_end = ts[1]
|
||||
|
||||
if word in punctuation:
|
||||
segments.append(
|
||||
Segment(start=current_start, end=current_end, text=current_text)
|
||||
)
|
||||
current_text = ""
|
||||
current_start = None
|
||||
|
||||
if current_text:
|
||||
segments.append(
|
||||
Segment(start=current_start, end=current_end, text=current_text)
|
||||
)
|
||||
|
||||
return segments
|
||||
|
||||
async def run(self) -> None:
|
||||
console.print("[cyan]Running speech recognition...[/]")
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(
|
||||
model="iic/SenseVoiceSmall",
|
||||
device="mps",
|
||||
vad_model="fsmn-vad",
|
||||
vad_kwargs={"max_single_segment_time": 30000},
|
||||
)
|
||||
|
||||
result = model.generate(
|
||||
input=str(self.paths.source_audio),
|
||||
language="zh",
|
||||
use_itn=True,
|
||||
batch_size_s=60,
|
||||
merge_length_s=15,
|
||||
output_timestamp=True,
|
||||
)
|
||||
|
||||
segments = self._group_chinese_segments(
|
||||
result[0]["words"], result[0]["timestamp"]
|
||||
)
|
||||
|
||||
with open(self.paths.segments_json, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
[s.model_dump() for s in segments], f, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
console.print(f"[green]✓ Found {len(segments)} segments[/]")
|
||||
Reference in New Issue
Block a user