124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import asyncpg
|
|
|
|
from utils.env import env
|
|
from utils.events import BG_EVENTS_CHANNEL
|
|
from utils.read import chats as chats_read
|
|
from utils.read import presence as presence_read
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Capacity of each subscriber's frame queue; once full, the hub drops
# further frames for that subscriber instead of blocking.
QUEUE_MAXSIZE: int = 256
|
|
|
|
|
|
class Subscriber:
    """A single consumer registered with the event hub.

    Carries the routing filter (``account_id`` plus an optional
    ``chat_id``; ``None`` matches every chat of the account) and a bounded
    queue that receives the frames routed to this consumer.
    """

    def __init__(self, account_id: int, chat_id: int | None) -> None:
        # Bounded by QUEUE_MAXSIZE so a stalled consumer cannot grow memory
        # without limit.
        inbox: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=QUEUE_MAXSIZE)
        self.account_id, self.chat_id = account_id, chat_id
        self.queue = inbox
|
|
|
|
|
|
class EventHub:
    """Fan out Postgres ``NOTIFY`` events on ``BG_EVENTS_CHANNEL`` to subscribers.

    A dedicated asyncpg connection listens on the channel. Each notification
    payload is decoded as JSON, matched against registered subscribers by
    ``account_id`` (and by ``chat_id`` unless the subscriber's ``chat_id`` is
    ``None``), turned into a client frame, and pushed onto every matching
    subscriber's queue.
    """

    def __init__(self) -> None:
        self._subscribers: set[Subscriber] = set()
        # Pool handed to the read helpers when building frames; set by start().
        self._pool: asyncpg.Pool | None = None
        # Dedicated LISTEN connection, opened by start() and closed by stop().
        self._conn: asyncpg.Connection | None = None
        # Strong references to in-flight dispatch tasks so they are not
        # garbage-collected mid-flight; each removes itself when done.
        self._tasks: set[asyncio.Task] = set()

    def subscribe(self, account_id: int, chat_id: int | None) -> Subscriber:
        """Register and return a new subscriber.

        Args:
            account_id: Only events whose ``account_id`` equals this are delivered.
            chat_id: Restrict delivery to this chat; ``None`` receives every chat.
        """
        sub = Subscriber(account_id, chat_id)
        self._subscribers.add(sub)
        return sub

    def unsubscribe(self, sub: Subscriber) -> None:
        """Remove *sub*; subscribers that are not registered are ignored."""
        self._subscribers.discard(sub)

    async def start(self, pool: asyncpg.Pool) -> None:
        """Open the LISTEN connection and begin dispatching notifications.

        Args:
            pool: Pool used to load message/presence data when building frames.

        Raises:
            Whatever ``asyncpg.connect`` or ``add_listener`` raise. If
            ``add_listener`` fails, the new connection is torn down first so
            it is not leaked.
        """
        self._pool = pool
        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            await conn.add_listener(BG_EVENTS_CHANNEL, self._on_notify)
        except BaseException:
            # terminate() is synchronous, so cleanup also works when we are
            # being cancelled.
            conn.terminate()
            raise
        self._conn = conn
        logger.info("Realtime hub listening on %s", BG_EVENTS_CHANNEL)

    async def stop(self) -> None:
        """Stop listening, then cancel and await in-flight dispatch tasks.

        The connection is closed first so it stops producing new dispatch
        tasks. The tasks outstanding at that point are then cancelled and
        awaited. Safe to call if the hub was never started or is already
        stopped.
        """
        conn, self._conn = self._conn, None
        try:
            if conn is not None:
                await conn.close()
        finally:
            # Snapshot: done callbacks discard from self._tasks while we wait.
            pending = list(self._tasks)
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

    def _on_notify(
        self, _conn: asyncpg.Connection, _pid: int, _channel: str, payload: str
    ) -> None:
        """Listener callback: schedule dispatch of *payload* as a background task."""
        # NOTE(review): each notification gets its own task, and frame
        # building awaits DB reads, so frames for back-to-back events (for
        # example a message followed by its delete) can reach subscribers out
        # of order. Confirm clients tolerate this.
        task = asyncio.create_task(self._dispatch(payload))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _dispatch(self, payload: str) -> None:
        """Decode one notification and deliver its frame to matching subscribers.

        Malformed payloads and frame-building failures are logged and dropped.
        This coroutine runs as a detached task, so an escaping exception would
        otherwise surface only as "Task exception was never retrieved".
        """
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            logger.warning("Ignoring non-JSON realtime payload: %.200r", payload)
            return
        if not isinstance(event, dict):
            # Valid JSON but not an object: .get() below would raise.
            logger.warning("Ignoring non-object realtime payload: %.200r", payload)
            return

        account_id = event.get("account_id")
        chat_id = event.get("chat_id")
        targets = [
            sub
            for sub in self._subscribers
            if sub.account_id == account_id
            and (sub.chat_id is None or sub.chat_id == chat_id)
        ]
        if not targets:
            # Skip the DB round-trip when nobody is listening.
            return

        try:
            frame = await self._build_frame(event)
        except Exception:
            logger.exception(
                "Failed to build realtime frame for %r event", event.get("kind")
            )
            return
        if frame is None:
            return

        for sub in targets:
            try:
                sub.queue.put_nowait(frame)
            except asyncio.QueueFull:
                logger.warning(
                    "Dropping %s event for slow subscriber (account=%s, chat=%s)",
                    frame.get("type"),
                    sub.account_id,
                    sub.chat_id,
                )

    async def _build_frame(  # noqa: PLR0911
        self, event: dict[str, Any]
    ) -> dict[str, Any] | None:
        """Translate a decoded notification into the frame sent to clients.

        Args:
            event: Decoded payload. It must contain ``account_id``; other
                required keys depend on ``kind``.

        Returns:
            The frame dict, or ``None`` when the hub has no pool, the
            referenced message cannot be loaded, or ``kind`` is unrecognised.

        Raises:
            KeyError: if a key this ``kind`` indexes directly is missing.
        """
        if self._pool is None:
            return None
        kind = event.get("kind")
        account_id = event["account_id"]

        if kind in {"message", "edit", "reaction"}:
            # Re-read the message so the frame reflects its current state.
            view = await chats_read.get_message(
                self._pool, account_id, event["chat_id"], event["message_id"]
            )
            if view is None:
                return None
            return {"type": kind, "message": view.model_dump(mode="json")}

        if kind == "delete":
            return {
                "type": "delete",
                "chat_id": event.get("chat_id"),
                "message_ids": event.get("message_ids", []),
            }

        if kind == "presence":
            # For presence events the chat_id field carries the peer's id.
            sample = await presence_read.current_presence(
                self._pool, account_id, event["chat_id"]
            )
            return {
                "type": "presence",
                "peer_id": event["chat_id"],
                "sample": sample.model_dump(mode="json") if sample else None,
            }

        if kind == "receipt":
            return {
                "type": "receipt",
                "chat_id": event["chat_id"],
                "read_up_to": event["message_id"],
            }

        return None
|
|
|
|
|
|
# Module-level EventHub instance shared by everything that imports `hub`.
hub: EventHub = EventHub()
|