feat: add STT

This commit is contained in:
h
2026-05-29 20:45:02 +02:00
parent 51093da660
commit bfd16ab02c
11 changed files with 196 additions and 4 deletions
@@ -25,6 +25,14 @@ def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta:
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
def meta_from_chat_id(chat_id: int, contacts: set[int]) -> ChatMeta:
    """Build a ``ChatMeta`` from a bare chat id, without a ``Chat`` object.

    The sign of ``chat_id`` alone decides the kind: a positive id is
    classified as ``ChatKind.DM`` and its ``is_contact`` flag is set from
    membership in ``contacts``; zero or a negative id is classified as
    ``ChatKind.GROUP``. Fields not passed here keep ``ChatMeta``'s defaults.
    """
    if chat_id <= 0:
        return ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP)
    known_contact = chat_id in contacts
    return ChatMeta(chat_id=chat_id, kind=ChatKind.DM, is_contact=known_contact)
def meta_from_peer(
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
) -> ChatMeta:
@@ -1,3 +1,3 @@
from userbot.modules.jobs.handlers import backfill, fetch_media
from userbot.modules.jobs.handlers import backfill, fetch_media, transcribe
__all__ = ["backfill", "fetch_media"]
__all__ = ["backfill", "fetch_media", "transcribe"]
@@ -0,0 +1,16 @@
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import register
from userbot.modules.stt import transcribe_message
@register("transcribe")
async def transcribe(ctx: JobContext) -> None:
    """Job handler registered as ``"transcribe"``.

    Reads ``chat_id`` and ``message_id`` from ``ctx.job.params`` and passes
    them to ``transcribe_message`` together with the client and the client's
    ``capture`` attribute. The job is a silent no-op when ``ctx.client`` is
    ``None`` or when the client has no ``capture`` attribute (or it is
    ``None``).
    """
    client = ctx.client
    # Without a client there is nothing to look ``capture`` up on.
    capture = None if client is None else getattr(client, "capture", None)
    if capture is None:
        return
    params = ctx.job.params
    await transcribe_message(
        client, capture, params["chat_id"], params["message_id"]
    )
@@ -0,0 +1,3 @@
from userbot.modules.stt.service import is_transcribable, transcribe_message
__all__ = ["is_transcribable", "transcribe_message"]
+23
View File
@@ -0,0 +1,23 @@
from pyrogram import Client
from pyrogram.errors import FloodPremiumWait, FloodWait, RPCError
from userbot.modules.capture.context import CaptureContext
from userbot.modules.jobs.repository import enqueue
from userbot.modules.stt.service import transcribe_message
from utils.logging import logger
async def safe_transcribe(
    client: Client, ctx: CaptureContext, chat_id: int, message_id: int
) -> None:
    """Transcribe one message, absorbing Telegram RPC failures.

    Calls ``transcribe_message`` directly. On ``FloodWait`` or
    ``FloodPremiumWait`` the work is handed to the job queue as a
    ``"transcribe"`` job carrying ``chat_id`` and ``message_id``. Any other
    ``RPCError`` is logged as a warning and dropped. Exceptions that are not
    ``RPCError`` propagate to the caller.
    """
    try:
        await transcribe_message(client, ctx, chat_id, message_id)
    except (FloodWait, FloodPremiumWait):
        # NOTE(review): the wait duration carried by the exception is not
        # forwarded to the queue; confirm the job runner delays retries.
        pool, account_id = ctx.pool, ctx.account_id
        job_params = {"chat_id": chat_id, "message_id": message_id}
        await enqueue(pool, account_id, "transcribe", job_params)
    except RPCError as exc:
        logger.warning(f"[yellow]STT failed for {chat_id}/{message_id}: {exc}[/]")
@@ -0,0 +1,45 @@
import asyncpg
from userbot.modules.capture.repository import CHANNEL_ID_THRESHOLD
# Values of media.kind that the pending-transcription queries select; bound
# to the text[] parameter ($4) of _PENDING_BOX and _PENDING_CHANNEL.
_VOICE_KINDS = ["voice", "video_note"]
# Writes the transcription ($4) into media.extracted_text for the row
# identified by account_id ($1), chat_id ($2) and message_id ($3).
_SET_EXTRACTED_TEXT = """
UPDATE media SET extracted_text = $4
WHERE account_id = $1 AND chat_id = $2 AND message_id = $3
"""
# Rows with no extracted_text yet, for the given message ids ($2), in any
# chat whose id is greater than $3. pending_voice_reads binds $3 to
# CHANNEL_ID_THRESHOLD.
# NOTE(review): presumably this selects the non-channel chats that share one
# message-id space -- confirm against CHANNEL_ID_THRESHOLD's definition.
_PENDING_BOX = """
SELECT chat_id, message_id FROM media
WHERE account_id = $1 AND message_id = ANY($2::bigint[])
AND chat_id > $3 AND kind = ANY($4::text[]) AND extracted_text IS NULL
"""
# Same selection as _PENDING_BOX, but restricted to exactly one chat id ($3).
_PENDING_CHANNEL = """
SELECT chat_id, message_id FROM media
WHERE account_id = $1 AND message_id = ANY($2::bigint[])
AND chat_id = $3 AND kind = ANY($4::text[]) AND extracted_text IS NULL
"""
async def set_extracted_text(
    pool: asyncpg.Pool, account_id: int, chat_id: int, message_id: int, text: str
) -> None:
    """Store ``text`` as the ``extracted_text`` of one ``media`` row.

    The row is addressed by ``(account_id, chat_id, message_id)``; the
    statement executed is ``_SET_EXTRACTED_TEXT``.
    """
    # Placeholders $1..$3 are the row key, $4 is the new text.
    row_key = (account_id, chat_id, message_id)
    await pool.execute(_SET_EXTRACTED_TEXT, *row_key, text)
async def pending_voice_reads(
    pool: asyncpg.Pool,
    account_id: int,
    message_ids: list[int],
    chat_id: int | None = None,
) -> list[tuple[int, int]]:
    """Return the media rows among ``message_ids`` that still lack a transcript.

    Only rows whose ``kind`` is in ``_VOICE_KINDS`` and whose
    ``extracted_text`` is NULL are returned.

    Args:
        pool: Connection pool the query is run on.
        account_id: Account whose ``media`` rows are searched.
        message_ids: Candidate message ids. An empty list yields ``[]``
            without touching the database.
        chat_id: When given, only rows of exactly this chat are matched
            (``_PENDING_CHANNEL``). When ``None``, rows of every chat whose
            id is greater than ``CHANNEL_ID_THRESHOLD`` are matched
            (``_PENDING_BOX``).

    Returns:
        A list of ``(chat_id, message_id)`` tuples, one per matching row.
    """
    # ``message_id = ANY('{}')`` can never match, so skip the round trip.
    if not message_ids:
        return []

    # Both queries take the same four parameters; only $3 differs in meaning.
    if chat_id is None:
        query, chat_filter = _PENDING_BOX, CHANNEL_ID_THRESHOLD
    else:
        query, chat_filter = _PENDING_CHANNEL, chat_id

    rows = await pool.fetch(query, account_id, message_ids, chat_filter, _VOICE_KINDS)
    return [(row["chat_id"], row["message_id"]) for row in rows]
@@ -0,0 +1,27 @@
from pyrogram import Client, raw
from pyrogram.types import Message
from userbot.modules.capture.context import CaptureContext
from userbot.modules.media import self_destruct_ttl
from userbot.modules.stt import repository
def is_transcribable(message: Message) -> bool:
    """Tell whether ``message`` is eligible for transcription.

    A message qualifies when ``self_destruct_ttl(message)`` is ``None`` and
    it carries a ``voice`` or a ``video_note`` payload.
    """
    has_ttl = self_destruct_ttl(message) is not None
    if has_ttl:
        return False
    # Eligible unless both audio-bearing attributes are absent.
    return not (message.voice is None and message.video_note is None)
async def transcribe_message(
    client: Client, ctx: CaptureContext, chat_id: int, message_id: int
) -> None:
    """Request a transcription for one message and persist the text.

    Resolves ``chat_id`` to a peer, invokes the raw
    ``messages.TranscribeAudio`` call, and writes the returned text through
    ``repository.set_extracted_text`` using ``ctx.pool`` and
    ``ctx.account_id``. Nothing is stored when the peer resolves to ``None``,
    when the result is flagged ``pending``, or when its text is empty.
    """
    peer = await client.resolve_peer(chat_id)
    if peer is None:
        return

    request = raw.functions.messages.TranscribeAudio(peer=peer, msg_id=message_id)
    response = await client.invoke(request)

    # NOTE(review): a pending result is dropped here; presumably the final
    # text arrives through a later update -- confirm it is handled elsewhere.
    if response.pending or not response.text:
        return

    await repository.set_extracted_text(
        ctx.pool, ctx.account_id, chat_id, message_id, response.text
    )