"""Userbot entry point: start one client per ``.session`` file and run until stopped."""

import asyncio
import contextlib
import json
from collections.abc import Iterable
from pathlib import Path

import asyncpg
import uvloop

from dependencies.container import container
from userbot import PyroClient
from userbot.modules.capture import build_capture_context
from userbot.modules.jobs import JobConsumer
from utils.env import env
from utils.logging import logger, setup_logging
from utils.storage import ContentAddressedStorage

setup_logging()

# Insert or refresh the `accounts` row for a Telegram user, keyed on tg_user_id.
# Re-running it for an already-known user re-activates the account and
# overwrites its profile fields; the row's account_id is returned either way.
_UPSERT_ACCOUNT = """
INSERT INTO accounts (tg_user_id, label, phone, session_name, is_active, raw, updated_at)
VALUES ($1, $2, $3, $4, TRUE, $5::jsonb, now())
ON CONFLICT (tg_user_id) DO UPDATE SET
    label = EXCLUDED.label,
    phone = EXCLUDED.phone,
    session_name = EXCLUDED.session_name,
    is_active = TRUE,
    raw = EXCLUDED.raw,
    updated_at = now()
RETURNING account_id
"""


def _discover_sessions(sessions_dir: Path) -> list[Path]:
    """Return the ``*.session`` files in *sessions_dir*, sorted by path.

    The directory is created (with parents) if missing, so a fresh
    deployment yields an empty list instead of raising.
    """
    sessions_dir.mkdir(parents=True, exist_ok=True)
    return sorted(sessions_dir.glob("*.session"))


async def _sync_account(
    pool: asyncpg.Pool, client: PyroClient, session_name: str
) -> int | None:
    """Upsert the ``accounts`` row for the user logged in on *client*.

    Returns the row's ``account_id``, or ``None`` (writing nothing) when
    ``client.me`` is not populated.
    """
    me = client.me
    if not me:
        return None
    raw = json.dumps(
        {
            "id": me.id,
            "first_name": me.first_name,
            "last_name": me.last_name,
            "username": me.username,
            "phone_number": me.phone_number,
        }
    )
    # Display name from the non-empty name parts; fall back to the username.
    label = " ".join(filter(None, [me.first_name, me.last_name])) or me.username
    account_id = await pool.fetchval(
        _UPSERT_ACCOUNT, me.id, label, me.phone_number, session_name, raw
    )
    logger.info(f"[green]Account synced:[/] {label} ({me.id})")
    return account_id


async def _setup_capture(
    pool: asyncpg.Pool,
    client: PyroClient,
    account_id: int,
    storage: ContentAddressedStorage,
) -> None:
    """Build the capture context for *account_id* and attach it as ``client.capture``."""
    client.capture = await build_capture_context(client, pool, storage, account_id)
    logger.info("[green]Capture context ready.[/]")


def _log_task_failure(task: asyncio.Task) -> None:
    """Done-callback: log the exception a background task died with, if any.

    Retrieving the exception here also stops asyncio from emitting its
    "Task exception was never retrieved" warning later.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error(f"[red]Background task {task.get_name()} failed:[/] {exc!r}")


async def _cancel_and_wait(tasks: Iterable[asyncio.Task]) -> None:
    """Cancel *tasks* and wait until every one of them has finished.

    ``return_exceptions=True`` ensures that a task which already crashed (or
    raises while being cancelled) cannot abort the caller's remaining cleanup.
    """
    # Snapshot first: done-callbacks may discard items from the source set.
    pending = list(tasks)
    for task in pending:
        task.cancel()
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)


async def _listen_policy_changes(
    clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
    """Open a dedicated connection that LISTENs on ``policy_changed``.

    Each notification schedules ``capture.reload_policies()`` for every
    client whose ``capture`` is set. The spawned tasks are kept in *tasks*,
    which holds strong references to them, and removed once done. The caller
    owns the returned connection and must close it.
    """

    def on_change(
        _conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
    ) -> None:
        for client in clients:
            if client.capture is None:
                continue
            task = asyncio.create_task(client.capture.reload_policies())
            tasks.add(task)
            task.add_done_callback(tasks.discard)
            task.add_done_callback(_log_task_failure)

    conn = await asyncpg.connect(dsn=env.db.connection_url)
    try:
        await conn.add_listener("policy_changed", on_change)
    except BaseException:
        # Don't leak the freshly opened connection if registration fails.
        conn.terminate()
        raise
    return conn


async def runner() -> None:
    """Start one client per session file and block until cancelled.

    For every ``*.session`` file in ``env.tg.sessions_dir`` this starts a
    ``PyroClient`` and upserts its ``accounts`` row. When that yields an
    ``account_id``, it also builds the capture context and starts a
    ``JobConsumer``. Policy reloads are driven by ``policy_changed``
    notifications. On exit, everything is torn down in reverse order, and
    teardown always reaches ``container.close()``.
    """
    pool = await container.get(asyncpg.Pool)
    storage = await container.get(ContentAddressedStorage)

    sessions_dir = Path(env.tg.sessions_dir)
    session_files = _discover_sessions(sessions_dir)
    if not session_files:
        logger.warning(
            f"[yellow]No .session files in {sessions_dir}/. "
            f"Log in first, then restart userbot.[/]"
        )

    clients: list[PyroClient] = []
    reload_tasks: set[asyncio.Task] = set()
    consumer_tasks: list[asyncio.Task] = []
    listen_conn: asyncpg.Connection | None = None
    try:
        for session_path in session_files:
            session_name = session_path.stem
            client = PyroClient(session_name, workdir=str(sessions_dir))
            await client.start()
            # Track the client as soon as it is started so `finally` stops it.
            clients.append(client)
            logger.info(
                f"[green]Client started:[/] "
                f"{client.me.full_name if client.me else 'unknown'} "
                f"{client.me.id if client.me else 'unknown'}"
            )

            account_id = await _sync_account(pool, client, session_name)
            if account_id is not None:
                await _setup_capture(pool, client, account_id, storage)
                consumer = JobConsumer(client, pool, account_id)
                consumer_task = asyncio.create_task(consumer.run())
                # Surface a crashed consumer immediately instead of at shutdown.
                consumer_task.add_done_callback(_log_task_failure)
                consumer_tasks.append(consumer_task)

        if clients:
            listen_conn = await _listen_policy_changes(clients, reload_tasks)

        logger.info("[green]Userbot running.[/]")
        await asyncio.Event().wait()  # run until the task is cancelled
    finally:
        # Close the listener first so no new reload tasks get spawned mid-teardown.
        if listen_conn is not None:
            with contextlib.suppress(Exception):
                await listen_conn.close()
        # gather(return_exceptions=True) inside: a consumer that already failed
        # must not abort the rest of this cleanup.
        await _cancel_and_wait(consumer_tasks)
        await _cancel_and_wait(reload_tasks)
        for client in clients:
            with contextlib.suppress(Exception):
                await client.stop()
        await container.close()


def main() -> None:
    """Install uvloop and run :func:`runner` until interrupted (e.g. Ctrl+C)."""
    uvloop.install()
    logger.info("Starting userbot...")
    with contextlib.suppress(KeyboardInterrupt):
        asyncio.run(runner())
    logger.info("[red]Userbot stopped.[/]")