feat: add backfills logic

This commit is contained in:
h
2026-05-29 19:33:57 +02:00
parent 4a471df8f1
commit 51093da660
12 changed files with 251 additions and 63 deletions
+4 -3
View File
@@ -2,10 +2,10 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import asyncpg
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka
from dishka.integrations.fastapi import DishkaRoute, FromDishka, setup_dishka
from fastapi import FastAPI
from api.routers import folders, policy
from api.routers import backfill, folders, policy
from dependencies.container import container
@@ -16,10 +16,10 @@ async def lifespan(app_: FastAPI) -> AsyncGenerator[None]:
app = FastAPI(title="beavergram API", lifespan=lifespan)
app.router.route_class = DishkaRoute
@app.get("/health")
@inject
async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
db_ok = await pool.fetchval("SELECT 1") == 1
timescale_ok = await pool.fetchval(
@@ -30,5 +30,6 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
app.include_router(policy.router)
app.include_router(folders.router)
app.include_router(backfill.router)
setup_dishka(container, app)
+119
View File
@@ -0,0 +1,119 @@
import json
from datetime import datetime
from typing import Annotated, Any
import asyncpg
from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api", tags=["backfill"], route_class=DishkaRoute)
JOBS_CHANGED_CHANNEL = "jobs_changed"
class BackfillRequest(BaseModel):
account_id: int
chat_id: int
media: bool = False
class FetchMediaRequest(BaseModel):
account_id: int
chat_id: int
message_id: int
class EnqueueResponse(BaseModel):
job_id: int
class JobView(BaseModel):
id: int
account_id: int
kind: str
status: str
params: dict[str, Any]
cursor: dict[str, Any] | None
progress: dict[str, Any]
flood_waits: int
attempts: int
error: str | None
created_at: datetime
started_at: datetime | None
finished_at: datetime | None
def _to_view(row: asyncpg.Record) -> JobView:
data = dict(row)
data["params"] = json.loads(data["params"])
data["progress"] = json.loads(data["progress"])
data["cursor"] = json.loads(data["cursor"]) if data["cursor"] is not None else None
return JobView(**data)
async def _enqueue(
pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
job_id = await pool.fetchval(
"INSERT INTO jobs (account_id, kind, params) "
"VALUES ($1, $2, $3::jsonb) RETURNING id",
account_id,
kind,
json.dumps(params),
)
await pool.execute(f"NOTIFY {JOBS_CHANGED_CHANNEL}")
return job_id
@router.post("/backfill", status_code=201)
async def enqueue_backfill(
pool: FromDishka[asyncpg.Pool], body: BackfillRequest
) -> EnqueueResponse:
job_id = await _enqueue(
pool,
body.account_id,
"backfill",
{"chat_id": body.chat_id, "media": body.media},
)
return EnqueueResponse(job_id=job_id)
@router.post("/media/fetch", status_code=201)
async def enqueue_fetch_media(
pool: FromDishka[asyncpg.Pool], body: FetchMediaRequest
) -> EnqueueResponse:
job_id = await _enqueue(
pool,
body.account_id,
"fetch_media",
{"chat_id": body.chat_id, "message_id": body.message_id},
)
return EnqueueResponse(job_id=job_id)
@router.get("/jobs")
async def list_jobs(
pool: FromDishka[asyncpg.Pool],
account_id: Annotated[int, Query()],
status: Annotated[str | None, Query()] = None,
) -> list[JobView]:
if status is None:
rows = await pool.fetch(
"SELECT * FROM jobs WHERE account_id = $1 ORDER BY id DESC", account_id
)
else:
rows = await pool.fetch(
"SELECT * FROM jobs WHERE account_id = $1 AND status = $2 ORDER BY id DESC",
account_id,
status,
)
return [_to_view(row) for row in rows]
@router.get("/jobs/{job_id}")
async def get_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView:
row = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
if row is None:
raise HTTPException(status_code=404, detail="job not found")
return _to_view(row)
+2 -3
View File
@@ -1,13 +1,13 @@
from typing import Annotated
import asyncpg
from dishka.integrations.fastapi import FromDishka, inject
from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, Query
from utils.policy import repository
from utils.policy.models import FolderSpec
router = APIRouter(prefix="/api/folders", tags=["folders"])
router = APIRouter(prefix="/api/folders", tags=["folders"], route_class=DishkaRoute)
def _serialize(spec: FolderSpec) -> dict:
@@ -28,7 +28,6 @@ def _serialize(spec: FolderSpec) -> dict:
@router.get("")
@inject
async def list_folders(
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
) -> list[dict]:
+2 -8
View File
@@ -1,7 +1,7 @@
from typing import Annotated
import asyncpg
from dishka.integrations.fastapi import FromDishka, inject
from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
@@ -15,7 +15,7 @@ from utils.policy.models import (
)
from utils.policy.resolver import resolve
router = APIRouter(prefix="/api/policy", tags=["policy"])
router = APIRouter(prefix="/api/policy", tags=["policy"], route_class=DishkaRoute)
POLICY_CHANGED_CHANNEL = "policy_changed"
@@ -29,7 +29,6 @@ class EffectiveQuery(BaseModel):
@router.get("")
@inject
async def list_policies(
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
) -> list[PolicyRecord]:
@@ -37,7 +36,6 @@ async def list_policies(
@router.get("/effective")
@inject
async def effective_policy(
pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()]
) -> CaptureToggles:
@@ -53,7 +51,6 @@ async def effective_policy(
@router.post("", status_code=201)
@inject
async def create_policy(
pool: FromDishka[asyncpg.Pool], body: PolicyCreate
) -> PolicyRecord:
@@ -65,7 +62,6 @@ async def create_policy(
@router.get("/{policy_id}")
@inject
async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord:
record = await repository.get_policy(pool, policy_id)
if record is None:
@@ -74,7 +70,6 @@ async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRe
@router.put("/{policy_id}")
@inject
async def update_policy(
pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles
) -> PolicyRecord:
@@ -86,7 +81,6 @@ async def update_policy(
@router.delete("/{policy_id}", status_code=204)
@inject
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
if not await repository.delete_policy(pool, policy_id):
raise HTTPException(status_code=404, detail="policy not found")
+1 -1
View File
@@ -1,9 +1,9 @@
from pyrogram.types import Message
from userbot import PyroClient
from userbot.handlers.messages import sender_id
from userbot.modules.capture import repository
from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.capture.message import sender_id
from userbot.modules.media import self_destruct_ttl
+3 -47
View File
@@ -1,33 +1,8 @@
from pyrogram.types import Message
from userbot import PyroClient
from userbot.modules.capture import repository
from userbot.modules.capture import capture_message
from userbot.modules.capture.chat_meta import meta_from_chat
from userbot.modules.media import capture_media, self_destruct_ttl
def sender_id(message: Message) -> int | None:
if message.from_user is not None:
return message.from_user.id
if message.sender_chat is not None:
return message.sender_chat.id
return None
def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
rows = getattr(message.reply_markup, "inline_keyboard", None)
if not rows:
return []
buttons: list[tuple[int, str | None, bytes | None]] = []
position = 0
for row in rows:
for button in row:
data = button.callback_data
if data is not None:
encoded = data.encode() if isinstance(data, str) else data
buttons.append((position, button.text, encoded))
position += 1
return buttons
@PyroClient.on_message()
@@ -35,30 +10,11 @@ async def on_message(client: PyroClient, message: Message) -> None:
ctx = client.capture
if ctx is None or message.empty or message.chat is None or message.date is None:
return
chat = message.chat
chat_id = chat.id or 0
meta = meta_from_chat(chat, ctx.contacts.ids)
meta = meta_from_chat(message.chat, ctx.contacts.ids)
toggles = ctx.resolve(meta)
if not toggles.messages:
return
await repository.upsert_message(
ctx.pool,
ctx.account_id,
chat_id,
message.id,
message.date,
sender_id(message),
message.text or message.caption,
str(message),
has_media=message.media is not None,
is_self_destruct=self_destruct_ttl(message) is not None,
)
await capture_media(client, message, ctx, chat_id, message.id, toggles)
buttons = _callbacks(message)
if buttons:
await repository.insert_callbacks(
ctx.pool, ctx.account_id, chat_id, message.id, buttons
)
await capture_message(client, message, ctx, toggles)
handlers = on_message.handlers
@@ -1,3 +1,4 @@
from userbot.modules.capture.context import CaptureContext, build_capture_context
from userbot.modules.capture.message import capture_message
__all__ = ["CaptureContext", "build_capture_context"]
__all__ = ["CaptureContext", "build_capture_context", "capture_message"]
@@ -0,0 +1,57 @@
from pyrogram import Client
from pyrogram.types import Message
from userbot.modules.capture import repository
from userbot.modules.capture.context import CaptureContext
from userbot.modules.media import capture_media, self_destruct_ttl
from utils.policy.models import CaptureToggles
def sender_id(message: Message) -> int | None:
if message.from_user is not None:
return message.from_user.id
if message.sender_chat is not None:
return message.sender_chat.id
return None
def callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
rows = getattr(message.reply_markup, "inline_keyboard", None)
if not rows:
return []
buttons: list[tuple[int, str | None, bytes | None]] = []
position = 0
for row in rows:
for button in row:
data = button.callback_data
if data is not None:
encoded = data.encode() if isinstance(data, str) else data
buttons.append((position, button.text, encoded))
position += 1
return buttons
async def capture_message(
client: Client, message: Message, ctx: CaptureContext, toggles: CaptureToggles
) -> None:
if message.empty or message.chat is None or message.date is None:
return
chat_id = message.chat.id or 0
await repository.upsert_message(
ctx.pool,
ctx.account_id,
chat_id,
message.id,
message.date,
sender_id(message),
message.text or message.caption,
str(message),
has_media=message.media is not None,
is_self_destruct=self_destruct_ttl(message) is not None,
)
await capture_media(client, message, ctx, chat_id, message.id, toggles)
buttons = callbacks(message)
if buttons:
await repository.insert_callbacks(
ctx.pool, ctx.account_id, chat_id, message.id, buttons
)
@@ -1,3 +1,4 @@
from userbot.modules.jobs import handlers
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
@@ -8,5 +9,6 @@ __all__ = [
"JobConsumer",
"JobContext",
"JobHandler",
"handlers",
"register",
]
@@ -0,0 +1,3 @@
from userbot.modules.jobs.handlers import backfill, fetch_media
__all__ = ["backfill", "fetch_media"]
@@ -0,0 +1,33 @@
from userbot.modules.capture import capture_message
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import register
from utils.policy.models import CaptureToggles
SAVE_EVERY = 100
@register("backfill")
async def backfill(ctx: JobContext) -> None:
client = ctx.client
if client is None:
return
capture = getattr(client, "capture", None)
if capture is None:
return
chat_id = ctx.job.params["chat_id"]
toggles = CaptureToggles(
messages=True,
media=bool(ctx.job.params.get("media")),
self_destruct_media=False,
)
max_id = (ctx.job.cursor or {}).get("max_id", 0)
processed = ctx.job.progress.get("processed", 0)
kwargs = {"max_id": max_id} if max_id else {}
async for message in client.get_chat_history(chat_id, **kwargs):
await capture_message(client, message, capture, toggles)
processed += 1
if processed % SAVE_EVERY == 0:
next_max = message.id - 1
await ctx.save_cursor({"max_id": next_max})
await ctx.report_progress({"processed": processed, "max_id": next_max})
await ctx.report_progress({"processed": processed, "done": True})
@@ -0,0 +1,23 @@
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import register
from userbot.modules.media import capture_media
from utils.policy.models import CaptureToggles
@register("fetch_media")
async def fetch_media(ctx: JobContext) -> None:
client = ctx.client
if client is None:
return
capture = getattr(client, "capture", None)
if capture is None:
return
chat_id = ctx.job.params["chat_id"]
message_id = ctx.job.params["message_id"]
message = await client.get_messages(chat_id, message_id)
if isinstance(message, list):
message = message[0] if message else None
if message is None or message.empty:
return
toggles = CaptureToggles(media=True, self_destruct_media=True)
await capture_media(client, message, capture, chat_id, message_id, toggles)