feat: add realtime capturing
This commit is contained in:
@@ -7,14 +7,11 @@ import asyncpg
|
||||
import uvloop
|
||||
|
||||
from dependencies.container import container
|
||||
from userbot.folders import FolderCache
|
||||
from userbot.handlers import dialog_filter_handler
|
||||
from userbot.modules import PyroClient
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import build_capture_context
|
||||
from utils.env import env
|
||||
from utils.logging import logger, setup_logging
|
||||
from utils.policy.models import ChatKind, ChatMeta
|
||||
from utils.policy.repository import load_policy_set
|
||||
from utils.policy.resolver import resolve
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
setup_logging()
|
||||
|
||||
@@ -61,22 +58,37 @@ async def _sync_account(
|
||||
return account_id
|
||||
|
||||
|
||||
async def _setup_policy(
|
||||
pool: asyncpg.Pool, client: PyroClient, account_id: int
|
||||
async def _setup_capture(
|
||||
pool: asyncpg.Pool,
|
||||
client: PyroClient,
|
||||
account_id: int,
|
||||
storage: ContentAddressedStorage,
|
||||
) -> None:
|
||||
cache = FolderCache(client, pool, account_id)
|
||||
await cache.refresh()
|
||||
client.add_handler(dialog_filter_handler(cache))
|
||||
if client.me:
|
||||
policies = await load_policy_set(pool, account_id)
|
||||
sample = resolve(
|
||||
ChatMeta(chat_id=client.me.id, kind=ChatKind.DM), cache.folders, policies
|
||||
)
|
||||
logger.info(f"[green]Sample resolve (self DM):[/] {sample.model_dump()}")
|
||||
client.capture = await build_capture_context(client, pool, storage, account_id)
|
||||
logger.info("[green]Capture context ready.[/]")
|
||||
|
||||
|
||||
async def _listen_policy_changes(
|
||||
clients: list[PyroClient], tasks: set[asyncio.Task]
|
||||
) -> asyncpg.Connection:
|
||||
def on_change(
|
||||
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
|
||||
) -> None:
|
||||
for client in clients:
|
||||
if client.capture is None:
|
||||
continue
|
||||
task = asyncio.create_task(client.capture.reload_policies())
|
||||
tasks.add(task)
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
conn = await asyncpg.connect(dsn=env.db.connection_url)
|
||||
await conn.add_listener("policy_changed", on_change)
|
||||
return conn
|
||||
|
||||
|
||||
async def runner() -> None:
|
||||
pool = await container.get(asyncpg.Pool)
|
||||
storage = await container.get(ContentAddressedStorage)
|
||||
|
||||
sessions_dir = Path(env.tg.sessions_dir)
|
||||
session_files = _discover_sessions(sessions_dir)
|
||||
@@ -88,6 +100,8 @@ async def runner() -> None:
|
||||
)
|
||||
|
||||
clients: list[PyroClient] = []
|
||||
reload_tasks: set[asyncio.Task] = set()
|
||||
listen_conn: asyncpg.Connection | None = None
|
||||
try:
|
||||
for session_path in session_files:
|
||||
session_name = session_path.stem
|
||||
@@ -101,12 +115,16 @@ async def runner() -> None:
|
||||
)
|
||||
account_id = await _sync_account(pool, client, session_name)
|
||||
if account_id is not None:
|
||||
await _setup_policy(pool, client, account_id)
|
||||
await _setup_capture(pool, client, account_id, storage)
|
||||
|
||||
if clients:
|
||||
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
||||
logger.info("[green]Userbot running.[/]")
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
if listen_conn is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await listen_conn.close()
|
||||
for client in clients:
|
||||
with contextlib.suppress(Exception):
|
||||
await client.stop()
|
||||
|
||||
Reference in New Issue
Block a user