feat: add realtime capturing

This commit is contained in:
h
2026-05-29 18:19:06 +02:00
parent 920a0235e2
commit 3c1a12750c
29 changed files with 967 additions and 47 deletions
@@ -0,0 +1,139 @@
"""realtime core
Revision ID: c3a8e5f1b6d2
Revises: b2f7c1a9d3e4
Create Date: 2026-05-29 18:30:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "c3a8e5f1b6d2"
down_revision: str | None = "b2f7c1a9d3e4"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column("messages", sa.Column("sender_id", sa.BigInteger(), nullable=True))
op.add_column("messages", sa.Column("text", sa.String(), nullable=True))
op.add_column(
"messages",
sa.Column("has_media", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"messages",
sa.Column(
"is_self_destruct", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
op.add_column(
"messages", sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True)
)
op.create_index("ix_messages_box", "messages", ["account_id", "message_id"])
op.create_table(
"message_versions",
sa.Column("account_id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("observed_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("edit_date", sa.DateTime(timezone=True), nullable=True),
sa.Column("text", sa.String(), nullable=True),
sa.Column("raw", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.PrimaryKeyConstraint("account_id", "chat_id", "message_id", "observed_at"),
)
op.create_table(
"media",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("account_id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("kind", sa.String(), nullable=False),
sa.Column("storage_key", sa.String(), nullable=True),
sa.Column("file_size", sa.BigInteger(), nullable=True),
sa.Column("mime", sa.String(), nullable=True),
sa.Column("ttl_seconds", sa.Integer(), nullable=True),
sa.Column(
"downloaded", sa.Boolean(), nullable=False, server_default=sa.false()
),
sa.Column("extracted_text", sa.String(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("account_id", "chat_id", "message_id"),
)
op.create_table(
"callbacks",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("account_id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("position", sa.Integer(), nullable=False),
sa.Column("label", sa.String(), nullable=True),
sa.Column("data", sa.LargeBinary(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_callbacks_message", "callbacks", ["account_id", "chat_id", "message_id"]
)
op.create_table(
"reactions",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("account_id", sa.Integer(), nullable=False),
sa.Column("chat_id", sa.BigInteger(), nullable=False),
sa.Column("message_id", sa.BigInteger(), nullable=False),
sa.Column("peer_id", sa.BigInteger(), nullable=False),
sa.Column("reaction", sa.String(), nullable=False),
sa.Column(
"added_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("removed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_reactions_message", "reactions", ["account_id", "chat_id", "message_id"]
)
op.execute(
"CREATE INDEX ix_reactions_open ON reactions "
"(account_id, chat_id, message_id) WHERE removed_at IS NULL"
)
def downgrade() -> None:
op.drop_table("reactions")
op.drop_table("callbacks")
op.drop_table("media")
op.drop_table("message_versions")
op.drop_index("ix_messages_box", table_name="messages")
op.drop_column("messages", "edited_at")
op.drop_column("messages", "is_self_destruct")
op.drop_column("messages", "has_media")
op.drop_column("messages", "text")
op.drop_column("messages", "sender_id")
+7 -1
View File
@@ -17,6 +17,8 @@ from utils.policy.resolver import resolve
router = APIRouter(prefix="/api/policy", tags=["policy"])
POLICY_CHANGED_CHANNEL = "policy_changed"
class EffectiveQuery(BaseModel):
account_id: int
@@ -55,9 +57,11 @@ async def effective_policy(
async def create_policy(
pool: FromDishka[asyncpg.Pool], body: PolicyCreate
) -> PolicyRecord:
return await repository.create_policy(
record = await repository.create_policy(
pool, body.account_id, body.scope_type, body.scope_id, body
)
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
return record
@router.get("/{policy_id}")
@@ -77,6 +81,7 @@ async def update_policy(
record = await repository.update_policy(pool, policy_id, body)
if record is None:
raise HTTPException(status_code=404, detail="policy not found")
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
return record
@@ -85,3 +90,4 @@ async def update_policy(
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
if not await repository.delete_policy(pool, policy_id):
raise HTTPException(status_code=404, detail="policy not found")
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
+2 -1
View File
@@ -1,5 +1,6 @@
from dishka import make_async_container
from dependencies.providers.postgres import DbProvider
from dependencies.providers.storage import StorageProvider
container = make_async_container(DbProvider())
container = make_async_container(DbProvider(), StorageProvider())
@@ -0,0 +1,10 @@
from dishka import Provider, Scope, provide
from utils.env import env
from utils.storage import ContentAddressedStorage
class StorageProvider(Provider):
@provide(scope=Scope.APP)
def get_storage(self) -> ContentAddressedStorage:
return ContentAddressedStorage(env.storage.root, env.storage.shard_depth)
+3
View File
@@ -0,0 +1,3 @@
from userbot.modules.client import PyroClient
__all__ = ["PyroClient"]
+2 -2
View File
@@ -1,3 +1,3 @@
from .dialog_filters import dialog_filter_handler
from userbot.handlers import deletes, edits, messages, raw
__all__ = ["dialog_filter_handler"]
handlers = messages.handlers + edits.handlers + deletes.handlers + raw.handlers
+28
View File
@@ -0,0 +1,28 @@
from pyrogram.types import Message
from userbot import PyroClient
from userbot.modules.capture import repository
from userbot.modules.capture.repository import CHANNEL_ID_THRESHOLD
@PyroClient.on_deleted_messages()
async def on_deleted_messages(client: PyroClient, messages: list[Message]) -> None:
ctx = client.capture
if ctx is None:
return
box: list[int] = []
channels: dict[int, list[int]] = {}
for message in messages:
chat = message.chat
chat_id = chat.id if chat is not None else None
if chat_id is None or chat_id > CHANNEL_ID_THRESHOLD:
box.append(message.id)
else:
channels.setdefault(chat_id, []).append(message.id)
if box:
await repository.mark_deleted_box(ctx.pool, ctx.account_id, box)
for chat_id, ids in channels.items():
await repository.mark_deleted_channel(ctx.pool, ctx.account_id, chat_id, ids)
handlers = on_deleted_messages.handlers
@@ -1,20 +0,0 @@
from pyrogram import Client, raw
from pyrogram.handlers import RawUpdateHandler
from userbot.folders import FolderCache
_FILTER_UPDATES = (
raw.types.UpdateDialogFilter,
raw.types.UpdateDialogFilters,
raw.types.UpdateDialogFilterOrder,
)
def dialog_filter_handler(cache: FolderCache) -> RawUpdateHandler:
async def on_update(
_client: Client, update: raw.base.Update, _users: dict, _chats: dict
) -> None:
if isinstance(update, _FILTER_UPDATES):
await cache.refresh()
return RawUpdateHandler(on_update)
+36
View File
@@ -0,0 +1,36 @@
from pyrogram.types import Message
from userbot import PyroClient
from userbot.handlers.messages import sender_id
from userbot.modules.capture import repository
from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.media import self_destruct_ttl
@PyroClient.on_edited_message()
async def on_edited_message(client: PyroClient, message: Message) -> None:
ctx = client.capture
if ctx is None or message.empty or message.chat is None or message.date is None:
return
chat = message.chat
chat_id = chat.id or 0
meta = meta_from_chat(chat, ctx.contacts.ids)
toggles = ctx.resolve(meta)
if not toggles.track_edits_deletes:
return
await repository.add_version(
ctx.pool,
ctx.account_id,
chat_id,
message.id,
message.date,
sender_id(message),
message.text or message.caption,
str(message),
message.edit_date,
has_media=message.media is not None,
is_self_destruct=self_destruct_ttl(message) is not None,
)
handlers = on_edited_message.handlers
+64
View File
@@ -0,0 +1,64 @@
from pyrogram.types import Message
from userbot import PyroClient
from userbot.modules.capture import repository
from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.media import capture_media, self_destruct_ttl
def sender_id(message: Message) -> int | None:
if message.from_user is not None:
return message.from_user.id
if message.sender_chat is not None:
return message.sender_chat.id
return None
def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
rows = getattr(message.reply_markup, "inline_keyboard", None)
if not rows:
return []
buttons: list[tuple[int, str | None, bytes | None]] = []
position = 0
for row in rows:
for button in row:
data = button.callback_data
if data is not None:
encoded = data.encode() if isinstance(data, str) else data
buttons.append((position, button.text, encoded))
position += 1
return buttons
@PyroClient.on_message()
async def on_message(client: PyroClient, message: Message) -> None:
ctx = client.capture
if ctx is None or message.empty or message.chat is None or message.date is None:
return
chat = message.chat
chat_id = chat.id or 0
meta = meta_from_chat(chat, ctx.contacts.ids)
toggles = ctx.resolve(meta)
if not toggles.messages:
return
await repository.upsert_message(
ctx.pool,
ctx.account_id,
chat_id,
message.id,
message.date,
sender_id(message),
message.text or message.caption,
str(message),
has_media=message.media is not None,
is_self_destruct=self_destruct_ttl(message) is not None,
)
await capture_media(client, message, ctx, chat_id, message.id, toggles)
buttons = _callbacks(message)
if buttons:
await repository.insert_callbacks(
ctx.pool, ctx.account_id, chat_id, message.id, buttons
)
handlers = on_message.handlers
@@ -0,0 +1,3 @@
from userbot.handlers.raw.dispatcher import handlers
__all__ = ["handlers"]
@@ -0,0 +1,22 @@
from pyrogram import raw
from userbot import PyroClient
HANDLES = (raw.types.UpdateContactsReset, raw.types.UpdateUser)
async def handle(
client: PyroClient, update: raw.base.Update, users: dict, _chats: dict
) -> None:
ctx = client.capture
if ctx is None:
return
if isinstance(update, raw.types.UpdateContactsReset):
await ctx.contacts.refresh()
return
if not isinstance(update, raw.types.UpdateUser):
return
user = users.get(update.user_id)
if not isinstance(user, raw.types.User) or user.min:
return
ctx.contacts.mark(user.id, is_contact=bool(user.contact))
@@ -0,0 +1,19 @@
from pyrogram import raw
from userbot import PyroClient
HANDLES = (
raw.types.UpdateDialogFilter,
raw.types.UpdateDialogFilters,
raw.types.UpdateDialogFilterOrder,
)
async def handle(
client: PyroClient, _update: raw.base.Update, _users: dict, _chats: dict
) -> None:
ctx = client.capture
if ctx is None:
return
await ctx.folders.refresh()
await ctx.reload_policies()
@@ -0,0 +1,25 @@
from collections.abc import Awaitable, Callable
from pyrogram import raw
from userbot import PyroClient
from userbot.handlers.raw import contacts, dialog_filters, reactions
RawHandler = Callable[..., Awaitable[None]]
_MODULES = (contacts, dialog_filters, reactions)
_REGISTRY: dict[type, RawHandler] = {
update_type: module.handle for module in _MODULES for update_type in module.HANDLES
}
@PyroClient.on_raw_update()
async def dispatch(
client: PyroClient, update: raw.base.Update, users: dict, chats: dict
) -> None:
handler = _REGISTRY.get(type(update))
if handler is not None:
await handler(client, update, users, chats)
handlers = dispatch.handlers
@@ -0,0 +1,49 @@
from pyrogram import raw, utils
from userbot import PyroClient
from userbot.modules.capture import repository
from userbot.modules.capture.chat_meta import meta_from_peer
HANDLES = (raw.types.UpdateMessageReactions,)
def _reaction_key(reaction: raw.base.Reaction) -> str | None:
if isinstance(reaction, raw.types.ReactionEmoji):
return reaction.emoticon
if isinstance(reaction, raw.types.ReactionCustomEmoji):
return f"custom:{reaction.document_id}"
return None
def _parse_recent(reactions: raw.base.MessageReactions) -> list[tuple[int, str]]:
recent = getattr(reactions, "recent_reactions", None) or []
parsed: list[tuple[int, str]] = []
for peer_reaction in recent:
key = _reaction_key(peer_reaction.reaction)
if key is None:
continue
try:
peer_id = utils.get_peer_id(peer_reaction.peer_id)
except (ValueError, AttributeError):
continue
parsed.append((peer_id, key))
return parsed
async def handle(
client: PyroClient,
update: raw.types.UpdateMessageReactions,
users: dict,
chats: dict,
) -> None:
ctx = client.capture
if ctx is None:
return
meta = meta_from_peer(update.peer, chats, users, ctx.contacts.ids)
toggles = ctx.resolve(meta)
if not toggles.reactions:
return
current = _parse_recent(update.reactions)
await repository.sync_reactions(
ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, current
)
-3
View File
@@ -1,3 +0,0 @@
from .client import PyroClient
__all__ = ["PyroClient"]
@@ -0,0 +1,3 @@
from userbot.modules.capture.context import CaptureContext, build_capture_context
__all__ = ["CaptureContext", "build_capture_context"]
@@ -0,0 +1,46 @@
from pyrogram import enums, raw, utils
from pyrogram.types import Chat
from utils.policy.models import ChatKind, ChatMeta
_KIND_BY_TYPE: dict[enums.ChatType, ChatKind] = {
enums.ChatType.PRIVATE: ChatKind.DM,
enums.ChatType.BOT: ChatKind.DM,
enums.ChatType.GROUP: ChatKind.GROUP,
enums.ChatType.SUPERGROUP: ChatKind.GROUP,
enums.ChatType.FORUM: ChatKind.GROUP,
enums.ChatType.CHANNEL: ChatKind.CHANNEL,
}
def chat_kind(chat_type: enums.ChatType | None) -> ChatKind:
return _KIND_BY_TYPE.get(chat_type, ChatKind.DM) if chat_type else ChatKind.DM
def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta:
kind = chat_kind(chat.type)
is_bot = chat.type is enums.ChatType.BOT
chat_id = chat.id or 0
is_contact = chat_id in contacts if kind is ChatKind.DM else None
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
def meta_from_peer(
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
) -> ChatMeta:
chat_id = utils.get_peer_id(peer)
if isinstance(peer, raw.types.PeerUser):
user = users.get(peer.user_id)
return ChatMeta(
chat_id=chat_id,
kind=ChatKind.DM,
is_bot=bool(getattr(user, "bot", False)),
is_contact=peer.user_id in contacts,
)
if isinstance(peer, raw.types.PeerChannel):
channel = chats.get(peer.channel_id)
kind = (
ChatKind.CHANNEL if getattr(channel, "broadcast", False) else ChatKind.GROUP
)
return ChatMeta(chat_id=chat_id, kind=kind)
return ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP)
@@ -0,0 +1,47 @@
import asyncpg
from pyrogram import Client
from userbot.modules.contacts import ContactCache
from userbot.modules.folders import FolderCache
from utils.policy.models import CaptureToggles, ChatMeta, PolicySet
from utils.policy.repository import load_policy_set
from utils.policy.resolver import resolve
from utils.storage import ContentAddressedStorage
class CaptureContext:
def __init__(
self,
account_id: int,
pool: asyncpg.Pool,
storage: ContentAddressedStorage,
folders: FolderCache,
contacts: ContactCache,
) -> None:
self.account_id = account_id
self.pool = pool
self.storage = storage
self.folders = folders
self.contacts = contacts
self.policies = PolicySet()
async def reload_policies(self) -> None:
self.policies = await load_policy_set(self.pool, self.account_id)
def resolve(self, chat: ChatMeta) -> CaptureToggles:
return resolve(chat, self.folders.folders, self.policies)
async def build_capture_context(
client: Client,
pool: asyncpg.Pool,
storage: ContentAddressedStorage,
account_id: int,
) -> CaptureContext:
folders = FolderCache(client, pool, account_id)
await folders.refresh()
contacts = ContactCache(client)
await contacts.refresh()
ctx = CaptureContext(account_id, pool, storage, folders, contacts)
await ctx.reload_policies()
return ctx
@@ -0,0 +1,223 @@
from datetime import datetime
import asyncpg
CHANNEL_ID_THRESHOLD = -1000000000000
_UPSERT_MESSAGE = """
INSERT INTO messages
(account_id, chat_id, message_id, date, sender_id, text, raw,
has_media, is_self_destruct)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
sender_id = EXCLUDED.sender_id,
text = EXCLUDED.text,
raw = EXCLUDED.raw,
has_media = EXCLUDED.has_media,
is_self_destruct = EXCLUDED.is_self_destruct
"""
_TOUCH_EDITED = """
INSERT INTO messages
(account_id, chat_id, message_id, date, sender_id, text, raw,
has_media, is_self_destruct, edited_at)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9, now())
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
edited_at = now()
"""
_INSERT_VERSION = """
INSERT INTO message_versions
(account_id, chat_id, message_id, observed_at, edit_date, text, raw)
VALUES ($1, $2, $3, now(), $4, $5, $6::jsonb)
ON CONFLICT DO NOTHING
"""
_INSERT_MEDIA = """
INSERT INTO media
(account_id, chat_id, message_id, kind, storage_key, file_size, mime,
ttl_seconds, downloaded)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (account_id, chat_id, message_id) DO UPDATE SET
kind = EXCLUDED.kind,
storage_key = EXCLUDED.storage_key,
file_size = EXCLUDED.file_size,
mime = EXCLUDED.mime,
ttl_seconds = EXCLUDED.ttl_seconds,
downloaded = EXCLUDED.downloaded
"""
async def upsert_message( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
date: datetime,
sender_id: int | None,
text: str | None,
raw: str,
*,
has_media: bool,
is_self_destruct: bool,
) -> None:
await pool.execute(
_UPSERT_MESSAGE,
account_id,
chat_id,
message_id,
date,
sender_id,
text,
raw,
has_media,
is_self_destruct,
)
async def mark_deleted_box(
pool: asyncpg.Pool, account_id: int, message_ids: list[int]
) -> None:
await pool.execute(
"UPDATE messages SET deleted_at = now() "
"WHERE account_id = $1 AND message_id = ANY($2::bigint[]) "
"AND chat_id > $3 AND deleted_at IS NULL",
account_id,
message_ids,
CHANNEL_ID_THRESHOLD,
)
async def mark_deleted_channel(
pool: asyncpg.Pool, account_id: int, chat_id: int, message_ids: list[int]
) -> None:
await pool.execute(
"UPDATE messages SET deleted_at = now() "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = ANY($3::bigint[]) "
"AND deleted_at IS NULL",
account_id,
chat_id,
message_ids,
)
async def add_version( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
date: datetime,
sender_id: int | None,
text: str | None,
raw: str,
edit_date: datetime | None,
*,
has_media: bool,
is_self_destruct: bool,
) -> None:
async with pool.acquire() as conn, conn.transaction():
await conn.execute(
_TOUCH_EDITED,
account_id,
chat_id,
message_id,
date,
sender_id,
text,
raw,
has_media,
is_self_destruct,
)
await conn.execute(
_INSERT_VERSION, account_id, chat_id, message_id, edit_date, text, raw
)
async def insert_media( # noqa: PLR0913
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
kind: str,
storage_key: str | None,
file_size: int | None,
mime: str | None,
ttl_seconds: int | None,
*,
downloaded: bool,
) -> None:
await pool.execute(
_INSERT_MEDIA,
account_id,
chat_id,
message_id,
kind,
storage_key,
file_size,
mime,
ttl_seconds,
downloaded,
)
async def insert_callbacks(
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
buttons: list[tuple[int, str | None, bytes | None]],
) -> None:
async with pool.acquire() as conn, conn.transaction():
await conn.execute(
"DELETE FROM callbacks "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3",
account_id,
chat_id,
message_id,
)
await conn.executemany(
"INSERT INTO callbacks "
"(account_id, chat_id, message_id, position, label, data) "
"VALUES ($1, $2, $3, $4, $5, $6)",
[
(account_id, chat_id, message_id, position, label, data)
for position, label, data in buttons
],
)
async def sync_reactions(
pool: asyncpg.Pool,
account_id: int,
chat_id: int,
message_id: int,
current: list[tuple[int, str]],
) -> None:
async with pool.acquire() as conn, conn.transaction():
rows = await conn.fetch(
"SELECT id, peer_id, reaction FROM reactions "
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3 "
"AND removed_at IS NULL",
account_id,
chat_id,
message_id,
)
existing = {(row["peer_id"], row["reaction"]): row["id"] for row in rows}
current_set = set(current)
stale = [pid for key, pid in existing.items() if key not in current_set]
if stale:
await conn.execute(
"UPDATE reactions SET removed_at = now() WHERE id = ANY($1::bigint[])",
stale,
)
fresh = [key for key in current_set if key not in existing]
if fresh:
await conn.executemany(
"INSERT INTO reactions "
"(account_id, chat_id, message_id, peer_id, reaction) "
"VALUES ($1, $2, $3, $4, $5)",
[
(account_id, chat_id, message_id, peer_id, reaction)
for peer_id, reaction in fresh
],
)
+15 -1
View File
@@ -1,8 +1,15 @@
from typing import TYPE_CHECKING
from pyrogram import Client, enums
if TYPE_CHECKING:
from userbot.modules.capture import CaptureContext
class PyroClient(Client):
def __init__(self, name: str, *, workdir: str = "sessions") -> None:
def __init__(
self, name: str, *, workdir: str = "sessions", load_handlers: bool = True
) -> None:
super().__init__(
name,
workdir=workdir,
@@ -14,6 +21,13 @@ class PyroClient(Client):
lang_pack="tdesktop",
client_platform=enums.ClientPlatform.DESKTOP,
)
self.capture: CaptureContext | None = None
if load_handlers:
from userbot import handlers # noqa: PLC0415
for handler in handlers.handlers:
self.add_handler(*handler)
__all__ = ["PyroClient"]
@@ -0,0 +1,3 @@
from userbot.modules.contacts.cache import ContactCache
__all__ = ["ContactCache"]
@@ -0,0 +1,20 @@
from pyrogram import Client
from utils.logging import logger
class ContactCache:
def __init__(self, client: Client) -> None:
self._client = client
self.ids: set[int] = set()
async def refresh(self) -> None:
contacts = await self._client.get_contacts()
self.ids = {user.id for user in contacts}
logger.info(f"[green]Contacts cached:[/] {len(self.ids)}")
def mark(self, user_id: int, *, is_contact: bool) -> None:
if is_contact:
self.ids.add(user_id)
else:
self.ids.discard(user_id)
@@ -0,0 +1,3 @@
from userbot.modules.folders.cache import FolderCache
__all__ = ["FolderCache"]
@@ -0,0 +1,3 @@
from userbot.modules.media.downloader import capture_media, self_destruct_ttl
__all__ = ["capture_media", "self_destruct_ttl"]
@@ -0,0 +1,71 @@
from io import BytesIO
from typing import Any
from pyrogram import Client
from pyrogram.types import Message
from userbot.modules.capture import repository
from userbot.modules.capture.context import CaptureContext
from utils.policy.models import CaptureToggles
_MEDIA_ATTRS = (
"photo",
"video",
"voice",
"video_note",
"document",
"audio",
"animation",
"sticker",
)
def media_object(message: Message) -> tuple[str | None, Any]:
for attr in _MEDIA_ATTRS:
obj = getattr(message, attr, None)
if obj is not None:
return attr, obj
return None, None
def self_destruct_ttl(message: Message) -> int | None:
_, obj = media_object(message)
return getattr(obj, "ttl_seconds", None) if obj is not None else None
async def capture_media( # noqa: PLR0913
client: Client,
message: Message,
ctx: CaptureContext,
chat_id: int,
message_id: int,
toggles: CaptureToggles,
) -> None:
kind, obj = media_object(message)
if obj is None:
return
ttl = getattr(obj, "ttl_seconds", None)
want = toggles.self_destruct_media if ttl else toggles.media
file_size = getattr(obj, "file_size", None)
mime = getattr(obj, "mime_type", None)
storage_key: str | None = None
downloaded = False
if want:
buffer = await client.download_media(message, in_memory=True)
if isinstance(buffer, BytesIO):
data = buffer.getvalue()
storage_key = ctx.storage.put(data)
file_size = len(data)
downloaded = True
await repository.insert_media(
ctx.pool,
ctx.account_id,
chat_id,
message_id,
kind or "unknown",
storage_key,
file_size,
mime,
ttl,
downloaded=downloaded,
)
+36 -18
View File
@@ -7,14 +7,11 @@ import asyncpg
import uvloop
from dependencies.container import container
from userbot.folders import FolderCache
from userbot.handlers import dialog_filter_handler
from userbot.modules import PyroClient
from userbot import PyroClient
from userbot.modules.capture import build_capture_context
from utils.env import env
from utils.logging import logger, setup_logging
from utils.policy.models import ChatKind, ChatMeta
from utils.policy.repository import load_policy_set
from utils.policy.resolver import resolve
from utils.storage import ContentAddressedStorage
setup_logging()
@@ -61,22 +58,37 @@ async def _sync_account(
return account_id
async def _setup_policy(
pool: asyncpg.Pool, client: PyroClient, account_id: int
async def _setup_capture(
pool: asyncpg.Pool,
client: PyroClient,
account_id: int,
storage: ContentAddressedStorage,
) -> None:
cache = FolderCache(client, pool, account_id)
await cache.refresh()
client.add_handler(dialog_filter_handler(cache))
if client.me:
policies = await load_policy_set(pool, account_id)
sample = resolve(
ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies
)
logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}")
client.capture = await build_capture_context(client, pool, storage, account_id)
logger.info("[green]Capture context ready.[/]")
async def _listen_policy_changes(
clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
def on_change(
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
) -> None:
for client in clients:
if client.capture is None:
continue
task = asyncio.create_task(client.capture.reload_policies())
tasks.add(task)
task.add_done_callback(tasks.discard)
conn = await asyncpg.connect(dsn=env.db.connection_url)
await conn.add_listener("policy_changed", on_change)
return conn
async def runner() -> None:
pool = await container.get(asyncpg.Pool)
storage = await container.get(ContentAddressedStorage)
sessions_dir = Path(env.tg.sessions_dir)
session_files = _discover_sessions(sessions_dir)
@@ -88,6 +100,8 @@ async def runner() -> None:
)
clients: list[PyroClient] = []
reload_tasks: set[asyncio.Task] = set()
listen_conn: asyncpg.Connection | None = None
try:
for session_path in session_files:
session_name = session_path.stem
@@ -101,12 +115,16 @@ async def runner() -> None:
)
account_id = await _sync_account(pool, client, session_name)
if account_id is not None:
await _setup_policy(pool, client, account_id)
await _setup_capture(pool, client, account_id, storage)
if clients:
listen_conn = await _listen_policy_changes(clients, reload_tasks)
logger.info("[green]Userbot running.[/]")
await asyncio.Event().wait()
finally:
if listen_conn is not None:
with contextlib.suppress(Exception):
await listen_conn.close()
for client in clients:
with contextlib.suppress(Exception):
await client.stop()
+88 -1
View File
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any
from sqlalchemy import BigInteger, Column, DateTime, func
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, SQLModel
@@ -42,9 +42,16 @@ class Message(SQLModel, table=True):
chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
date: datetime = Field(sa_column=Column(DateTime(timezone=True), primary_key=True))
sender_id: int | None = Field(default=None, sa_column=Column(BigInteger))
text: str | None = None
raw: dict[str, Any] = Field(
default_factory=dict, sa_column=Column(JSONB, nullable=False)
)
has_media: bool = False
is_self_destruct: bool = False
edited_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
deleted_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
@@ -71,6 +78,86 @@ class Folder(SQLModel, table=True):
)
class MessageVersion(SQLModel, table=True):
__tablename__ = "message_versions"
account_id: int = Field(primary_key=True)
chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
observed_at: datetime = Field(
sa_column=Column(DateTime(timezone=True), primary_key=True)
)
edit_date: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
text: str | None = None
raw: dict[str, Any] = Field(
default_factory=dict, sa_column=Column(JSONB, nullable=False)
)
class Media(SQLModel, table=True):
__tablename__ = "media"
id: int | None = Field(default=None, primary_key=True)
account_id: int
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
kind: str
storage_key: str | None = None
file_size: int | None = Field(default=None, sa_column=Column(BigInteger))
mime: str | None = None
ttl_seconds: int | None = None
downloaded: bool = False
extracted_text: str | None = None
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
)
class Callback(SQLModel, table=True):
__tablename__ = "callbacks"
id: int | None = Field(default=None, primary_key=True)
account_id: int
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
position: int
label: str | None = None
data: bytes | None = Field(default=None, sa_column=Column(LargeBinary))
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
)
class Reaction(SQLModel, table=True):
__tablename__ = "reactions"
id: int | None = Field(default=None, primary_key=True)
account_id: int
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
peer_id: int = Field(sa_column=Column(BigInteger, nullable=False))
reaction: str
added_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
)
removed_at: datetime | None = Field(
default=None, sa_column=Column(DateTime(timezone=True))
)
created_at: datetime = Field(
sa_column=Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
)
class CapturePolicy(SQLModel, table=True):
__tablename__ = "capture_policy"