feat: add realtime capturing
This commit is contained in:
@@ -0,0 +1,139 @@
|
||||
"""realtime core
|
||||
|
||||
Revision ID: c3a8e5f1b6d2
|
||||
Revises: b2f7c1a9d3e4
|
||||
Create Date: 2026-05-29 18:30:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "c3a8e5f1b6d2"
|
||||
down_revision: str | None = "b2f7c1a9d3e4"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("messages", sa.Column("sender_id", sa.BigInteger(), nullable=True))
|
||||
op.add_column("messages", sa.Column("text", sa.String(), nullable=True))
|
||||
op.add_column(
|
||||
"messages",
|
||||
sa.Column("has_media", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"messages",
|
||||
sa.Column(
|
||||
"is_self_destruct", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"messages", sa.Column("edited_at", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
op.create_index("ix_messages_box", "messages", ["account_id", "message_id"])
|
||||
|
||||
op.create_table(
|
||||
"message_versions",
|
||||
sa.Column("account_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("observed_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("edit_date", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("text", sa.String(), nullable=True),
|
||||
sa.Column("raw", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.PrimaryKeyConstraint("account_id", "chat_id", "message_id", "observed_at"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"media",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("account_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("kind", sa.String(), nullable=False),
|
||||
sa.Column("storage_key", sa.String(), nullable=True),
|
||||
sa.Column("file_size", sa.BigInteger(), nullable=True),
|
||||
sa.Column("mime", sa.String(), nullable=True),
|
||||
sa.Column("ttl_seconds", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"downloaded", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
sa.Column("extracted_text", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("account_id", "chat_id", "message_id"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"callbacks",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("account_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("position", sa.Integer(), nullable=False),
|
||||
sa.Column("label", sa.String(), nullable=True),
|
||||
sa.Column("data", sa.LargeBinary(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_callbacks_message", "callbacks", ["account_id", "chat_id", "message_id"]
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"reactions",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("account_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("message_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("peer_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("reaction", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"added_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("removed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_reactions_message", "reactions", ["account_id", "chat_id", "message_id"]
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX ix_reactions_open ON reactions "
|
||||
"(account_id, chat_id, message_id) WHERE removed_at IS NULL"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("reactions")
|
||||
op.drop_table("callbacks")
|
||||
op.drop_table("media")
|
||||
op.drop_table("message_versions")
|
||||
op.drop_index("ix_messages_box", table_name="messages")
|
||||
op.drop_column("messages", "edited_at")
|
||||
op.drop_column("messages", "is_self_destruct")
|
||||
op.drop_column("messages", "has_media")
|
||||
op.drop_column("messages", "text")
|
||||
op.drop_column("messages", "sender_id")
|
||||
@@ -17,6 +17,8 @@ from utils.policy.resolver import resolve
|
||||
|
||||
router = APIRouter(prefix="/api/policy", tags=["policy"])
|
||||
|
||||
POLICY_CHANGED_CHANNEL = "policy_changed"
|
||||
|
||||
|
||||
class EffectiveQuery(BaseModel):
|
||||
account_id: int
|
||||
@@ -55,9 +57,11 @@ async def effective_policy(
|
||||
async def create_policy(
|
||||
pool: FromDishka[asyncpg.Pool], body: PolicyCreate
|
||||
) -> PolicyRecord:
|
||||
return await repository.create_policy(
|
||||
record = await repository.create_policy(
|
||||
pool, body.account_id, body.scope_type, body.scope_id, body
|
||||
)
|
||||
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
|
||||
return record
|
||||
|
||||
|
||||
@router.get("/{policy_id}")
|
||||
@@ -77,6 +81,7 @@ async def update_policy(
|
||||
record = await repository.update_policy(pool, policy_id, body)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail="policy not found")
|
||||
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
|
||||
return record
|
||||
|
||||
|
||||
@@ -85,3 +90,4 @@ async def update_policy(
|
||||
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
|
||||
if not await repository.delete_policy(pool, policy_id):
|
||||
raise HTTPException(status_code=404, detail="policy not found")
|
||||
await pool.execute(f"NOTIFY {POLICY_CHANGED_CHANNEL}")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from dishka import make_async_container
|
||||
|
||||
from dependencies.providers.postgres import DbProvider
|
||||
from dependencies.providers.storage import StorageProvider
|
||||
|
||||
container = make_async_container(DbProvider())
|
||||
container = make_async_container(DbProvider(), StorageProvider())
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from dishka import Provider, Scope, provide
|
||||
|
||||
from utils.env import env
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
|
||||
class StorageProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
def get_storage(self) -> ContentAddressedStorage:
|
||||
return ContentAddressedStorage(env.storage.root, env.storage.shard_depth)
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.client import PyroClient
|
||||
|
||||
__all__ = ["PyroClient"]
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .dialog_filters import dialog_filter_handler
|
||||
from userbot.handlers import deletes, edits, messages, raw
|
||||
|
||||
__all__ = ["dialog_filter_handler"]
|
||||
handlers = messages.handlers + edits.handlers + deletes.handlers + raw.handlers
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.repository import CHANNEL_ID_THRESHOLD
|
||||
|
||||
|
||||
@PyroClient.on_deleted_messages()
|
||||
async def on_deleted_messages(client: PyroClient, messages: list[Message]) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None:
|
||||
return
|
||||
box: list[int] = []
|
||||
channels: dict[int, list[int]] = {}
|
||||
for message in messages:
|
||||
chat = message.chat
|
||||
chat_id = chat.id if chat is not None else None
|
||||
if chat_id is None or chat_id > CHANNEL_ID_THRESHOLD:
|
||||
box.append(message.id)
|
||||
else:
|
||||
channels.setdefault(chat_id, []).append(message.id)
|
||||
if box:
|
||||
await repository.mark_deleted_box(ctx.pool, ctx.account_id, box)
|
||||
for chat_id, ids in channels.items():
|
||||
await repository.mark_deleted_channel(ctx.pool, ctx.account_id, chat_id, ids)
|
||||
|
||||
|
||||
handlers = on_deleted_messages.handlers
|
||||
@@ -1,20 +0,0 @@
|
||||
from pyrogram import Client, raw
|
||||
from pyrogram.handlers import RawUpdateHandler
|
||||
|
||||
from userbot.folders import FolderCache
|
||||
|
||||
_FILTER_UPDATES = (
|
||||
raw.types.UpdateDialogFilter,
|
||||
raw.types.UpdateDialogFilters,
|
||||
raw.types.UpdateDialogFilterOrder,
|
||||
)
|
||||
|
||||
|
||||
def dialog_filter_handler(cache: FolderCache) -> RawUpdateHandler:
|
||||
async def on_update(
|
||||
_client: Client, update: raw.base.Update, _users: dict, _chats: dict
|
||||
) -> None:
|
||||
if isinstance(update, _FILTER_UPDATES):
|
||||
await cache.refresh()
|
||||
|
||||
return RawUpdateHandler(on_update)
|
||||
@@ -0,0 +1,36 @@
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.handlers.messages import sender_id
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||
from userbot.modules.media import self_destruct_ttl
|
||||
|
||||
|
||||
@PyroClient.on_edited_message()
|
||||
async def on_edited_message(client: PyroClient, message: Message) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None or message.empty or message.chat is None or message.date is None:
|
||||
return
|
||||
chat = message.chat
|
||||
chat_id = chat.id or 0
|
||||
meta = meta_from_chat(chat, ctx.contacts.ids)
|
||||
toggles = ctx.resolve(meta)
|
||||
if not toggles.track_edits_deletes:
|
||||
return
|
||||
await repository.add_version(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
chat_id,
|
||||
message.id,
|
||||
message.date,
|
||||
sender_id(message),
|
||||
message.text or message.caption,
|
||||
str(message),
|
||||
message.edit_date,
|
||||
has_media=message.media is not None,
|
||||
is_self_destruct=self_destruct_ttl(message) is not None,
|
||||
)
|
||||
|
||||
|
||||
handlers = on_edited_message.handlers
|
||||
@@ -0,0 +1,64 @@
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||
from userbot.modules.media import capture_media, self_destruct_ttl
|
||||
|
||||
|
||||
def sender_id(message: Message) -> int | None:
|
||||
if message.from_user is not None:
|
||||
return message.from_user.id
|
||||
if message.sender_chat is not None:
|
||||
return message.sender_chat.id
|
||||
return None
|
||||
|
||||
|
||||
def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
|
||||
rows = getattr(message.reply_markup, "inline_keyboard", None)
|
||||
if not rows:
|
||||
return []
|
||||
buttons: list[tuple[int, str | None, bytes | None]] = []
|
||||
position = 0
|
||||
for row in rows:
|
||||
for button in row:
|
||||
data = button.callback_data
|
||||
if data is not None:
|
||||
encoded = data.encode() if isinstance(data, str) else data
|
||||
buttons.append((position, button.text, encoded))
|
||||
position += 1
|
||||
return buttons
|
||||
|
||||
|
||||
@PyroClient.on_message()
|
||||
async def on_message(client: PyroClient, message: Message) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None or message.empty or message.chat is None or message.date is None:
|
||||
return
|
||||
chat = message.chat
|
||||
chat_id = chat.id or 0
|
||||
meta = meta_from_chat(chat, ctx.contacts.ids)
|
||||
toggles = ctx.resolve(meta)
|
||||
if not toggles.messages:
|
||||
return
|
||||
await repository.upsert_message(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
chat_id,
|
||||
message.id,
|
||||
message.date,
|
||||
sender_id(message),
|
||||
message.text or message.caption,
|
||||
str(message),
|
||||
has_media=message.media is not None,
|
||||
is_self_destruct=self_destruct_ttl(message) is not None,
|
||||
)
|
||||
await capture_media(client, message, ctx, chat_id, message.id, toggles)
|
||||
buttons = _callbacks(message)
|
||||
if buttons:
|
||||
await repository.insert_callbacks(
|
||||
ctx.pool, ctx.account_id, chat_id, message.id, buttons
|
||||
)
|
||||
|
||||
|
||||
handlers = on_message.handlers
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.handlers.raw.dispatcher import handlers
|
||||
|
||||
__all__ = ["handlers"]
|
||||
@@ -0,0 +1,22 @@
|
||||
from pyrogram import raw
|
||||
|
||||
from userbot import PyroClient
|
||||
|
||||
HANDLES = (raw.types.UpdateContactsReset, raw.types.UpdateUser)
|
||||
|
||||
|
||||
async def handle(
|
||||
client: PyroClient, update: raw.base.Update, users: dict, _chats: dict
|
||||
) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None:
|
||||
return
|
||||
if isinstance(update, raw.types.UpdateContactsReset):
|
||||
await ctx.contacts.refresh()
|
||||
return
|
||||
if not isinstance(update, raw.types.UpdateUser):
|
||||
return
|
||||
user = users.get(update.user_id)
|
||||
if not isinstance(user, raw.types.User) or user.min:
|
||||
return
|
||||
ctx.contacts.mark(user.id, is_contact=bool(user.contact))
|
||||
@@ -0,0 +1,19 @@
|
||||
from pyrogram import raw
|
||||
|
||||
from userbot import PyroClient
|
||||
|
||||
HANDLES = (
|
||||
raw.types.UpdateDialogFilter,
|
||||
raw.types.UpdateDialogFilters,
|
||||
raw.types.UpdateDialogFilterOrder,
|
||||
)
|
||||
|
||||
|
||||
async def handle(
|
||||
client: PyroClient, _update: raw.base.Update, _users: dict, _chats: dict
|
||||
) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None:
|
||||
return
|
||||
await ctx.folders.refresh()
|
||||
await ctx.reload_policies()
|
||||
@@ -0,0 +1,25 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from pyrogram import raw
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.handlers.raw import contacts, dialog_filters, reactions
|
||||
|
||||
RawHandler = Callable[..., Awaitable[None]]
|
||||
|
||||
_MODULES = (contacts, dialog_filters, reactions)
|
||||
_REGISTRY: dict[type, RawHandler] = {
|
||||
update_type: module.handle for module in _MODULES for update_type in module.HANDLES
|
||||
}
|
||||
|
||||
|
||||
@PyroClient.on_raw_update()
|
||||
async def dispatch(
|
||||
client: PyroClient, update: raw.base.Update, users: dict, chats: dict
|
||||
) -> None:
|
||||
handler = _REGISTRY.get(type(update))
|
||||
if handler is not None:
|
||||
await handler(client, update, users, chats)
|
||||
|
||||
|
||||
handlers = dispatch.handlers
|
||||
@@ -0,0 +1,49 @@
|
||||
from pyrogram import raw, utils
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.chat_meta import meta_from_peer
|
||||
|
||||
HANDLES = (raw.types.UpdateMessageReactions,)
|
||||
|
||||
|
||||
def _reaction_key(reaction: raw.base.Reaction) -> str | None:
|
||||
if isinstance(reaction, raw.types.ReactionEmoji):
|
||||
return reaction.emoticon
|
||||
if isinstance(reaction, raw.types.ReactionCustomEmoji):
|
||||
return f"custom:{reaction.document_id}"
|
||||
return None
|
||||
|
||||
|
||||
def _parse_recent(reactions: raw.base.MessageReactions) -> list[tuple[int, str]]:
|
||||
recent = getattr(reactions, "recent_reactions", None) or []
|
||||
parsed: list[tuple[int, str]] = []
|
||||
for peer_reaction in recent:
|
||||
key = _reaction_key(peer_reaction.reaction)
|
||||
if key is None:
|
||||
continue
|
||||
try:
|
||||
peer_id = utils.get_peer_id(peer_reaction.peer_id)
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
parsed.append((peer_id, key))
|
||||
return parsed
|
||||
|
||||
|
||||
async def handle(
|
||||
client: PyroClient,
|
||||
update: raw.types.UpdateMessageReactions,
|
||||
users: dict,
|
||||
chats: dict,
|
||||
) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None:
|
||||
return
|
||||
meta = meta_from_peer(update.peer, chats, users, ctx.contacts.ids)
|
||||
toggles = ctx.resolve(meta)
|
||||
if not toggles.reactions:
|
||||
return
|
||||
current = _parse_recent(update.reactions)
|
||||
await repository.sync_reactions(
|
||||
ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, current
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .client import PyroClient
|
||||
|
||||
__all__ = ["PyroClient"]
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.capture.context import CaptureContext, build_capture_context
|
||||
|
||||
__all__ = ["CaptureContext", "build_capture_context"]
|
||||
@@ -0,0 +1,46 @@
|
||||
from pyrogram import enums, raw, utils
|
||||
from pyrogram.types import Chat
|
||||
|
||||
from utils.policy.models import ChatKind, ChatMeta
|
||||
|
||||
_KIND_BY_TYPE: dict[enums.ChatType, ChatKind] = {
|
||||
enums.ChatType.PRIVATE: ChatKind.DM,
|
||||
enums.ChatType.BOT: ChatKind.DM,
|
||||
enums.ChatType.GROUP: ChatKind.GROUP,
|
||||
enums.ChatType.SUPERGROUP: ChatKind.GROUP,
|
||||
enums.ChatType.FORUM: ChatKind.GROUP,
|
||||
enums.ChatType.CHANNEL: ChatKind.CHANNEL,
|
||||
}
|
||||
|
||||
|
||||
def chat_kind(chat_type: enums.ChatType | None) -> ChatKind:
|
||||
return _KIND_BY_TYPE.get(chat_type, ChatKind.DM) if chat_type else ChatKind.DM
|
||||
|
||||
|
||||
def meta_from_chat(chat: Chat, contacts: set[int]) -> ChatMeta:
|
||||
kind = chat_kind(chat.type)
|
||||
is_bot = chat.type is enums.ChatType.BOT
|
||||
chat_id = chat.id or 0
|
||||
is_contact = chat_id in contacts if kind is ChatKind.DM else None
|
||||
return ChatMeta(chat_id=chat_id, kind=kind, is_bot=is_bot, is_contact=is_contact)
|
||||
|
||||
|
||||
def meta_from_peer(
|
||||
peer: raw.base.Peer, chats: dict, users: dict, contacts: set[int]
|
||||
) -> ChatMeta:
|
||||
chat_id = utils.get_peer_id(peer)
|
||||
if isinstance(peer, raw.types.PeerUser):
|
||||
user = users.get(peer.user_id)
|
||||
return ChatMeta(
|
||||
chat_id=chat_id,
|
||||
kind=ChatKind.DM,
|
||||
is_bot=bool(getattr(user, "bot", False)),
|
||||
is_contact=peer.user_id in contacts,
|
||||
)
|
||||
if isinstance(peer, raw.types.PeerChannel):
|
||||
channel = chats.get(peer.channel_id)
|
||||
kind = (
|
||||
ChatKind.CHANNEL if getattr(channel, "broadcast", False) else ChatKind.GROUP
|
||||
)
|
||||
return ChatMeta(chat_id=chat_id, kind=kind)
|
||||
return ChatMeta(chat_id=chat_id, kind=ChatKind.GROUP)
|
||||
@@ -0,0 +1,47 @@
|
||||
import asyncpg
|
||||
from pyrogram import Client
|
||||
|
||||
from userbot.modules.contacts import ContactCache
|
||||
from userbot.modules.folders import FolderCache
|
||||
from utils.policy.models import CaptureToggles, ChatMeta, PolicySet
|
||||
from utils.policy.repository import load_policy_set
|
||||
from utils.policy.resolver import resolve
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
|
||||
class CaptureContext:
|
||||
def __init__(
|
||||
self,
|
||||
account_id: int,
|
||||
pool: asyncpg.Pool,
|
||||
storage: ContentAddressedStorage,
|
||||
folders: FolderCache,
|
||||
contacts: ContactCache,
|
||||
) -> None:
|
||||
self.account_id = account_id
|
||||
self.pool = pool
|
||||
self.storage = storage
|
||||
self.folders = folders
|
||||
self.contacts = contacts
|
||||
self.policies = PolicySet()
|
||||
|
||||
async def reload_policies(self) -> None:
|
||||
self.policies = await load_policy_set(self.pool, self.account_id)
|
||||
|
||||
def resolve(self, chat: ChatMeta) -> CaptureToggles:
|
||||
return resolve(chat, self.folders.folders, self.policies)
|
||||
|
||||
|
||||
async def build_capture_context(
|
||||
client: Client,
|
||||
pool: asyncpg.Pool,
|
||||
storage: ContentAddressedStorage,
|
||||
account_id: int,
|
||||
) -> CaptureContext:
|
||||
folders = FolderCache(client, pool, account_id)
|
||||
await folders.refresh()
|
||||
contacts = ContactCache(client)
|
||||
await contacts.refresh()
|
||||
ctx = CaptureContext(account_id, pool, storage, folders, contacts)
|
||||
await ctx.reload_policies()
|
||||
return ctx
|
||||
@@ -0,0 +1,223 @@
|
||||
from datetime import datetime
|
||||
|
||||
import asyncpg
|
||||
|
||||
CHANNEL_ID_THRESHOLD = -1000000000000
|
||||
|
||||
_UPSERT_MESSAGE = """
|
||||
INSERT INTO messages
|
||||
(account_id, chat_id, message_id, date, sender_id, text, raw,
|
||||
has_media, is_self_destruct)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
|
||||
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
|
||||
sender_id = EXCLUDED.sender_id,
|
||||
text = EXCLUDED.text,
|
||||
raw = EXCLUDED.raw,
|
||||
has_media = EXCLUDED.has_media,
|
||||
is_self_destruct = EXCLUDED.is_self_destruct
|
||||
"""
|
||||
|
||||
_TOUCH_EDITED = """
|
||||
INSERT INTO messages
|
||||
(account_id, chat_id, message_id, date, sender_id, text, raw,
|
||||
has_media, is_self_destruct, edited_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9, now())
|
||||
ON CONFLICT (account_id, chat_id, message_id, date) DO UPDATE SET
|
||||
edited_at = now()
|
||||
"""
|
||||
|
||||
_INSERT_VERSION = """
|
||||
INSERT INTO message_versions
|
||||
(account_id, chat_id, message_id, observed_at, edit_date, text, raw)
|
||||
VALUES ($1, $2, $3, now(), $4, $5, $6::jsonb)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
|
||||
_INSERT_MEDIA = """
|
||||
INSERT INTO media
|
||||
(account_id, chat_id, message_id, kind, storage_key, file_size, mime,
|
||||
ttl_seconds, downloaded)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (account_id, chat_id, message_id) DO UPDATE SET
|
||||
kind = EXCLUDED.kind,
|
||||
storage_key = EXCLUDED.storage_key,
|
||||
file_size = EXCLUDED.file_size,
|
||||
mime = EXCLUDED.mime,
|
||||
ttl_seconds = EXCLUDED.ttl_seconds,
|
||||
downloaded = EXCLUDED.downloaded
|
||||
"""
|
||||
|
||||
|
||||
async def upsert_message( # noqa: PLR0913
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
date: datetime,
|
||||
sender_id: int | None,
|
||||
text: str | None,
|
||||
raw: str,
|
||||
*,
|
||||
has_media: bool,
|
||||
is_self_destruct: bool,
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
_UPSERT_MESSAGE,
|
||||
account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
date,
|
||||
sender_id,
|
||||
text,
|
||||
raw,
|
||||
has_media,
|
||||
is_self_destruct,
|
||||
)
|
||||
|
||||
|
||||
async def mark_deleted_box(
|
||||
pool: asyncpg.Pool, account_id: int, message_ids: list[int]
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE messages SET deleted_at = now() "
|
||||
"WHERE account_id = $1 AND message_id = ANY($2::bigint[]) "
|
||||
"AND chat_id > $3 AND deleted_at IS NULL",
|
||||
account_id,
|
||||
message_ids,
|
||||
CHANNEL_ID_THRESHOLD,
|
||||
)
|
||||
|
||||
|
||||
async def mark_deleted_channel(
|
||||
pool: asyncpg.Pool, account_id: int, chat_id: int, message_ids: list[int]
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE messages SET deleted_at = now() "
|
||||
"WHERE account_id = $1 AND chat_id = $2 AND message_id = ANY($3::bigint[]) "
|
||||
"AND deleted_at IS NULL",
|
||||
account_id,
|
||||
chat_id,
|
||||
message_ids,
|
||||
)
|
||||
|
||||
|
||||
async def add_version( # noqa: PLR0913
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
date: datetime,
|
||||
sender_id: int | None,
|
||||
text: str | None,
|
||||
raw: str,
|
||||
edit_date: datetime | None,
|
||||
*,
|
||||
has_media: bool,
|
||||
is_self_destruct: bool,
|
||||
) -> None:
|
||||
async with pool.acquire() as conn, conn.transaction():
|
||||
await conn.execute(
|
||||
_TOUCH_EDITED,
|
||||
account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
date,
|
||||
sender_id,
|
||||
text,
|
||||
raw,
|
||||
has_media,
|
||||
is_self_destruct,
|
||||
)
|
||||
await conn.execute(
|
||||
_INSERT_VERSION, account_id, chat_id, message_id, edit_date, text, raw
|
||||
)
|
||||
|
||||
|
||||
async def insert_media( # noqa: PLR0913
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
kind: str,
|
||||
storage_key: str | None,
|
||||
file_size: int | None,
|
||||
mime: str | None,
|
||||
ttl_seconds: int | None,
|
||||
*,
|
||||
downloaded: bool,
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
_INSERT_MEDIA,
|
||||
account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
kind,
|
||||
storage_key,
|
||||
file_size,
|
||||
mime,
|
||||
ttl_seconds,
|
||||
downloaded,
|
||||
)
|
||||
|
||||
|
||||
async def insert_callbacks(
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
buttons: list[tuple[int, str | None, bytes | None]],
|
||||
) -> None:
|
||||
async with pool.acquire() as conn, conn.transaction():
|
||||
await conn.execute(
|
||||
"DELETE FROM callbacks "
|
||||
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3",
|
||||
account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
)
|
||||
await conn.executemany(
|
||||
"INSERT INTO callbacks "
|
||||
"(account_id, chat_id, message_id, position, label, data) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
[
|
||||
(account_id, chat_id, message_id, position, label, data)
|
||||
for position, label, data in buttons
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def sync_reactions(
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
current: list[tuple[int, str]],
|
||||
) -> None:
|
||||
async with pool.acquire() as conn, conn.transaction():
|
||||
rows = await conn.fetch(
|
||||
"SELECT id, peer_id, reaction FROM reactions "
|
||||
"WHERE account_id = $1 AND chat_id = $2 AND message_id = $3 "
|
||||
"AND removed_at IS NULL",
|
||||
account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
)
|
||||
existing = {(row["peer_id"], row["reaction"]): row["id"] for row in rows}
|
||||
current_set = set(current)
|
||||
stale = [pid for key, pid in existing.items() if key not in current_set]
|
||||
if stale:
|
||||
await conn.execute(
|
||||
"UPDATE reactions SET removed_at = now() WHERE id = ANY($1::bigint[])",
|
||||
stale,
|
||||
)
|
||||
fresh = [key for key in current_set if key not in existing]
|
||||
if fresh:
|
||||
await conn.executemany(
|
||||
"INSERT INTO reactions "
|
||||
"(account_id, chat_id, message_id, peer_id, reaction) "
|
||||
"VALUES ($1, $2, $3, $4, $5)",
|
||||
[
|
||||
(account_id, chat_id, message_id, peer_id, reaction)
|
||||
for peer_id, reaction in fresh
|
||||
],
|
||||
)
|
||||
@@ -1,8 +1,15 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pyrogram import Client, enums
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from userbot.modules.capture import CaptureContext
|
||||
|
||||
|
||||
class PyroClient(Client):
|
||||
def __init__(self, name: str, *, workdir: str = "sessions") -> None:
|
||||
def __init__(
|
||||
self, name: str, *, workdir: str = "sessions", load_handlers: bool = True
|
||||
) -> None:
|
||||
super().__init__(
|
||||
name,
|
||||
workdir=workdir,
|
||||
@@ -14,6 +21,13 @@ class PyroClient(Client):
|
||||
lang_pack="tdesktop",
|
||||
client_platform=enums.ClientPlatform.DESKTOP,
|
||||
)
|
||||
self.capture: CaptureContext | None = None
|
||||
|
||||
if load_handlers:
|
||||
from userbot import handlers # noqa: PLC0415
|
||||
|
||||
for handler in handlers.handlers:
|
||||
self.add_handler(*handler)
|
||||
|
||||
|
||||
__all__ = ["PyroClient"]
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.contacts.cache import ContactCache
|
||||
|
||||
__all__ = ["ContactCache"]
|
||||
@@ -0,0 +1,20 @@
|
||||
from pyrogram import Client
|
||||
|
||||
from utils.logging import logger
|
||||
|
||||
|
||||
class ContactCache:
|
||||
def __init__(self, client: Client) -> None:
|
||||
self._client = client
|
||||
self.ids: set[int] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
contacts = await self._client.get_contacts()
|
||||
self.ids = {user.id for user in contacts}
|
||||
logger.info(f"[green]Contacts cached:[/] {len(self.ids)}")
|
||||
|
||||
def mark(self, user_id: int, *, is_contact: bool) -> None:
|
||||
if is_contact:
|
||||
self.ids.add(user_id)
|
||||
else:
|
||||
self.ids.discard(user_id)
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.folders.cache import FolderCache
|
||||
|
||||
__all__ = ["FolderCache"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.media.downloader import capture_media, self_destruct_ttl
|
||||
|
||||
__all__ = ["capture_media", "self_destruct_ttl"]
|
||||
@@ -0,0 +1,71 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.context import CaptureContext
|
||||
from utils.policy.models import CaptureToggles
|
||||
|
||||
_MEDIA_ATTRS = (
|
||||
"photo",
|
||||
"video",
|
||||
"voice",
|
||||
"video_note",
|
||||
"document",
|
||||
"audio",
|
||||
"animation",
|
||||
"sticker",
|
||||
)
|
||||
|
||||
|
||||
def media_object(message: Message) -> tuple[str | None, Any]:
|
||||
for attr in _MEDIA_ATTRS:
|
||||
obj = getattr(message, attr, None)
|
||||
if obj is not None:
|
||||
return attr, obj
|
||||
return None, None
|
||||
|
||||
|
||||
def self_destruct_ttl(message: Message) -> int | None:
|
||||
_, obj = media_object(message)
|
||||
return getattr(obj, "ttl_seconds", None) if obj is not None else None
|
||||
|
||||
|
||||
async def capture_media( # noqa: PLR0913
|
||||
client: Client,
|
||||
message: Message,
|
||||
ctx: CaptureContext,
|
||||
chat_id: int,
|
||||
message_id: int,
|
||||
toggles: CaptureToggles,
|
||||
) -> None:
|
||||
kind, obj = media_object(message)
|
||||
if obj is None:
|
||||
return
|
||||
ttl = getattr(obj, "ttl_seconds", None)
|
||||
want = toggles.self_destruct_media if ttl else toggles.media
|
||||
file_size = getattr(obj, "file_size", None)
|
||||
mime = getattr(obj, "mime_type", None)
|
||||
storage_key: str | None = None
|
||||
downloaded = False
|
||||
if want:
|
||||
buffer = await client.download_media(message, in_memory=True)
|
||||
if isinstance(buffer, BytesIO):
|
||||
data = buffer.getvalue()
|
||||
storage_key = ctx.storage.put(data)
|
||||
file_size = len(data)
|
||||
downloaded = True
|
||||
await repository.insert_media(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
chat_id,
|
||||
message_id,
|
||||
kind or "unknown",
|
||||
storage_key,
|
||||
file_size,
|
||||
mime,
|
||||
ttl,
|
||||
downloaded=downloaded,
|
||||
)
|
||||
@@ -7,14 +7,11 @@ import asyncpg
|
||||
import uvloop
|
||||
|
||||
from dependencies.container import container
|
||||
from userbot.folders import FolderCache
|
||||
from userbot.handlers import dialog_filter_handler
|
||||
from userbot.modules import PyroClient
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import build_capture_context
|
||||
from utils.env import env
|
||||
from utils.logging import logger, setup_logging
|
||||
from utils.policy.models import ChatKind, ChatMeta
|
||||
from utils.policy.repository import load_policy_set
|
||||
from utils.policy.resolver import resolve
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
setup_logging()
|
||||
|
||||
@@ -61,22 +58,37 @@ async def _sync_account(
|
||||
return account_id
|
||||
|
||||
|
||||
async def _setup_policy(
|
||||
pool: asyncpg.Pool, client: PyroClient, account_id: int
|
||||
async def _setup_capture(
|
||||
pool: asyncpg.Pool,
|
||||
client: PyroClient,
|
||||
account_id: int,
|
||||
storage: ContentAddressedStorage,
|
||||
) -> None:
|
||||
cache = FolderCache(client, pool, account_id)
|
||||
await cache.refresh()
|
||||
client.add_handler(dialog_filter_handler(cache))
|
||||
if client.me:
|
||||
policies = await load_policy_set(pool, account_id)
|
||||
sample = resolve(
|
||||
ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies
|
||||
)
|
||||
logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}")
|
||||
client.capture = await build_capture_context(client, pool, storage, account_id)
|
||||
logger.info("[green]Capture context ready.[/]")
|
||||
|
||||
|
||||
async def _listen_policy_changes(
|
||||
clients: list[PyroClient], tasks: set[asyncio.Task]
|
||||
) -> asyncpg.Connection:
|
||||
def on_change(
|
||||
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
|
||||
) -> None:
|
||||
for client in clients:
|
||||
if client.capture is None:
|
||||
continue
|
||||
task = asyncio.create_task(client.capture.reload_policies())
|
||||
tasks.add(task)
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
conn = await asyncpg.connect(dsn=env.db.connection_url)
|
||||
await conn.add_listener("policy_changed", on_change)
|
||||
return conn
|
||||
|
||||
|
||||
async def runner() -> None:
|
||||
pool = await container.get(asyncpg.Pool)
|
||||
storage = await container.get(ContentAddressedStorage)
|
||||
|
||||
sessions_dir = Path(env.tg.sessions_dir)
|
||||
session_files = _discover_sessions(sessions_dir)
|
||||
@@ -88,6 +100,8 @@ async def runner() -> None:
|
||||
)
|
||||
|
||||
clients: list[PyroClient] = []
|
||||
reload_tasks: set[asyncio.Task] = set()
|
||||
listen_conn: asyncpg.Connection | None = None
|
||||
try:
|
||||
for session_path in session_files:
|
||||
session_name = session_path.stem
|
||||
@@ -101,12 +115,16 @@ async def runner() -> None:
|
||||
)
|
||||
account_id = await _sync_account(pool, client, session_name)
|
||||
if account_id is not None:
|
||||
await _setup_policy(pool, client, account_id)
|
||||
await _setup_capture(pool, client, account_id, storage)
|
||||
|
||||
if clients:
|
||||
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
||||
logger.info("[green]Userbot running.[/]")
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
if listen_conn is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await listen_conn.close()
|
||||
for client in clients:
|
||||
with contextlib.suppress(Exception):
|
||||
await client.stop()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Column, DateTime, func
|
||||
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
@@ -42,9 +42,16 @@ class Message(SQLModel, table=True):
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
date: datetime = Field(sa_column=Column(DateTime(timezone=True), primary_key=True))
|
||||
sender_id: int | None = Field(default=None, sa_column=Column(BigInteger))
|
||||
text: str | None = None
|
||||
raw: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
has_media: bool = False
|
||||
is_self_destruct: bool = False
|
||||
edited_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
deleted_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
@@ -71,6 +78,86 @@ class Folder(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class MessageVersion(SQLModel, table=True):
|
||||
__tablename__ = "message_versions"
|
||||
|
||||
account_id: int = Field(primary_key=True)
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, primary_key=True))
|
||||
observed_at: datetime = Field(
|
||||
sa_column=Column(DateTime(timezone=True), primary_key=True)
|
||||
)
|
||||
edit_date: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
text: str | None = None
|
||||
raw: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
|
||||
|
||||
class Media(SQLModel, table=True):
|
||||
__tablename__ = "media"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
account_id: int
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
kind: str
|
||||
storage_key: str | None = None
|
||||
file_size: int | None = Field(default=None, sa_column=Column(BigInteger))
|
||||
mime: str | None = None
|
||||
ttl_seconds: int | None = None
|
||||
downloaded: bool = False
|
||||
extracted_text: str | None = None
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Callback(SQLModel, table=True):
|
||||
__tablename__ = "callbacks"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
account_id: int
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
position: int
|
||||
label: str | None = None
|
||||
data: bytes | None = Field(default=None, sa_column=Column(LargeBinary))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Reaction(SQLModel, table=True):
|
||||
__tablename__ = "reactions"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
account_id: int
|
||||
chat_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
message_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
peer_id: int = Field(sa_column=Column(BigInteger, nullable=False))
|
||||
reaction: str
|
||||
added_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
removed_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CapturePolicy(SQLModel, table=True):
|
||||
__tablename__ = "capture_policy"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user