feat: create message capture policies
This commit is contained in:
@@ -0,0 +1,97 @@
|
|||||||
|
"""folders and capture policy
|
||||||
|
|
||||||
|
Revision ID: b2f7c1a9d3e4
|
||||||
|
Revises: 77df960a31de
|
||||||
|
Create Date: 2026-05-29 17:30:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "b2f7c1a9d3e4"
|
||||||
|
down_revision: str | None = "77df960a31de"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
_TOGGLES = (
|
||||||
|
"messages",
|
||||||
|
"media",
|
||||||
|
"self_destruct_media",
|
||||||
|
"stt",
|
||||||
|
"reactions",
|
||||||
|
"track_edits_deletes",
|
||||||
|
"profile_history",
|
||||||
|
"stories",
|
||||||
|
"presence",
|
||||||
|
"backfill",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _toggle_columns() -> list[sa.Column]:
|
||||||
|
return [
|
||||||
|
sa.Column(name, sa.Boolean(), nullable=False, server_default=sa.false())
|
||||||
|
for name in _TOGGLES
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"folders",
|
||||||
|
sa.Column("account_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("folder_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("title", sa.String(), nullable=False),
|
||||||
|
sa.Column("order_index", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("is_chatlist", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("raw", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("account_id", "folder_id"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"capture_policy",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("account_id", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("scope_type", sa.String(), nullable=False),
|
||||||
|
sa.Column("scope_id", sa.BigInteger(), nullable=True),
|
||||||
|
*_toggle_columns(),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("account_id", "scope_type", "scope_id"),
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX ix_capture_policy_default "
|
||||||
|
"ON capture_policy (scope_type) WHERE scope_type LIKE 'default_%'"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"INSERT INTO capture_policy "
|
||||||
|
"(scope_type, messages, media, self_destruct_media, stt, reactions, "
|
||||||
|
"track_edits_deletes, profile_history, stories, presence, backfill) VALUES "
|
||||||
|
"('default_channel', FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE), "
|
||||||
|
"('default_group', TRUE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE), "
|
||||||
|
"('default_dm', TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("capture_policy")
|
||||||
|
op.drop_table("folders")
|
||||||
@@ -5,6 +5,7 @@ import asyncpg
|
|||||||
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka
|
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from api.routers import folders, policy
|
||||||
from dependencies.container import container
|
from dependencies.container import container
|
||||||
|
|
||||||
|
|
||||||
@@ -27,4 +28,7 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
|
|||||||
return {"db": db_ok, "timescaledb": bool(timescale_ok)}
|
return {"db": db_ok, "timescaledb": bool(timescale_ok)}
|
||||||
|
|
||||||
|
|
||||||
|
app.include_router(policy.router)
|
||||||
|
app.include_router(folders.router)
|
||||||
|
|
||||||
setup_dishka(container, app)
|
setup_dishka(container, app)
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from fastapi import APIRouter, Query
|
||||||
|
|
||||||
|
from utils.policy import repository
|
||||||
|
from utils.policy.models import FolderSpec
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/folders", tags=["folders"])
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize(spec: FolderSpec) -> dict:
|
||||||
|
return {
|
||||||
|
"folder_id": spec.folder_id,
|
||||||
|
"order_index": spec.order_index,
|
||||||
|
"title": spec.title,
|
||||||
|
"include_ids": sorted(spec.include_ids),
|
||||||
|
"exclude_ids": sorted(spec.exclude_ids),
|
||||||
|
"pinned_ids": sorted(spec.pinned_ids),
|
||||||
|
"contacts": spec.contacts,
|
||||||
|
"non_contacts": spec.non_contacts,
|
||||||
|
"groups": spec.groups,
|
||||||
|
"broadcasts": spec.broadcasts,
|
||||||
|
"bots": spec.bots,
|
||||||
|
"is_chatlist": spec.is_chatlist,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
@inject
|
||||||
|
async def list_folders(
|
||||||
|
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
|
||||||
|
) -> list[dict]:
|
||||||
|
folders = await repository.list_folders(pool, account_id)
|
||||||
|
return [_serialize(spec) for spec in folders]
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from dishka.integrations.fastapi import FromDishka, inject
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from utils.policy import repository
|
||||||
|
from utils.policy.models import (
|
||||||
|
CaptureToggles,
|
||||||
|
ChatKind,
|
||||||
|
ChatMeta,
|
||||||
|
PolicyCreate,
|
||||||
|
PolicyRecord,
|
||||||
|
)
|
||||||
|
from utils.policy.resolver import resolve
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/policy", tags=["policy"])
|
||||||
|
|
||||||
|
|
||||||
|
class EffectiveQuery(BaseModel):
|
||||||
|
account_id: int
|
||||||
|
chat_id: int
|
||||||
|
kind: ChatKind
|
||||||
|
is_bot: bool = False
|
||||||
|
is_contact: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
@inject
|
||||||
|
async def list_policies(
|
||||||
|
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
|
||||||
|
) -> list[PolicyRecord]:
|
||||||
|
return await repository.list_policies(pool, account_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/effective")
|
||||||
|
@inject
|
||||||
|
async def effective_policy(
|
||||||
|
pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()]
|
||||||
|
) -> CaptureToggles:
|
||||||
|
folders = await repository.list_folders(pool, query.account_id)
|
||||||
|
policies = await repository.load_policy_set(pool, query.account_id)
|
||||||
|
chat = ChatMeta(
|
||||||
|
chat_id=query.chat_id,
|
||||||
|
kind=query.kind,
|
||||||
|
is_bot=query.is_bot,
|
||||||
|
is_contact=query.is_contact,
|
||||||
|
)
|
||||||
|
return resolve(chat, folders, policies)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", status_code=201)
|
||||||
|
@inject
|
||||||
|
async def create_policy(
|
||||||
|
pool: FromDishka[asyncpg.Pool], body: PolicyCreate
|
||||||
|
) -> PolicyRecord:
|
||||||
|
return await repository.create_policy(
|
||||||
|
pool, body.account_id, body.scope_type, body.scope_id, body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{policy_id}")
|
||||||
|
@inject
|
||||||
|
async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord:
|
||||||
|
record = await repository.get_policy(pool, policy_id)
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=404, detail="policy not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{policy_id}")
|
||||||
|
@inject
|
||||||
|
async def update_policy(
|
||||||
|
pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles
|
||||||
|
) -> PolicyRecord:
|
||||||
|
record = await repository.update_policy(pool, policy_id, body)
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=404, detail="policy not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{policy_id}", status_code=204)
|
||||||
|
@inject
|
||||||
|
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
|
||||||
|
if not await repository.delete_policy(pool, policy_id):
|
||||||
|
raise HTTPException(status_code=404, detail="policy not found")
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from pyrogram import Client, raw, utils
|
||||||
|
|
||||||
|
from utils.logging import logger
|
||||||
|
from utils.policy.models import FolderSpec
|
||||||
|
from utils.policy.repository import replace_folders
|
||||||
|
|
||||||
|
|
||||||
|
def _peer_ids(peers: Iterable[raw.base.InputPeer]) -> set[int]:
|
||||||
|
ids: set[int] = set()
|
||||||
|
for peer in peers:
|
||||||
|
try:
|
||||||
|
ids.add(utils.get_peer_id(peer))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
continue
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _title(raw_title: object) -> str:
|
||||||
|
return getattr(raw_title, "text", None) or str(raw_title)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse(raw_filter: raw.base.DialogFilter, order_index: int) -> FolderSpec | None:
|
||||||
|
if isinstance(raw_filter, raw.types.DialogFilterDefault):
|
||||||
|
return None
|
||||||
|
if isinstance(raw_filter, raw.types.DialogFilterChatlist):
|
||||||
|
return FolderSpec(
|
||||||
|
folder_id=raw_filter.id,
|
||||||
|
order_index=order_index,
|
||||||
|
title=_title(raw_filter.title),
|
||||||
|
include_ids=_peer_ids(raw_filter.include_peers),
|
||||||
|
pinned_ids=_peer_ids(raw_filter.pinned_peers),
|
||||||
|
is_chatlist=True,
|
||||||
|
)
|
||||||
|
return FolderSpec(
|
||||||
|
folder_id=raw_filter.id,
|
||||||
|
order_index=order_index,
|
||||||
|
title=_title(raw_filter.title),
|
||||||
|
include_ids=_peer_ids(raw_filter.include_peers),
|
||||||
|
exclude_ids=_peer_ids(raw_filter.exclude_peers),
|
||||||
|
pinned_ids=_peer_ids(raw_filter.pinned_peers),
|
||||||
|
contacts=bool(raw_filter.contacts),
|
||||||
|
non_contacts=bool(raw_filter.non_contacts),
|
||||||
|
groups=bool(raw_filter.groups),
|
||||||
|
broadcasts=bool(raw_filter.broadcasts),
|
||||||
|
bots=bool(raw_filter.bots),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FolderCache:
|
||||||
|
def __init__(self, client: Client, pool: asyncpg.Pool, account_id: int) -> None:
|
||||||
|
self._client = client
|
||||||
|
self._pool = pool
|
||||||
|
self._account_id = account_id
|
||||||
|
self.folders: list[FolderSpec] = []
|
||||||
|
|
||||||
|
async def refresh(self) -> None:
|
||||||
|
result = await self._client.invoke(raw.functions.messages.GetDialogFilters())
|
||||||
|
specs = [
|
||||||
|
spec
|
||||||
|
for order_index, raw_filter in enumerate(result.filters)
|
||||||
|
if (spec := _parse(raw_filter, order_index)) is not None
|
||||||
|
]
|
||||||
|
self.folders = specs
|
||||||
|
await replace_folders(self._pool, self._account_id, specs)
|
||||||
|
logger.info(f"[green]Folders cached:[/] {len(specs)}")
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from .dialog_filters import dialog_filter_handler
|
||||||
|
|
||||||
|
__all__ = ["dialog_filter_handler"]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from pyrogram import Client, raw
|
||||||
|
from pyrogram.handlers import RawUpdateHandler
|
||||||
|
|
||||||
|
from userbot.folders import FolderCache
|
||||||
|
|
||||||
|
_FILTER_UPDATES = (
|
||||||
|
raw.types.UpdateDialogFilter,
|
||||||
|
raw.types.UpdateDialogFilters,
|
||||||
|
raw.types.UpdateDialogFilterOrder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dialog_filter_handler(cache: FolderCache) -> RawUpdateHandler:
|
||||||
|
async def on_update(
|
||||||
|
_client: Client, update: raw.base.Update, _users: dict, _chats: dict
|
||||||
|
) -> None:
|
||||||
|
if isinstance(update, _FILTER_UPDATES):
|
||||||
|
await cache.refresh()
|
||||||
|
|
||||||
|
return RawUpdateHandler(on_update)
|
||||||
@@ -7,9 +7,14 @@ import asyncpg
|
|||||||
import uvloop
|
import uvloop
|
||||||
|
|
||||||
from dependencies.container import container
|
from dependencies.container import container
|
||||||
|
from userbot.folders import FolderCache
|
||||||
|
from userbot.handlers import dialog_filter_handler
|
||||||
from userbot.modules import PyroClient
|
from userbot.modules import PyroClient
|
||||||
from utils.env import env
|
from utils.env import env
|
||||||
from utils.logging import logger, setup_logging
|
from utils.logging import logger, setup_logging
|
||||||
|
from utils.policy.models import ChatKind, ChatMeta
|
||||||
|
from utils.policy.repository import load_policy_set
|
||||||
|
from utils.policy.resolver import resolve
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
@@ -24,6 +29,7 @@ ON CONFLICT (tg_user_id) DO UPDATE SET
|
|||||||
is_active = TRUE,
|
is_active = TRUE,
|
||||||
raw = EXCLUDED.raw,
|
raw = EXCLUDED.raw,
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
|
RETURNING account_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -34,10 +40,10 @@ def _discover_sessions(sessions_dir: Path) -> list[Path]:
|
|||||||
|
|
||||||
async def _sync_account(
|
async def _sync_account(
|
||||||
pool: asyncpg.Pool, client: PyroClient, session_name: str
|
pool: asyncpg.Pool, client: PyroClient, session_name: str
|
||||||
) -> None:
|
) -> int | None:
|
||||||
me = client.me
|
me = client.me
|
||||||
if not me:
|
if not me:
|
||||||
return
|
return None
|
||||||
raw = json.dumps(
|
raw = json.dumps(
|
||||||
{
|
{
|
||||||
"id": me.id,
|
"id": me.id,
|
||||||
@@ -48,10 +54,25 @@ async def _sync_account(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
label = " ".join(filter(None, [me.first_name, me.last_name])) or me.username
|
label = " ".join(filter(None, [me.first_name, me.last_name])) or me.username
|
||||||
await pool.execute(
|
account_id = await pool.fetchval(
|
||||||
_UPSERT_ACCOUNT, me.id, label, me.phone_number, session_name, raw
|
_UPSERT_ACCOUNT, me.id, label, me.phone_number, session_name, raw
|
||||||
)
|
)
|
||||||
logger.info(f"[green]Account synced:[/] {label} ({me.id})")
|
logger.info(f"[green]Account synced:[/] {label} ({me.id})")
|
||||||
|
return account_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_policy(
|
||||||
|
pool: asyncpg.Pool, client: PyroClient, account_id: int
|
||||||
|
) -> None:
|
||||||
|
cache = FolderCache(client, pool, account_id)
|
||||||
|
await cache.refresh()
|
||||||
|
client.add_handler(dialog_filter_handler(cache))
|
||||||
|
if client.me:
|
||||||
|
policies = await load_policy_set(pool, account_id)
|
||||||
|
sample = resolve(
|
||||||
|
ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies
|
||||||
|
)
|
||||||
|
logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}")
|
||||||
|
|
||||||
|
|
||||||
async def runner() -> None:
|
async def runner() -> None:
|
||||||
@@ -78,10 +99,12 @@ async def runner() -> None:
|
|||||||
f"{client.me.full_name if client.me else 'unknown'} "
|
f"{client.me.full_name if client.me else 'unknown'} "
|
||||||
f"{client.me.id if client.me else 'unknown'}"
|
f"{client.me.id if client.me else 'unknown'}"
|
||||||
)
|
)
|
||||||
await _sync_account(pool, client, session_name)
|
account_id = await _sync_account(pool, client, session_name)
|
||||||
|
if account_id is not None:
|
||||||
|
await _setup_policy(pool, client, account_id)
|
||||||
|
|
||||||
if clients:
|
if clients:
|
||||||
logger.info("[green]Userbot running. Idle (no handlers until phase 3).[/]")
|
logger.info("[green]Userbot running.[/]")
|
||||||
await asyncio.Event().wait()
|
await asyncio.Event().wait()
|
||||||
finally:
|
finally:
|
||||||
for client in clients:
|
for client in clients:
|
||||||
|
|||||||
@@ -48,3 +48,56 @@ class Message(SQLModel, table=True):
|
|||||||
deleted_at: datetime | None = Field(
|
deleted_at: datetime | None = Field(
|
||||||
default=None, sa_column=Column(DateTime(timezone=True))
|
default=None, sa_column=Column(DateTime(timezone=True))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Folder(SQLModel, table=True):
|
||||||
|
__tablename__ = "folders"
|
||||||
|
|
||||||
|
account_id: int = Field(primary_key=True)
|
||||||
|
folder_id: int = Field(primary_key=True)
|
||||||
|
title: str
|
||||||
|
order_index: int
|
||||||
|
is_chatlist: bool = False
|
||||||
|
raw: dict[str, Any] = Field(
|
||||||
|
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||||
|
)
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
sa_column=Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CapturePolicy(SQLModel, table=True):
|
||||||
|
__tablename__ = "capture_policy"
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
account_id: int | None = None
|
||||||
|
scope_type: str
|
||||||
|
scope_id: int | None = Field(default=None, sa_column=Column(BigInteger))
|
||||||
|
messages: bool = False
|
||||||
|
media: bool = False
|
||||||
|
self_destruct_media: bool = False
|
||||||
|
stt: bool = False
|
||||||
|
reactions: bool = False
|
||||||
|
track_edits_deletes: bool = False
|
||||||
|
profile_history: bool = False
|
||||||
|
stories: bool = False
|
||||||
|
presence: bool = False
|
||||||
|
backfill: bool = False
|
||||||
|
created_at: datetime = Field(
|
||||||
|
sa_column=Column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
sa_column=Column(
|
||||||
|
DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from utils.policy.models import CaptureToggles, ChatKind
|
||||||
|
|
||||||
|
DEFAULTS: dict[ChatKind, CaptureToggles] = {
|
||||||
|
ChatKind.CHANNEL: CaptureToggles(),
|
||||||
|
ChatKind.GROUP: CaptureToggles(messages=True),
|
||||||
|
ChatKind.DM: CaptureToggles(
|
||||||
|
messages=True,
|
||||||
|
media=True,
|
||||||
|
self_destruct_media=True,
|
||||||
|
stt=True,
|
||||||
|
reactions=True,
|
||||||
|
track_edits_deletes=True,
|
||||||
|
profile_history=True,
|
||||||
|
stories=True,
|
||||||
|
presence=True,
|
||||||
|
backfill=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ChatKind(StrEnum):
|
||||||
|
DM = "dm"
|
||||||
|
GROUP = "group"
|
||||||
|
CHANNEL = "channel"
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeType(StrEnum):
|
||||||
|
DEFAULT_DM = "default_dm"
|
||||||
|
DEFAULT_GROUP = "default_group"
|
||||||
|
DEFAULT_CHANNEL = "default_channel"
|
||||||
|
FOLDER = "folder"
|
||||||
|
CHAT = "chat"
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_SCOPE_BY_KIND: dict[ChatKind, ScopeType] = {
|
||||||
|
ChatKind.DM: ScopeType.DEFAULT_DM,
|
||||||
|
ChatKind.GROUP: ScopeType.DEFAULT_GROUP,
|
||||||
|
ChatKind.CHANNEL: ScopeType.DEFAULT_CHANNEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
KIND_BY_DEFAULT_SCOPE: dict[ScopeType, ChatKind] = {
|
||||||
|
scope: kind for kind, scope in DEFAULT_SCOPE_BY_KIND.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureToggles(BaseModel):
|
||||||
|
messages: bool = False
|
||||||
|
media: bool = False
|
||||||
|
self_destruct_media: bool = False
|
||||||
|
stt: bool = False
|
||||||
|
reactions: bool = False
|
||||||
|
track_edits_deletes: bool = False
|
||||||
|
profile_history: bool = False
|
||||||
|
stories: bool = False
|
||||||
|
presence: bool = False
|
||||||
|
backfill: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyCreate(CaptureToggles):
|
||||||
|
account_id: int | None = None
|
||||||
|
scope_type: ScopeType
|
||||||
|
scope_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PolicyRecord(CaptureToggles):
|
||||||
|
id: int
|
||||||
|
account_id: int | None
|
||||||
|
scope_type: ScopeType
|
||||||
|
scope_id: int | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChatMeta:
|
||||||
|
chat_id: int
|
||||||
|
kind: ChatKind
|
||||||
|
is_bot: bool = False
|
||||||
|
is_contact: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class FolderSpec:
|
||||||
|
folder_id: int
|
||||||
|
order_index: int
|
||||||
|
title: str
|
||||||
|
include_ids: set[int] = field(default_factory=set)
|
||||||
|
exclude_ids: set[int] = field(default_factory=set)
|
||||||
|
pinned_ids: set[int] = field(default_factory=set)
|
||||||
|
contacts: bool = False
|
||||||
|
non_contacts: bool = False
|
||||||
|
groups: bool = False
|
||||||
|
broadcasts: bool = False
|
||||||
|
bots: bool = False
|
||||||
|
is_chatlist: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PolicySet:
|
||||||
|
chat: dict[int, CaptureToggles] = field(default_factory=dict)
|
||||||
|
folder: dict[int, CaptureToggles] = field(default_factory=dict)
|
||||||
|
defaults: dict[ChatKind, CaptureToggles] = field(default_factory=dict)
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from utils.policy.models import (
|
||||||
|
KIND_BY_DEFAULT_SCOPE,
|
||||||
|
CaptureToggles,
|
||||||
|
FolderSpec,
|
||||||
|
PolicyRecord,
|
||||||
|
PolicySet,
|
||||||
|
ScopeType,
|
||||||
|
)
|
||||||
|
|
||||||
|
TOGGLES: tuple[str, ...] = tuple(CaptureToggles.model_fields)
|
||||||
|
_TOGGLE_COLS = ", ".join(TOGGLES)
|
||||||
|
_TOGGLE_SET = ", ".join(f"{name} = ${i}" for i, name in enumerate(TOGGLES, start=2))
|
||||||
|
|
||||||
|
|
||||||
|
def _toggle_values(toggles: CaptureToggles) -> tuple[bool, ...]:
|
||||||
|
return tuple(getattr(toggles, name) for name in TOGGLES)
|
||||||
|
|
||||||
|
|
||||||
|
def _folder_raw(spec: FolderSpec) -> str:
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"include_ids": sorted(spec.include_ids),
|
||||||
|
"exclude_ids": sorted(spec.exclude_ids),
|
||||||
|
"pinned_ids": sorted(spec.pinned_ids),
|
||||||
|
"contacts": spec.contacts,
|
||||||
|
"non_contacts": spec.non_contacts,
|
||||||
|
"groups": spec.groups,
|
||||||
|
"broadcasts": spec.broadcasts,
|
||||||
|
"bots": spec.bots,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_folder(row: asyncpg.Record) -> FolderSpec:
|
||||||
|
raw = json.loads(row["raw"])
|
||||||
|
return FolderSpec(
|
||||||
|
folder_id=row["folder_id"],
|
||||||
|
order_index=row["order_index"],
|
||||||
|
title=row["title"],
|
||||||
|
include_ids=set(raw["include_ids"]),
|
||||||
|
exclude_ids=set(raw["exclude_ids"]),
|
||||||
|
pinned_ids=set(raw["pinned_ids"]),
|
||||||
|
contacts=raw["contacts"],
|
||||||
|
non_contacts=raw["non_contacts"],
|
||||||
|
groups=raw["groups"],
|
||||||
|
broadcasts=raw["broadcasts"],
|
||||||
|
bots=raw["bots"],
|
||||||
|
is_chatlist=row["is_chatlist"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def replace_folders(
|
||||||
|
pool: asyncpg.Pool, account_id: int, specs: list[FolderSpec]
|
||||||
|
) -> None:
|
||||||
|
async with pool.acquire() as conn, conn.transaction():
|
||||||
|
await conn.execute("DELETE FROM folders WHERE account_id = $1", account_id)
|
||||||
|
await conn.executemany(
|
||||||
|
"INSERT INTO folders "
|
||||||
|
"(account_id, folder_id, title, order_index, is_chatlist, raw, updated_at) "
|
||||||
|
"VALUES ($1, $2, $3, $4, $5, $6::jsonb, now())",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
account_id,
|
||||||
|
spec.folder_id,
|
||||||
|
spec.title,
|
||||||
|
spec.order_index,
|
||||||
|
spec.is_chatlist,
|
||||||
|
_folder_raw(spec),
|
||||||
|
)
|
||||||
|
for spec in specs
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_folders(pool: asyncpg.Pool, account_id: int) -> list[FolderSpec]:
|
||||||
|
rows = await pool.fetch(
|
||||||
|
"SELECT folder_id, title, order_index, is_chatlist, raw "
|
||||||
|
"FROM folders WHERE account_id = $1 ORDER BY order_index",
|
||||||
|
account_id,
|
||||||
|
)
|
||||||
|
return [_row_to_folder(row) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def create_policy(
|
||||||
|
pool: asyncpg.Pool,
|
||||||
|
account_id: int | None,
|
||||||
|
scope_type: ScopeType,
|
||||||
|
scope_id: int | None,
|
||||||
|
toggles: CaptureToggles,
|
||||||
|
) -> PolicyRecord:
|
||||||
|
placeholders = ", ".join(f"${i}" for i in range(4, 4 + len(TOGGLES)))
|
||||||
|
row = await pool.fetchrow(
|
||||||
|
f"INSERT INTO capture_policy " # noqa: S608
|
||||||
|
f"(account_id, scope_type, scope_id, {_TOGGLE_COLS}) "
|
||||||
|
f"VALUES ($1, $2, $3, {placeholders}) RETURNING *",
|
||||||
|
account_id,
|
||||||
|
scope_type.value,
|
||||||
|
scope_id,
|
||||||
|
*_toggle_values(toggles),
|
||||||
|
)
|
||||||
|
return PolicyRecord(**dict(row))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_policy(pool: asyncpg.Pool, policy_id: int) -> PolicyRecord | None:
|
||||||
|
row = await pool.fetchrow("SELECT * FROM capture_policy WHERE id = $1", policy_id)
|
||||||
|
return PolicyRecord(**dict(row)) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
async def list_policies(pool: asyncpg.Pool, account_id: int) -> list[PolicyRecord]:
|
||||||
|
rows = await pool.fetch(
|
||||||
|
"SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL "
|
||||||
|
"ORDER BY scope_type, scope_id",
|
||||||
|
account_id,
|
||||||
|
)
|
||||||
|
return [PolicyRecord(**dict(row)) for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_policy(
|
||||||
|
pool: asyncpg.Pool, policy_id: int, toggles: CaptureToggles
|
||||||
|
) -> PolicyRecord | None:
|
||||||
|
row = await pool.fetchrow(
|
||||||
|
f"UPDATE capture_policy SET {_TOGGLE_SET}, updated_at = now() " # noqa: S608
|
||||||
|
"WHERE id = $1 RETURNING *",
|
||||||
|
policy_id,
|
||||||
|
*_toggle_values(toggles),
|
||||||
|
)
|
||||||
|
return PolicyRecord(**dict(row)) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_policy(pool: asyncpg.Pool, policy_id: int) -> bool:
|
||||||
|
result = await pool.execute("DELETE FROM capture_policy WHERE id = $1", policy_id)
|
||||||
|
return result.endswith("1")
|
||||||
|
|
||||||
|
|
||||||
|
async def load_policy_set(pool: asyncpg.Pool, account_id: int) -> PolicySet:
|
||||||
|
rows = await pool.fetch(
|
||||||
|
"SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL",
|
||||||
|
account_id,
|
||||||
|
)
|
||||||
|
policies = PolicySet()
|
||||||
|
for row in rows:
|
||||||
|
toggles = CaptureToggles(**{name: row[name] for name in TOGGLES})
|
||||||
|
scope_type = ScopeType(row["scope_type"])
|
||||||
|
if scope_type is ScopeType.CHAT and row["scope_id"] is not None:
|
||||||
|
policies.chat[row["scope_id"]] = toggles
|
||||||
|
elif scope_type is ScopeType.FOLDER and row["scope_id"] is not None:
|
||||||
|
policies.folder[row["scope_id"]] = toggles
|
||||||
|
elif scope_type in KIND_BY_DEFAULT_SCOPE:
|
||||||
|
policies.defaults[KIND_BY_DEFAULT_SCOPE[scope_type]] = toggles
|
||||||
|
return policies
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from utils.policy.models import (
|
||||||
|
CaptureToggles,
|
||||||
|
ChatKind,
|
||||||
|
ChatMeta,
|
||||||
|
FolderSpec,
|
||||||
|
PolicySet,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _category_match(folder: FolderSpec, chat: ChatMeta) -> bool:
|
||||||
|
if chat.kind is ChatKind.GROUP:
|
||||||
|
return folder.groups
|
||||||
|
if chat.kind is ChatKind.CHANNEL:
|
||||||
|
return folder.broadcasts
|
||||||
|
if chat.is_bot:
|
||||||
|
return folder.bots
|
||||||
|
if chat.is_contact is True:
|
||||||
|
return folder.contacts
|
||||||
|
if chat.is_contact is False:
|
||||||
|
return folder.non_contacts
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def folder_contains(folder: FolderSpec, chat: ChatMeta) -> bool:
|
||||||
|
if chat.chat_id in folder.exclude_ids:
|
||||||
|
return False
|
||||||
|
if chat.chat_id in folder.include_ids or chat.chat_id in folder.pinned_ids:
|
||||||
|
return True
|
||||||
|
if folder.is_chatlist:
|
||||||
|
return False
|
||||||
|
return _category_match(folder, chat)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve(
|
||||||
|
chat: ChatMeta, folders: Iterable[FolderSpec], policies: PolicySet
|
||||||
|
) -> CaptureToggles:
|
||||||
|
chat_override = policies.chat.get(chat.chat_id)
|
||||||
|
if chat_override is not None:
|
||||||
|
return chat_override
|
||||||
|
for folder in sorted(folders, key=lambda f: f.order_index):
|
||||||
|
toggles = policies.folder.get(folder.folder_id)
|
||||||
|
if toggles is not None and folder_contains(folder, chat):
|
||||||
|
return toggles
|
||||||
|
return policies.defaults.get(chat.kind, CaptureToggles())
|
||||||
Reference in New Issue
Block a user