feat: add backfills logic
This commit is contained in:
@@ -2,10 +2,10 @@ from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import FromDishka, inject, setup_dishka
|
||||
from dishka.integrations.fastapi import DishkaRoute, FromDishka, setup_dishka
|
||||
from fastapi import FastAPI
|
||||
|
||||
from api.routers import folders, policy
|
||||
from api.routers import backfill, folders, policy
|
||||
from dependencies.container import container
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ async def lifespan(app_: FastAPI) -> AsyncGenerator[None]:
|
||||
|
||||
|
||||
app = FastAPI(title="beavergram API", lifespan=lifespan)
|
||||
app.router.route_class = DishkaRoute
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@inject
|
||||
async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
|
||||
db_ok = await pool.fetchval("SELECT 1") == 1
|
||||
timescale_ok = await pool.fetchval(
|
||||
@@ -30,5 +30,6 @@ async def health(pool: FromDishka[asyncpg.Pool]) -> dict[str, bool]:
|
||||
|
||||
app.include_router(policy.router)
|
||||
app.include_router(folders.router)
|
||||
app.include_router(backfill.router)
|
||||
|
||||
setup_dishka(container, app)
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import DishkaRoute, FromDishka
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["backfill"], route_class=DishkaRoute)
|
||||
|
||||
JOBS_CHANGED_CHANNEL = "jobs_changed"
|
||||
|
||||
|
||||
class BackfillRequest(BaseModel):
|
||||
account_id: int
|
||||
chat_id: int
|
||||
media: bool = False
|
||||
|
||||
|
||||
class FetchMediaRequest(BaseModel):
|
||||
account_id: int
|
||||
chat_id: int
|
||||
message_id: int
|
||||
|
||||
|
||||
class EnqueueResponse(BaseModel):
|
||||
job_id: int
|
||||
|
||||
|
||||
class JobView(BaseModel):
|
||||
id: int
|
||||
account_id: int
|
||||
kind: str
|
||||
status: str
|
||||
params: dict[str, Any]
|
||||
cursor: dict[str, Any] | None
|
||||
progress: dict[str, Any]
|
||||
flood_waits: int
|
||||
attempts: int
|
||||
error: str | None
|
||||
created_at: datetime
|
||||
started_at: datetime | None
|
||||
finished_at: datetime | None
|
||||
|
||||
|
||||
def _to_view(row: asyncpg.Record) -> JobView:
|
||||
data = dict(row)
|
||||
data["params"] = json.loads(data["params"])
|
||||
data["progress"] = json.loads(data["progress"])
|
||||
data["cursor"] = json.loads(data["cursor"]) if data["cursor"] is not None else None
|
||||
return JobView(**data)
|
||||
|
||||
|
||||
async def _enqueue(
|
||||
pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
|
||||
) -> int:
|
||||
job_id = await pool.fetchval(
|
||||
"INSERT INTO jobs (account_id, kind, params) "
|
||||
"VALUES ($1, $2, $3::jsonb) RETURNING id",
|
||||
account_id,
|
||||
kind,
|
||||
json.dumps(params),
|
||||
)
|
||||
await pool.execute(f"NOTIFY {JOBS_CHANGED_CHANNEL}")
|
||||
return job_id
|
||||
|
||||
|
||||
@router.post("/backfill", status_code=201)
|
||||
async def enqueue_backfill(
|
||||
pool: FromDishka[asyncpg.Pool], body: BackfillRequest
|
||||
) -> EnqueueResponse:
|
||||
job_id = await _enqueue(
|
||||
pool,
|
||||
body.account_id,
|
||||
"backfill",
|
||||
{"chat_id": body.chat_id, "media": body.media},
|
||||
)
|
||||
return EnqueueResponse(job_id=job_id)
|
||||
|
||||
|
||||
@router.post("/media/fetch", status_code=201)
|
||||
async def enqueue_fetch_media(
|
||||
pool: FromDishka[asyncpg.Pool], body: FetchMediaRequest
|
||||
) -> EnqueueResponse:
|
||||
job_id = await _enqueue(
|
||||
pool,
|
||||
body.account_id,
|
||||
"fetch_media",
|
||||
{"chat_id": body.chat_id, "message_id": body.message_id},
|
||||
)
|
||||
return EnqueueResponse(job_id=job_id)
|
||||
|
||||
|
||||
@router.get("/jobs")
|
||||
async def list_jobs(
|
||||
pool: FromDishka[asyncpg.Pool],
|
||||
account_id: Annotated[int, Query()],
|
||||
status: Annotated[str | None, Query()] = None,
|
||||
) -> list[JobView]:
|
||||
if status is None:
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM jobs WHERE account_id = $1 ORDER BY id DESC", account_id
|
||||
)
|
||||
else:
|
||||
rows = await pool.fetch(
|
||||
"SELECT * FROM jobs WHERE account_id = $1 AND status = $2 ORDER BY id DESC",
|
||||
account_id,
|
||||
status,
|
||||
)
|
||||
return [_to_view(row) for row in rows]
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
async def get_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView:
|
||||
row = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
return _to_view(row)
|
||||
@@ -1,13 +1,13 @@
|
||||
from typing import Annotated
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import FromDishka, inject
|
||||
from dishka.integrations.fastapi import DishkaRoute, FromDishka
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from utils.policy import repository
|
||||
from utils.policy.models import FolderSpec
|
||||
|
||||
router = APIRouter(prefix="/api/folders", tags=["folders"])
|
||||
router = APIRouter(prefix="/api/folders", tags=["folders"], route_class=DishkaRoute)
|
||||
|
||||
|
||||
def _serialize(spec: FolderSpec) -> dict:
|
||||
@@ -28,7 +28,6 @@ def _serialize(spec: FolderSpec) -> dict:
|
||||
|
||||
|
||||
@router.get("")
|
||||
@inject
|
||||
async def list_folders(
|
||||
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
|
||||
) -> list[dict]:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Annotated
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import FromDishka, inject
|
||||
from dishka.integrations.fastapi import DishkaRoute, FromDishka
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -15,7 +15,7 @@ from utils.policy.models import (
|
||||
)
|
||||
from utils.policy.resolver import resolve
|
||||
|
||||
router = APIRouter(prefix="/api/policy", tags=["policy"])
|
||||
router = APIRouter(prefix="/api/policy", tags=["policy"], route_class=DishkaRoute)
|
||||
|
||||
POLICY_CHANGED_CHANNEL = "policy_changed"
|
||||
|
||||
@@ -29,7 +29,6 @@ class EffectiveQuery(BaseModel):
|
||||
|
||||
|
||||
@router.get("")
|
||||
@inject
|
||||
async def list_policies(
|
||||
pool: FromDishka[asyncpg.Pool], account_id: Annotated[int, Query()]
|
||||
) -> list[PolicyRecord]:
|
||||
@@ -37,7 +36,6 @@ async def list_policies(
|
||||
|
||||
|
||||
@router.get("/effective")
|
||||
@inject
|
||||
async def effective_policy(
|
||||
pool: FromDishka[asyncpg.Pool], query: Annotated[EffectiveQuery, Query()]
|
||||
) -> CaptureToggles:
|
||||
@@ -53,7 +51,6 @@ async def effective_policy(
|
||||
|
||||
|
||||
@router.post("", status_code=201)
|
||||
@inject
|
||||
async def create_policy(
|
||||
pool: FromDishka[asyncpg.Pool], body: PolicyCreate
|
||||
) -> PolicyRecord:
|
||||
@@ -65,7 +62,6 @@ async def create_policy(
|
||||
|
||||
|
||||
@router.get("/{policy_id}")
|
||||
@inject
|
||||
async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRecord:
|
||||
record = await repository.get_policy(pool, policy_id)
|
||||
if record is None:
|
||||
@@ -74,7 +70,6 @@ async def get_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> PolicyRe
|
||||
|
||||
|
||||
@router.put("/{policy_id}")
|
||||
@inject
|
||||
async def update_policy(
|
||||
pool: FromDishka[asyncpg.Pool], policy_id: int, body: CaptureToggles
|
||||
) -> PolicyRecord:
|
||||
@@ -86,7 +81,6 @@ async def update_policy(
|
||||
|
||||
|
||||
@router.delete("/{policy_id}", status_code=204)
|
||||
@inject
|
||||
async def delete_policy(pool: FromDishka[asyncpg.Pool], policy_id: int) -> None:
|
||||
if not await repository.delete_policy(pool, policy_id):
|
||||
raise HTTPException(status_code=404, detail="policy not found")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.handlers.messages import sender_id
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||
from userbot.modules.capture.message import sender_id
|
||||
from userbot.modules.media import self_destruct_ttl
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +1,8 @@
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture import capture_message
|
||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||
from userbot.modules.media import capture_media, self_destruct_ttl
|
||||
|
||||
|
||||
def sender_id(message: Message) -> int | None:
|
||||
if message.from_user is not None:
|
||||
return message.from_user.id
|
||||
if message.sender_chat is not None:
|
||||
return message.sender_chat.id
|
||||
return None
|
||||
|
||||
|
||||
def _callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
|
||||
rows = getattr(message.reply_markup, "inline_keyboard", None)
|
||||
if not rows:
|
||||
return []
|
||||
buttons: list[tuple[int, str | None, bytes | None]] = []
|
||||
position = 0
|
||||
for row in rows:
|
||||
for button in row:
|
||||
data = button.callback_data
|
||||
if data is not None:
|
||||
encoded = data.encode() if isinstance(data, str) else data
|
||||
buttons.append((position, button.text, encoded))
|
||||
position += 1
|
||||
return buttons
|
||||
|
||||
|
||||
@PyroClient.on_message()
|
||||
@@ -35,30 +10,11 @@ async def on_message(client: PyroClient, message: Message) -> None:
|
||||
ctx = client.capture
|
||||
if ctx is None or message.empty or message.chat is None or message.date is None:
|
||||
return
|
||||
chat = message.chat
|
||||
chat_id = chat.id or 0
|
||||
meta = meta_from_chat(chat, ctx.contacts.ids)
|
||||
meta = meta_from_chat(message.chat, ctx.contacts.ids)
|
||||
toggles = ctx.resolve(meta)
|
||||
if not toggles.messages:
|
||||
return
|
||||
await repository.upsert_message(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
chat_id,
|
||||
message.id,
|
||||
message.date,
|
||||
sender_id(message),
|
||||
message.text or message.caption,
|
||||
str(message),
|
||||
has_media=message.media is not None,
|
||||
is_self_destruct=self_destruct_ttl(message) is not None,
|
||||
)
|
||||
await capture_media(client, message, ctx, chat_id, message.id, toggles)
|
||||
buttons = _callbacks(message)
|
||||
if buttons:
|
||||
await repository.insert_callbacks(
|
||||
ctx.pool, ctx.account_id, chat_id, message.id, buttons
|
||||
)
|
||||
await capture_message(client, message, ctx, toggles)
|
||||
|
||||
|
||||
handlers = on_message.handlers
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from userbot.modules.capture.context import CaptureContext, build_capture_context
|
||||
from userbot.modules.capture.message import capture_message
|
||||
|
||||
__all__ = ["CaptureContext", "build_capture_context"]
|
||||
__all__ = ["CaptureContext", "build_capture_context", "capture_message"]
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
from pyrogram import Client
|
||||
from pyrogram.types import Message
|
||||
|
||||
from userbot.modules.capture import repository
|
||||
from userbot.modules.capture.context import CaptureContext
|
||||
from userbot.modules.media import capture_media, self_destruct_ttl
|
||||
from utils.policy.models import CaptureToggles
|
||||
|
||||
|
||||
def sender_id(message: Message) -> int | None:
|
||||
if message.from_user is not None:
|
||||
return message.from_user.id
|
||||
if message.sender_chat is not None:
|
||||
return message.sender_chat.id
|
||||
return None
|
||||
|
||||
|
||||
def callbacks(message: Message) -> list[tuple[int, str | None, bytes | None]]:
|
||||
rows = getattr(message.reply_markup, "inline_keyboard", None)
|
||||
if not rows:
|
||||
return []
|
||||
buttons: list[tuple[int, str | None, bytes | None]] = []
|
||||
position = 0
|
||||
for row in rows:
|
||||
for button in row:
|
||||
data = button.callback_data
|
||||
if data is not None:
|
||||
encoded = data.encode() if isinstance(data, str) else data
|
||||
buttons.append((position, button.text, encoded))
|
||||
position += 1
|
||||
return buttons
|
||||
|
||||
|
||||
async def capture_message(
|
||||
client: Client, message: Message, ctx: CaptureContext, toggles: CaptureToggles
|
||||
) -> None:
|
||||
if message.empty or message.chat is None or message.date is None:
|
||||
return
|
||||
chat_id = message.chat.id or 0
|
||||
await repository.upsert_message(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
chat_id,
|
||||
message.id,
|
||||
message.date,
|
||||
sender_id(message),
|
||||
message.text or message.caption,
|
||||
str(message),
|
||||
has_media=message.media is not None,
|
||||
is_self_destruct=self_destruct_ttl(message) is not None,
|
||||
)
|
||||
await capture_media(client, message, ctx, chat_id, message.id, toggles)
|
||||
buttons = callbacks(message)
|
||||
if buttons:
|
||||
await repository.insert_callbacks(
|
||||
ctx.pool, ctx.account_id, chat_id, message.id, buttons
|
||||
)
|
||||
@@ -1,3 +1,4 @@
|
||||
from userbot.modules.jobs import handlers
|
||||
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
|
||||
@@ -8,5 +9,6 @@ __all__ = [
|
||||
"JobConsumer",
|
||||
"JobContext",
|
||||
"JobHandler",
|
||||
"handlers",
|
||||
"register",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from userbot.modules.jobs.handlers import backfill, fetch_media
|
||||
|
||||
__all__ = ["backfill", "fetch_media"]
|
||||
@@ -0,0 +1,33 @@
|
||||
from userbot.modules.capture import capture_message
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import register
|
||||
from utils.policy.models import CaptureToggles
|
||||
|
||||
SAVE_EVERY = 100
|
||||
|
||||
|
||||
@register("backfill")
|
||||
async def backfill(ctx: JobContext) -> None:
|
||||
client = ctx.client
|
||||
if client is None:
|
||||
return
|
||||
capture = getattr(client, "capture", None)
|
||||
if capture is None:
|
||||
return
|
||||
chat_id = ctx.job.params["chat_id"]
|
||||
toggles = CaptureToggles(
|
||||
messages=True,
|
||||
media=bool(ctx.job.params.get("media")),
|
||||
self_destruct_media=False,
|
||||
)
|
||||
max_id = (ctx.job.cursor or {}).get("max_id", 0)
|
||||
processed = ctx.job.progress.get("processed", 0)
|
||||
kwargs = {"max_id": max_id} if max_id else {}
|
||||
async for message in client.get_chat_history(chat_id, **kwargs):
|
||||
await capture_message(client, message, capture, toggles)
|
||||
processed += 1
|
||||
if processed % SAVE_EVERY == 0:
|
||||
next_max = message.id - 1
|
||||
await ctx.save_cursor({"max_id": next_max})
|
||||
await ctx.report_progress({"processed": processed, "max_id": next_max})
|
||||
await ctx.report_progress({"processed": processed, "done": True})
|
||||
@@ -0,0 +1,23 @@
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import register
|
||||
from userbot.modules.media import capture_media
|
||||
from utils.policy.models import CaptureToggles
|
||||
|
||||
|
||||
@register("fetch_media")
|
||||
async def fetch_media(ctx: JobContext) -> None:
|
||||
client = ctx.client
|
||||
if client is None:
|
||||
return
|
||||
capture = getattr(client, "capture", None)
|
||||
if capture is None:
|
||||
return
|
||||
chat_id = ctx.job.params["chat_id"]
|
||||
message_id = ctx.job.params["message_id"]
|
||||
message = await client.get_messages(chat_id, message_id)
|
||||
if isinstance(message, list):
|
||||
message = message[0] if message else None
|
||||
if message is None or message.empty:
|
||||
return
|
||||
toggles = CaptureToggles(media=True, self_destruct_media=True)
|
||||
await capture_media(client, message, capture, chat_id, message_id, toggles)
|
||||
Reference in New Issue
Block a user