feat: add realtime capturing

This commit is contained in:
h
2026-05-29 18:19:06 +02:00
parent 920a0235e2
commit 3c1a12750c
29 changed files with 967 additions and 47 deletions
+36 -18
View File
@@ -7,14 +7,11 @@ import asyncpg
import uvloop
from dependencies.container import container
from userbot.folders import FolderCache
from userbot.handlers import dialog_filter_handler
from userbot.modules import PyroClient
from userbot import PyroClient
from userbot.modules.capture import build_capture_context
from utils.env import env
from utils.logging import logger, setup_logging
from utils.policy.models import ChatKind, ChatMeta
from utils.policy.repository import load_policy_set
from utils.policy.resolver import resolve
from utils.storage import ContentAddressedStorage
setup_logging()
@@ -61,22 +58,37 @@ async def _sync_account(
return account_id
async def _setup_policy(
pool: asyncpg.Pool, client: PyroClient, account_id: int
async def _setup_capture(
pool: asyncpg.Pool,
client: PyroClient,
account_id: int,
storage: ContentAddressedStorage,
) -> None:
cache = FolderCache(client, pool, account_id)
await cache.refresh()
client.add_handler(dialog_filter_handler(cache))
if client.me:
policies = await load_policy_set(pool, account_id)
sample = resolve(
ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies
)
logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}")
client.capture = await build_capture_context(client, pool, storage, account_id)
logger.info("[green]Capture context ready.[/]")
async def _listen_policy_changes(
clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
def on_change(
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
) -> None:
for client in clients:
if client.capture is None:
continue
task = asyncio.create_task(client.capture.reload_policies())
tasks.add(task)
task.add_done_callback(tasks.discard)
conn = await asyncpg.connect(dsn=env.db.connection_url)
await conn.add_listener("policy_changed", on_change)
return conn
async def runner() -> None:
pool = await container.get(asyncpg.Pool)
storage = await container.get(ContentAddressedStorage)
sessions_dir = Path(env.tg.sessions_dir)
session_files = _discover_sessions(sessions_dir)
@@ -88,6 +100,8 @@ async def runner() -> None:
)
clients: list[PyroClient] = []
reload_tasks: set[asyncio.Task] = set()
listen_conn: asyncpg.Connection | None = None
try:
for session_path in session_files:
session_name = session_path.stem
@@ -101,12 +115,16 @@ async def runner() -> None:
)
account_id = await _sync_account(pool, client, session_name)
if account_id is not None:
await _setup_policy(pool, client, account_id)
await _setup_capture(pool, client, account_id, storage)
if clients:
listen_conn = await _listen_policy_changes(clients, reload_tasks)
logger.info("[green]Userbot running.[/]")
await asyncio.Event().wait()
finally:
if listen_conn is not None:
with contextlib.suppress(Exception):
await listen_conn.close()
for client in clients:
with contextlib.suppress(Exception):
await client.stop()