feat: add event watcher
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
"""phase10.1 alert dedup key
|
||||
|
||||
Revision ID: a3f1c8e94d72
|
||||
Revises: 6b2d95ac3b46
|
||||
Create Date: 2026-05-30 02:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "a3f1c8e94d72"
|
||||
down_revision: str | None = "6b2d95ac3b46"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("alerts", sa.Column("dedup_key", sa.Text(), nullable=True))
|
||||
op.create_index(
|
||||
"ix_alerts_watch_dedup",
|
||||
"alerts",
|
||||
["watch_id", "dedup_key"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("dedup_key IS NOT NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_alerts_watch_dedup", table_name="alerts")
|
||||
op.drop_column("alerts", "dedup_key")
|
||||
@@ -13,6 +13,7 @@ async def on_message(client: PyroClient, message: Message) -> None:
|
||||
if ctx is None or message.empty or message.chat is None or message.date is None:
|
||||
return
|
||||
meta = meta_from_chat(message.chat, ctx.contacts.ids)
|
||||
await ctx.watches.on_text(meta.chat_id, message.id, message.text or message.caption)
|
||||
toggles = ctx.resolve(meta)
|
||||
if not toggles.messages:
|
||||
return
|
||||
|
||||
@@ -21,6 +21,7 @@ async def on_user_status(client: PyroClient, user: User) -> None:
|
||||
user.next_offline_date,
|
||||
str(user.raw),
|
||||
)
|
||||
await ctx.watches.on_status(user.id, is_online=user.status.name.lower() == "online")
|
||||
|
||||
|
||||
handlers = on_user_status.handlers
|
||||
|
||||
@@ -20,3 +20,4 @@ async def handle(
|
||||
await repository.set_extracted_text(
|
||||
ctx.pool, ctx.account_id, meta.chat_id, update.msg_id, update.text
|
||||
)
|
||||
await ctx.watches.on_text(meta.chat_id, update.msg_id, update.text)
|
||||
|
||||
@@ -3,6 +3,7 @@ from pyrogram import Client
|
||||
|
||||
from userbot.modules.contacts import ContactCache
|
||||
from userbot.modules.folders import FolderCache
|
||||
from userbot.modules.watches import WatchCache
|
||||
from utils.policy.models import CaptureToggles, ChatMeta, PolicySet
|
||||
from utils.policy.repository import load_policy_set
|
||||
from utils.policy.resolver import resolve
|
||||
@@ -23,6 +24,7 @@ class CaptureContext:
|
||||
self.storage = storage
|
||||
self.folders = folders
|
||||
self.contacts = contacts
|
||||
self.watches = WatchCache(pool, account_id)
|
||||
self.policies = PolicySet()
|
||||
|
||||
async def reload_policies(self) -> None:
|
||||
@@ -44,4 +46,5 @@ async def build_capture_context(
|
||||
await contacts.refresh()
|
||||
ctx = CaptureContext(account_id, pool, storage, folders, contacts)
|
||||
await ctx.reload_policies()
|
||||
await ctx.watches.refresh()
|
||||
return ctx
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.watches.cache import WatchCache
|
||||
|
||||
__all__ = ["WatchCache"]
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from userbot.modules.watches.evaluator import (
|
||||
KIND_KEYWORD,
|
||||
KIND_PEER_ONLINE,
|
||||
match_keyword,
|
||||
match_peer_online,
|
||||
)
|
||||
from utils.logging import logger
|
||||
from utils.read import watches as watches_read
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import asyncpg
|
||||
|
||||
from utils.read.models import WatchView
|
||||
|
||||
|
||||
class WatchCache:
|
||||
def __init__(self, pool: asyncpg.Pool, account_id: int) -> None:
|
||||
self._pool = pool
|
||||
self._account_id = account_id
|
||||
self.watches: list[WatchView] = []
|
||||
self._online: set[int] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
rows = await watches_read.list_watches(self._pool, self._account_id)
|
||||
self.watches = [watch for watch in rows if watch.enabled]
|
||||
logger.info(f"[green]Watches cached:[/] {len(self.watches)}")
|
||||
|
||||
async def on_text(self, chat_id: int, message_id: int, text: str | None) -> None:
|
||||
if not text:
|
||||
return
|
||||
for watch in self.watches:
|
||||
if watch.kind == KIND_KEYWORD and match_keyword(
|
||||
text, chat_id, watch.params
|
||||
):
|
||||
await watches_read.insert_alert(
|
||||
self._pool,
|
||||
self._account_id,
|
||||
watch.id,
|
||||
{"chat_id": chat_id, "message_id": message_id},
|
||||
dedup_key=f"{chat_id}:{message_id}",
|
||||
)
|
||||
|
||||
async def on_status(self, peer_id: int, *, is_online: bool) -> None:
|
||||
if not is_online:
|
||||
self._online.discard(peer_id)
|
||||
return
|
||||
if peer_id in self._online:
|
||||
return
|
||||
self._online.add(peer_id)
|
||||
for watch in self.watches:
|
||||
if watch.kind == KIND_PEER_ONLINE and match_peer_online(
|
||||
peer_id, watch.params
|
||||
):
|
||||
await watches_read.insert_alert(
|
||||
self._pool, self._account_id, watch.id, {"peer_id": peer_id}
|
||||
)
|
||||
@@ -0,0 +1,24 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
KIND_KEYWORD = "keyword"
|
||||
KIND_PEER_ONLINE = "peer_online"
|
||||
|
||||
|
||||
def match_keyword(text: str, chat_id: int, params: dict[str, Any]) -> bool:
|
||||
target_chat = params.get("chat_id")
|
||||
if target_chat is not None and target_chat != chat_id:
|
||||
return False
|
||||
pattern = params.get("pattern")
|
||||
if not pattern:
|
||||
return False
|
||||
if params.get("regex"):
|
||||
try:
|
||||
return re.search(pattern, text) is not None
|
||||
except re.error:
|
||||
return False
|
||||
return pattern.casefold() in text.casefold()
|
||||
|
||||
|
||||
def match_peer_online(peer_id: int, params: dict[str, Any]) -> bool:
|
||||
return params.get("peer_id") == peer_id
|
||||
@@ -1,17 +1,20 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
from collections.abc import Callable, Coroutine
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
import uvloop
|
||||
|
||||
from dependencies.container import container
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import build_capture_context
|
||||
from userbot.modules.capture import CaptureContext, build_capture_context
|
||||
from userbot.modules.jobs import JobConsumer
|
||||
from utils.env import env
|
||||
from utils.logging import logger, setup_logging
|
||||
from utils.read.watches import WATCHES_CHANGED_CHANNEL
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
setup_logging()
|
||||
@@ -69,21 +72,32 @@ async def _setup_capture(
|
||||
logger.info("[green]Capture context ready.[/]")
|
||||
|
||||
|
||||
async def _listen_policy_changes(
|
||||
async def _listen_changes(
|
||||
clients: list[PyroClient], tasks: set[asyncio.Task]
|
||||
) -> asyncpg.Connection:
|
||||
def on_change(
|
||||
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
|
||||
def reload(
|
||||
make_coro: Callable[[CaptureContext], Coroutine[Any, Any, None]],
|
||||
) -> None:
|
||||
for client in clients:
|
||||
if client.capture is None:
|
||||
continue
|
||||
task = asyncio.create_task(client.capture.reload_policies())
|
||||
task = asyncio.create_task(make_coro(client.capture))
|
||||
tasks.add(task)
|
||||
task.add_done_callback(tasks.discard)
|
||||
|
||||
def on_policy(
|
||||
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
|
||||
) -> None:
|
||||
reload(lambda capture: capture.reload_policies())
|
||||
|
||||
def on_watch(
|
||||
_conn: asyncpg.Connection, _pid: int, _channel: str, _payload: str
|
||||
) -> None:
|
||||
reload(lambda capture: capture.watches.refresh())
|
||||
|
||||
conn = await asyncpg.connect(dsn=env.db.connection_url)
|
||||
await conn.add_listener("policy_changed", on_change)
|
||||
await conn.add_listener("policy_changed", on_policy)
|
||||
await conn.add_listener(WATCHES_CHANGED_CHANNEL, on_watch)
|
||||
return conn
|
||||
|
||||
|
||||
@@ -122,7 +136,7 @@ async def runner() -> None:
|
||||
consumer_tasks.append(asyncio.create_task(consumer.run()))
|
||||
|
||||
if clients:
|
||||
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
||||
listen_conn = await _listen_changes(clients, reload_tasks)
|
||||
logger.info("[green]Userbot running.[/]")
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
|
||||
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, Text, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
@@ -471,6 +471,7 @@ class Alert(SQLModel, table=True):
|
||||
default_factory=dict, sa_column=Column(JSONB, nullable=False)
|
||||
)
|
||||
seen: bool = False
|
||||
dedup_key: str | None = Field(default=None, sa_column=Column(Text, nullable=True))
|
||||
created_at: datetime = Field(
|
||||
sa_column=Column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
|
||||
@@ -6,6 +6,9 @@ import asyncpg
|
||||
|
||||
from utils.read.models import AlertView, Page, WatchView
|
||||
|
||||
WATCHES_CHANGED_CHANNEL = "watches_changed"
|
||||
ALERTS_CHANGED_CHANNEL = "alerts_changed"
|
||||
|
||||
_WATCH_COLS = "id, account_id, kind, params, enabled, created_at, updated_at"
|
||||
_ALERT_COLS = "id, account_id, watch_id, ts, payload, seen, created_at"
|
||||
|
||||
@@ -55,6 +58,7 @@ async def create_watch(
|
||||
json.dumps(params),
|
||||
enabled,
|
||||
)
|
||||
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
|
||||
return _to_watch(row)
|
||||
|
||||
|
||||
@@ -68,12 +72,18 @@ async def update_watch(
|
||||
json.dumps(params),
|
||||
enabled,
|
||||
)
|
||||
return _to_watch(row) if row else None
|
||||
if row is None:
|
||||
return None
|
||||
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
|
||||
return _to_watch(row)
|
||||
|
||||
|
||||
async def delete_watch(pool: asyncpg.Pool, watch_id: int) -> bool:
|
||||
result = await pool.execute("DELETE FROM watches WHERE id = $1", watch_id)
|
||||
return result.endswith("1")
|
||||
if not result.endswith("1"):
|
||||
return False
|
||||
await pool.execute(f"NOTIFY {WATCHES_CHANGED_CHANNEL}")
|
||||
return True
|
||||
|
||||
|
||||
async def list_alerts(
|
||||
@@ -95,16 +105,27 @@ async def list_alerts(
|
||||
|
||||
|
||||
async def insert_alert(
|
||||
pool: asyncpg.Pool, account_id: int, watch_id: int, payload: dict[str, Any]
|
||||
) -> AlertView:
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
watch_id: int,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
dedup_key: str | None = None,
|
||||
) -> AlertView | None:
|
||||
row = await pool.fetchrow(
|
||||
"INSERT INTO alerts (account_id, watch_id, ts, payload) " # noqa: S608
|
||||
f"VALUES ($1, $2, $3, $4::jsonb) RETURNING {_ALERT_COLS}",
|
||||
"INSERT INTO alerts (account_id, watch_id, ts, payload, seen, dedup_key) " # noqa: S608
|
||||
"VALUES ($1, $2, $3, $4::jsonb, false, $5) "
|
||||
"ON CONFLICT (watch_id, dedup_key) WHERE dedup_key IS NOT NULL "
|
||||
f"DO NOTHING RETURNING {_ALERT_COLS}",
|
||||
account_id,
|
||||
watch_id,
|
||||
datetime.now(UTC),
|
||||
json.dumps(payload),
|
||||
dedup_key,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
await pool.execute(f"NOTIFY {ALERTS_CHANGED_CHANNEL}")
|
||||
return _to_alert(row)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user