diff --git a/backend/migrations/versions/c3a8e5f1b6d2_realtime_core.py b/backend/migrations/versions/c3a8e5f1b6d2_realtime_core.py new file mode 100644 index 0000000..c159efa --- /dev/null +++ b/backend/migrations/versions/c3a8e5f1b6d2_realtime_core.py @@ -0,0 +1,139 @@ +"""realtime core + +Revision ID: c3a8e5f1b6d2 +Revises: b2f7c1a9d3e4 +Create Date: 2026-05-29 18:30:00.000000 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "c3a8e5f1b6d2" +down_revision: str | None = "b2f7c1a9d3e4" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column("messages", sa.Column("sender_id", sa.BigInteger(), nullable=True)) + op.add_column("messages", sa.Column("text", sa.String(), nullable=True)) + op.add_column( + "messages", + sa.Column("has_media", sa.Boolean(), nullable=False, server_default=sa.false()), + ) + op.add_column( + "messages", + sa.Column( + "is_self_destruct", sa.Boolean(), nullable=False, server_default=sa.false() + ), + ) + op.add_column( + "messages", sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True) + ) + op.create_index("ix_messages_box", "messages", ["account_id", "message_id"]) + + op.create_table( + "message_versions", + sa.Column("account_id", sa.Integer(), nullable=False), + sa.Column("chat_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("observed_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("edit_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("text", sa.String(), nullable=True), + sa.Column("raw", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.PrimaryKeyConstraint("account_id", "chat_id", "message_id", "observed_at"), + ) + + op.create_table( + "media", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("account_id", 
sa.Integer(), nullable=False), + sa.Column("chat_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("kind", sa.String(), nullable=False), + sa.Column("storage_key", sa.String(), nullable=True), + sa.Column("file_size", sa.BigInteger(), nullable=True), + sa.Column("mime", sa.String(), nullable=True), + sa.Column("ttl_seconds", sa.Integer(), nullable=True), + sa.Column( + "downloaded", sa.Boolean(), nullable=False, server_default=sa.false() + ), + sa.Column("extracted_text", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("account_id", "chat_id", "message_id"), + ) + + op.create_table( + "callbacks", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("account_id", sa.Integer(), nullable=False), + sa.Column("chat_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("data", sa.LargeBinary(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_callbacks_message", "callbacks", ["account_id", "chat_id", "message_id"] + ) + + op.create_table( + "reactions", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("account_id", sa.Integer(), nullable=False), + sa.Column("chat_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("peer_id", sa.BigInteger(), nullable=False), + sa.Column("reaction", sa.String(), nullable=False), + sa.Column( + "added_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("removed_at", sa.DateTime(timezone=True), 
nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_reactions_message", "reactions", ["account_id", "chat_id", "message_id"] + ) + op.execute( + "CREATE INDEX ix_reactions_open ON reactions " + "(account_id, chat_id, message_id) WHERE removed_at IS NULL" + ) + + +def downgrade() -> None: + op.drop_table("reactions") + op.drop_table("callbacks") + op.drop_table("media") + op.drop_table("message_versions") + op.drop_index("ix_messages_box", table_name="messages") + op.drop_column("messages", "edited_at") + op.drop_column("messages", "is_self_destruct") + op.drop_column("messages", "has_media") + op.drop_column("messages", "text") + op.drop_column("messages", "sender_id") diff --git a/backend/src/api/routers/policy.py b/backend/src/api/routers/policy.py index 0ac0264..b4785a5 100644 --- a/backend/src/api/routers/policy.py +++ b/backend/src/api/routers/policy.py @@ -17,6 +17,8 @@ from utils.policy.resolver import resolve router = APIRouter(prefix="/api/policy", tags=["policy"]) +POLICY_CHANGED_CHANNEL = "policy_changed" + class EffectiveQuery(BaseModel): account_id: int @@ -55,9 +57,11 @@ async def effective_policy( async def create_policy( pool: FromDishka[asyncpg.Pool], body: PolicyCreate ) -> PolicyRecord: - return await repository.create_policy( + record = await repository.create_policy( pool, body.account_id, body.scope_type, body.scope_id, body ) + await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}") + return record @router.get("/{policy_id}") @@ -77,6 +81,7 @@ async def update_policy( record = await repository.update_policy(pool, policy_id, body) if record is None: raise HTTPException(status_code=404, detail="policy not found") + await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}") return record @@ -85,3 +90,4 @@ async def update_policy( async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None: 
if not await repository.delete_policy(pool, policy_id): raise HTTPException(status_code=404, detail="policy not found") + await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}") diff --git a/backend/src/dependencies/container.py b/backend/src/dependencies/container.py index b9f04ce..f2cf878 100644 --- a/backend/src/dependencies/container.py +++ b/backend/src/dependencies/container.py @@ -1,5 +1,6 @@ from dishka import make_async_container from dependencies.providers.postgres import DbProvider +from dependencies.providers.storage import StorageProvider -container = make_async_container(DbProvider()) +container = make_async_container(DbProvider(), StorageProvider()) diff --git a/backend/src/dependencies/providers/storage.py b/backend/src/dependencies/providers/storage.py new file mode 100644 index 0000000..cdd7576 --- /dev/null +++ b/backend/src/dependencies/providers/storage.py @@ -0,0 +1,10 @@ +from dishka import Provider, Scope, provide + +from utils.env import env +from utils.storage import ContentAddressedStorage + + +class StorageProvider(Provider): + @provide(scope=Scope.APP) + def get_storage(self) -> ContentAddressedStorage: + return ContentAddressedStorage(env.storage.root, env.storage.shard_depth) diff --git a/backend/src/userbot/__init__.py b/backend/src/userbot/__init__.py index e69de29..5ea2b14 100644 --- a/backend/src/userbot/__init__.py +++ b/backend/src/userbot/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.client import PyroClient + +__all__ = ["PyroClient"] diff --git a/backend/src/userbot/handlers/__init__.py b/backend/src/userbot/handlers/__init__.py index f88cb35..bdb4621 100644 --- a/backend/src/userbot/handlers/__init__.py +++ b/backend/src/userbot/handlers/__init__.py @@ -1,3 +1,3 @@ -from .dialog_filters import dialog_filter_handler +from userbot.handlers import deletes, edits, messages, raw -__all__ = ["dialog_filter_handler"] +handlers = messages.handlers + edits.handlers + deletes.handlers + raw.handlers diff --git 
a/backend/src/userbot/handlers/deletes.py b/backend/src/userbot/handlers/deletes.py new file mode 100644 index 0000000..4d9dd60 --- /dev/null +++ b/backend/src/userbot/handlers/deletes.py @@ -0,0 +1,28 @@ +from pyrogram.types import Message + +from userbot import PyroClient +from userbot.modules.capture import repository +from userbot.modules.capture.repository import CHANNEL_ID_THRESHOLD + + +@PyroClient.on_deleted_messages() +async def on_deleted_messages(client: PyroClient, messages: list[Message]) -> None: + ctx = client.capture + if ctx is None: + return + box: list[int] = [] + channels: dict[int, list[int]] = {} + for message in messages: + chat = message.chat + chat_id = chat.id if chat is not None else None + if chat_id is None or chat_id > CHANNEL_ID_THRESHOLD: + box.append(message.id) + else: + channels.setdefault(chat_id, []).append(message.id) + if box: + await repository.mark_deleted_box(ctx.pool, ctx.account_id, box) + for chat_id, ids in channels.items(): + await repository.mark_deleted_channel(ctx.pool, ctx.account_id, chat_id, ids) + + +handlers = on_deleted_messages.handlers diff --git a/backend/src/userbot/handlers/dialog_filters.py b/backend/src/userbot/handlers/dialog_filters.py deleted file mode 100644 index 269bcfe..0000000 --- a/backend/src/userbot/handlers/dialog_filters.py +++ /dev/null @@ -1,20 +0,0 @@ -from pyrogram import Client, raw -from pyrogram.handlers import RawUpdateHandler - -from userbot.folders import FolderCache - -_FILTER_UPDATES = ( - raw.types.UpdateDialogFilter, - raw.types.UpdateDialogFilters, - raw.types.UpdateDialogFilterOrder, -) - - -def dialog_filter_handler(cache: FolderCache) -> RawUpdateHandler: - async def on_update( - _client: Client, update: raw.base.Update, _users: dict, _chats: dict - ) -> None: - if isinstance(update, _FILTER_UPDATES): - await cache.refresh() - - return RawUpdateHandler(on_update) diff --git a/backend/src/userbot/handlers/edits.py b/backend/src/userbot/handlers/edits.py new file mode 100644 
index 0000000..fc130f3 --- /dev/null +++ b/backend/src/userbot/handlers/edits.py @@ -0,0 +1,36 @@ +from pyrogram.types import Message + +from userbot import PyroClient +from userbot.handlers.messages import sender_id +from userbot.modules.capture import repository +from userbot.modules.capture.chat_meta import meta_from_chat +from userbot.modules.media import self_destruct_ttl + + +@PyroClient.on_edited_message() +async def on_edited_message(client: PyroClient, message: Message) -> None: + ctx = client.capture + if ctx is None or message.empty or message.chat is None or message.date is None: + return + chat = message.chat + chat_id = chat.id or 0 + meta = meta_from_chat(chat, ctx.contacts.ids) + toggles = ctx.resolve(meta) + if not toggles.track_edits_deletes: + return + await repository.add_version( + ctx.pool, + ctx.account_id, + chat_id, + message.id, + message.date, + sender_id(message), + message.text or message.caption, + str(message), + message.edit_date, + has_media=message.media is not None, + is_self_destruct=self_destruct_ttl(message) is not None, + ) + + +handlers = on_edited_message.handlers diff --git a/backend/src/userbot/handlers/messages.py b/backend/src/userbot/handlers/messages.py new file mode 100644 index 0000000..ef4c568 --- /dev/null +++ b/backend/src/userbot/handlers/messages.py @@ -0,0 +1,64 @@ +from pyrogram.types import Message + +from userbot import PyroClient +from userbot.modules.capture import repository +from userbot.modules.capture.chat_meta import meta_from_chat +from userbot.modules.media import capture_media, self_destruct_ttl + + +def sender_id(message: Message) -> int | None: + if message.from_user is not None: + return message.from_user.id + if message.sender_chat is not None: + return message.sender_chat.id + return None + + +def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]: + rows = getattr(message.reply_markup, "inline_keyboard", None) + if not rows: + return [] + buttons: list[tuple[int, str | 
None, bytes | None]] = [] + position = 0 + for row in rows: + for button in row: + data = button.callback_data + if data is not None: + encoded = data.encode() if isinstance(data, str) else data + buttons.append((position, button.text, encoded)) + position += 1 + return buttons + + +@PyroClient.on_message() +async def on_message(client: PyroClient, message: Message) -> None: + ctx = client.capture + if ctx is None or message.empty or message.chat is None or message.date is None: + return + chat = message.chat + chat_id = chat.id or 0 + meta = meta_from_chat(chat, ctx.contacts.ids) + toggles = ctx.resolve(meta) + if not toggles.messages: + return + await repository.upsert_message( + ctx.pool, + ctx.account_id, + chat_id, + message.id, + message.date, + sender_id(message), + message.text or message.caption, + str(message), + has_media=message.media is not None, + is_self_destruct=self_destruct_ttl(message) is not None, + ) + await capture_media(client, message, ctx, chat_id, message.id, toggles) + buttons = _callbacks(message) + if buttons: + await repository.insert_callbacks( + ctx.pool, ctx.account_id, chat_id, message.id, buttons + ) + + +handlers = on_message.handlers diff --git a/backend/src/userbot/handlers/raw/__init__.py b/backend/src/userbot/handlers/raw/__init__.py new file mode 100644 index 0000000..237ac73 --- /dev/null +++ b/backend/src/userbot/handlers/raw/__init__.py @@ -0,0 +1,3 @@ +from userbot.handlers.raw.dispatcher import handlers + +__all__ = ["handlers"] diff --git a/backend/src/userbot/handlers/raw/contacts.py b/backend/src/userbot/handlers/raw/contacts.py new file mode 100644 index 0000000..6655705 --- /dev/null +++ b/backend/src/userbot/handlers/raw/contacts.py @@ -0,0 +1,22 @@ +from pyrogram import raw + +from userbot import PyroClient + +HANDLES = (raw.types.UpdateContactsReset, raw.types.UpdateUser) + + +async def handle( + client: PyroClient, update: raw.base.Update, users: dict, _chats: dict +) -> None: + ctx = client.capture + if ctx is 
None: + return + if isinstance(update, raw.types.UpdateContactsReset): + await ctx.contacts.refresh() + return + if not isinstance(update, raw.types.UpdateUser): + return + user = users.get(update.user_id) + if not isinstance(user, raw.types.User) or user.min: + return + ctx.contacts.mark(user.id, is_contact=bool(user.contact)) diff --git a/backend/src/userbot/handlers/raw/dialog_filters.py b/backend/src/userbot/handlers/raw/dialog_filters.py new file mode 100644 index 0000000..ff802b2 --- /dev/null +++ b/backend/src/userbot/handlers/raw/dialog_filters.py @@ -0,0 +1,19 @@ +from pyrogram import raw + +from userbot import PyroClient + +HANDLES = ( + raw.types.UpdateDialogFilter, + raw.types.UpdateDialogFilters, + raw.types.UpdateDialogFilterOrder, +) + + +async def handle( + client: PyroClient, _update: raw.base.Update, _users: dict, _chats: dict +) -> None: + ctx = client.capture + if ctx is None: + return + await ctx.folders.refresh() + await ctx.reload_policies() diff --git a/backend/src/userbot/handlers/raw/dispatcher.py b/backend/src/userbot/handlers/raw/dispatcher.py new file mode 100644 index 0000000..44bf7cd --- /dev/null +++ b/backend/src/userbot/handlers/raw/dispatcher.py @@ -0,0 +1,25 @@ +from collections.abc import Awaitable, Callable + +from pyrogram import raw + +from userbot import PyroClient +from userbot.handlers.raw import contacts, dialog_filters, reactions + +RawHandler = Callable[..., Awaitable[None]] + +_MODULES = (contacts, dialog_filters, reactions) +_REGISTRY: dict[type, RawHandler] = { + update_type: module.handle for module in _MODULES for update_type in module.HANDLES +} + + +@PyroClient.on_raw_update() +async def dispatch( + client: PyroClient, update: raw.base.Update, users: dict, chats: dict +) -> None: + handler = _REGISTRY.get(type(update)) + if handler is not None: + await handler(client, update, users, chats) + + +handlers = dispatch.handlers diff --git a/backend/src/userbot/handlers/raw/reactions.py 
b/backend/src/userbot/handlers/raw/reactions.py new file mode 100644 index 0000000..7bda931 --- /dev/null +++ b/backend/src/userbot/handlers/raw/reactions.py @@ -0,0 +1,49 @@ +from pyrogram import raw, utils + +from userbot import PyroClient +from userbot.modules.capture import repository +from userbot.modules.capture.chat_meta import meta_from_peer + +HANDLES = (raw.types.UpdateMessageReactions,) + + +def _reaction_key(reaction: raw.base.Reaction) -> str | None: + if isinstance(reaction, raw.types.ReactionEmoji): + return reaction.emoticon + if isinstance(reaction, raw.types.ReactionCustomEmoji): + return f"custom:{reaction.document_id}" + return None + + +def _parse_recent(reactions: raw.base.MessageReactions) -> list[tuple[int, str]]: + recent = getattr(reactions, "recent_reactions", None) or [] + parsed: list[tuple[int, str]] = [] + for peer_reaction in recent: + key = _reaction_key(peer_reaction.reaction) + if key is None: + continue + try: + peer_id = utils.get_peer_id(peer_reaction.peer_id) + except (ValueError, AttributeError): + continue + parsed.append((peer_id, key)) + return parsed + + +async def handle( + client: PyroClient, + update: raw.types.UpdateMessageReactions, + users: dict, + chats: dict, +) -> None: + ctx = client.capture + if ctx is None: + return + meta = meta_from_peer(update.peer, chats, users, ctx.contacts.ids) + toggles = ctx.resolve(meta) + if not toggles.reactions: + return + current = _parse_recent(update.reactions) + await repository.sync_reactions( + ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, current + ) diff --git a/backend/src/userbot/modules/__init__.py b/backend/src/userbot/modules/__init__.py index 06b2724..e69de29 100644 --- a/backend/src/userbot/modules/__init__.py +++ b/backend/src/userbot/modules/__init__.py @@ -1,3 +0,0 @@ -from .client import PyroClient - -__all__ = ["PyroClient"] diff --git a/backend/src/userbot/modules/capture/__init__.py b/backend/src/userbot/modules/capture/__init__.py new file mode 
100644 index 0000000..c8437ad --- /dev/null +++ b/backend/src/userbot/modules/capture/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.capture.context import CaptureContext, build_capture_context + +__all__ = ["CaptureContext", "build_capture_context"] diff --git a/backend/src/userbot/modules/capture/chat_meta.py b/backend/src/userbot/modules/capture/chat_meta.py new file mode 100644 index 0000000..6da537d --- /dev/null +++ b/backend/src/userbot/modules/capture/chat_meta.py @@ -0,0 +1,46 @@ +from pyrogram import enums, raw, utils +from pyrogram.types import Chat + +from utils.policy.models import ChatKind, ChatMeta + +_KIND_BY_TYPE: dict[enums.ChatType, ChatKind] = { + enums.ChatType.PRIVATE: ChatKind.DM, + enums.ChatType.BOT: ChatKind.DM, + enums.ChatType.GROUP: ChatKind.GROUP, + enums.ChatType.SUPERGROUP: ChatKind.GROUP, + enums.ChatType.FORUM: ChatKind.GROUP, + enums.ChatType.CHANNEL: ChatKind.CHANNEL, +} + + +def chat_kind(chat_type: enums.ChatType | None) -> ChatKind: + return _KIND_BY_TYPE.get(chat_type, ChatKind.DM) if chat_type else ChatKind.DM + + +def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta: + kind = chat_kind(chat.type) + is_bot = chat.type is enums.ChatType.BOT + chat_id = chat.id or 0 + is_contact = chat_id in contacts if kind is ChatKind.DM else None + return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact) + + +def meta_from_peer( + peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int] +) -> ChatMeta: + chat_id = utils.get_peer_id(peer) + if isinstance(peer, raw.types.PeerUser): + user = users.get(peer.user_id) + return ChatMeta( + chat_id=chat_id, + kind=ChatKind.DM, + is_bot=bool(getattr(user, "bot", False)), + is_contact=peer.user_id in contacts, + ) + if isinstance(peer, raw.types.PeerChannel): + channel = chats.get(peer.channel_id) + kind = ( + ChatKind.CHANNEL if getattr(channel, "broadcast", False) else ChatKind.GROUP + ) + return ChatMeta(chat_id=chat_id, kind=kind) + return 
ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP) diff --git a/backend/src/userbot/modules/capture/context.py b/backend/src/userbot/modules/capture/context.py new file mode 100644 index 0000000..d8bb9e2 --- /dev/null +++ b/backend/src/userbot/modules/capture/context.py @@ -0,0 +1,47 @@ +import asyncpg +from pyrogram import Client + +from userbot.modules.contacts import ContactCache +from userbot.modules.folders import FolderCache +from utils.policy.models import CaptureToggles, ChatMeta, PolicySet +from utils.policy.repository import load_policy_set +from utils.policy.resolver import resolve +from utils.storage import ContentAddressedStorage + + +class CaptureContext: + def __init__( + self, + account_id: int, + pool: asyncpg.Pool, + storage: ContentAddressedStorage, + folders: FolderCache, + contacts: ContactCache, + ) -> None: + self.account_id = account_id + self.pool = pool + self.storage = storage + self.folders = folders + self.contacts = contacts + self.policies = PolicySet() + + async def reload_policies(self) -> None: + self.policies = await load_policy_set(self.pool, self.account_id) + + def resolve(self, chat: ChatMeta) -> CaptureToggles: + return resolve(chat, self.folders.folders, self.policies) + + +async def build_capture_context( + client: Client, + pool: asyncpg.Pool, + storage: ContentAddressedStorage, + account_id: int, +) -> CaptureContext: + folders = FolderCache(client, pool, account_id) + await folders.refresh() + contacts = ContactCache(client) + await contacts.refresh() + ctx = CaptureContext(account_id, pool, storage, folders, contacts) + await ctx.reload_policies() + return ctx diff --git a/backend/src/userbot/modules/capture/repository.py b/backend/src/userbot/modules/capture/repository.py new file mode 100644 index 0000000..dd02270 --- /dev/null +++ b/backend/src/userbot/modules/capture/repository.py @@ -0,0 +1,223 @@ +from datetime import datetime + +import asyncpg + +CHANNEL_ID_THRESHOLD = -1000000000000 + +_UPSERT_MESSAGE = """ +INSERT 
INTO messages + (account_id, chat_id, message_id, date, sender_id, text, raw, + has_media, is_self_destruct) +VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9) +ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET + sender_id = EXCLUDED.sender_id, + text = EXCLUDED.text, + raw = EXCLUDED.raw, + has_media = EXCLUDED.has_media, + is_self_destruct = EXCLUDED.is_self_destruct +""" + +_TOUCH_EDITED = """ +INSERT INTO messages + (account_id, chat_id, message_id, date, sender_id, text, raw, + has_media, is_self_destruct, edited_at) +VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9, now()) +ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET + edited_at = now() +""" + +_INSERT_VERSION = """ +INSERT INTO message_versions + (account_id, chat_id, message_id, observed_at, edit_date, text, raw) +VALUES ($1, $2, $3, now(), $4, $5, $6::jsonb) +ON CONFLICT DO NOTHING +""" + +_INSERT_MEDIA = """ +INSERT INTO media + (account_id, chat_id, message_id, kind, storage_key, file_size, mime, + ttl_seconds, downloaded) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +ON CONFLICT (account_id, chat_id, message_id) DO UPDATE SET + kind = EXCLUDED.kind, + storage_key = EXCLUDED.storage_key, + file_size = EXCLUDED.file_size, + mime = EXCLUDED.mime, + ttl_seconds = EXCLUDED.ttl_seconds, + downloaded = EXCLUDED.downloaded +""" + + +async def upsert_message( # noqa: PLR0913 + pool: asyncpg.Pool, + account_id: int, + chat_id: int, + message_id: int, + date: datetime, + sender_id: int | None, + text: str | None, + raw: str, + *, + has_media: bool, + is_self_destruct: bool, +) -> None: + await pool.execute( + _UPSERT_MESSAGE, + account_id, + chat_id, + message_id, + date, + sender_id, + text, + raw, + has_media, + is_self_destruct, + ) + + +async def mark_deleted_box( + pool: asyncpg.Pool, account_id: int, message_ids: list[int] +) -> None: + await pool.execute( + "UPDATE messages SET deleted_at = now() " + "WHERE account_id = $1 AND message_id = ANY($2::bigint[]) " + "AND 
chat_id > $3 AND deleted_at IS NULL", + account_id, + message_ids, + CHANNEL_ID_THRESHOLD, + ) + + +async def mark_deleted_channel( + pool: asyncpg.Pool, account_id: int, chat_id: int, message_ids: list[int] +) -> None: + await pool.execute( + "UPDATE messages SET deleted_at = now() " + "WHERE account_id = $1 AND chat_id = $2 AND message_id = ANY($3::bigint[]) " + "AND deleted_at IS NULL", + account_id, + chat_id, + message_ids, + ) + + +async def add_version( # noqa: PLR0913 + pool: asyncpg.Pool, + account_id: int, + chat_id: int, + message_id: int, + date: datetime, + sender_id: int | None, + text: str | None, + raw: str, + edit_date: datetime | None, + *, + has_media: bool, + is_self_destruct: bool, +) -> None: + async with pool.acquire() as conn, conn.transaction(): + await conn.execute( + _TOUCH_EDITED, + account_id, + chat_id, + message_id, + date, + sender_id, + text, + raw, + has_media, + is_self_destruct, + ) + await conn.execute( + _INSERT_VERSION, account_id, chat_id, message_id, edit_date, text, raw + ) + + +async def insert_media( # noqa: PLR0913 + pool: asyncpg.Pool, + account_id: int, + chat_id: int, + message_id: int, + kind: str, + storage_key: str | None, + file_size: int | None, + mime: str | None, + ttl_seconds: int | None, + *, + downloaded: bool, +) -> None: + await pool.execute( + _INSERT_MEDIA, + account_id, + chat_id, + message_id, + kind, + storage_key, + file_size, + mime, + ttl_seconds, + downloaded, + ) + + +async def insert_callbacks( + pool: asyncpg.Pool, + account_id: int, + chat_id: int, + message_id: int, + buttons: list[tuple[int, str | None, bytes | None]], +) -> None: + async with pool.acquire() as conn, conn.transaction(): + await conn.execute( + "DELETE FROM callbacks " + "WHERE account_id = $1 AND chat_id = $2 AND message_id = $3", + account_id, + chat_id, + message_id, + ) + await conn.executemany( + "INSERT INTO callbacks " + "(account_id, chat_id, message_id, position, label, data) " + "VALUES ($1, $2, $3, $4, $5, $6)", + 
[ + (account_id, chat_id, message_id, position, label, data) + for position, label, data in buttons + ], + ) + + +async def sync_reactions( + pool: asyncpg.Pool, + account_id: int, + chat_id: int, + message_id: int, + current: list[tuple[int, str]], +) -> None: + async with pool.acquire() as conn, conn.transaction(): + rows = await conn.fetch( + "SELECT id, peer_id, reaction FROM reactions " + "WHERE account_id = $1 AND chat_id = $2 AND message_id = $3 " + "AND removed_at IS NULL", + account_id, + chat_id, + message_id, + ) + existing = {(row["peer_id"], row["reaction"]): row["id"] for row in rows} + current_set = set(current) + stale = [pid for key, pid in existing.items() if key not in current_set] + if stale: + await conn.execute( + "UPDATE reactions SET removed_at = now() WHERE id = ANY($1::bigint[])", + stale, + ) + fresh = [key for key in current_set if key not in existing] + if fresh: + await conn.executemany( + "INSERT INTO reactions " + "(account_id, chat_id, message_id, peer_id, reaction) " + "VALUES ($1, $2, $3, $4, $5)", + [ + (account_id, chat_id, message_id, peer_id, reaction) + for peer_id, reaction in fresh + ], + ) diff --git a/backend/src/userbot/modules/client.py b/backend/src/userbot/modules/client.py index 0fb39f7..1b9010a 100644 --- a/backend/src/userbot/modules/client.py +++ b/backend/src/userbot/modules/client.py @@ -1,8 +1,15 @@ +from typing import TYPE_CHECKING + from pyrogram import Client, enums +if TYPE_CHECKING: + from userbot.modules.capture import CaptureContext + class PyroClient(Client): - def __init__(self, name: str, *, workdir: str = "sessions") -> None: + def __init__( + self, name: str, *, workdir: str = "sessions", load_handlers: bool = True + ) -> None: super().__init__( name, workdir=workdir, @@ -14,6 +21,13 @@ class PyroClient(Client): lang_pack="tdesktop", client_platform=enums.ClientPlatform.DESKTOP, ) + self.capture: CaptureContext | None = None + + if load_handlers: + from userbot import handlers # noqa: PLC0415 + + 
for handler in handlers.handlers: + self.add_handler(*handler) __all__ = ["PyroClient"] diff --git a/backend/src/userbot/modules/contacts/__init__.py b/backend/src/userbot/modules/contacts/__init__.py new file mode 100644 index 0000000..44f265a --- /dev/null +++ b/backend/src/userbot/modules/contacts/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.contacts.cache import ContactCache + +__all__ = ["ContactCache"] diff --git a/backend/src/userbot/modules/contacts/cache.py b/backend/src/userbot/modules/contacts/cache.py new file mode 100644 index 0000000..a2474b8 --- /dev/null +++ b/backend/src/userbot/modules/contacts/cache.py @@ -0,0 +1,20 @@ +from pyrogram import Client + +from utils.logging import logger + + +class ContactCache: + def __init__(self, client: Client) -> None: + self._client = client + self.ids: set[int] = set() + + async def refresh(self) -> None: + contacts = await self._client.get_contacts() + self.ids = {user.id for user in contacts} + logger.info(f"[green]Contacts cached:[/] {len(self.ids)}") + + def mark(self, user_id: int, *, is_contact: bool) -> None: + if is_contact: + self.ids.add(user_id) + else: + self.ids.discard(user_id) diff --git a/backend/src/userbot/modules/folders/__init__.py b/backend/src/userbot/modules/folders/__init__.py new file mode 100644 index 0000000..f2bba38 --- /dev/null +++ b/backend/src/userbot/modules/folders/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.folders.cache import FolderCache + +__all__ = ["FolderCache"] diff --git a/backend/src/userbot/folders.py b/backend/src/userbot/modules/folders/cache.py similarity index 100% rename from backend/src/userbot/folders.py rename to backend/src/userbot/modules/folders/cache.py diff --git a/backend/src/userbot/modules/media/__init__.py b/backend/src/userbot/modules/media/__init__.py new file mode 100644 index 0000000..b1ed28d --- /dev/null +++ b/backend/src/userbot/modules/media/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.media.downloader import capture_media, 
self_destruct_ttl + +__all__ = ["capture_media", "self_destruct_ttl"] diff --git a/backend/src/userbot/modules/media/downloader.py b/backend/src/userbot/modules/media/downloader.py new file mode 100644 index 0000000..277d08e --- /dev/null +++ b/backend/src/userbot/modules/media/downloader.py @@ -0,0 +1,71 @@ +from io import BytesIO +from typing import Any + +from pyrogram import Client +from pyrogram.types import Message + +from userbot.modules.capture import repository +from userbot.modules.capture.context import CaptureContext +from utils.policy.models import CaptureToggles + +_MEDIA_ATTRS = ( + "photo", + "video", + "voice", + "video_note", + "document", + "audio", + "animation", + "sticker", +) + + +def media_object(message: Message) -> tuple[str | None, Any]: + for attr in _MEDIA_ATTRS: + obj = getattr(message, attr, None) + if obj is not None: + return attr, obj + return None, None + + +def self_destruct_ttl(message: Message) -> int | None: + _, obj = media_object(message) + return getattr(obj, "ttl_seconds", None) if obj is not None else None + + +async def capture_media( # noqa: PLR0913 + client: Client, + message: Message, + ctx: CaptureContext, + chat_id: int, + message_id: int, + toggles: CaptureToggles, +) -> None: + kind, obj = media_object(message) + if obj is None: + return + ttl = getattr(obj, "ttl_seconds", None) + want = toggles.self_destruct_media if ttl else toggles.media + file_size = getattr(obj, "file_size", None) + mime = getattr(obj, "mime_type", None) + storage_key: str | None = None + downloaded = False + if want: + buffer = await client.download_media(message, in_memory=True) + if isinstance(buffer, BytesIO): + data = buffer.getvalue() + storage_key = ctx.storage.put(data) + file_size = len(data) + downloaded = True + await repository.insert_media( + ctx.pool, + ctx.account_id, + chat_id, + message_id, + kind or "unknown", + storage_key, + file_size, + mime, + ttl, + downloaded=downloaded, + ) diff --git a/backend/src/userbot/runner.py 
async def _setup_capture(
    pool: asyncpg.Pool,
    client: PyroClient,
    account_id: int,
    storage: ContentAddressedStorage,
) -> None:
    """Build the capture context for one account and attach it to its client.

    Args:
        pool: Database pool handed to the capture context builder.
        client: Client that receives the context as ``client.capture``.
        account_id: Account the context is built for.
        storage: Storage backend handed to the capture context builder.
    """
    client.capture = await build_capture_context(client, pool, storage, account_id)
    logger.info("[green]Capture context ready.[/]")


async def _listen_policy_changes(
    clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
    """Subscribe to ``policy_changed`` notifications on a dedicated connection.

    Each notification schedules ``capture.reload_policies()`` for every client
    that has a capture context.

    Args:
        clients: Clients whose capture contexts are reloaded on notification.
            The list is read at notification time, not copied.
        tasks: Set that holds the scheduled reload tasks while they run.

    Returns:
        The open listener connection; the caller is responsible for closing it.

    Raises:
        Whatever ``asyncpg.connect`` or ``Connection.add_listener`` raises.
        If registering the listener fails, the connection is terminated before
        the error propagates.
    """

    def on_change(
        _conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
    ) -> None:
        for client in clients:
            if client.capture is None:
                continue
            task = asyncio.create_task(client.capture.reload_policies())
            # Hold a strong reference until the task finishes so it cannot be
            # garbage-collected mid-flight; drop it once done.
            tasks.add(task)
            task.add_done_callback(tasks.discard)

    conn = await asyncpg.connect(dsn=env.db.connection_url)
    try:
        await conn.add_listener("policy_changed", on_change)
    except BaseException:
        # The caller only learns about the connection through the return
        # value, so on failure (or cancellation) it must be released here or
        # it would leak. terminate() is synchronous, which keeps this path
        # safe even while a cancellation is being delivered.
        conn.terminate()
        raise
    return conn
class MessageVersion(SQLModel, table=True):
    """Row of the ``message_versions`` table.

    The primary key is the composite
    ``(account_id, chat_id, message_id, observed_at)``, so one message can
    have several rows that differ in ``observed_at``.
    """

    __tablename__ = "message_versions"

    account_id: int = Field(primary_key=True)
    chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
    message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
    # Timezone-aware timestamp; part of the primary key.
    observed_at: datetime = Field(
        sa_column=Column(DateTime(timezone=True), primary_key=True)
    )
    # Nullable, timezone-aware. Presumably the edit timestamp reported by the
    # source message — TODO confirm against the writer.
    edit_date: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
    text: str | None = None
    # Non-null JSONB payload; defaults to an empty dict on the Python side.
    raw: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )


class Media(SQLModel, table=True):
    """Row of the ``media`` table, keyed by a surrogate integer ``id``.

    NOTE(review): the ``realtime core`` migration declares a unique constraint
    on ``(account_id, chat_id, message_id)`` for this table; that constraint
    is not mirrored on this model.
    """

    __tablename__ = "media"

    id: int | None = Field(default=None, primary_key=True)
    account_id: int
    chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    kind: str
    # Nullable; presumably the key returned by the storage backend when the
    # payload was downloaded — verify against the writer.
    storage_key: str | None = None
    file_size: int | None = Field(default=None, sa_column=Column(BigInteger))
    mime: str | None = None
    ttl_seconds: int | None = None
    downloaded: bool = False
    extracted_text: str | None = None
    # Non-null, timezone-aware; the database fills it with now() by default.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )


class Callback(SQLModel, table=True):
    """Row of the ``callbacks`` table, keyed by a surrogate integer ``id``.

    NOTE(review): the ``realtime core`` migration creates the index
    ``ix_callbacks_message`` on ``(account_id, chat_id, message_id)``; it is
    not declared on this model.
    """

    __tablename__ = "callbacks"

    id: int | None = Field(default=None, primary_key=True)
    account_id: int
    chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    position: int
    label: str | None = None
    # Nullable binary column (LargeBinary).
    data: bytes | None = Field(default=None, sa_column=Column(LargeBinary))
    # Non-null, timezone-aware; the database fills it with now() by default.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )


class Reaction(SQLModel, table=True):
    """Row of the ``reactions`` table, keyed by a surrogate integer ``id``."""

    __tablename__ = "reactions"

    id: int | None = Field(default=None, primary_key=True)
    account_id: int
    chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    peer_id: int = Field(sa_column=Column(BigInteger, nullable=False))
    reaction: str
    # Non-null, timezone-aware; the database fills it with now() by default.
    added_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )
    # Nullable, timezone-aware. Presumably set when the reaction is withdrawn
    # — TODO confirm against the writer.
    removed_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
    # Non-null, timezone-aware; the database fills it with now() by default.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )