"""In-process fan-out of Postgres ``NOTIFY`` events to realtime subscribers.

A dedicated asyncpg connection LISTENs on ``BG_EVENTS_CHANNEL``. Each
notification payload is parsed as JSON and matched against the registered
subscribers. It is then turned into a client-facing frame, re-reading data
through the pool where needed, and pushed onto each matching subscriber's
bounded queue.
"""

import asyncio
import json
import logging
from typing import Any

import asyncpg

from utils.env import env
from utils.events import BG_EVENTS_CHANNEL
from utils.read import chats as chats_read
from utils.read import presence as presence_read

logger = logging.getLogger(__name__)

# Maximum number of undelivered frames buffered per subscriber. Once a
# subscriber's queue is full, further frames for it are dropped.
QUEUE_MAXSIZE = 256


class Subscriber:
    """A single consumer of realtime frames for one account.

    Attributes:
        account_id: Only events whose ``account_id`` equals this are delivered.
        chat_id: When ``None``, events for every chat of the account are
            delivered; otherwise only events whose ``chat_id`` matches.
        queue: Bounded queue the hub pushes frames onto; the consumer drains it.
    """

    def __init__(self, account_id: int, chat_id: int | None) -> None:
        self.account_id = account_id
        self.chat_id = chat_id
        self.queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=QUEUE_MAXSIZE)


class EventHub:
    """Listen for background events in Postgres and route them to subscribers."""

    def __init__(self) -> None:
        self._subscribers: set[Subscriber] = set()
        # Pool used for the read queries made while building frames.
        self._pool: asyncpg.Pool | None = None
        # Dedicated connection that holds the LISTEN registration.
        self._conn: asyncpg.Connection | None = None
        # Strong references to in-flight dispatch tasks: the event loop only
        # keeps weak references, so unreferenced tasks could vanish mid-run.
        self._tasks: set[asyncio.Task] = set()

    def subscribe(self, account_id: int, chat_id: int | None) -> Subscriber:
        """Register and return a new subscriber for ``account_id``.

        Args:
            account_id: Account whose events should be delivered.
            chat_id: Restrict delivery to this chat, or ``None`` for all chats.
        """
        sub = Subscriber(account_id, chat_id)
        self._subscribers.add(sub)
        return sub

    def unsubscribe(self, sub: Subscriber) -> None:
        """Stop delivering frames to ``sub``; a no-op if it is not registered."""
        self._subscribers.discard(sub)

    async def start(self, pool: asyncpg.Pool) -> None:
        """Open the LISTEN connection and begin routing notifications.

        Args:
            pool: Pool used for the read queries made while building frames.

        Raises:
            Whatever ``asyncpg.connect`` or ``add_listener`` raises. If
            registering the listener fails, the new connection is closed
            before the error propagates.
        """
        self._pool = pool
        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            await conn.add_listener(BG_EVENTS_CHANNEL, self._on_notify)
        except BaseException:
            # Don't leak the freshly opened connection if LISTEN fails.
            await conn.close()
            raise
        self._conn = conn
        logger.info("Realtime hub listening on %s", BG_EVENTS_CHANNEL)

    async def stop(self) -> None:
        """Close the LISTEN connection, then cancel and await in-flight dispatches."""
        conn, self._conn = self._conn, None
        try:
            if conn is not None:
                # Close first so no new notification spawns a dispatch task
                # after the cancellation snapshot below.
                await conn.close()
        finally:
            # Snapshot: done callbacks discard from self._tasks as tasks finish.
            tasks = list(self._tasks)
            for task in tasks:
                task.cancel()
            if tasks:
                # Await the cancellations so no task is destroyed while pending.
                await asyncio.gather(*tasks, return_exceptions=True)

    def _on_notify(
        self, _conn: asyncpg.Connection, _pid: int, _channel: str, payload: str
    ) -> None:
        """Synchronous asyncpg listener callback: schedule an async dispatch."""
        task = asyncio.create_task(self._dispatch(payload))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _dispatch(self, payload: str) -> None:
        """Parse one notification payload and deliver its frame to matching subscribers.

        Runs as a fire-and-forget task, so failures are logged here rather
        than propagated: nobody awaits this coroutine.
        """
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            logger.warning("Ignoring malformed realtime event payload: %.200s", payload)
            return
        if not isinstance(event, dict):
            # Valid JSON that isn't an object (null, list, number) has no fields.
            logger.warning("Ignoring non-object realtime event payload: %.200s", payload)
            return

        account_id = event.get("account_id")
        chat_id = event.get("chat_id")
        targets = [
            sub
            for sub in self._subscribers
            if sub.account_id == account_id
            and (sub.chat_id is None or sub.chat_id == chat_id)
        ]
        if not targets:
            return

        try:
            frame = await self._build_frame(event)
        except Exception:
            # e.g. a missing key in the event or a failed read query.
            logger.exception(
                "Failed to build realtime frame for %r event", event.get("kind")
            )
            return
        if frame is None:
            return

        for sub in targets:
            # Skip subscribers that unsubscribed while the frame was being built;
            # nobody drains their queue any more.
            if sub not in self._subscribers:
                continue
            try:
                sub.queue.put_nowait(frame)
            except asyncio.QueueFull:
                logger.warning("Dropping event for slow subscriber")

    async def _build_frame(  # noqa: PLR0911
        self, event: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Translate a raw event into the frame sent to clients.

        Args:
            event: Decoded notification payload. It must contain
                ``account_id``. Depending on ``kind`` it must also contain
                ``chat_id`` and ``message_id`` (or optionally ``message_ids``
                for deletes).

        Returns:
            The frame dict. Returns ``None`` in three cases: the hub has not
            been started (no pool), the message can no longer be read, or
            ``kind`` is unrecognised.

        Raises:
            KeyError: If a field required for the event's ``kind`` is missing.
        """
        if self._pool is None:
            return None
        kind = event.get("kind")
        account_id = event["account_id"]

        if kind in {"message", "edit", "reaction"}:
            view = await chats_read.get_message(
                self._pool, account_id, event["chat_id"], event["message_id"]
            )
            if view is None:
                return None
            return {"type": kind, "message": view.model_dump(mode="json")}

        if kind == "delete":
            return {
                "type": "delete",
                "chat_id": event.get("chat_id"),
                "message_ids": event.get("message_ids", []),
            }

        if kind == "presence":
            # For presence events the event's chat_id field carries the peer id.
            sample = await presence_read.current_presence(
                self._pool, account_id, event["chat_id"]
            )
            return {
                "type": "presence",
                "peer_id": event["chat_id"],
                "sample": sample.model_dump(mode="json") if sample else None,
            }

        if kind == "receipt":
            return {
                "type": "receipt",
                "chat_id": event["chat_id"],
                "read_up_to": event["message_id"],
            }

        return None


# Module-level singleton shared by importers of this module.
hub = EventHub()