From 51093da660d7adcdb6aa258eace441d66ebe0bd8 Mon Sep 17 00:00:00 2001 From: h Date: Fri, 29 May 2026 19:33:57 +0200 Subject: [PATCH] feat: add backfills logic --- backend/src/api/app.py | 7 +- backend/src/api/routers/backfill.py | 119 ++++++++++++++++++ backend/src/api/routers/folders.py | 5 +- backend/src/api/routers/policy.py | 10 +- backend/src/userbot/handlers/edits.py | 2 +- backend/src/userbot/handlers/messages.py | 50 +------- .../src/userbot/modules/capture/__init__.py | 3 +- .../src/userbot/modules/capture/message.py | 57 +++++++++ backend/src/userbot/modules/jobs/__init__.py | 2 + .../userbot/modules/jobs/handlers/__init__.py | 3 + .../userbot/modules/jobs/handlers/backfill.py | 33 +++++ .../modules/jobs/handlers/fetch_media.py | 23 ++++ 12 files changed, 251 insertions(+), 63 deletions(-) create mode 100644 backend/src/api/routers/backfill.py create mode 100644 backend/src/userbot/modules/capture/message.py create mode 100644 backend/src/userbot/modules/jobs/handlers/__init__.py create mode 100644 backend/src/userbot/modules/jobs/handlers/backfill.py create mode 100644 backend/src/userbot/modules/jobs/handlers/fetch_media.py diff --git a/backend/src/api/app.py b/backend/src/api/app.py index 1aea606..3944054 100644 --- a/backend/src/api/app.py +++ b/backend/src/api/app.py @@ -2,10 +2,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager import asyncpg -from dishka.integrations.fastapi import FromDishka, inject, setup_dishka +from dishka.integrations.fastapi import DishkaRoute, FromDishka, setup_dishka from fastapi import FastAPI -from api.routers import folders, policy +from api.routers import backfill, folders, policy from dependencies.container import container @@ -16,10 +16,10 @@ async def lifespan(app_: FastAPI) -> AsyncGenerator[None]: app = FastAPI(title="beavergram API", lifespan=lifespan) +app.router.route_class = DishkaRoute @app.get("/health") -@inject async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]: db_ok = await pool.fetchval("SELECT 1") == 1 timescale_ok = await pool.fetchval( @@ -30,5 +30,6 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]: app.include_router(policy.router) app.include_router(folders.router) +app.include_router(backfill.router) setup_dishka(container, app) diff --git a/backend/src/api/routers/backfill.py b/backend/src/api/routers/backfill.py new file mode 100644 index 0000000..0e4b969 --- /dev/null +++ b/backend/src/api/routers/backfill.py @@ -0,0 +1,119 @@ +import json +from datetime import datetime +from typing import Annotated, Any + +import asyncpg +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel + +router = APIRouter(prefix="/api", tags=["backfill"], route_class=DishkaRoute) + +JOBS_CHANGED_CHANNEL = "jobs_changed" + + +class BackfillRequest(BaseModel): + account_id: int + chat_id: int + media: bool = False + + +class FetchMediaRequest(BaseModel): + account_id: int + chat_id: int + message_id: int + + +class EnqueueResponse(BaseModel): + job_id: int + + +class JobView(BaseModel): + id: int + account_id: int + kind: str + status: str + params: dict[str, Any] + cursor: dict[str, Any] | None + progress: dict[str, Any] + flood_waits: int + attempts: int + error: str | None + created_at: datetime + started_at: datetime | None + finished_at: datetime | None + + +def _to_view(row: asyncpg.Record) -> JobView: + data = dict(row) + data["params"] = json.loads(data["params"]) + data["progress"] = json.loads(data["progress"]) + data["cursor"] = json.loads(data["cursor"]) if data["cursor"] is not None else None + return JobView(**data) + + +async def _enqueue( + pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any] +) -> int: + job_id = await pool.fetchval( + "INSERT INTO jobs (account_id, kind, params) " + "VALUES ($1, $2, $3::jsonb) RETURNING id", + account_id, + kind, + json.dumps(params), + ) + await pool.execute(f"NOTIFY {JOBS_CHANGED_CHANNEL}") + return job_id + + +@router.post("/backfill", status_code=201) +async def enqueue_backfill( + pool: FromDishka[asyncpg.Pool], body: BackfillRequest +) -> EnqueueResponse: + job_id = await _enqueue( + pool, + body.account_id, + "backfill", + {"chat_id": body.chat_id, "media": body.media}, + ) + return EnqueueResponse(job_id=job_id) + + +@router.post("/media/fetch", status_code=201) +async def enqueue_fetch_media( + pool: FromDishka[asyncpg.Pool], body: FetchMediaRequest +) -> EnqueueResponse: + job_id = await _enqueue( + pool, + body.account_id, + "fetch_media", + {"chat_id": body.chat_id, "message_id": body.message_id}, + ) + return EnqueueResponse(job_id=job_id) + + +@router.get("/jobs") +async def list_jobs( + pool: FromDishka[asyncpg.Pool], + account_id: Annotated[int, Query()], + status: Annotated[str | None, Query()] = None, +) -> list[JobView]: + if status is None: + rows = await pool.fetch( + "SELECT * FROM jobs WHERE account_id = $1 ORDER BY id DESC", account_id + ) + else: + rows = await pool.fetch( + "SELECT * FROM jobs WHERE account_id = $1 AND status = $2 ORDER BY id DESC", + account_id, + status, + ) + return [_to_view(row) for row in rows] + + +@router.get("/jobs/{job_id}") +async def get_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView: + row = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id) + if row is None: + raise HTTPException(status_code=404, detail="job not found") + return _to_view(row) diff --git a/backend/src/api/routers/folders.py b/backend/src/api/routers/folders.py index def1af4..57a0218 100644 --- a/backend/src/api/routers/folders.py +++ b/backend/src/api/routers/folders.py @@ -1,13 +1,13 @@ from typing import Annotated import asyncpg -from dishka.integrations.fastapi import FromDishka, inject +from dishka.integrations.fastapi import DishkaRoute, FromDishka from fastapi import APIRouter, Query from utils.policy import repository from utils.policy.models import FolderSpec -router = APIRouter(prefix="/api/folders", tags=["folders"]) +router = APIRouter(prefix="/api/folders", tags=["folders"], route_class=DishkaRoute) def _serialize(spec: FolderSpec) -> dict: @@ -28,7 +28,6 @@ def _serialize(spec: FolderSpec) -> dict: @router.get("") -@inject async def list_folders( pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] ) -> list[dict]: diff --git a/backend/src/api/routers/policy.py b/backend/src/api/routers/policy.py index b4785a5..4970bf5 100644 --- a/backend/src/api/routers/policy.py +++ b/backend/src/api/routers/policy.py @@ -1,7 +1,7 @@ from typing import Annotated import asyncpg -from dishka.integrations.fastapi import FromDishka, inject +from dishka.integrations.fastapi import DishkaRoute, FromDishka from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel @@ -15,7 +15,7 @@ from utils.policy.models import ( ) from utils.policy.resolver import resolve -router = APIRouter(prefix="/api/policy", tags=["policy"]) +router = APIRouter(prefix="/api/policy", tags=["policy"], route_class=DishkaRoute) POLICY_CHANGED_CHANNEL = "policy_changed" @@ -29,7 +29,6 @@ class EffectiveQuery(BaseModel): @router.get("") -@inject async def list_policies( pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] ) -> list[PolicyRecord]: @@ -37,7 +36,6 @@ async def list_policies( @router.get("/effective") -@inject async def effective_policy( pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()] ) -> CaptureToggles: @@ -53,7 +51,6 @@ async def effective_policy( @router.post("", status_code=201) -@inject async def create_policy( pool: FromDishka[asyncpg.Pool], body: PolicyCreate ) -> PolicyRecord: @@ -65,7 +62,6 @@ async def create_policy( @router.get("/{policy_id}") -@inject async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord: record = await repository.get_policy(pool, policy_id) if record is None: @@ -74,7 +70,6 @@ async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRe @router.put("/{policy_id}") -@inject async def update_policy( pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles ) -> PolicyRecord: @@ -86,7 +81,6 @@ async def update_policy( @router.delete("/{policy_id}", status_code=204) -@inject async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None: if not await repository.delete_policy(pool, policy_id): raise HTTPException(status_code=404, detail="policy not found") diff --git a/backend/src/userbot/handlers/edits.py b/backend/src/userbot/handlers/edits.py index fc130f3..16396d5 100644 --- a/backend/src/userbot/handlers/edits.py +++ b/backend/src/userbot/handlers/edits.py @@ -1,9 +1,9 @@ from pyrogram.types import Message from userbot import PyroClient -from userbot.handlers.messages import sender_id from userbot.modules.capture import repository from userbot.modules.capture.chat_meta import meta_from_chat +from userbot.modules.capture.message import sender_id from userbot.modules.media import self_destruct_ttl diff --git a/backend/src/userbot/handlers/messages.py b/backend/src/userbot/handlers/messages.py index ef4c568..879f669 100644 --- a/backend/src/userbot/handlers/messages.py +++ b/backend/src/userbot/handlers/messages.py @@ -1,33 +1,8 @@ from pyrogram.types import Message from userbot import PyroClient -from userbot.modules.capture import repository +from userbot.modules.capture import capture_message from userbot.modules.capture.chat_meta import meta_from_chat -from userbot.modules.media import capture_media, self_destruct_ttl - - -def sender_id(message: Message) -> int | None: - if message.from_user is not None: - return message.from_user.id - if message.sender_chat is not None: - return message.sender_chat.id - return None - - -def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]: - rows = getattr(message.reply_markup, "inline_keyboard", None) - if not rows: - return [] - buttons: list[tuple[int, str | None, bytes | None]] = [] - position = 0 - for row in rows: - for button in row: - data = button.callback_data - if data is not None: - encoded = data.encode() if isinstance(data, str) else data - buttons.append((position, button.text, encoded)) - position += 1 - return buttons @PyroClient.on_message() @@ -35,30 +10,11 @@ async def on_message(client: PyroClient, message: Message) -> None: ctx = client.capture if ctx is None or message.empty or message.chat is None or message.date is None: return - chat = message.chat - chat_id = chat.id or 0 - meta = meta_from_chat(chat, ctx.contacts.ids) + meta = meta_from_chat(message.chat, ctx.contacts.ids) toggles = ctx.resolve(meta) if not toggles.messages: return - await repository.upsert_message( - ctx.pool, - ctx.account_id, - chat_id, - message.id, - message.date, - sender_id(message), - message.text or message.caption, - str(message), - has_media=message.media is not None, - is_self_destruct=self_destruct_ttl(message) is not None, - ) - await capture_media(client, message, ctx, chat_id, message.id, toggles) - buttons = _callbacks(message) - if buttons: - await repository.insert_callbacks( - ctx.pool, ctx.account_id, chat_id, message.id, buttons - ) + await capture_message(client, message, ctx, toggles) handlers = on_message.handlers diff --git a/backend/src/userbot/modules/capture/__init__.py b/backend/src/userbot/modules/capture/__init__.py index c8437ad..7905452 100644 --- a/backend/src/userbot/modules/capture/__init__.py +++ b/backend/src/userbot/modules/capture/__init__.py @@ -1,3 +1,4 @@ from userbot.modules.capture.context import CaptureContext, build_capture_context +from userbot.modules.capture.message import capture_message -__all__ = ["CaptureContext", "build_capture_context"] +__all__ = ["CaptureContext", "build_capture_context", "capture_message"] diff --git a/backend/src/userbot/modules/capture/message.py b/backend/src/userbot/modules/capture/message.py new file mode 100644 index 0000000..60b7263 --- /dev/null +++ b/backend/src/userbot/modules/capture/message.py @@ -0,0 +1,57 @@ +from pyrogram import Client +from pyrogram.types import Message + +from userbot.modules.capture import repository +from userbot.modules.capture.context import CaptureContext +from userbot.modules.media import capture_media, self_destruct_ttl +from utils.policy.models import CaptureToggles + + +def sender_id(message: Message) -> int | None: + if message.from_user is not None: + return message.from_user.id + if message.sender_chat is not None: + return message.sender_chat.id + return None + + +def callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]: + rows = getattr(message.reply_markup, "inline_keyboard", None) + if not rows: + return [] + buttons: list[tuple[int, str | None, bytes | None]] = [] + position = 0 + for row in rows: + for button in row: + data = button.callback_data + if data is not None: + encoded = data.encode() if isinstance(data, str) else data + buttons.append((position, button.text, encoded)) + position += 1 + return buttons + + +async def capture_message( + client: Client, message: Message, ctx: CaptureContext, toggles: CaptureToggles +) -> None: + if message.empty or message.chat is None or message.date is None: + return + chat_id = message.chat.id or 0 + await repository.upsert_message( + ctx.pool, + ctx.account_id, + chat_id, + message.id, + message.date, + sender_id(message), + message.text or message.caption, + str(message), + has_media=message.media is not None, + is_self_destruct=self_destruct_ttl(message) is not None, + ) + await capture_media(client, message, ctx, chat_id, message.id, toggles) + buttons = callbacks(message) + if buttons: + await repository.insert_callbacks( + ctx.pool, ctx.account_id, chat_id, message.id, buttons + ) diff --git a/backend/src/userbot/modules/jobs/__init__.py b/backend/src/userbot/modules/jobs/__init__.py index 1b97288..9a7d0a0 100644 --- a/backend/src/userbot/modules/jobs/__init__.py +++ b/backend/src/userbot/modules/jobs/__init__.py @@ -1,3 +1,4 @@ +from userbot.modules.jobs import handlers from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer from userbot.modules.jobs.context import JobContext from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register @@ -8,5 +9,6 @@ __all__ = [ "JobConsumer", "JobContext", "JobHandler", + "handlers", "register", ] diff --git a/backend/src/userbot/modules/jobs/handlers/__init__.py b/backend/src/userbot/modules/jobs/handlers/__init__.py new file mode 100644 index 0000000..76bdaba --- /dev/null +++ b/backend/src/userbot/modules/jobs/handlers/__init__.py @@ -0,0 +1,3 @@ +from userbot.modules.jobs.handlers import backfill, fetch_media + +__all__ = ["backfill", "fetch_media"] diff --git a/backend/src/userbot/modules/jobs/handlers/backfill.py b/backend/src/userbot/modules/jobs/handlers/backfill.py new file mode 100644 index 0000000..9aa8b94 --- /dev/null +++ b/backend/src/userbot/modules/jobs/handlers/backfill.py @@ -0,0 +1,33 @@ +from userbot.modules.capture import capture_message +from userbot.modules.jobs.context import JobContext +from userbot.modules.jobs.registry import register +from utils.policy.models import CaptureToggles + +SAVE_EVERY = 100 + + +@register("backfill") +async def backfill(ctx: JobContext) -> None: + client = ctx.client + if client is None: + return + capture = getattr(client, "capture", None) + if capture is None: + return + chat_id = ctx.job.params["chat_id"] + toggles = CaptureToggles( + messages=True, + media=bool(ctx.job.params.get("media")), + self_destruct_media=False, + ) + max_id = (ctx.job.cursor or {}).get("max_id", 0) + processed = ctx.job.progress.get("processed", 0) + kwargs = {"max_id": max_id} if max_id else {} + async for message in client.get_chat_history(chat_id, **kwargs): + await capture_message(client, message, capture, toggles) + processed += 1 + if processed % SAVE_EVERY == 0: + next_max = message.id - 1 + await ctx.save_cursor({"max_id": next_max}) + await ctx.report_progress({"processed": processed, "max_id": next_max}) + await ctx.report_progress({"processed": processed, "done": True}) diff --git a/backend/src/userbot/modules/jobs/handlers/fetch_media.py b/backend/src/userbot/modules/jobs/handlers/fetch_media.py new file mode 100644 index 0000000..c283864 --- /dev/null +++ b/backend/src/userbot/modules/jobs/handlers/fetch_media.py @@ -0,0 +1,23 @@ +from userbot.modules.jobs.context import JobContext +from userbot.modules.jobs.registry import register +from userbot.modules.media import capture_media +from utils.policy.models import CaptureToggles + + +@register("fetch_media") +async def fetch_media(ctx: JobContext) -> None: + client = ctx.client + if client is None: + return + capture = getattr(client, "capture", None) + if capture is None: + return + chat_id = ctx.job.params["chat_id"] + message_id = ctx.job.params["message_id"] + message = await client.get_messages(chat_id, message_id) + if isinstance(message, list): + message = message[0] if message else None + if message is None or message.empty: + return + toggles = CaptureToggles(media=True, self_destruct_media=True) + await capture_media(client, message, capture, chat_id, message_id, toggles)