feat: create message capture policies
This commit is contained in:
@@ -48,3 +48,56 @@ class Message(SQLModel, table=True):
|
||||
deleted_at: datetime | None = Field(
|
||||
default=None, sa_column=Column(DateTime(timezone=True))
|
||||
)
|
||||
|
||||
|
||||
class Folder(SQLModel, table=True):
|
||||
__tablename__ = "folders"
|
||||
|
||||
account_id: int = Field(primary_key=True)
|
||||
folder_id: int = Field(primary_key=True)
|
||||
title: str
|
||||
order_index: int
|
||||
is_chatlist: bool = False
|
||||
raw: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class CapturePolicy(SQLModel, table=True):
|
||||
__tablename__ = "capture_policy"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
account_id: int | None = None
|
||||
scope_type: str
|
||||
scope_id: int | None = Field(default=None, sa_column=Column(BigInteger))
|
||||
messages: bool = False
|
||||
media: bool = False
|
||||
self_destruct_media: bool = False
|
||||
stt: bool = False
|
||||
reactions: bool = False
|
||||
track_edits_deletes: bool = False
|
||||
profile_history: bool = False
|
||||
stories: bool = False
|
||||
presence: bool = False
|
||||
backfill: bool = False
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from utils.policy.models import CaptureToggles, ChatKind
|
||||
|
||||
DEFAULTS: dict[ChatKind, CaptureToggles] = {
|
||||
ChatKind.CHANNEL: CaptureToggles(),
|
||||
ChatKind.GROUP: CaptureToggles(messages=True),
|
||||
ChatKind.DM: CaptureToggles(
|
||||
messages=True,
|
||||
media=True,
|
||||
self_destruct_media=True,
|
||||
stt=True,
|
||||
reactions=True,
|
||||
track_edits_deletes=True,
|
||||
profile_history=True,
|
||||
stories=True,
|
||||
presence=True,
|
||||
backfill=True,
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatKind(StrEnum):
|
||||
DM = "dm"
|
||||
GROUP = "group"
|
||||
CHANNEL = "channel"
|
||||
|
||||
|
||||
class ScopeType(StrEnum):
|
||||
DEFAULT_DM = "default_dm"
|
||||
DEFAULT_GROUP = "default_group"
|
||||
DEFAULT_CHANNEL = "default_channel"
|
||||
FOLDER = "folder"
|
||||
CHAT = "chat"
|
||||
|
||||
|
||||
DEFAULT_SCOPE_BY_KIND: dict[ChatKind, ScopeType] = {
|
||||
ChatKind.DM: ScopeType.DEFAULT_DM,
|
||||
ChatKind.GROUP: ScopeType.DEFAULT_GROUP,
|
||||
ChatKind.CHANNEL: ScopeType.DEFAULT_CHANNEL,
|
||||
}
|
||||
|
||||
KIND_BY_DEFAULT_SCOPE: dict[ScopeType, ChatKind] = {
|
||||
scope: kind for kind, scope in DEFAULT_SCOPE_BY_KIND.items()
|
||||
}
|
||||
|
||||
|
||||
class CaptureToggles(BaseModel):
|
||||
messages: bool = False
|
||||
media: bool = False
|
||||
self_destruct_media: bool = False
|
||||
stt: bool = False
|
||||
reactions: bool = False
|
||||
track_edits_deletes: bool = False
|
||||
profile_history: bool = False
|
||||
stories: bool = False
|
||||
presence: bool = False
|
||||
backfill: bool = False
|
||||
|
||||
|
||||
class PolicyCreate(CaptureToggles):
|
||||
account_id: int | None = None
|
||||
scope_type: ScopeType
|
||||
scope_id: int | None = None
|
||||
|
||||
|
||||
class PolicyRecord(CaptureToggles):
|
||||
id: int
|
||||
account_id: int | None
|
||||
scope_type: ScopeType
|
||||
scope_id: int | None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChatMeta:
|
||||
chat_id: int
|
||||
kind: ChatKind
|
||||
is_bot: bool = False
|
||||
is_contact: bool | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FolderSpec:
|
||||
folder_id: int
|
||||
order_index: int
|
||||
title: str
|
||||
include_ids: set[int] = field(default_factory=set)
|
||||
exclude_ids: set[int] = field(default_factory=set)
|
||||
pinned_ids: set[int] = field(default_factory=set)
|
||||
contacts: bool = False
|
||||
non_contacts: bool = False
|
||||
groups: bool = False
|
||||
broadcasts: bool = False
|
||||
bots: bool = False
|
||||
is_chatlist: bool = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PolicySet:
|
||||
chat: dict[int, CaptureToggles] = field(default_factory=dict)
|
||||
folder: dict[int, CaptureToggles] = field(default_factory=dict)
|
||||
defaults: dict[ChatKind, CaptureToggles] = field(default_factory=dict)
|
||||
@@ -0,0 +1,154 @@
|
||||
import json
|
||||
|
||||
import asyncpg
|
||||
|
||||
from utils.policy.models import (
|
||||
KIND_BY_DEFAULT_SCOPE,
|
||||
CaptureToggles,
|
||||
FolderSpec,
|
||||
PolicyRecord,
|
||||
PolicySet,
|
||||
ScopeType,
|
||||
)
|
||||
|
||||
TOGGLES: tuple[str, ...] = tuple(CaptureToggles.model_fields)
|
||||
_TOGGLE_COLS = ", ".join(TOGGLES)
|
||||
_TOGGLE_SET = ", ".join(f"{name} = ${i}" for i, name in enumerate(TOGGLES, start=2))
|
||||
|
||||
|
||||
def _toggle_values(toggles: CaptureToggles) -> tuple[bool, ...]:
|
||||
return tuple(getattr(toggles, name) for name in TOGGLES)
|
||||
|
||||
|
||||
def _folder_raw(spec: FolderSpec) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"include_ids": sorted(spec.include_ids),
|
||||
"exclude_ids": sorted(spec.exclude_ids),
|
||||
"pinned_ids": sorted(spec.pinned_ids),
|
||||
"contacts": spec.contacts,
|
||||
"non_contacts": spec.non_contacts,
|
||||
"groups": spec.groups,
|
||||
"broadcasts": spec.broadcasts,
|
||||
"bots": spec.bots,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _row_to_folder(row: asyncpg.Record) -> FolderSpec:
|
||||
raw = json.loads(row["raw"])
|
||||
return FolderSpec(
|
||||
folder_id=row["folder_id"],
|
||||
order_index=row["order_index"],
|
||||
title=row["title"],
|
||||
include_ids=set(raw["include_ids"]),
|
||||
exclude_ids=set(raw["exclude_ids"]),
|
||||
pinned_ids=set(raw["pinned_ids"]),
|
||||
contacts=raw["contacts"],
|
||||
non_contacts=raw["non_contacts"],
|
||||
groups=raw["groups"],
|
||||
broadcasts=raw["broadcasts"],
|
||||
bots=raw["bots"],
|
||||
is_chatlist=row["is_chatlist"],
|
||||
)
|
||||
|
||||
|
||||
async def replace_folders(
|
||||
pool: asyncpg.Pool, account_id: int, specs: list[FolderSpec]
|
||||
) -> None:
|
||||
async with pool.acquire() as conn, conn.transaction():
|
||||
await conn.execute("DELETE FROM folders WHERE account_id = $1", account_id)
|
||||
await conn.executemany(
|
||||
"INSERT INTO folders "
|
||||
"(account_id, folder_id, title, order_index, is_chatlist, raw, updated_at) "
|
||||
"VALUES ($1, $2, $3, $4, $5, $6::jsonb, now())",
|
||||
[
|
||||
(
|
||||
account_id,
|
||||
spec.folder_id,
|
||||
spec.title,
|
||||
spec.order_index,
|
||||
spec.is_chatlist,
|
||||
_folder_raw(spec),
|
||||
)
|
||||
for spec in specs
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def list_folders(pool: asyncpg.Pool, account_id: int) -> list[FolderSpec]:
|
||||
rows = await pool.fetch(
|
||||
"SELECT folder_id, title, order_index, is_chatlist, raw "
|
||||
"FROM folders WHERE account_id = $1 ORDER BY order_index",
|
||||
account_id,
|
||||
)
|
||||
return [_row_to_folder(row) for row in rows]
|
||||
|
||||
|
||||
async def create_policy(
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int | None,
|
||||
scope_type: ScopeType,
|
||||
scope_id: int | None,
|
||||
toggles: CaptureToggles,
|
||||
) -> PolicyRecord:
|
||||
placeholders = ", ".join(f"${i}" for i in range(4, 4 + len(TOGGLES)))
|
||||
row = await pool.fetchrow(
|
||||
f"INSERT INTO capture_policy " # noqa: S608
|
||||
f"(account_id, scope_type, scope_id, {_TOGGLE_COLS}) "
|
||||
f"VALUES ($1, $2, $3, {placeholders}) RETURNING *",
|
||||
account_id,
|
||||
scope_type.value,
|
||||
scope_id,
|
||||
*_toggle_values(toggles),
|
||||
)
|
||||
return PolicyRecord(**dict(row))
|
||||
|
||||
|
||||
async def get_policy(pool: asyncpg.Pool, policy_id: int) -> PolicyRecord | None:
|
||||
row = await pool.fetchrow("SELECT * FROM capture_policy WHERE id = $1", policy_id)
|
||||
return PolicyRecord(**dict(row)) if row else None
|
||||
|
||||
|
||||
async def list_policies(pool: asyncpg.Pool, account_id: int) -> list[PolicyRecord]:
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL "
|
||||
"ORDER BY scope_type, scope_id",
|
||||
account_id,
|
||||
)
|
||||
return [PolicyRecord(**dict(row)) for row in rows]
|
||||
|
||||
|
||||
async def update_policy(
|
||||
pool: asyncpg.Pool, policy_id: int, toggles: CaptureToggles
|
||||
) -> PolicyRecord | None:
|
||||
row = await pool.fetchrow(
|
||||
f"UPDATE capture_policy SET {_TOGGLE_SET}, updated_at = now() " # noqa: S608
|
||||
"WHERE id = $1 RETURNING *",
|
||||
policy_id,
|
||||
*_toggle_values(toggles),
|
||||
)
|
||||
return PolicyRecord(**dict(row)) if row else None
|
||||
|
||||
|
||||
async def delete_policy(pool: asyncpg.Pool, policy_id: int) -> bool:
|
||||
result = await pool.execute("DELETE FROM capture_policy WHERE id = $1", policy_id)
|
||||
return result.endswith("1")
|
||||
|
||||
|
||||
async def load_policy_set(pool: asyncpg.Pool, account_id: int) -> PolicySet:
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM capture_policy WHERE account_id = $1 OR account_id IS NULL",
|
||||
account_id,
|
||||
)
|
||||
policies = PolicySet()
|
||||
for row in rows:
|
||||
toggles = CaptureToggles(**{name: row[name] for name in TOGGLES})
|
||||
scope_type = ScopeType(row["scope_type"])
|
||||
if scope_type is ScopeType.CHAT and row["scope_id"] is not None:
|
||||
policies.chat[row["scope_id"]] = toggles
|
||||
elif scope_type is ScopeType.FOLDER and row["scope_id"] is not None:
|
||||
policies.folder[row["scope_id"]] = toggles
|
||||
elif scope_type in KIND_BY_DEFAULT_SCOPE:
|
||||
policies.defaults[KIND_BY_DEFAULT_SCOPE[scope_type]] = toggles
|
||||
return policies
|
||||
@@ -0,0 +1,46 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from utils.policy.models import (
|
||||
CaptureToggles,
|
||||
ChatKind,
|
||||
ChatMeta,
|
||||
FolderSpec,
|
||||
PolicySet,
|
||||
)
|
||||
|
||||
|
||||
def _category_match(folder: FolderSpec, chat: ChatMeta) -> bool:
|
||||
if chat.kind is ChatKind.GROUP:
|
||||
return folder.groups
|
||||
if chat.kind is ChatKind.CHANNEL:
|
||||
return folder.broadcasts
|
||||
if chat.is_bot:
|
||||
return folder.bots
|
||||
if chat.is_contact is True:
|
||||
return folder.contacts
|
||||
if chat.is_contact is False:
|
||||
return folder.non_contacts
|
||||
return False
|
||||
|
||||
|
||||
def folder_contains(folder: FolderSpec, chat: ChatMeta) -> bool:
|
||||
if chat.chat_id in folder.exclude_ids:
|
||||
return False
|
||||
if chat.chat_id in folder.include_ids or chat.chat_id in folder.pinned_ids:
|
||||
return True
|
||||
if folder.is_chatlist:
|
||||
return False
|
||||
return _category_match(folder, chat)
|
||||
|
||||
|
||||
def resolve(
|
||||
chat: ChatMeta, folders: Iterable[FolderSpec], policies: PolicySet
|
||||
) -> CaptureToggles:
|
||||
chat_override = policies.chat.get(chat.chat_id)
|
||||
if chat_override is not None:
|
||||
return chat_override
|
||||
for folder in sorted(folders, key=lambda f: f.order_index):
|
||||
toggles = policies.folder.get(folder.folder_id)
|
||||
if toggles is not None and folder_contains(folder, chat):
|
||||
return toggles
|
||||
return policies.defaults.get(chat.kind, CaptureToggles())
|
||||
Reference in New Issue
Block a user