From 920a0235e22a4087743f3f8138e14502a751ac2a Mon Sep 17 00:00:00 2001 From: h Date: Fri, 29 May 2026 16:47:02 +0200 Subject: [PATCH] feat: create message capture policies --- ...b2f7c1a9d3e4_folders_and_capture_policy.py | 97 +++++++++++ backend/src/api/app.py | 4 + backend/src/api/routers/__init__.py | 0 backend/src/api/routers/folders.py | 36 ++++ backend/src/api/routers/policy.py | 87 ++++++++++ backend/src/userbot/folders.py | 68 ++++++++ backend/src/userbot/handlers/__init__.py | 3 + .../src/userbot/handlers/dialog_filters.py | 20 +++ backend/src/userbot/runner.py | 33 +++- backend/src/utils/db/models.py | 53 ++++++ backend/src/utils/policy/__init__.py | 0 backend/src/utils/policy/defaults.py | 18 ++ backend/src/utils/policy/models.py | 86 ++++++++++ backend/src/utils/policy/repository.py | 154 ++++++++++++++++++ backend/src/utils/policy/resolver.py | 46 ++++++ 15 files changed, 700 insertions(+), 5 deletions(-) create mode 100644 backend/migrations/versions/b2f7c1a9d3e4_folders_and_capture_policy.py create mode 100644 backend/src/api/routers/__init__.py create mode 100644 backend/src/api/routers/folders.py create mode 100644 backend/src/api/routers/policy.py create mode 100644 backend/src/userbot/folders.py create mode 100644 backend/src/userbot/handlers/__init__.py create mode 100644 backend/src/userbot/handlers/dialog_filters.py create mode 100644 backend/src/utils/policy/__init__.py create mode 100644 backend/src/utils/policy/defaults.py create mode 100644 backend/src/utils/policy/models.py create mode 100644 backend/src/utils/policy/repository.py create mode 100644 backend/src/utils/policy/resolver.py diff --git a/backend/migrations/versions/b2f7c1a9d3e4_folders_and_capture_policy.py b/backend/migrations/versions/b2f7c1a9d3e4_folders_and_capture_policy.py new file mode 100644 index 0000000..75c6461 --- /dev/null +++ b/backend/migrations/versions/b2f7c1a9d3e4_folders_and_capture_policy.py @@ -0,0 +1,97 @@ +"""folders and capture policy + +Revision ID: b2f7c1a9d3e4 +Revises: 77df960a31de +Create Date: 2026-05-29 17:30:00.000000 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "b2f7c1a9d3e4" +down_revision: str | None = "77df960a31de" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +_TOGGLES = ( + "messages", + "media", + "self_destruct_media", + "stt", + "reactions", + "track_edits_deletes", + "profile_history", + "stories", + "presence", + "backfill", +) + + +def _toggle_columns() -> list[sa.Column]: + return [ + sa.Column(name, sa.Boolean(), nullable=False, server_default=sa.false()) + for name in _TOGGLES + ] + + +def upgrade() -> None: + op.create_table( + "folders", + sa.Column("account_id", sa.Integer(), nullable=False), + sa.Column("folder_id", sa.Integer(), nullable=False), + sa.Column("title", sa.String(), nullable=False), + sa.Column("order_index", sa.Integer(), nullable=False), + sa.Column("is_chatlist", sa.Boolean(), nullable=False), + sa.Column("raw", postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("account_id", "folder_id"), + ) + op.create_table( + "capture_policy", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("account_id", sa.Integer(), nullable=True), + sa.Column("scope_type", sa.String(), nullable=False), + sa.Column("scope_id", sa.BigInteger(), nullable=True), + *_toggle_columns(), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("account_id", "scope_type", "scope_id"), + ) + op.execute( + "CREATE UNIQUE INDEX ix_capture_policy_default " + "ON capture_policy (scope_type) WHERE scope_type LIKE 'default_%'" + ) + + op.execute( + "INSERT INTO capture_policy " + "(scope_type, messages, media, self_destruct_media, stt, reactions, " + "track_edits_deletes, profile_history, stories, presence, backfill) VALUES " + "('default_channel', FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE), " + "('default_group', TRUE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE), " + "('default_dm', TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)" + ) + + +def downgrade() -> None: + op.drop_table("capture_policy") + op.drop_table("folders") diff --git a/backend/src/api/app.py b/backend/src/api/app.py index 8d74ce8..1aea606 100644 --- a/backend/src/api/app.py +++ b/backend/src/api/app.py @@ -5,6 +5,7 @@ import asyncpg from dishka.integrations.fastapi import FromDishka, inject, setup_dishka from fastapi import FastAPI +from api.routers import folders, policy from dependencies.container import container @@ -27,4 +28,7 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]: return {"db": db_ok, "timescaledb": bool(timescale_ok)} +app.include_router(policy.router) +app.include_router(folders.router) + setup_dishka(container, app) diff --git a/backend/src/api/routers/__init__.py b/backend/src/api/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/src/api/routers/folders.py b/backend/src/api/routers/folders.py new file mode 100644 index 0000000..def1af4 --- /dev/null +++ b/backend/src/api/routers/folders.py @@ -0,0 +1,36 @@ +from typing import Annotated + +import asyncpg +from dishka.integrations.fastapi import FromDishka, inject +from fastapi import APIRouter, Query + +from utils.policy import repository +from utils.policy.models import FolderSpec + +router = APIRouter(prefix="/api/folders", tags=["folders"]) + + +def _serialize(spec: FolderSpec) -> dict: + return { + "folder_id": spec.folder_id, + "order_index": spec.order_index, + "title": spec.title, + "include_ids": sorted(spec.include_ids), + "exclude_ids": sorted(spec.exclude_ids), + "pinned_ids": sorted(spec.pinned_ids), + "contacts": spec.contacts, + "non_contacts": spec.non_contacts, + "groups": spec.groups, + "broadcasts": spec.broadcasts, + "bots": spec.bots, + "is_chatlist": spec.is_chatlist, + } + + +@router.get("") +@inject +async def list_folders( + pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] +) -> list[dict]: + folders = await repository.list_folders(pool, account_id) + return [_serialize(spec) for spec in folders] diff --git a/backend/src/api/routers/policy.py b/backend/src/api/routers/policy.py new file mode 100644 index 0000000..0ac0264 --- /dev/null +++ b/backend/src/api/routers/policy.py @@ -0,0 +1,87 @@ +from typing import Annotated + +import asyncpg +from dishka.integrations.fastapi import FromDishka, inject +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel + +from utils.policy import repository +from utils.policy.models import ( + CaptureToggles, + ChatKind, + ChatMeta, + PolicyCreate, + PolicyRecord, +) +from utils.policy.resolver import resolve + +router = APIRouter(prefix="/api/policy", tags=["policy"]) + + +class EffectiveQuery(BaseModel): + account_id: int + chat_id: int + kind: ChatKind + is_bot: bool = False + is_contact: bool | None = None + + +@router.get("") +@inject +async def list_policies( + pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] +) -> list[PolicyRecord]: + return await repository.list_policies(pool, account_id) + + +@router.get("/effective") +@inject +async def effective_policy( + pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()] +) -> CaptureToggles: + folders = await repository.list_folders(pool, query.account_id) + policies = await repository.load_policy_set(pool, query.account_id) + chat = ChatMeta( + chat_id=query.chat_id, + kind=query.kind, + is_bot=query.is_bot, + is_contact=query.is_contact, + ) + return resolve(chat, folders, policies) + + +@router.post("", status_code=201) +@inject +async def create_policy( + pool: FromDishka[asyncpg.Pool], body: PolicyCreate +) -> PolicyRecord: + return await repository.create_policy( + pool, body.account_id, body.scope_type, body.scope_id, body + ) + + +@router.get("/{policy_id}") +@inject +async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord: + record = await repository.get_policy(pool, policy_id) + if record is None: + raise HTTPException(status_code=404, detail="policy not found") + return record + + +@router.put("/{policy_id}") +@inject +async def update_policy( + pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles +) -> PolicyRecord: + record = await repository.update_policy(pool, policy_id, body) + if record is None: + raise HTTPException(status_code=404, detail="policy not found") + return record + + +@router.delete("/{policy_id}", status_code=204) +@inject +async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None: + if not await repository.delete_policy(pool, policy_id): + raise HTTPException(status_code=404, detail="policy not found") diff --git a/backend/src/userbot/folders.py b/backend/src/userbot/folders.py new file mode 100644 index 0000000..1338a6e --- /dev/null +++ b/backend/src/userbot/folders.py @@ -0,0 +1,68 @@ +from collections.abc import Iterable + +import asyncpg +from pyrogram import Client, raw, utils + +from utils.logging import logger +from utils.policy.models import FolderSpec +from utils.policy.repository import replace_folders + + +def _peer_ids(peers: Iterable[raw.base.InputPeer]) -> set[int]: + ids: set[int] = set() + for peer in peers: + try: + ids.add(utils.get_peer_id(peer)) + except (ValueError, AttributeError): + continue + return ids + + +def _title(raw_title: object) -> str: + return getattr(raw_title, "text", None) or str(raw_title) + + +def _parse(raw_filter: raw.base.DialogFilter, order_index: int) -> FolderSpec | None: + if isinstance(raw_filter, raw.types.DialogFilterDefault): + return None + if isinstance(raw_filter, raw.types.DialogFilterChatlist): + return FolderSpec( + folder_id=raw_filter.id, + order_index=order_index, + title=_title(raw_filter.title), + include_ids=_peer_ids(raw_filter.include_peers), + pinned_ids=_peer_ids(raw_filter.pinned_peers), + is_chatlist=True, + ) + return FolderSpec( + folder_id=raw_filter.id, + order_index=order_index, + title=_title(raw_filter.title), + include_ids=_peer_ids(raw_filter.include_peers), + exclude_ids=_peer_ids(raw_filter.exclude_peers), + pinned_ids=_peer_ids(raw_filter.pinned_peers), + contacts=bool(raw_filter.contacts), + non_contacts=bool(raw_filter.non_contacts), + groups=bool(raw_filter.groups), + broadcasts=bool(raw_filter.broadcasts), + bots=bool(raw_filter.bots), + ) + + +class FolderCache: + def __init__(self, client: Client, pool: asyncpg.Pool, account_id: int) -> None: + self._client = client + self._pool = pool + self._account_id = account_id + self.folders: list[FolderSpec] = [] + + async def refresh(self) -> None: + result = await self._client.invoke(raw.functions.messages.GetDialogFilters()) + specs = [ + spec + for order_index, raw_filter in enumerate(result.filters) + if (spec := _parse(raw_filter, order_index)) is not None + ] + self.folders = specs + await replace_folders(self._pool, self._account_id, specs) + logger.info(f"[green]Folders cached:[/] {len(specs)}") diff --git a/backend/src/userbot/handlers/__init__.py b/backend/src/userbot/handlers/__init__.py new file mode 100644 index 0000000..f88cb35 --- /dev/null +++ b/backend/src/userbot/handlers/__init__.py @@ -0,0 +1,3 @@ +from .dialog_filters import dialog_filter_handler + +__all__ = ["dialog_filter_handler"] diff --git a/backend/src/userbot/handlers/dialog_filters.py b/backend/src/userbot/handlers/dialog_filters.py new file mode 100644 index 0000000..269bcfe --- /dev/null +++ b/backend/src/userbot/handlers/dialog_filters.py @@ -0,0 +1,20 @@ +from pyrogram import Client, raw +from pyrogram.handlers import RawUpdateHandler + +from userbot.folders import FolderCache + +_FILTER_UPDATES = ( + raw.types.UpdateDialogFilter, + raw.types.UpdateDialogFilters, + raw.types.UpdateDialogFilterOrder, +) + + +def dialog_filter_handler(cache: FolderCache) -> RawUpdateHandler: + async def on_update( + _client: Client, update: raw.base.Update, _users: dict, _chats: dict + ) -> None: + if isinstance(update, _FILTER_UPDATES): + await cache.refresh() + + return RawUpdateHandler(on_update) diff --git a/backend/src/userbot/runner.py b/backend/src/userbot/runner.py index 5aef66e..602707e 100644 --- a/backend/src/userbot/runner.py +++ b/backend/src/userbot/runner.py @@ -7,9 +7,14 @@ import asyncpg import uvloop from dependencies.container import container +from userbot.folders import FolderCache +from userbot.handlers import dialog_filter_handler from userbot.modules import PyroClient from utils.env import env from utils.logging import logger, setup_logging +from utils.policy.models import ChatKind, ChatMeta +from utils.policy.repository import load_policy_set +from utils.policy.resolver import resolve setup_logging() @@ -24,6 +29,7 @@ ON CONFLICT (tg_user_id) DO UPDATE SET is_active = TRUE, raw = EXCLUDED.raw, updated_at = now() +RETURNING account_id """ @@ -34,10 +40,10 @@ def _discover_sessions(sessions_dir: Path) -> list[Path]: async def _sync_account( pool: asyncpg.Pool, client: PyroClient, session_name: str -) -> None: +) -> int | None: me = client.me if not me: - return + return None raw = json.dumps( { "id": me.id, @@ -48,10 +54,25 @@ async def _sync_account( } ) label = " ".join(filter(None, [me.first_name, me.last_name])) or me.username - await pool.execute( + account_id = await pool.fetchval( _UPSERT_ACCOUNT, me.id, label, me.phone_number, session_name, raw ) logger.info(f"[green]Account synced:[/] {label} ({me.id})") + return account_id + + +async def _setup_policy( + pool: asyncpg.Pool, client: PyroClient, account_id: int +) -> None: + cache = FolderCache(client, pool, account_id) + await cache.refresh() + client.add_handler(dialog_filter_handler(cache)) + if client.me: + policies = await load_policy_set(pool, account_id) + sample = resolve( + ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies + ) + logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}") async def runner() -> None: @@ -78,10 +99,12 @@ async def runner() -> None: f"{client.me.full_name if client.me else 'unknown'} " f"{client.me.id if client.me else 'unknown'}" ) - await _sync_account(pool, client, session_name) + account_id = await _sync_account(pool, client, session_name) + if account_id is not None: + await _setup_policy(pool, client, account_id) if clients: - logger.info("[green]Userbot running. Idle (no handlers until phase 3).[/]") + logger.info("[green]Userbot running.[/]") await asyncio.Event().wait() finally: for client in clients: diff --git a/backend/src/utils/db/models.py b/backend/src/utils/db/models.py index 3485f87..d680807 100644 --- a/backend/src/utils/db/models.py +++ b/backend/src/utils/db/models.py @@ -48,3 +48,56 @@ class Message(SQLModel, table=True): deleted_at: datetime | None = Field( default=None, sa_column=Column(DateTime(timezone=True)) ) + + +class Folder(SQLModel, table=True): + __tablename__ = "folders" + + account_id: int = Field(primary_key=True) + folder_id: int = Field(primary_key=True) + title: str + order_index: int + is_chatlist: bool = False + raw: dict[str, Any] = Field( + default_factory=dict, sa_column=Column(JSONB, nullable=False) + ) + updated_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ) + ) + + +class CapturePolicy(SQLModel, table=True): + __tablename__ = "capture_policy" + + id: int | None = Field(default=None, primary_key=True) + account_id: int | None = None + scope_type: str + scope_id: int | None = Field(default=None, sa_column=Column(BigInteger)) + messages: bool = False + media: bool = False + self_destruct_media: bool = False + stt: bool = False + reactions: bool = False + track_edits_deletes: bool = False + profile_history: bool = False + stories: bool = False + presence: bool = False + backfill: bool = False + created_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + ) + updated_at: datetime = Field( + sa_column=Column( + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ) + ) diff --git a/backend/src/utils/policy/__init__.py b/backend/src/utils/policy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/src/utils/policy/defaults.py b/backend/src/utils/policy/defaults.py new file mode 100644 index 0000000..1b6c811 --- /dev/null +++ b/backend/src/utils/policy/defaults.py @@ -0,0 +1,18 @@ +from utils.policy.models import CaptureToggles, ChatKind + +DEFAULTS: dict[ChatKind, CaptureToggles] = { + ChatKind.CHANNEL: CaptureToggles(), + ChatKind.GROUP: CaptureToggles(messages=True), + ChatKind.DM: CaptureToggles( + messages=True, + media=True, + self_destruct_media=True, + stt=True, + reactions=True, + track_edits_deletes=True, + profile_history=True, + stories=True, + presence=True, + backfill=True, + ), +} diff --git a/backend/src/utils/policy/models.py b/backend/src/utils/policy/models.py new file mode 100644 index 0000000..a2dce6e --- /dev/null +++ b/backend/src/utils/policy/models.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass, field +from enum import StrEnum + +from pydantic import BaseModel + + +class ChatKind(StrEnum): + DM = "dm" + GROUP = "group" + CHANNEL = "channel" + + +class ScopeType(StrEnum): + DEFAULT_DM = "default_dm" + DEFAULT_GROUP = "default_group" + DEFAULT_CHANNEL = "default_channel" + FOLDER = "folder" + CHAT = "chat" + + +DEFAULT_SCOPE_BY_KIND: dict[ChatKind, ScopeType] = { + ChatKind.DM: ScopeType.DEFAULT_DM, + ChatKind.GROUP: ScopeType.DEFAULT_GROUP, + ChatKind.CHANNEL: ScopeType.DEFAULT_CHANNEL, +} + +KIND_BY_DEFAULT_SCOPE: dict[ScopeType, ChatKind] = { + scope: kind for kind, scope in DEFAULT_SCOPE_BY_KIND.items() +} + + +class CaptureToggles(BaseModel): + messages: bool = False + media: bool = False + self_destruct_media: bool = False + stt: bool = False + reactions: bool = False + track_edits_deletes: bool = False + profile_history: bool = False + stories: bool = False + presence: bool = False + backfill: bool = False + + +class PolicyCreate(CaptureToggles): + account_id: int | None = None + scope_type: ScopeType + scope_id: int | None = None + + +class PolicyRecord(CaptureToggles): + id: int + account_id: int | None + scope_type: ScopeType + scope_id: int | None + + +@dataclass(slots=True) +class ChatMeta: + chat_id: int + kind: ChatKind + is_bot: bool = False + is_contact: bool | None = None + + +@dataclass(slots=True) +class FolderSpec: + folder_id: int + order_index: int + title: str + include_ids: set[int] = field(default_factory=set) + exclude_ids: set[int] = field(default_factory=set) + pinned_ids: set[int] = field(default_factory=set) + contacts: bool = False + non_contacts: bool = False + groups: bool = False + broadcasts: bool = False + bots: bool = False + is_chatlist: bool = False + + +@dataclass(slots=True) +class PolicySet: + chat: dict[int, CaptureToggles] = field(default_factory=dict) + folder: dict[int, CaptureToggles] = field(default_factory=dict) + defaults: dict[ChatKind, CaptureToggles] = field(default_factory=dict) diff --git a/backend/src/utils/policy/repository.py b/backend/src/utils/policy/repository.py new file mode 100644 index 0000000..1cb623f --- /dev/null +++ b/backend/src/utils/policy/repository.py @@ -0,0 +1,154 @@ +import json + +import asyncpg + +from utils.policy.models import ( + KIND_BY_DEFAULT_SCOPE, + CaptureToggles, + FolderSpec, + PolicyRecord, + PolicySet, + ScopeType, +) + +TOGGLES: tuple[str, ...] = tuple(CaptureToggles.model_fields) +_TOGGLE_COLS = ", ".join(TOGGLES) +_TOGGLE_SET = ", ".join(f"{name} = ${i}" for i, name in enumerate(TOGGLES, start=2)) + + +def _toggle_values(toggles: CaptureToggles) -> tuple[bool, ...]: + return tuple(getattr(toggles, name) for name in TOGGLES) + + +def _folder_raw(spec: FolderSpec) -> str: + return json.dumps( + { + "include_ids": sorted(spec.include_ids), + "exclude_ids": sorted(spec.exclude_ids), + "pinned_ids": sorted(spec.pinned_ids), + "contacts": spec.contacts, + "non_contacts": spec.non_contacts, + "groups": spec.groups, + "broadcasts": spec.broadcasts, + "bots": spec.bots, + } + ) + + +def _row_to_folder(row: asyncpg.Record) -> FolderSpec: + raw = json.loads(row["raw"]) + return FolderSpec( + folder_id=row["folder_id"], + order_index=row["order_index"], + title=row["title"], + include_ids=set(raw["include_ids"]), + exclude_ids=set(raw["exclude_ids"]), + pinned_ids=set(raw["pinned_ids"]), + contacts=raw["contacts"], + non_contacts=raw["non_contacts"], + groups=raw["groups"], + broadcasts=raw["broadcasts"], + bots=raw["bots"], + is_chatlist=row["is_chatlist"], + ) + + +async def replace_folders( + pool: asyncpg.Pool, account_id: int, specs: list[FolderSpec] +) -> None: + async with pool.acquire() as conn, conn.transaction(): + await conn.execute("DELETE FROM folders WHERE account_id = $1", account_id) + await conn.executemany( + "INSERT INTO folders " + "(account_id, folder_id, title, order_index, is_chatlist, raw, updated_at) " + "VALUES ($1, $2, $3, $4, $5, $6::jsonb, now())", + [ + ( + account_id, + spec.folder_id, + spec.title, + spec.order_index, + spec.is_chatlist, + _folder_raw(spec), + ) + for spec in specs + ], + ) + + +async def list_folders(pool: asyncpg.Pool, account_id: int) -> list[FolderSpec]: + rows = await pool.fetch( + "SELECT folder_id, title, order_index, is_chatlist, raw " + "FROM folders WHERE account_id = $1 ORDER BY order_index", + account_id, + ) + return [_row_to_folder(row) for row in rows] + + +async def create_policy( + pool: asyncpg.Pool, + account_id: int | None, + scope_type: ScopeType, + scope_id: int | None, + toggles: CaptureToggles, +) -> PolicyRecord: + placeholders = ", ".join(f"${i}" for i in range(4, 4 + len(TOGGLES))) + row = await pool.fetchrow( + f"INSERT INTO capture_policy " # noqa: S608 + f"(account_id, scope_type, scope_id, {_TOGGLE_COLS}) " + f"VALUES ($1, $2, $3, {placeholders}) RETURNING *", + account_id, + scope_type.value, + scope_id, + *_toggle_values(toggles), + ) + return PolicyRecord(**dict(row)) + + +async def get_policy(pool: asyncpg.Pool, policy_id: int) -> PolicyRecord | None: + row = await pool.fetchrow("SELECT * FROM capture_policy WHERE id = $1", policy_id) + return PolicyRecord(**dict(row)) if row else None + + +async def list_policies(pool: asyncpg.Pool, account_id: int) -> list[PolicyRecord]: + rows = await pool.fetch( + "SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL " + "ORDER BY scope_type, scope_id", + account_id, + ) + return [PolicyRecord(**dict(row)) for row in rows] + + +async def update_policy( + pool: asyncpg.Pool, policy_id: int, toggles: CaptureToggles +) -> PolicyRecord | None: + row = await pool.fetchrow( + f"UPDATE capture_policy SET {_TOGGLE_SET}, updated_at = now() " # noqa: S608 + "WHERE id = $1 RETURNING *", + policy_id, + *_toggle_values(toggles), + ) + return PolicyRecord(**dict(row)) if row else None + + +async def delete_policy(pool: asyncpg.Pool, policy_id: int) -> bool: + result = await pool.execute("DELETE FROM capture_policy WHERE id = $1", policy_id) + return result.endswith("1") + + +async def load_policy_set(pool: asyncpg.Pool, account_id: int) -> PolicySet: + rows = await pool.fetch( + "SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL", + account_id, + ) + policies = PolicySet() + for row in rows: + toggles = CaptureToggles(**{name: row[name] for name in TOGGLES}) + scope_type = ScopeType(row["scope_type"]) + if scope_type is ScopeType.CHAT and row["scope_id"] is not None: + policies.chat[row["scope_id"]] = toggles + elif scope_type is ScopeType.FOLDER and row["scope_id"] is not None: + policies.folder[row["scope_id"]] = toggles + elif scope_type in KIND_BY_DEFAULT_SCOPE: + policies.defaults[KIND_BY_DEFAULT_SCOPE[scope_type]] = toggles + return policies diff --git a/backend/src/utils/policy/resolver.py b/backend/src/utils/policy/resolver.py new file mode 100644 index 0000000..7b597f1 --- /dev/null +++ b/backend/src/utils/policy/resolver.py @@ -0,0 +1,46 @@ +from collections.abc import Iterable + +from utils.policy.models import ( + CaptureToggles, + ChatKind, + ChatMeta, + FolderSpec, + PolicySet, +) + + +def _category_match(folder: FolderSpec, chat: ChatMeta) -> bool: + if chat.kind is ChatKind.GROUP: + return folder.groups + if chat.kind is ChatKind.CHANNEL: + return folder.broadcasts + if chat.is_bot: + return folder.bots + if chat.is_contact is True: + return folder.contacts + if chat.is_contact is False: + return folder.non_contacts + return False + + +def folder_contains(folder: FolderSpec, chat: ChatMeta) -> bool: + if chat.chat_id in folder.exclude_ids: + return False + if chat.chat_id in folder.include_ids or chat.chat_id in folder.pinned_ids: + return True + if folder.is_chatlist: + return False + return _category_match(folder, chat) + + +def resolve( + chat: ChatMeta, folders: Iterable[FolderSpec], policies: PolicySet +) -> CaptureToggles: + chat_override = policies.chat.get(chat.chat_id) + if chat_override is not None: + return chat_override + for folder in sorted(folders, key=lambda f: f.order_index): + toggles = policies.folder.get(folder.folder_id) + if toggles is not None and folder_contains(folder, chat): + return toggles + return policies.defaults.get(chat.kind, CaptureToggles())