feat: add realtime capturing

This commit is contained in:
h
2026-05-29 18:19:06 +02:00
parent 920a0235e2
commit 3c1a12750c
29 changed files with 967 additions and 47 deletions
-3
View File
@@ -1,3 +0,0 @@
from .client import PyroClient
__all__ = ["PyroClient"]
@@ -0,0 +1,3 @@
from userbot.modules.capture.context import CaptureContext, build_capture_context
__all__ = ["CaptureContext", "build_capture_context"]
@@ -0,0 +1,46 @@
from pyrogram import enums, raw, utils
from pyrogram.types import Chat
from utils.policy.models import ChatKind, ChatMeta
_KIND_BY_TYPE: dict[enums.ChatType, ChatKind] = {
enums.ChatType.PRIVATE: ChatKind.DM,
enums.ChatType.BOT: ChatKind.DM,
enums.ChatType.GROUP: ChatKind.GROUP,
enums.ChatType.SUPERGROUP: ChatKind.GROUP,
enums.ChatType.FORUM: ChatKind.GROUP,
enums.ChatType.CHANNEL: ChatKind.CHANNEL,
}
def chat_kind(chat_type: enums.ChatType | None) -> ChatKind:
return _KIND_BY_TYPE.get(chat_type, ChatKind.DM) if chat_type else ChatKind.DM
def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta:
kind = chat_kind(chat.type)
is_bot = chat.type is enums.ChatType.BOT
chat_id = chat.id or 0
is_contact = chat_id in contacts if kind is ChatKind.DM else None
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
def meta_from_peer(
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
) -> ChatMeta:
chat_id = utils.get_peer_id(peer)
if isinstance(peer, raw.types.PeerUser):
user = users.get(peer.user_id)
return ChatMeta(
chat_id=chat_id,
kind=ChatKind.DM,
is_bot=bool(getattr(user, "bot", False)),
is_contact=peer.user_id in contacts,
)
if isinstance(peer, raw.types.PeerChannel):
channel = chats.get(peer.channel_id)
kind = (
ChatKind.CHANNEL if getattr(channel, "broadcast", False) else ChatKind.GROUP
)
return ChatMeta(chat_id=chat_id, kind=kind)
return ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP)
@@ -0,0 +1,47 @@
import asyncpg
from pyrogram import Client
from userbot.modules.contacts import ContactCache
from userbot.modules.folders import FolderCache
from utils.policy.models import CaptureToggles, ChatMeta, PolicySet
from utils.policy.repository import load_policy_set
from utils.policy.resolver import resolve
from utils.storage import ContentAddressedStorage
class CaptureContext:
def __init__(
self,
account_id: int,
pool: asyncpg.Pool,
storage: ContentAddressedStorage,
folders: FolderCache,
contacts: ContactCache,
) -> None:
self.account_id = account_id
self.pool = pool
self.storage = storage
self.folders = folders
self.contacts = contacts
self.policies = PolicySet()
async def reload_policies(self) -> None:
self.policies = await load_policy_set(self.pool, self.account_id)
def resolve(self, chat: ChatMeta) -> CaptureToggles:
return resolve(chat, self.folders.folders, self.policies)
async def build_capture_context(
client: Client,
pool: asyncpg.Pool,
storage: ContentAddressedStorage,
account_id: int,
) -> CaptureContext:
folders = FolderCache(client, pool, account_id)
await folders.refresh()
contacts = ContactCache(client)
await contacts.refresh()
ctx = CaptureContext(account_id, pool, storage, folders, contacts)
await ctx.reload_policies()
return ctx
@@ -0,0 +1,223 @@
from datetime import datetime
import asyncpg
CHANNEL_ID_THRESHOLD = -1000000000000
_UPSERT_MESSAGE = """
INSERT INTO messages
(account_id, chat_id, message_id, date, sender_id, text, raw,
has_media, is_self_destruct)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
sender_id = EXCLUDED.sender_id,
text = EXCLUDED.text,
raw = EXCLUDED.raw,
has_media = EXCLUDED.has_media,
is_self_destruct = EXCLUDED.is_self_destruct
"""
_TOUCH_EDITED = """
INSERT INTO messages
(account_id, chat_id, message_id, date, sender_id, text, raw,
has_media, is_self_destruct, edited_at)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9, now())
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
edited_at = now()
"""
_INSERT_VERSION = """
INSERT INTO message_versions
(account_id, chat_id, message_id, observed_at, edit_date, text, raw)
VALUES ($1, $2, $3, now(), $4, $5, $6::jsonb)
ON CONFLICT DO NOTHING
"""
_INSERT_MEDIA = """
INSERT INTO media
(account_id, chat_id, message_id, kind, storage_key, file_size, mime,
ttl_seconds, downloaded)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (account_id, chat_id, message_id) DO UPDATE SET
kind = EXCLUDED.kind,
storage_key = EXCLUDED.storage_key,
file_size = EXCLUDED.file_size,
mime = EXCLUDED.mime,
ttl_seconds = EXCLUDED.ttl_seconds,
downloaded = EXCLUDED.downloaded
"""
async def upsert_message( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
date: datetime,
sender_id: int | None,
text: str | None,
raw: str,
*,
has_media: bool,
is_self_destruct: bool,
) -> None:
await pool.execute(
_UPSERT_MESSAGE,
account_id,
chat_id,
message_id,
date,
sender_id,
text,
raw,
has_media,
is_self_destruct,
)
async def mark_deleted_box(
pool: asyncpg.Pool, account_id: int, message_ids: list[int]
) -> None:
await pool.execute(
"UPDATE messages SET deleted_at = now() "
"WHERE account_id = $1 AND message_id = ANY($2::bigint[]) "
"AND chat_id > $3 AND deleted_at IS NULL",
account_id,
message_ids,
CHANNEL_ID_THRESHOLD,
)
async def mark_deleted_channel(
pool: asyncpg.Pool, account_id: int, chat_id: int, message_ids: list[int]
) -> None:
await pool.execute(
"UPDATE messages SET deleted_at = now() "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = ANY($3::bigint[]) "
"AND deleted_at IS NULL",
account_id,
chat_id,
message_ids,
)
async def add_version( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
date: datetime,
sender_id: int | None,
text: str | None,
raw: str,
edit_date: datetime | None,
*,
has_media: bool,
is_self_destruct: bool,
) -> None:
async with pool.acquire() as conn, conn.transaction():
await conn.execute(
_TOUCH_EDITED,
account_id,
chat_id,
message_id,
date,
sender_id,
text,
raw,
has_media,
is_self_destruct,
)
await conn.execute(
_INSERT_VERSION, account_id, chat_id, message_id, edit_date, text, raw
)
async def insert_media( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
kind: str,
storage_key: str | None,
file_size: int | None,
mime: str | None,
ttl_seconds: int | None,
*,
downloaded: bool,
) -> None:
await pool.execute(
_INSERT_MEDIA,
account_id,
chat_id,
message_id,
kind,
storage_key,
file_size,
mime,
ttl_seconds,
downloaded,
)
async def insert_callbacks(
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
buttons: list[tuple[int, str | None, bytes | None]],
) -> None:
async with pool.acquire() as conn, conn.transaction():
await conn.execute(
"DELETE FROM callbacks "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3",
account_id,
chat_id,
message_id,
)
await conn.executemany(
"INSERT INTO callbacks "
"(account_id, chat_id, message_id, position, label, data) "
"VALUES ($1, $2, $3, $4, $5, $6)",
[
(account_id, chat_id, message_id, position, label, data)
for position, label, data in buttons
],
)
async def sync_reactions(
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
current: list[tuple[int, str]],
) -> None:
async with pool.acquire() as conn, conn.transaction():
rows = await conn.fetch(
"SELECT id, peer_id, reaction FROM reactions "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3 "
"AND removed_at IS NULL",
account_id,
chat_id,
message_id,
)
existing = {(row["peer_id"], row["reaction"]): row["id"] for row in rows}
current_set = set(current)
stale = [pid for key, pid in existing.items() if key not in current_set]
if stale:
await conn.execute(
"UPDATE reactions SET removed_at = now() WHERE id = ANY($1::bigint[])",
stale,
)
fresh = [key for key in current_set if key not in existing]
if fresh:
await conn.executemany(
"INSERT INTO reactions "
"(account_id, chat_id, message_id, peer_id, reaction) "
"VALUES ($1, $2, $3, $4, $5)",
[
(account_id, chat_id, message_id, peer_id, reaction)
for peer_id, reaction in fresh
],
)
+15 -1
View File
@@ -1,8 +1,15 @@
from typing import TYPE_CHECKING
from pyrogram import Client, enums
if TYPE_CHECKING:
from userbot.modules.capture import CaptureContext
class PyroClient(Client):
def __init__(self, name: str, *, workdir: str = "sessions") -> None:
def __init__(
self, name: str, *, workdir: str = "sessions", load_handlers: bool = True
) -> None:
super().__init__(
name,
workdir=workdir,
@@ -14,6 +21,13 @@ class PyroClient(Client):
lang_pack="tdesktop",
client_platform=enums.ClientPlatform.DESKTOP,
)
self.capture: CaptureContext | None = None
if load_handlers:
from userbot import handlers # noqa: PLC0415
for handler in handlers.handlers:
self.add_handler(*handler)
__all__ = ["PyroClient"]
@@ -0,0 +1,3 @@
from userbot.modules.contacts.cache import ContactCache
__all__ = ["ContactCache"]
@@ -0,0 +1,20 @@
from pyrogram import Client
from utils.logging import logger
class ContactCache:
def __init__(self, client: Client) -> None:
self._client = client
self.ids: set[int] = set()
async def refresh(self) -> None:
contacts = await self._client.get_contacts()
self.ids = {user.id for user in contacts}
logger.info(f"[green]Contacts cached:[/] {len(self.ids)}")
def mark(self, user_id: int, *, is_contact: bool) -> None:
if is_contact:
self.ids.add(user_id)
else:
self.ids.discard(user_id)
@@ -0,0 +1,3 @@
from userbot.modules.folders.cache import FolderCache
__all__ = ["FolderCache"]
@@ -0,0 +1,68 @@
from collections.abc import Iterable
import asyncpg
from pyrogram import Client, raw, utils
from utils.logging import logger
from utils.policy.models import FolderSpec
from utils.policy.repository import replace_folders
def _peer_ids(peers: Iterable[raw.base.InputPeer]) -> set[int]:
ids: set[int] = set()
for peer in peers:
try:
ids.add(utils.get_peer_id(peer))
except (ValueError, AttributeError):
continue
return ids
def _title(raw_title: object) -> str:
return getattr(raw_title, "text", None) or str(raw_title)
def _parse(raw_filter: raw.base.DialogFilter, order_index: int) -> FolderSpec | None:
if isinstance(raw_filter, raw.types.DialogFilterDefault):
return None
if isinstance(raw_filter, raw.types.DialogFilterChatlist):
return FolderSpec(
folder_id=raw_filter.id,
order_index=order_index,
title=_title(raw_filter.title),
include_ids=_peer_ids(raw_filter.include_peers),
pinned_ids=_peer_ids(raw_filter.pinned_peers),
is_chatlist=True,
)
return FolderSpec(
folder_id=raw_filter.id,
order_index=order_index,
title=_title(raw_filter.title),
include_ids=_peer_ids(raw_filter.include_peers),
exclude_ids=_peer_ids(raw_filter.exclude_peers),
pinned_ids=_peer_ids(raw_filter.pinned_peers),
contacts=bool(raw_filter.contacts),
non_contacts=bool(raw_filter.non_contacts),
groups=bool(raw_filter.groups),
broadcasts=bool(raw_filter.broadcasts),
bots=bool(raw_filter.bots),
)
class FolderCache:
def __init__(self, client: Client, pool: asyncpg.Pool, account_id: int) -> None:
self._client = client
self._pool = pool
self._account_id = account_id
self.folders: list[FolderSpec] = []
async def refresh(self) -> None:
result = await self._client.invoke(raw.functions.messages.GetDialogFilters())
specs = [
spec
for order_index, raw_filter in enumerate(result.filters)
if (spec := _parse(raw_filter, order_index)) is not None
]
self.folders = specs
await replace_folders(self._pool, self._account_id, specs)
logger.info(f"[green]Folders cached:[/] {len(specs)}")
@@ -0,0 +1,3 @@
from userbot.modules.media.downloader import capture_media, self_destruct_ttl
__all__ = ["capture_media", "self_destruct_ttl"]
@@ -0,0 +1,71 @@
from io import BytesIO
from typing import Any
from pyrogram import Client
from pyrogram.types import Message
from userbot.modules.capture import repository
from userbot.modules.capture.context import CaptureContext
from utils.policy.models import CaptureToggles
_MEDIA_ATTRS = (
"photo",
"video",
"voice",
"video_note",
"document",
"audio",
"animation",
"sticker",
)
def media_object(message: Message) -> tuple[str | None, Any]:
for attr in _MEDIA_ATTRS:
obj = getattr(message, attr, None)
if obj is not None:
return attr, obj
return None, None
def self_destruct_ttl(message: Message) -> int | None:
_, obj = media_object(message)
return getattr(obj, "ttl_seconds", None) if obj is not None else None
async def capture_media( # noqa: PLR0913
client: Client,
message: Message,
ctx: CaptureContext,
chat_id: int,
message_id: int,
toggles: CaptureToggles,
) -> None:
kind, obj = media_object(message)
if obj is None:
return
ttl = getattr(obj, "ttl_seconds", None)
want = toggles.self_destruct_media if ttl else toggles.media
file_size = getattr(obj, "file_size", None)
mime = getattr(obj, "mime_type", None)
storage_key: str | None = None
downloaded = False
if want:
buffer = await client.download_media(message, in_memory=True)
if isinstance(buffer, BytesIO):
data = buffer.getvalue()
storage_key = ctx.storage.put(data)
file_size = len(data)
downloaded = True
await repository.insert_media(
ctx.pool,
ctx.account_id,
chat_id,
message_id,
kind or "unknown",
storage_key,
file_size,
mime,
ttl,
downloaded=downloaded,
)