Files
beavergram/backend/src/userbot/runner.py
T
2026-05-29 19:04:36 +02:00

149 lines
4.7 KiB
Python

import asyncio
import contextlib
import json
from pathlib import Path
import asyncpg
import uvloop
from dependencies.container import container
from userbot import PyroClient
from userbot.modules.capture import build_capture_context
from userbot.modules.jobs import JobConsumer
from utils.env import env
from utils.logging import logger, setup_logging
from utils.storage import ContentAddressedStorage
# Configure logging once, at import time, before any function below logs.
setup_logging()
# Upsert one Telegram account row keyed by tg_user_id and return its account_id.
# Bind parameters: $1 tg_user_id, $2 label, $3 phone, $4 session_name,
# $5 raw profile JSON text (cast to jsonb). On conflict the existing row is
# re-activated (is_active = TRUE) and label/phone/session_name/raw are
# overwritten with the new values; updated_at is refreshed in both paths.
# NOTE(review): ON CONFLICT (tg_user_id) needs a unique index/constraint on
# accounts.tg_user_id — confirm it exists in the schema.
_UPSERT_ACCOUNT = """
INSERT INTO accounts
(tg_user_id, label, phone, session_name, is_active, raw, updated_at)
VALUES ($1, $2, $3, $4, TRUE, $5::jsonb, now())
ON CONFLICT (tg_user_id) DO UPDATE SET
label = EXCLUDED.label,
phone = EXCLUDED.phone,
session_name = EXCLUDED.session_name,
is_active = TRUE,
raw = EXCLUDED.raw,
updated_at = now()
RETURNING account_id
"""
def _discover_sessions(sessions_dir: Path) -> list[Path]:
sessions_dir.mkdir(parents=True, exist_ok=True)
return sorted(sessions_dir.glob("*.session"))
async def _sync_account(
    pool: asyncpg.Pool, client: PyroClient, session_name: str
) -> int | None:
    """Upsert the Telegram user behind *client* into ``accounts``.

    The label is the first and last name joined by a space, falling back to
    the username when both are empty. A JSON snapshot of id, names, username
    and phone number is stored as ``raw``.

    Returns the ``account_id`` from the upsert, or ``None`` (without touching
    the database) when ``client.me`` is falsy.
    """
    user = client.me
    if not user:
        return None

    name_parts = [part for part in (user.first_name, user.last_name) if part]
    label = " ".join(name_parts) or user.username

    # Key order matters only for the stored JSON text; keep it stable.
    profile = {
        "id": user.id,
        "first_name": user.first_name,
        "last_name": user.last_name,
        "username": user.username,
        "phone_number": user.phone_number,
    }
    account_id = await pool.fetchval(
        _UPSERT_ACCOUNT,
        user.id,
        label,
        user.phone_number,
        session_name,
        json.dumps(profile),
    )
    logger.info(f"[green]Account synced:[/] {label} ({user.id})")
    return account_id
async def _setup_capture(
    pool: asyncpg.Pool,
    client: PyroClient,
    account_id: int,
    storage: ContentAddressedStorage,
) -> None:
    """Build the capture context for *account_id* and attach it to *client*.

    The object returned by ``build_capture_context`` is stored on
    ``client.capture``; a confirmation line is logged afterwards.
    """
    capture_ctx = await build_capture_context(client, pool, storage, account_id)
    client.capture = capture_ctx
    logger.info("[green]Capture context ready.[/]")
async def _listen_policy_changes(
    clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
    """Open a dedicated connection that LISTENs on ``policy_changed``.

    Each notification schedules ``client.capture.reload_policies()`` for every
    client in *clients* that has a capture context. Scheduled tasks are held
    in *tasks* (so they stay referenced while running) and removed when done;
    a reload that fails is logged rather than silently dropped.

    Returns the listening connection. The caller owns it and must close it.
    If registering the listener fails, the connection is terminated before
    the error propagates, so it cannot leak.
    """

    def on_reload_done(task: asyncio.Task) -> None:
        tasks.discard(task)
        if task.cancelled():
            return
        # Retrieve the exception explicitly so a failed reload is reported
        # right away instead of as "Task exception was never retrieved".
        exc = task.exception()
        if exc is not None:
            logger.error(f"[red]Policy reload failed:[/] {exc!r}")

    def on_change(
        _conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
    ) -> None:
        for client in clients:
            if client.capture is None:
                continue
            task = asyncio.create_task(client.capture.reload_policies())
            tasks.add(task)
            task.add_done_callback(on_reload_done)

    conn = await asyncpg.connect(dsn=env.db.connection_url)
    try:
        await conn.add_listener("policy_changed", on_change)
    except BaseException:
        # The caller never receives conn on this path, so release it here.
        # terminate() is synchronous, which keeps this safe under cancellation.
        conn.terminate()
        raise
    return conn
async def _cancel_and_wait(tasks: list[asyncio.Task]) -> None:
    """Cancel *tasks* and wait until every one of them has finished.

    Uses ``gather(return_exceptions=True)`` so a task that had already crashed
    with a regular exception is logged instead of re-raised; awaiting such a
    task directly would abort the caller's cleanup.
    """
    for task in tasks:
        task.cancel()
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for task, result in zip(tasks, results):
        # CancelledError derives from BaseException, so normal cancellation
        # is skipped here and only genuine failures are reported.
        if isinstance(result, Exception):
            logger.error(
                f"[red]Background task {task.get_name()} failed:[/] {result!r}"
            )


async def _shutdown(
    clients: list[PyroClient],
    consumer_tasks: list[asyncio.Task],
    reload_tasks: set[asyncio.Task],
    listen_conn: asyncpg.Connection | None,
) -> None:
    """Release everything ``runner`` acquired; later steps run even if earlier ones fail."""
    try:
        # Close the LISTEN connection first so no new policy-reload tasks are
        # spawned after reload_tasks is snapshotted below.
        if listen_conn is not None:
            with contextlib.suppress(Exception):
                await listen_conn.close()
        try:
            # Consumers were built with a client and the pool, and reloads run
            # on capture contexts built from them, so stop these tasks before
            # tearing down clients and the container.
            await _cancel_and_wait([*consumer_tasks, *reload_tasks])
        finally:
            for client in clients:
                with contextlib.suppress(Exception):
                    await client.stop()
    finally:
        await container.close()


async def runner() -> None:
    """Start one PyroClient per ``*.session`` file and run until cancelled.

    For each session file: start the client, upsert its account row and, when
    that yields an ``account_id``, build its capture context and spawn a
    ``JobConsumer`` task. Once at least one client is up, a LISTEN connection
    reloads capture policies on ``policy_changed`` notifications. The
    coroutine then blocks forever; on cancellation or any startup error all
    acquired resources are released via ``_shutdown``.
    """
    clients: list[PyroClient] = []
    reload_tasks: set[asyncio.Task] = set()
    consumer_tasks: list[asyncio.Task] = []
    listen_conn: asyncpg.Connection | None = None
    try:
        # Resolved inside the try so the container is closed even if setup
        # fails before any client starts.
        pool = await container.get(asyncpg.Pool)
        storage = await container.get(ContentAddressedStorage)
        sessions_dir = Path(env.tg.sessions_dir)
        session_files = _discover_sessions(sessions_dir)
        if not session_files:
            logger.warning(
                f"[yellow]No .session files in {sessions_dir}/. "
                f"Log in first, then restart userbot.[/]"
            )
        for session_path in session_files:
            session_name = session_path.stem
            client = PyroClient(session_name, workdir=str(sessions_dir))
            await client.start()
            # Track the client as soon as it is started so shutdown stops it
            # even if a later step for this session fails.
            clients.append(client)
            logger.info(
                f"[green]Client started:[/] "
                f"{client.me.full_name if client.me else 'unknown'} "
                f"{client.me.id if client.me else 'unknown'}"
            )
            account_id = await _sync_account(pool, client, session_name)
            if account_id is not None:
                await _setup_capture(pool, client, account_id, storage)
                consumer = JobConsumer(client, pool, account_id)
                consumer_tasks.append(asyncio.create_task(consumer.run()))
        if clients:
            listen_conn = await _listen_policy_changes(clients, reload_tasks)
        logger.info("[green]Userbot running.[/]")
        # Never set: park here until the task is cancelled (e.g. Ctrl+C).
        await asyncio.Event().wait()
    finally:
        await _shutdown(clients, consumer_tasks, reload_tasks, listen_conn)
def main() -> None:
    """Process entry point: run ``runner()`` on a uvloop-backed event loop.

    A KeyboardInterrupt (Ctrl+C) is treated as a normal stop request, so no
    traceback is printed. The final "stopped" line is logged after a clean
    return or such an interrupt.
    """
    uvloop.install()
    logger.info("Starting userbot...")
    try:
        asyncio.run(runner())
    except KeyboardInterrupt:
        pass
    logger.info("[red]Userbot stopped.[/]")