feat: add event watcher

This commit is contained in:
h
2026-05-30 01:54:49 +02:00
parent c40e720163
commit f0afb7ec5b
11 changed files with 176 additions and 14 deletions
@@ -0,0 +1,33 @@
"""phase10.1 alert dedup key
Revision ID: a3f1c8e94d72
Revises: 6b2d95ac3b46
Create Date: 2026-05-30 02:00:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "a3f1c8e94d72"
down_revision: str | None = "6b2d95ac3b46"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.add_column("alerts", sa.Column("dedup_key", sa.Text(), nullable=True))
op.create_index(
"ix_alerts_watch_dedup",
"alerts",
["watch_id", "dedup_key"],
unique=True,
postgresql_where=sa.text("dedup_key IS NOT NULL"),
)
def downgrade() -> None:
op.drop_index("ix_alerts_watch_dedup", table_name="alerts")
op.drop_column("alerts", "dedup_key")
+1
View File
@@ -13,6 +13,7 @@ async def on_message(client: PyroClient, message: Message) -> None:
if ctx is None or message.empty or message.chat is None or message.date is None: if ctx is None or message.empty or message.chat is None or message.date is None:
return return
meta = meta_from_chat(message.chat, ctx.contacts.ids) meta = meta_from_chat(message.chat, ctx.contacts.ids)
await ctx.watches.on_text(meta.chat_id, message.id, message.text or message.caption)
toggles = ctx.resolve(meta) toggles = ctx.resolve(meta)
if not toggles.messages: if not toggles.messages:
return return
+1
View File
@@ -21,6 +21,7 @@ async def on_user_status(client: PyroClient, user: User) -> None:
user.next_offline_date, user.next_offline_date,
str(user.raw), str(user.raw),
) )
await ctx.watches.on_status(user.id, is_online=user.status.name.lower() == "online")
handlers = on_user_status.handlers handlers = on_user_status.handlers
@@ -20,3 +20,4 @@ async def handle(
await repository.set_extracted_text( await repository.set_extracted_text(
ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, update.text ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, update.text
) )
await ctx.watches.on_text(meta.chat_id, update.msg_id, update.text)
@@ -3,6 +3,7 @@ from pyrogram import Client
from userbot.modules.contacts import ContactCache from userbot.modules.contacts import ContactCache
from userbot.modules.folders import FolderCache from userbot.modules.folders import FolderCache
from userbot.modules.watches import WatchCache
from utils.policy.models import CaptureToggles, ChatMeta, PolicySet from utils.policy.models import CaptureToggles, ChatMeta, PolicySet
from utils.policy.repository import load_policy_set from utils.policy.repository import load_policy_set
from utils.policy.resolver import resolve from utils.policy.resolver import resolve
@@ -23,6 +24,7 @@ class CaptureContext:
self.storage = storage self.storage = storage
self.folders = folders self.folders = folders
self.contacts = contacts self.contacts = contacts
self.watches = WatchCache(pool, account_id)
self.policies = PolicySet() self.policies = PolicySet()
async def reload_policies(self) -> None: async def reload_policies(self) -> None:
@@ -44,4 +46,5 @@ async def build_capture_context(
await contacts.refresh() await contacts.refresh()
ctx = CaptureContext(account_id, pool, storage, folders, contacts) ctx = CaptureContext(account_id, pool, storage, folders, contacts)
await ctx.reload_policies() await ctx.reload_policies()
await ctx.watches.refresh()
return ctx return ctx
@@ -0,0 +1,3 @@
from userbot.modules.watches.cache import WatchCache
__all__ = ["WatchCache"]
@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from userbot.modules.watches.evaluator import (
KIND_KEYWORD,
KIND_PEER_ONLINE,
match_keyword,
match_peer_online,
)
from utils.logging import logger
from utils.read import watches as watches_read
if TYPE_CHECKING:
import asyncpg
from utils.read.models import WatchView
class WatchCache:
def __init__(self, pool: asyncpg.Pool, account_id: int) -> None:
self._pool = pool
self._account_id = account_id
self.watches: list[WatchView] = []
self._online: set[int] = set()
async def refresh(self) -> None:
rows = await watches_read.list_watches(self._pool, self._account_id)
self.watches = [watch for watch in rows if watch.enabled]
logger.info(f"[green]Watches cached:[/] {len(self.watches)}")
async def on_text(self, chat_id: int, message_id: int, text: str | None) -> None:
if not text:
return
for watch in self.watches:
if watch.kind == KIND_KEYWORD and match_keyword(
text, chat_id, watch.params
):
await watches_read.insert_alert(
self._pool,
self._account_id,
watch.id,
{"chat_id": chat_id, "message_id": message_id},
dedup_key=f"{chat_id}:{message_id}",
)
async def on_status(self, peer_id: int, *, is_online: bool) -> None:
if not is_online:
self._online.discard(peer_id)
return
if peer_id in self._online:
return
self._online.add(peer_id)
for watch in self.watches:
if watch.kind == KIND_PEER_ONLINE and match_peer_online(
peer_id, watch.params
):
await watches_read.insert_alert(
self._pool, self._account_id, watch.id, {"peer_id": peer_id}
)
@@ -0,0 +1,24 @@
import re
from typing import Any
KIND_KEYWORD = "keyword"
KIND_PEER_ONLINE = "peer_online"
def match_keyword(text: str, chat_id: int, params: dict[str, Any]) -> bool:
target_chat = params.get("chat_id")
if target_chat is not None and target_chat != chat_id:
return False
pattern = params.get("pattern")
if not pattern:
return False
if params.get("regex"):
try:
return re.search(pattern, text) is not None
except re.error:
return False
return pattern.casefold() in text.casefold()
def match_peer_online(peer_id: int, params: dict[str, Any]) -> bool:
return params.get("peer_id") == peer_id
+21 -7
View File
@@ -1,17 +1,20 @@
import asyncio import asyncio
import contextlib import contextlib
import json import json
from collections.abc import Callable, Coroutine
from pathlib import Path from pathlib import Path
from typing import Any
import asyncpg import asyncpg
import uvloop import uvloop
from dependencies.container import container from dependencies.container import container
from userbot import PyroClient from userbot import PyroClient
from userbot.modules.capture import build_capture_context from userbot.modules.capture import CaptureContext, build_capture_context
from userbot.modules.jobs import JobConsumer from userbot.modules.jobs import JobConsumer
from utils.env import env from utils.env import env
from utils.logging import logger, setup_logging from utils.logging import logger, setup_logging
from utils.read.watches import WATCHES_CHANGED_CHANNEL
from utils.storage import ContentAddressedStorage from utils.storage import ContentAddressedStorage
setup_logging() setup_logging()
@@ -69,21 +72,32 @@ async def _setup_capture(
logger.info("[green]Capture context ready.[/]") logger.info("[green]Capture context ready.[/]")
async def _listen_policy_changes( async def _listen_changes(
clients: list[PyroClient], tasks: set[asyncio.Task] clients: list[PyroClient], tasks: set[asyncio.Task]
) -> asyncpg.Connection: ) -> asyncpg.Connection:
def on_change( def reload(
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str make_coro: Callable[[CaptureContext], Coroutine[Any, Any, None]],
) -> None: ) -> None:
for client in clients: for client in clients:
if client.capture is None: if client.capture is None:
continue continue
task = asyncio.create_task(client.capture.reload_policies()) task = asyncio.create_task(make_coro(client.capture))
tasks.add(task) tasks.add(task)
task.add_done_callback(tasks.discard) task.add_done_callback(tasks.discard)
def on_policy(
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
) -> None:
reload(lambda capture: capture.reload_policies())
def on_watch(
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
) -> None:
reload(lambda capture: capture.watches.refresh())
conn = await asyncpg.connect(dsn=env.db.connection_url) conn = await asyncpg.connect(dsn=env.db.connection_url)
await conn.add_listener("policy_changed", on_change) await conn.add_listener("policy_changed", on_policy)
await conn.add_listener(WATCHES_CHANGED_CHANNEL, on_watch)
return conn return conn
@@ -122,7 +136,7 @@ async def runner() -> None:
consumer_tasks.append(asyncio.create_task(consumer.run())) consumer_tasks.append(asyncio.create_task(consumer.run()))
if clients: if clients:
listen_conn = await _listen_policy_changes(clients, reload_tasks) listen_conn = await _listen_changes(clients, reload_tasks)
logger.info("[green]Userbot running.[/]") logger.info("[green]Userbot running.[/]")
await asyncio.Event().wait() await asyncio.Event().wait()
finally: finally:
+2 -1
View File
@@ -2,7 +2,7 @@ from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any from typing import Any
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, Text, func
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
@@ -471,6 +471,7 @@ class Alert(SQLModel, table=True):
default_factory=dict, sa_column=Column(JSONB, nullable=False) default_factory=dict, sa_column=Column(JSONB, nullable=False)
) )
seen: bool = False seen: bool = False
dedup_key: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
created_at: datetime = Field( created_at: datetime = Field(
sa_column=Column( sa_column=Column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
+27 -6
View File
@@ -6,6 +6,9 @@ import asyncpg
from utils.read.models import AlertView, Page, WatchView from utils.read.models import AlertView, Page, WatchView
WATCHES_CHANGED_CHANNEL = "watches_changed"
ALERTS_CHANGED_CHANNEL = "alerts_changed"
_WATCH_COLS = "id, account_id, kind, params, enabled, created_at, updated_at" _WATCH_COLS = "id, account_id, kind, params, enabled, created_at, updated_at"
_ALERT_COLS = "id, account_id, watch_id, ts, payload, seen, created_at" _ALERT_COLS = "id, account_id, watch_id, ts, payload, seen, created_at"
@@ -55,6 +58,7 @@ async def create_watch(
json.dumps(params), json.dumps(params),
enabled, enabled,
) )
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
return _to_watch(row) return _to_watch(row)
@@ -68,12 +72,18 @@ async def update_watch(
json.dumps(params), json.dumps(params),
enabled, enabled,
) )
return _to_watch(row) if row else None if row is None:
return None
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
return _to_watch(row)
async def delete_watch(pool: asyncpg.Pool, watch_id: int) -> bool: async def delete_watch(pool: asyncpg.Pool, watch_id: int) -> bool:
result = await pool.execute("DELETE FROM watches WHERE id = $1", watch_id) result = await pool.execute("DELETE FROM watches WHERE id = $1", watch_id)
return result.endswith("1") if not result.endswith("1"):
return False
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
return True
async def list_alerts( async def list_alerts(
@@ -95,16 +105,27 @@ async def list_alerts(
async def insert_alert( async def insert_alert(
pool: asyncpg.Pool, account_id: int, watch_id: int, payload: dict[str, Any] pool: asyncpg.Pool,
) -> AlertView: account_id: int,
watch_id: int,
payload: dict[str, Any],
*,
dedup_key: str | None = None,
) -> AlertView | None:
row = await pool.fetchrow( row = await pool.fetchrow(
"INSERT INTO alerts (account_id, watch_id, ts, payload) " # noqa: S608 "INSERT INTO alerts (account_id, watch_id, ts, payload, seen, dedup_key) " # noqa: S608
f"VALUES ($1, $2, $3, $4::jsonb) RETURNING {_ALERT_COLS}", "VALUES ($1, $2, $3, $4::jsonb, false, $5) "
"ON CONFLICT (watch_id, dedup_key) WHERE dedup_key IS NOT NULL "
f"DO NOTHING RETURNING {_ALERT_COLS}",
account_id, account_id,
watch_id, watch_id,
datetime.now(UTC), datetime.now(UTC),
json.dumps(payload), json.dumps(payload),
dedup_key,
) )
if row is None:
return None
await pool.execute(f"NOTIFY {ALERTS_CHANGED_CHANNEL}")
return _to_alert(row) return _to_alert(row)