feat: add STT
This commit is contained in:
@@ -3,6 +3,8 @@ from pyrogram.types import Message
|
|||||||
from userbot import PyroClient
|
from userbot import PyroClient
|
||||||
from userbot.modules.capture import capture_message
|
from userbot.modules.capture import capture_message
|
||||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||||
|
from userbot.modules.stt import is_transcribable
|
||||||
|
from userbot.modules.stt.gate import safe_transcribe
|
||||||
|
|
||||||
|
|
||||||
@PyroClient.on_message()
|
@PyroClient.on_message()
|
||||||
@@ -15,6 +17,12 @@ async def on_message(client: PyroClient, message: Message) -> None:
|
|||||||
if not toggles.messages:
|
if not toggles.messages:
|
||||||
return
|
return
|
||||||
await capture_message(client, message, ctx, toggles)
|
await capture_message(client, message, ctx, toggles)
|
||||||
|
if (
|
||||||
|
toggles.stt
|
||||||
|
and is_transcribable(message)
|
||||||
|
and (message.outgoing or message.unread_media is False)
|
||||||
|
):
|
||||||
|
await safe_transcribe(client, ctx, meta.chat_id, message.id)
|
||||||
|
|
||||||
|
|
||||||
handlers = on_message.handlers
|
handlers = on_message.handlers
|
||||||
|
|||||||
@@ -3,11 +3,17 @@ from collections.abc import Awaitable, Callable
|
|||||||
from pyrogram import raw
|
from pyrogram import raw
|
||||||
|
|
||||||
from userbot import PyroClient
|
from userbot import PyroClient
|
||||||
from userbot.handlers.raw import contacts, dialog_filters, reactions
|
from userbot.handlers.raw import (
|
||||||
|
contacts,
|
||||||
|
dialog_filters,
|
||||||
|
reactions,
|
||||||
|
read_contents,
|
||||||
|
transcribed,
|
||||||
|
)
|
||||||
|
|
||||||
RawHandler = Callable[..., Awaitable[None]]
|
RawHandler = Callable[..., Awaitable[None]]
|
||||||
|
|
||||||
_MODULES = (contacts, dialog_filters, reactions)
|
_MODULES = (contacts, dialog_filters, reactions, read_contents, transcribed)
|
||||||
_REGISTRY: dict[type, RawHandler] = {
|
_REGISTRY: dict[type, RawHandler] = {
|
||||||
update_type: module.handle for module in _MODULES for update_type in module.HANDLES
|
update_type: module.handle for module in _MODULES for update_type in module.HANDLES
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
from pyrogram import raw, utils
|
||||||
|
|
||||||
|
from userbot import PyroClient
|
||||||
|
from userbot.modules.capture.chat_meta import meta_from_chat_id
|
||||||
|
from userbot.modules.stt import repository
|
||||||
|
from userbot.modules.stt.gate import safe_transcribe
|
||||||
|
|
||||||
|
# Raw update types that this module's ``handle`` coroutine accepts.
HANDLES = (
    raw.types.UpdateReadMessagesContents,
    raw.types.UpdateChannelReadMessagesContents,
)


async def handle(
    client: PyroClient, update: raw.base.Update, _users: dict, _chats: dict
) -> None:
    """React to a "message contents read" update by transcribing pending media.

    Candidate ``(chat_id, message_id)`` pairs come from
    ``repository.pending_voice_reads``; each one is transcribed through
    ``safe_transcribe`` only when the resolved toggles for its chat have
    ``stt`` enabled.

    Args:
        client: Client whose ``capture`` context is used; nothing happens
            when that context is ``None``.
        update: The raw update. Any type other than the two in ``HANDLES``
            is ignored.
        _users: Unused.
        _chats: Unused.
    """
    capture = client.capture
    if capture is None:
        return

    if isinstance(update, raw.types.UpdateChannelReadMessagesContents):
        # The channel variant names its chat, so the lookup is scoped to it.
        channel_peer = raw.types.PeerChannel(channel_id=update.channel_id)
        channel_chat_id = utils.get_peer_id(channel_peer)
        pending = await repository.pending_voice_reads(
            capture.pool,
            capture.account_id,
            update.messages,
            chat_id=channel_chat_id,
        )
    elif isinstance(update, raw.types.UpdateReadMessagesContents):
        # No chat is given here; the repository query is called without one.
        pending = await repository.pending_voice_reads(
            capture.pool, capture.account_id, update.messages
        )
    else:
        return

    for target_chat_id, message_id in pending:
        chat_meta = meta_from_chat_id(target_chat_id, capture.contacts.ids)
        if not capture.resolve(chat_meta).stt:
            continue
        await safe_transcribe(client, capture, target_chat_id, message_id)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from pyrogram import raw
|
||||||
|
|
||||||
|
from userbot import PyroClient
|
||||||
|
from userbot.modules.capture.chat_meta import meta_from_peer
|
||||||
|
from userbot.modules.stt import repository
|
||||||
|
|
||||||
|
# Raw update types that this module's ``handle`` coroutine accepts.
HANDLES = (raw.types.UpdateTranscribedAudio,)


async def handle(
    client: PyroClient,
    update: raw.types.UpdateTranscribedAudio,
    users: dict,
    chats: dict,
) -> None:
    """Store the text carried by an ``UpdateTranscribedAudio`` update.

    Args:
        client: Client whose ``capture`` context supplies the pool, account id
            and contact ids; nothing happens when that context is ``None``.
        update: The transcription update. Updates with empty text are skipped.
        users: Users mapping forwarded to ``meta_from_peer``.
        chats: Chats mapping forwarded to ``meta_from_peer``.
    """
    capture = client.capture
    if capture is None:
        return

    # NOTE(review): ``update.pending`` is not inspected here, so any non-empty
    # text is written as-is; confirm whether partial results should be skipped.
    transcript = update.text
    if not transcript:
        return

    chat_meta = meta_from_peer(update.peer, chats, users, capture.contacts.ids)
    await repository.set_extracted_text(
        capture.pool,
        capture.account_id,
        chat_meta.chat_id,
        update.msg_id,
        transcript,
    )
|
||||||
@@ -25,6 +25,14 @@ def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta:
|
|||||||
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
|
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
|
||||||
|
|
||||||
|
|
||||||
|
def meta_from_chat_id(chat_id: int, contacts: set[int]) -> ChatMeta:
    """Build a ``ChatMeta`` from nothing but a numeric chat id.

    Positive ids are classified as ``ChatKind.DM`` with ``is_contact`` taken
    from membership in ``contacts``; every other id is classified as
    ``ChatKind.GROUP``. Fields not passed here keep their ``ChatMeta``
    defaults.

    NOTE(review): channel ids also end up in the GROUP branch; confirm that
    is the intended classification.
    """
    if chat_id <= 0:
        return ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP)
    return ChatMeta(
        chat_id=chat_id,
        kind=ChatKind.DM,
        is_contact=chat_id in contacts,
    )
|
||||||
|
|
||||||
|
|
||||||
def meta_from_peer(
|
def meta_from_peer(
|
||||||
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
|
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
|
||||||
) -> ChatMeta:
|
) -> ChatMeta:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from userbot.modules.jobs.handlers import backfill, fetch_media
|
from userbot.modules.jobs.handlers import backfill, fetch_media, transcribe
|
||||||
|
|
||||||
__all__ = ["backfill", "fetch_media"]
|
__all__ = ["backfill", "fetch_media", "transcribe"]
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
from userbot.modules.jobs.context import JobContext
|
||||||
|
from userbot.modules.jobs.registry import register
|
||||||
|
from userbot.modules.stt import transcribe_message
|
||||||
|
|
||||||
|
|
||||||
|
@register("transcribe")
async def transcribe(ctx: JobContext) -> None:
    """Job handler registered as "transcribe": transcribe one stored message.

    Reads ``chat_id`` and ``message_id`` from ``ctx.job.params`` and forwards
    them to ``transcribe_message``. Returns without doing anything when the
    job has no client or the client has no ``capture`` context.
    """
    client = ctx.client
    # A missing client and a missing capture context are both a silent no-op.
    capture = None if client is None else getattr(client, "capture", None)
    if capture is None:
        return

    params = ctx.job.params
    await transcribe_message(
        client, capture, params["chat_id"], params["message_id"]
    )
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from userbot.modules.stt.service import is_transcribable, transcribe_message
|
||||||
|
|
||||||
|
__all__ = ["is_transcribable", "transcribe_message"]
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
from pyrogram import Client
|
||||||
|
from pyrogram.errors import FloodPremiumWait, FloodWait, RPCError
|
||||||
|
|
||||||
|
from userbot.modules.capture.context import CaptureContext
|
||||||
|
from userbot.modules.jobs.repository import enqueue
|
||||||
|
from userbot.modules.stt.service import transcribe_message
|
||||||
|
from utils.logging import logger
|
||||||
|
|
||||||
|
|
||||||
|
async def safe_transcribe(
    client: Client, ctx: CaptureContext, chat_id: int, message_id: int
) -> None:
    """Run ``transcribe_message`` without letting RPC errors escape.

    A flood-wait error re-queues the work as a "transcribe" job carrying the
    same chat and message ids; any other ``RPCError`` is logged as a warning.
    Exceptions that are not ``RPCError`` still propagate to the caller.
    """
    try:
        await transcribe_message(client, ctx, chat_id, message_id)
    except (FloodWait, FloodPremiumWait):
        # Listed ahead of the generic RPCError clause so flood waits are
        # queued for later rather than only logged.
        job_params = {"chat_id": chat_id, "message_id": message_id}
        await enqueue(ctx.pool, ctx.account_id, "transcribe", job_params)
    except RPCError as exc:
        logger.warning(f"[yellow]STT failed for {chat_id}/{message_id}: {exc}[/]")
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from userbot.modules.capture.repository import CHANNEL_ID_THRESHOLD
|
||||||
|
|
||||||
|
_VOICE_KINDS = ["voice", "video_note"]
|
||||||
|
|
||||||
|
_SET_EXTRACTED_TEXT = """
|
||||||
|
UPDATE media SET extracted_text = $4
|
||||||
|
WHERE account_id = $1 AND chat_id = $2 AND message_id = $3
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PENDING_BOX = """
|
||||||
|
SELECT chat_id, message_id FROM media
|
||||||
|
WHERE account_id = $1 AND message_id = ANY($2::bigint[])
|
||||||
|
AND chat_id > $3 AND kind = ANY($4::text[]) AND extracted_text IS NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PENDING_CHANNEL = """
|
||||||
|
SELECT chat_id, message_id FROM media
|
||||||
|
WHERE account_id = $1 AND message_id = ANY($2::bigint[])
|
||||||
|
AND chat_id = $3 AND kind = ANY($4::text[]) AND extracted_text IS NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def set_extracted_text(
    pool: asyncpg.Pool, account_id: int, chat_id: int, message_id: int, text: str
) -> None:
    """Write ``text`` into ``media.extracted_text`` for one message.

    Runs ``_SET_EXTRACTED_TEXT`` on ``pool``; the rows touched are those
    matching ``account_id``, ``chat_id`` and ``message_id``.
    """
    # Positional order must line up with $1..$4 in the statement.
    query_args = (account_id, chat_id, message_id, text)
    await pool.execute(_SET_EXTRACTED_TEXT, *query_args)
|
||||||
|
|
||||||
|
|
||||||
|
async def pending_voice_reads(
    pool: asyncpg.Pool,
    account_id: int,
    message_ids: list[int],
    chat_id: int | None = None,
) -> list[tuple[int, int]]:
    """Find stored voice/video-note media that has no extracted text yet.

    Args:
        pool: Connection pool the lookup runs on.
        account_id: Account whose media rows are searched.
        message_ids: Message ids to look for. An empty list yields an empty
            result without querying the database.
        chat_id: When given, only rows of exactly this chat are considered
            (``_PENDING_CHANNEL``). When ``None``, rows of every chat whose id
            is greater than ``CHANNEL_ID_THRESHOLD`` are considered
            (``_PENDING_BOX``).

    Returns:
        ``(chat_id, message_id)`` pairs for the matching rows.
    """
    if not message_ids:
        # ``message_id = ANY('{}')`` can never match, so skip the round-trip.
        return []

    if chat_id is None:
        query, chat_scope = _PENDING_BOX, CHANNEL_ID_THRESHOLD
    else:
        query, chat_scope = _PENDING_CHANNEL, chat_id

    rows = await pool.fetch(query, account_id, message_ids, chat_scope, _VOICE_KINDS)
    return [(row["chat_id"], row["message_id"]) for row in rows]
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
from pyrogram import Client, raw
|
||||||
|
from pyrogram.types import Message
|
||||||
|
|
||||||
|
from userbot.modules.capture.context import CaptureContext
|
||||||
|
from userbot.modules.media import self_destruct_ttl
|
||||||
|
from userbot.modules.stt import repository
|
||||||
|
|
||||||
|
|
||||||
|
def is_transcribable(message: Message) -> bool:
    """Tell whether ``message`` qualifies for transcription.

    A message qualifies when ``self_destruct_ttl`` returns ``None`` for it and
    it carries a voice message or a video note.
    """
    return self_destruct_ttl(message) is None and (
        message.voice is not None or message.video_note is not None
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe_message(
    client: Client, ctx: CaptureContext, chat_id: int, message_id: int
) -> None:
    """Request a transcription for one message and store a finished result.

    Invokes ``messages.TranscribeAudio`` for the message. The returned text is
    saved through ``repository.set_extracted_text`` only when the result is
    not marked pending and the text is non-empty; otherwise nothing is
    written here.
    """
    peer = await client.resolve_peer(chat_id)
    if peer is None:
        return

    request = raw.functions.messages.TranscribeAudio(peer=peer, msg_id=message_id)
    result = await client.invoke(request)

    # Pending or empty results are not stored by this function.
    if result.pending or not result.text:
        return

    await repository.set_extracted_text(
        ctx.pool, ctx.account_id, chat_id, message_id, result.text
    )
|
||||||
Reference in New Issue
Block a user