149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
import asyncio
|
|
import contextlib
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import asyncpg
|
|
import uvloop
|
|
|
|
from dependencies.container import container
|
|
from userbot import PyroClient
|
|
from userbot.modules.capture import build_capture_context
|
|
from userbot.modules.jobs import JobConsumer
|
|
from utils.env import env
|
|
from utils.logging import logger, setup_logging
|
|
from utils.storage import ContentAddressedStorage
|
|
|
|
# Runs at import time, before anything below logs. Presumably configures the
# handlers/format behind `logger` from utils.logging; the "[green]...[/]"
# strings used in this module suggest markup-aware rendering (confirm there).
setup_logging()
|
|
|
|
_UPSERT_ACCOUNT = """
|
|
INSERT INTO accounts
|
|
(tg_user_id, label, phone, session_name, is_active, raw, updated_at)
|
|
VALUES ($1, $2, $3, $4, TRUE, $5::jsonb, now())
|
|
ON CONFLICT (tg_user_id) DO UPDATE SET
|
|
label = EXCLUDED.label,
|
|
phone = EXCLUDED.phone,
|
|
session_name = EXCLUDED.session_name,
|
|
is_active = TRUE,
|
|
raw = EXCLUDED.raw,
|
|
updated_at = now()
|
|
RETURNING account_id
|
|
"""
|
|
|
|
|
|
def _discover_sessions(sessions_dir: Path) -> list[Path]:
|
|
sessions_dir.mkdir(parents=True, exist_ok=True)
|
|
return sorted(sessions_dir.glob("*.session"))
|
|
|
|
|
|
async def _sync_account(
    pool: asyncpg.Pool, client: PyroClient, session_name: str
) -> int | None:
    """Upsert the logged-in user behind *client* into ``accounts``.

    The row's label is the user's first/last name joined by a space, falling
    back to the username. The ``raw`` column receives a JSON snapshot of the
    basic profile fields.

    Returns:
        The ``account_id`` returned by the upsert, or ``None`` (without
        touching the database) when ``client.me`` is not set.
    """
    user = client.me
    if not user:
        return None

    profile = {
        "id": user.id,
        "first_name": user.first_name,
        "last_name": user.last_name,
        "username": user.username,
        "phone_number": user.phone_number,
    }
    name_parts = [part for part in (user.first_name, user.last_name) if part]
    label = " ".join(name_parts) or user.username

    account_id = await pool.fetchval(
        _UPSERT_ACCOUNT,
        user.id,
        label,
        user.phone_number,
        session_name,
        json.dumps(profile),
    )
    logger.info(f"[green]Account synced:[/] {label} ({user.id})")
    return account_id
|
|
|
|
|
|
async def _setup_capture(
    pool: asyncpg.Pool,
    client: PyroClient,
    account_id: int,
    storage: ContentAddressedStorage,
) -> None:
    """Build the capture context for *account_id* and attach it to *client*.

    The result of ``build_capture_context`` is stored on ``client.capture``.
    Other code in this module treats that attribute as "capture enabled" when
    it is not ``None``.
    """
    context = await build_capture_context(client, pool, storage, account_id)
    client.capture = context
    logger.info("[green]Capture context ready.[/]")
|
|
|
|
|
|
async def _listen_policy_changes(
    clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection:
    """Open a dedicated connection that listens on the ``policy_changed`` channel.

    Every notification schedules ``client.capture.reload_policies()`` for each
    client in *clients* whose ``capture`` is not ``None``. The payload is
    ignored.

    Args:
        clients: The live clients. The list is read at notification time, so
            clients appended later are also reloaded.
        tasks: Caller-owned set that holds a strong reference to every
            scheduled reload task until it finishes. asyncio itself keeps
            only weak references to tasks.

    Returns:
        The listening connection. The caller owns it and must close it.

    Raises:
        Whatever ``asyncpg.connect`` / ``add_listener`` raise. If subscribing
        fails, the connection is terminated before the error propagates.
    """

    def _on_reload_done(task: asyncio.Task) -> None:
        # Drop our reference, and retrieve/log a failure instead of letting it
        # vanish as an "exception was never retrieved" warning at GC time.
        tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"[red]Policy reload failed:[/] {exc!r}")

    def on_change(
        _conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
    ) -> None:
        for client in clients:
            if client.capture is None:
                continue
            task = asyncio.create_task(client.capture.reload_policies())
            tasks.add(task)
            task.add_done_callback(_on_reload_done)

    conn = await asyncpg.connect(dsn=env.db.connection_url)
    try:
        await conn.add_listener("policy_changed", on_change)
    except BaseException:
        # Don't leak the connection. terminate() is synchronous, so it is safe
        # even while a cancellation is propagating.
        conn.terminate()
        raise
    return conn
|
|
|
|
|
|
async def _cancel_and_drain(tasks: list[asyncio.Task], what: str) -> None:
    """Cancel *tasks*, wait for all of them, and log any that had failed.

    ``gather(..., return_exceptions=True)`` collects a task's earlier crash as
    a value instead of re-raising it, so one failed task cannot abort the
    caller's shutdown sequence.
    """
    for task in tasks:
        task.cancel()
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        # CancelledError derives from BaseException (3.8+), so it is skipped.
        if isinstance(result, Exception):
            logger.error(f"[red]{what} task failed:[/] {result!r}")


async def runner() -> None:
    """Start one client per ``.session`` file and keep running until cancelled.

    For each session this starts a ``PyroClient`` and syncs its account row.
    When an ``account_id`` comes back, it also sets up capture and launches a
    ``JobConsumer`` task. Once clients exist, it listens for ``policy_changed``
    notifications.

    On exit (cancellation or error) it shuts down in this order:
    consumer tasks, the listener connection, pending policy reloads, the
    clients, and finally the dependency container. Cleanup is best-effort and
    failures are logged rather than allowed to skip later steps.
    """
    pool = await container.get(asyncpg.Pool)
    storage = await container.get(ContentAddressedStorage)

    sessions_dir = Path(env.tg.sessions_dir)
    session_files = _discover_sessions(sessions_dir)

    if not session_files:
        logger.warning(
            f"[yellow]No .session files in {sessions_dir}/. "
            f"Log in first, then restart userbot.[/]"
        )

    clients: list[PyroClient] = []
    reload_tasks: set[asyncio.Task] = set()
    consumer_tasks: list[asyncio.Task] = []
    listen_conn: asyncpg.Connection | None = None
    try:
        for session_path in session_files:
            session_name = session_path.stem
            client = PyroClient(session_name, workdir=str(sessions_dir))
            await client.start()
            # Track the client as soon as it is started so `finally` stops it
            # even if syncing or capture setup fails below.
            clients.append(client)
            logger.info(
                f"[green]Client started:[/] "
                f"{client.me.full_name if client.me else 'unknown'} "
                f"{client.me.id if client.me else 'unknown'}"
            )

            account_id = await _sync_account(pool, client, session_name)
            if account_id is not None:
                await _setup_capture(pool, client, account_id, storage)
                consumer = JobConsumer(client, pool, account_id)
                consumer_tasks.append(asyncio.create_task(consumer.run()))

        if clients:
            listen_conn = await _listen_policy_changes(clients, reload_tasks)
        logger.info("[green]Userbot running.[/]")
        # Park forever; only cancellation (e.g. Ctrl+C via asyncio.run) exits.
        await asyncio.Event().wait()
    finally:
        await _cancel_and_drain(consumer_tasks, "Job consumer")

        # Close the listener before draining reloads so no new reload tasks
        # get scheduled while we are shutting down.
        if listen_conn is not None:
            try:
                await listen_conn.close()
            except Exception as exc:
                logger.warning(f"[yellow]Failed to close policy listener:[/] {exc!r}")

        # Snapshot: done callbacks discard from the set while we wait.
        await _cancel_and_drain(list(reload_tasks), "Policy reload")

        for client in clients:
            try:
                await client.stop()
            except Exception as exc:
                logger.warning(f"[yellow]Failed to stop client:[/] {exc!r}")

        await container.close()
|
|
|
|
|
|
def main() -> None:
    """Process entry point: install uvloop, run ``runner()``, exit on Ctrl+C.

    A ``KeyboardInterrupt`` escaping ``asyncio.run`` is treated as a normal
    shutdown, and the final "stopped" line is logged either way.
    """
    uvloop.install()
    logger.info("Starting userbot...")
    try:
        asyncio.run(runner())
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the bot; not an error.
        pass
    logger.info("[red]Userbot stopped.[/]")
|