feat: add backfill logic

This commit is contained in:
h
2026-05-29 19:33:57 +02:00
parent 4a471df8f1
commit 51093da660
12 changed files with 251 additions and 63 deletions
+4 -3
View File
@@ -2,10 +2,10 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncpg import asyncpg
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka from dishka.integrations.fastapi import DishkaRoute, FromDishka, setup_dishka
from fastapi import FastAPI from fastapi import FastAPI
from api.routers import folders, policy from api.routers import backfill, folders, policy
from dependencies.container import container from dependencies.container import container
@@ -16,10 +16,10 @@ async def lifespan(app_: FastAPI) -> AsyncGenerator[None]:
app = FastAPI(title="beavergram API", lifespan=lifespan) app = FastAPI(title="beavergram API", lifespan=lifespan)
app.router.route_class = DishkaRoute
@app.get("/health") @app.get("/health")
@inject
async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]: async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
db_ok = await pool.fetchval("SELECT 1") == 1 db_ok = await pool.fetchval("SELECT 1") == 1
timescale_ok = await pool.fetchval( timescale_ok = await pool.fetchval(
@@ -30,5 +30,6 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
app.include_router(policy.router) app.include_router(policy.router)
app.include_router(folders.router) app.include_router(folders.router)
app.include_router(backfill.router)
setup_dishka(container, app) setup_dishka(container, app)
+119
View File
@@ -0,0 +1,119 @@
import json
from datetime import datetime
from typing import Annotated, Any
import asyncpg
from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
# Router for the job endpoints defined in this module, mounted under /api.
# DishkaRoute is set as the route class; this commit drops the per-handler
# @inject decorators in the other routers when adding the same route class.
router = APIRouter(prefix="/api", tags=["backfill"], route_class=DishkaRoute)
# Postgres channel that _enqueue NOTIFYs after inserting a job row.
JOBS_CHANGED_CHANNEL = "jobs_changed"
# Request body accepted by POST /api/backfill.
class BackfillRequest(BaseModel):
    # Account the job row is created for.
    account_id: int
    # Chat to backfill; stored in the job params as "chat_id".
    chat_id: int
    # Stored in the job params as "media"; off unless the caller sets it.
    media: bool = False
# Request body accepted by POST /api/media/fetch.
class FetchMediaRequest(BaseModel):
    # Account the job row is created for.
    account_id: int
    # Chat containing the message; stored in the job params as "chat_id".
    chat_id: int
    # Message whose media is requested; stored in the job params as "message_id".
    message_id: int
# Response returned by both enqueue endpoints.
class EnqueueResponse(BaseModel):
    # Id of the job row that was just inserted.
    job_id: int
# API representation of one row of the jobs table, as built by _to_view.
class JobView(BaseModel):
    id: int
    account_id: int
    # Job kind; this module enqueues "backfill" and "fetch_media".
    kind: str
    status: str
    # The next three fields are JSON columns decoded into Python objects
    # by _to_view; cursor may be absent.
    params: dict[str, Any]
    cursor: dict[str, Any] | None
    progress: dict[str, Any]
    flood_waits: int
    attempts: int
    error: str | None
    created_at: datetime
    started_at: datetime | None
    finished_at: datetime | None
def _to_view(row: asyncpg.Record) -> JobView:
    """Convert a jobs-table record into a JobView.

    The "params" and "progress" columns are always JSON-decoded; "cursor"
    is decoded only when it is not None.
    """
    fields = dict(row)
    for column in ("params", "progress"):
        fields[column] = json.loads(fields[column])
    raw_cursor = fields["cursor"]
    fields["cursor"] = None if raw_cursor is None else json.loads(raw_cursor)
    return JobView(**fields)
async def _enqueue(
    pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
    """Insert a job row and notify listeners on the jobs channel.

    The INSERT and the NOTIFY run inside one transaction on one connection.
    PostgreSQL delivers a NOTIFY issued in a transaction only when that
    transaction commits, so listeners are woken exactly when the row becomes
    visible, and a failure of either statement rolls the row back instead of
    leaving an unannounced job behind while the caller sees an error.

    Args:
        pool: asyncpg pool to take the connection from.
        account_id: value for the row's account_id column.
        kind: value for the row's kind column.
        params: JSON-serialisable mapping stored in the params column.

    Returns:
        The id of the inserted row.
    """
    async with pool.acquire() as conn, conn.transaction():
        job_id = await conn.fetchval(
            "INSERT INTO jobs (account_id, kind, params) "
            "VALUES ($1, $2, $3::jsonb) RETURNING id",
            account_id,
            kind,
            json.dumps(params),
        )
        # NOTIFY takes an identifier, which cannot be sent as a bind
        # parameter; the channel is a module constant, not request input.
        await conn.execute(f"NOTIFY {JOBS_CHANGED_CHANNEL}")
    return job_id
# POST /api/backfill: enqueue a "backfill" job for the requested account
# and chat, and answer 201 with the new job id.
@router.post("/backfill", status_code=201)
async def enqueue_backfill(
    pool: FromDishka[asyncpg.Pool], body: BackfillRequest
) -> EnqueueResponse:
    job_params = {"chat_id": body.chat_id, "media": body.media}
    new_id = await _enqueue(pool, body.account_id, "backfill", job_params)
    return EnqueueResponse(job_id=new_id)
# POST /api/media/fetch: enqueue a "fetch_media" job for one message and
# answer 201 with the new job id.
@router.post("/media/fetch", status_code=201)
async def enqueue_fetch_media(
    pool: FromDishka[asyncpg.Pool], body: FetchMediaRequest
) -> EnqueueResponse:
    job_params = {"chat_id": body.chat_id, "message_id": body.message_id}
    new_id = await _enqueue(pool, body.account_id, "fetch_media", job_params)
    return EnqueueResponse(job_id=new_id)
# GET /api/jobs: return the account's jobs ordered by id descending,
# restricted to one status when the optional "status" query value is given.
@router.get("/jobs")
async def list_jobs(
    pool: FromDishka[asyncpg.Pool],
    account_id: Annotated[int, Query()],
    status: Annotated[str | None, Query()] = None,
) -> list[JobView]:
    # Pick the statement and its bind values first, then run a single fetch.
    if status is None:
        sql = "SELECT * FROM jobs WHERE account_id = $1 ORDER BY id DESC"
        bind = (account_id,)
    else:
        sql = "SELECT * FROM jobs WHERE account_id = $1 AND status = $2 ORDER BY id DESC"
        bind = (account_id, status)
    records = await pool.fetch(sql, *bind)
    return list(map(_to_view, records))
# GET /api/jobs/{job_id}: return a single job, or 404 when no row has that id.
@router.get("/jobs/{job_id}")
async def get_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView:
    record = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
    if record is not None:
        return _to_view(record)
    raise HTTPException(status_code=404, detail="job not found")
+2 -3
View File
@@ -1,13 +1,13 @@
from typing import Annotated from typing import Annotated
import asyncpg import asyncpg
from dishka.integrations.fastapi import FromDishka, inject from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from utils.policy import repository from utils.policy import repository
from utils.policy.models import FolderSpec from utils.policy.models import FolderSpec
router = APIRouter(prefix="/api/folders", tags=["folders"]) router = APIRouter(prefix="/api/folders", tags=["folders"], route_class=DishkaRoute)
def _serialize(spec: FolderSpec) -> dict: def _serialize(spec: FolderSpec) -> dict:
@@ -28,7 +28,6 @@ def _serialize(spec: FolderSpec) -> dict:
@router.get("") @router.get("")
@inject
async def list_folders( async def list_folders(
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
) -> list[dict]: ) -> list[dict]:
+2 -8
View File
@@ -1,7 +1,7 @@
from typing import Annotated from typing import Annotated
import asyncpg import asyncpg
from dishka.integrations.fastapi import FromDishka, inject from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
@@ -15,7 +15,7 @@ from utils.policy.models import (
) )
from utils.policy.resolver import resolve from utils.policy.resolver import resolve
router = APIRouter(prefix="/api/policy", tags=["policy"]) router = APIRouter(prefix="/api/policy", tags=["policy"], route_class=DishkaRoute)
POLICY_CHANGED_CHANNEL = "policy_changed" POLICY_CHANGED_CHANNEL = "policy_changed"
@@ -29,7 +29,6 @@ class EffectiveQuery(BaseModel):
@router.get("") @router.get("")
@inject
async def list_policies( async def list_policies(
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()] pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
) -> list[PolicyRecord]: ) -> list[PolicyRecord]:
@@ -37,7 +36,6 @@ async def list_policies(
@router.get("/effective") @router.get("/effective")
@inject
async def effective_policy( async def effective_policy(
pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()] pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()]
) -> CaptureToggles: ) -> CaptureToggles:
@@ -53,7 +51,6 @@ async def effective_policy(
@router.post("", status_code=201) @router.post("", status_code=201)
@inject
async def create_policy( async def create_policy(
pool: FromDishka[asyncpg.Pool], body: PolicyCreate pool: FromDishka[asyncpg.Pool], body: PolicyCreate
) -> PolicyRecord: ) -> PolicyRecord:
@@ -65,7 +62,6 @@ async def create_policy(
@router.get("/{policy_id}") @router.get("/{policy_id}")
@inject
async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord: async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord:
record = await repository.get_policy(pool, policy_id) record = await repository.get_policy(pool, policy_id)
if record is None: if record is None:
@@ -74,7 +70,6 @@ async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRe
@router.put("/{policy_id}") @router.put("/{policy_id}")
@inject
async def update_policy( async def update_policy(
pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles
) -> PolicyRecord: ) -> PolicyRecord:
@@ -86,7 +81,6 @@ async def update_policy(
@router.delete("/{policy_id}", status_code=204) @router.delete("/{policy_id}", status_code=204)
@inject
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None: async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
if not await repository.delete_policy(pool, policy_id): if not await repository.delete_policy(pool, policy_id):
raise HTTPException(status_code=404, detail="policy not found") raise HTTPException(status_code=404, detail="policy not found")
+1 -1
View File
@@ -1,9 +1,9 @@
from pyrogram.types import Message from pyrogram.types import Message
from userbot import PyroClient from userbot import PyroClient
from userbot.handlers.messages import sender_id
from userbot.modules.capture import repository from userbot.modules.capture import repository
from userbot.modules.capture.chat_meta import meta_from_chat from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.capture.message import sender_id
from userbot.modules.media import self_destruct_ttl from userbot.modules.media import self_destruct_ttl
+3 -47
View File
@@ -1,33 +1,8 @@
from pyrogram.types import Message from pyrogram.types import Message
from userbot import PyroClient from userbot import PyroClient
from userbot.modules.capture import repository from userbot.modules.capture import capture_message
from userbot.modules.capture.chat_meta import meta_from_chat from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.media import capture_media, self_destruct_ttl
def sender_id(message: Message) -> int | None:
    """Return the id of the message's author, if it has one.

    The user in ``from_user`` takes precedence; ``sender_chat`` is consulted
    only when there is no user. Returns None when both are missing.
    """
    user = message.from_user
    if user is not None:
        return user.id
    chat = message.sender_chat
    return None if chat is None else chat.id
def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
    """Collect the callback buttons of a message's inline keyboard.

    Returns one ``(position, text, data)`` tuple per button that carries
    callback data. ``position`` counts every button of the keyboard in
    row-major order, including buttons without callback data. String data
    is encoded to bytes. A message without an inline keyboard yields [].
    """
    keyboard = getattr(message.reply_markup, "inline_keyboard", None)
    if not keyboard:
        return []
    found: list[tuple[int, str | None, bytes | None]] = []
    every_button = (button for row in keyboard for button in row)
    for index, button in enumerate(every_button):
        payload = button.callback_data
        if payload is None:
            continue
        if isinstance(payload, str):
            payload = payload.encode()
        found.append((index, button.text, payload))
    return found
@PyroClient.on_message() @PyroClient.on_message()
@@ -35,30 +10,11 @@ async def on_message(client: PyroClient, message: Message) -> None:
ctx = client.capture ctx = client.capture
if ctx is None or message.empty or message.chat is None or message.date is None: if ctx is None or message.empty or message.chat is None or message.date is None:
return return
chat = message.chat meta = meta_from_chat(message.chat, ctx.contacts.ids)
chat_id = chat.id or 0
meta = meta_from_chat(chat, ctx.contacts.ids)
toggles = ctx.resolve(meta) toggles = ctx.resolve(meta)
if not toggles.messages: if not toggles.messages:
return return
await repository.upsert_message( await capture_message(client, message, ctx, toggles)
ctx.pool,
ctx.account_id,
chat_id,
message.id,
message.date,
sender_id(message),
message.text or message.caption,
str(message),
has_media=message.media is not None,
is_self_destruct=self_destruct_ttl(message) is not None,
)
await capture_media(client, message, ctx, chat_id, message.id, toggles)
buttons = _callbacks(message)
if buttons:
await repository.insert_callbacks(
ctx.pool, ctx.account_id, chat_id, message.id, buttons
)
handlers = on_message.handlers handlers = on_message.handlers
@@ -1,3 +1,4 @@
from userbot.modules.capture.context import CaptureContext, build_capture_context from userbot.modules.capture.context import CaptureContext, build_capture_context
from userbot.modules.capture.message import capture_message
__all__ = ["CaptureContext", "build_capture_context"] __all__ = ["CaptureContext", "build_capture_context", "capture_message"]
@@ -0,0 +1,57 @@
from pyrogram import Client
from pyrogram.types import Message
from userbot.modules.capture import repository
from userbot.modules.capture.context import CaptureContext
from userbot.modules.media import capture_media, self_destruct_ttl
from utils.policy.models import CaptureToggles
def sender_id(message: Message) -> int | None:
    """Return the id of whoever sent the message, or None if unknown.

    A user sender (``from_user``) wins over a chat sender (``sender_chat``);
    the chat is only looked at when there is no user.
    """
    author = message.from_user
    if author is not None:
        return author.id
    origin_chat = message.sender_chat
    if origin_chat is not None:
        return origin_chat.id
    return None
def callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
    """Extract the callback buttons from a message's inline keyboard.

    Each result is ``(position, text, data)``. ``position`` is the button's
    index when the keyboard is read row by row, counting buttons that have
    no callback data as well; only buttons with callback data are returned.
    Callback data given as ``str`` is encoded to bytes. Returns an empty
    list when the reply markup has no (or an empty) inline keyboard.
    """
    keyboard = getattr(message.reply_markup, "inline_keyboard", None)
    if not keyboard:
        return []
    collected: list[tuple[int, str | None, bytes | None]] = []
    flat_buttons = (button for row in keyboard for button in row)
    for index, button in enumerate(flat_buttons):
        payload = button.callback_data
        if payload is None:
            continue
        if isinstance(payload, str):
            payload = payload.encode()
        collected.append((index, button.text, payload))
    return collected
async def capture_message(
    client: Client, message: Message, ctx: CaptureContext, toggles: CaptureToggles
) -> None:
    """Store one message, then its media, then its inline callback buttons.

    Returns without doing anything when the message is empty or lacks a
    chat or a date. Otherwise it upserts the message through the capture
    repository, passes the message to ``capture_media`` together with
    ``toggles``, and inserts any callback buttons found on the message.
    """
    unusable = message.empty or message.chat is None or message.date is None
    if unusable:
        return

    chat_id = message.chat.id or 0  # a falsy chat id is recorded as 0
    body_text = message.text or message.caption
    await repository.upsert_message(
        ctx.pool,
        ctx.account_id,
        chat_id,
        message.id,
        message.date,
        sender_id(message),
        body_text,
        str(message),
        has_media=message.media is not None,
        is_self_destruct=self_destruct_ttl(message) is not None,
    )
    await capture_media(client, message, ctx, chat_id, message.id, toggles)

    if buttons := callbacks(message):
        await repository.insert_callbacks(
            ctx.pool, ctx.account_id, chat_id, message.id, buttons
        )
@@ -1,3 +1,4 @@
from userbot.modules.jobs import handlers
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
from userbot.modules.jobs.context import JobContext from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
@@ -8,5 +9,6 @@ __all__ = [
"JobConsumer", "JobConsumer",
"JobContext", "JobContext",
"JobHandler", "JobHandler",
"handlers",
"register", "register",
] ]
@@ -0,0 +1,3 @@
from userbot.modules.jobs.handlers import backfill, fetch_media
__all__ = ["backfill", "fetch_media"]
@@ -0,0 +1,33 @@
from userbot.modules.capture import capture_message
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import register
from utils.policy.models import CaptureToggles
# Number of processed messages between cursor/progress checkpoints.
SAVE_EVERY = 100


@register("backfill")
async def backfill(ctx: JobContext) -> None:
    """Walk a chat's history and capture every message it yields.

    Returns immediately when the job context has no client or the client
    has no ``capture`` attribute. Messages are captured with message
    capture on, self-destruct media off, and media capture taken from the
    job's "media" param. The starting point comes from the job cursor's
    "max_id" and the running count from the job progress's "processed".
    Every SAVE_EVERY messages the cursor and progress are written back; a
    final progress report with "done": True is sent after the loop.
    """
    client = ctx.client
    capture = None if client is None else getattr(client, "capture", None)
    if capture is None:
        return

    params = ctx.job.params
    chat_id = params["chat_id"]
    toggles = CaptureToggles(
        messages=True,
        media=bool(params.get("media")),
        self_destruct_media=False,
    )

    resume_from = (ctx.job.cursor or {}).get("max_id", 0)
    processed = ctx.job.progress.get("processed", 0)

    # NOTE(review): assumes this client's get_chat_history accepts a
    # ``max_id`` keyword - confirm against the Pyrogram fork in use.
    history_kwargs = {}
    if resume_from:
        history_kwargs["max_id"] = resume_from

    async for message in client.get_chat_history(chat_id, **history_kwargs):
        await capture_message(client, message, capture, toggles)
        processed += 1
        if processed % SAVE_EVERY:
            continue
        # NOTE(review): if ``max_id`` is an exclusive bound, resuming from
        # ``message.id - 1`` would skip that one message - TODO confirm.
        next_max = message.id - 1
        await ctx.save_cursor({"max_id": next_max})
        await ctx.report_progress({"processed": processed, "max_id": next_max})

    await ctx.report_progress({"processed": processed, "done": True})
@@ -0,0 +1,23 @@
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import register
from userbot.modules.media import capture_media
from utils.policy.models import CaptureToggles
@register("fetch_media")
async def fetch_media(ctx: JobContext) -> None:
    """Fetch one message and hand it to ``capture_media``.

    Returns immediately when the job context has no client or the client
    has no ``capture`` attribute. The chat and message ids come from the
    job params. If ``get_messages`` returns a list, its first element is
    used; a missing or empty message ends the job without capturing.
    Media and self-destruct media capture are both enabled for the call.
    """
    client = ctx.client
    capture = None if client is None else getattr(client, "capture", None)
    if capture is None:
        return

    params = ctx.job.params
    chat_id = params["chat_id"]
    message_id = params["message_id"]

    fetched = await client.get_messages(chat_id, message_id)
    if isinstance(fetched, list):
        fetched = next(iter(fetched), None)
    if fetched is None or fetched.empty:
        return

    await capture_media(
        client,
        fetched,
        capture,
        chat_id,
        message_id,
        CaptureToggles(media=True, self_destruct_media=True),
    )