feat: add jobs queue

This commit is contained in:
h
2026-05-29 19:04:36 +02:00
parent 3c1a12750c
commit 4a471df8f1
8 changed files with 356 additions and 0 deletions
@@ -0,0 +1,72 @@
"""jobs queue
Revision ID: d4b9f2e6a1c7
Revises: c3a8e5f1b6d2
Create Date: 2026-05-29 19:30:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# Revision identifiers read by Alembic.
revision: str = "d4b9f2e6a1c7"
# The revision this migration is applied on top of.
down_revision: str | None = "c3a8e5f1b6d2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Create the ``jobs`` table and a partial index over its pending rows."""

    def jsonb() -> postgresql.JSONB:
        # A fresh type object per column, as in an inline declaration.
        return postgresql.JSONB(astext_type=sa.Text())

    def timestamptz() -> sa.DateTime:
        return sa.DateTime(timezone=True)

    def counter(name: str) -> sa.Column:
        # Non-null integer whose server-side default is zero.
        return sa.Column(
            name, sa.Integer(), nullable=False, server_default=sa.text("0")
        )

    def stamped_now(name: str) -> sa.Column:
        # Non-null timestamp filled in by the server with now().
        return sa.Column(
            name, timestamptz(), server_default=sa.text("now()"), nullable=False
        )

    def json_object(name: str) -> sa.Column:
        # Non-null JSONB column defaulting to an empty object.
        return sa.Column(
            name, jsonb(), nullable=False, server_default=sa.text("'{}'::jsonb")
        )

    columns = [
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("account_id", sa.Integer(), nullable=False),
        sa.Column("kind", sa.String(), nullable=False),
        sa.Column(
            "status", sa.String(), nullable=False, server_default=sa.text("'pending'")
        ),
        json_object("params"),
        sa.Column("cursor", jsonb(), nullable=True),
        json_object("progress"),
        counter("flood_waits"),
        counter("attempts"),
        sa.Column("error", sa.String(), nullable=True),
        stamped_now("created_at"),
        stamped_now("updated_at"),
        sa.Column("started_at", timestamptz(), nullable=True),
        sa.Column("finished_at", timestamptz(), nullable=True),
    ]
    op.create_table("jobs", *columns, sa.PrimaryKeyConstraint("id"))

    # Partial index: only rows whose status is 'pending' are indexed, keyed by
    # (account_id, id).
    op.execute(
        "CREATE INDEX ix_jobs_pending ON jobs (account_id, id) WHERE status = 'pending'"
    )
def downgrade() -> None:
    """Revert this revision by dropping the ``jobs`` table."""
    # No separate index drop is issued; only the table is removed here.
    table_name = "jobs"
    op.drop_table(table_name)
@@ -0,0 +1,12 @@
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
# Names re-exported as the public surface of the jobs package.
__all__ = [
    "JOBS_CHANGED_CHANNEL",
    "JOB_HANDLERS",
    "JobConsumer",
    "JobContext",
    "JobHandler",
    "register",
]
@@ -0,0 +1,85 @@
import asyncio
import contextlib
import asyncpg
from pyrogram import Client
from pyrogram.errors import FloodPremiumWait, FloodWait
from userbot.modules.jobs import repository
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler
from utils.db.models import Job, JobStatus
from utils.env import env
from utils.logging import logger
# Notification channel the consumer listens on to be woken up early.
JOBS_CHANGED_CHANNEL = "jobs_changed"
# Longest time, in seconds, the consumer waits for a wake-up before it checks
# the queue again anyway.
POLL_INTERVAL_SECONDS = 60.0
class JobConsumer:
    """Sequentially runs queued jobs that belong to a single account.

    The consumer claims pending jobs through the repository, dispatches each
    one to the handler registered for its ``kind``, and records the outcome.
    Between bursts of work it sleeps until either a notification arrives on
    ``JOBS_CHANGED_CHANNEL`` or ``POLL_INTERVAL_SECONDS`` elapses.
    """

    def __init__(
        self,
        client: Client | None,
        pool: asyncpg.Pool,
        account_id: int,
        registry: dict[str, JobHandler] | None = None,
    ) -> None:
        """Store the collaborators used to run jobs.

        Args:
            client: Pyrogram client handed to handlers via the job context;
                may be ``None``.
            pool: Connection pool used for all queue bookkeeping.
            account_id: Only jobs of this account are claimed.
            registry: Mapping of job kind to handler. Defaults to the
                module-level ``JOB_HANDLERS``.
        """
        self.client = client
        self.pool = pool
        self.account_id = account_id
        self.registry = registry if registry is not None else JOB_HANDLERS
        # Set by the notification listener to cut the idle wait short.
        self._wake = asyncio.Event()

    async def execute(self, job: Job) -> None:
        """Run one claimed job until it succeeds or fails.

        A job whose kind has no handler is marked failed immediately. When the
        handler raises a flood-wait error, the wait is counted, slept out, and
        the handler is invoked again from the top. Any other ``Exception``
        marks the job failed with the exception text as its error.
        """
        ctx = JobContext(self.client, self.pool, self.account_id, job)
        handler = self.registry.get(job.kind)
        if handler is None:
            await repository.finish(
                self.pool, ctx.job_id, JobStatus.FAILED, f"no handler for {job.kind!r}"
            )
            return
        while True:
            try:
                await handler(ctx)
            except (FloodWait, FloodPremiumWait) as exc:
                await repository.bump_flood_wait(self.pool, ctx.job_id)
                # Fall back to one second when the error carries no integer delay.
                wait = exc.value if isinstance(exc.value, int) else 1
                logger.warning(
                    f"[yellow]FloodWait {wait}s on job {ctx.job_id} ({job.kind})[/]"
                )
                await asyncio.sleep(wait)
            except Exception as exc:
                logger.exception(f"Job {ctx.job_id} ({job.kind}) failed")
                await repository.finish(
                    self.pool, ctx.job_id, JobStatus.FAILED, str(exc)
                )
                return
            else:
                await repository.finish(self.pool, ctx.job_id, JobStatus.DONE)
                return

    async def drain(self) -> None:
        """Claim and execute jobs one after another until none is pending."""
        while True:
            job = await repository.claim_next(self.pool, self.account_id)
            if job is None:
                return
            await self.execute(job)

    async def run(self) -> None:
        """Consume jobs forever; intended to run as a long-lived task.

        On start, jobs of this account still marked as running are put back to
        pending. A dedicated connection is then opened to listen for
        notifications, and the loop alternates between draining the queue and
        waiting for a wake-up. The connection is closed when the loop exits,
        including on cancellation.
        """
        requeued = await repository.requeue_running(self.pool, self.account_id)
        if requeued:
            logger.info(f"[yellow]Requeued {requeued} stale running job(s).[/]")
        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            # Registered inside the try so that a failure here still closes
            # the freshly opened connection.
            await conn.add_listener(
                JOBS_CHANGED_CHANNEL, lambda *_: self._wake.set()
            )
            logger.info(
                f"[green]Job consumer running for account {self.account_id}.[/]"
            )
            while True:
                # Clear BEFORE draining: a notification that arrives while the
                # queue is being drained must survive until the wait below,
                # otherwise the new job would sit idle for a full poll interval.
                self._wake.clear()
                await self.drain()
                with contextlib.suppress(TimeoutError):
                    await asyncio.wait_for(
                        self._wake.wait(), timeout=POLL_INTERVAL_SECONDS
                    )
        finally:
            with contextlib.suppress(Exception):
                await conn.close()
@@ -0,0 +1,26 @@
from typing import Any
import asyncpg
from pyrogram import Client
from userbot.modules.jobs import repository
from utils.db.models import Job
class JobContext:
    """Everything a job handler receives for one job execution.

    Carries the client, the connection pool, the owning account id and the job
    itself, and offers helpers that persist the job's cursor and progress.
    """

    def __init__(
        self, client: Client | None, pool: asyncpg.Pool, account_id: int, job: Job
    ) -> None:
        """Bind the context to ``job`` and the resources used to run it."""
        self.job = job
        # A job that carries no id is addressed as 0.
        identifier = job.id
        self.job_id = identifier if identifier else 0
        self.account_id = account_id
        self.pool = pool
        self.client = client

    async def save_cursor(self, cursor: dict[str, Any]) -> None:
        """Set the cursor on the in-memory job, then persist it."""
        job, pool, job_id = self.job, self.pool, self.job_id
        job.cursor = cursor
        await repository.save_cursor(pool, job_id, cursor)

    async def report_progress(self, progress: dict[str, Any]) -> None:
        """Set the progress on the in-memory job, then persist it."""
        job, pool, job_id = self.job, self.pool, self.job_id
        job.progress = progress
        await repository.report_progress(pool, job_id, progress)
@@ -0,0 +1,15 @@
from collections.abc import Awaitable, Callable
from userbot.modules.jobs.context import JobContext
# Shape of a job handler: an async callable that takes the JobContext and
# returns nothing.
JobHandler = Callable[[JobContext], Awaitable[None]]
# Module-level mapping of job kind -> handler, filled in by register().
JOB_HANDLERS: dict[str, JobHandler] = {}
def register(kind: str) -> Callable[[JobHandler], JobHandler]:
    """Return a decorator that files a handler under ``kind`` in ``JOB_HANDLERS``.

    The decorated function is returned unchanged. Registering another handler
    for the same kind replaces the earlier entry.
    """

    def _store(fn: JobHandler) -> JobHandler:
        JOB_HANDLERS[kind] = fn
        return fn

    return _store
@@ -0,0 +1,91 @@
import json
from typing import Any
import asyncpg
from utils.db.models import Job, JobStatus
def _row_to_job(row: asyncpg.Record) -> Job:
    """Build a ``Job`` from a ``jobs`` row, decoding its JSON columns.

    ``params`` and ``progress`` are always parsed; ``cursor`` is parsed only
    when it is not ``None``.
    """
    fields = dict(row)
    for column in ("params", "progress"):
        fields[column] = json.loads(fields[column])
    raw_cursor = fields["cursor"]
    if raw_cursor is not None:
        fields["cursor"] = json.loads(raw_cursor)
    return Job(**fields)
async def enqueue(
    pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
    """Insert a job row for ``account_id`` and return the id of the new row.

    ``params`` is serialised to JSON text and cast to ``jsonb`` by the query.
    Columns not listed in the statement are left to the table's defaults.
    """
    query = (
        "INSERT INTO jobs (account_id, kind, params) "
        "VALUES ($1, $2, $3::jsonb) RETURNING id"
    )
    encoded_params = json.dumps(params)
    new_id = await pool.fetchval(query, account_id, kind, encoded_params)
    return new_id
async def claim_next(pool: asyncpg.Pool, account_id: int) -> Job | None:
    """Claim the lowest-id pending job of ``account_id`` and mark it running.

    The single statement picks the row with ``FOR UPDATE SKIP LOCKED``, sets
    its status to running, stamps ``started_at``/``updated_at`` and increments
    ``attempts``. Returns the claimed job, or ``None`` when nothing is pending.
    """
    claim_sql = (
        "UPDATE jobs SET status = 'running', started_at = now(), "
        "updated_at = now(), attempts = attempts + 1 "
        "WHERE id = ("
        " SELECT id FROM jobs WHERE account_id = $1 AND status = 'pending' "
        " ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1"
        ") RETURNING *"
    )
    claimed = await pool.fetchrow(claim_sql, account_id)
    if not claimed:
        return None
    return _row_to_job(claimed)
async def requeue_running(pool: asyncpg.Pool, account_id: int) -> int:
    """Put every running job of ``account_id`` back to pending.

    Returns the number of rows changed, read from the status string that
    ``pool.execute`` returns.
    """
    statement = (
        "UPDATE jobs SET status = 'pending', updated_at = now() "
        "WHERE account_id = $1 AND status = 'running'"
    )
    status_tag = await pool.execute(statement, account_id)
    # The last whitespace-separated token of the status string is the count.
    row_count = status_tag.split()[-1]
    return int(row_count)
async def save_cursor(pool: asyncpg.Pool, job_id: int, cursor: dict[str, Any]) -> None:
    """Overwrite the job's ``cursor`` column and refresh ``updated_at``."""
    statement = "UPDATE jobs SET cursor = $2::jsonb, updated_at = now() WHERE id = $1"
    encoded_cursor = json.dumps(cursor)
    await pool.execute(statement, job_id, encoded_cursor)
async def report_progress(
    pool: asyncpg.Pool, job_id: int, progress: dict[str, Any]
) -> None:
    """Overwrite the job's ``progress`` column and refresh ``updated_at``."""
    statement = (
        "UPDATE jobs SET progress = $2::jsonb, updated_at = now() WHERE id = $1"
    )
    encoded_progress = json.dumps(progress)
    await pool.execute(statement, job_id, encoded_progress)
async def bump_flood_wait(pool: asyncpg.Pool, job_id: int) -> None:
    """Increment the job's ``flood_waits`` counter and refresh ``updated_at``."""
    statement = (
        "UPDATE jobs SET flood_waits = flood_waits + 1, updated_at = now() "
        "WHERE id = $1"
    )
    await pool.execute(statement, job_id)
async def finish(
    pool: asyncpg.Pool, job_id: int, status: JobStatus, error: str | None = None
) -> None:
    """Record the final outcome of a job.

    Writes ``status`` (its enum value) and ``error`` to the row and stamps
    both ``finished_at`` and ``updated_at``. ``error`` is written as given, so
    omitting it stores NULL.
    """
    statement = (
        "UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
        "updated_at = now() WHERE id = $1"
    )
    final_status = status.value
    await pool.execute(statement, job_id, final_status, error)
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
    """Load a single job by id, or return ``None`` when no such row exists."""
    record = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
    if not record:
        return None
    return _row_to_job(record)
+9
View File
@@ -9,6 +9,7 @@ import uvloop
from dependencies.container import container
from userbot import PyroClient
from userbot.modules.capture import build_capture_context
from userbot.modules.jobs import JobConsumer
from utils.env import env
from utils.logging import logger, setup_logging
from utils.storage import ContentAddressedStorage
@@ -101,6 +102,7 @@ async def runner() -> None:
clients: list[PyroClient] = []
reload_tasks: set[asyncio.Task] = set()
consumer_tasks: list[asyncio.Task] = []
listen_conn: asyncpg.Connection | None = None
try:
for session_path in session_files:
@@ -116,12 +118,19 @@ async def runner() -> None:
account_id = await _sync_account(pool, client, session_name)
if account_id is not None:
await _setup_capture(pool, client, account_id, storage)
consumer = JobConsumer(client, pool, account_id)
consumer_tasks.append(asyncio.create_task(consumer.run()))
if clients:
listen_conn = await _listen_policy_changes(clients, reload_tasks)
logger.info("[green]Userbot running.[/]")
await asyncio.Event().wait()
finally:
for task in consumer_tasks:
task.cancel()
for task in consumer_tasks:
with contextlib.suppress(asyncio.CancelledError):
await task
if listen_conn is not None:
with contextlib.suppress(Exception):
await listen_conn.close()
+46
View File
@@ -1,4 +1,5 @@
from datetime import datetime
from enum import StrEnum
from typing import Any
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
@@ -6,6 +7,13 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, SQLModel
class JobStatus(StrEnum):
    """States a job moves through, stored as text in its ``status`` field."""

    # Queued and waiting to be claimed.
    PENDING = "pending"
    # Claimed by a consumer.
    RUNNING = "running"
    # Finished successfully.
    DONE = "done"
    # Finished with an error.
    FAILED = "failed"
class Account(SQLModel, table=True):
__tablename__ = "accounts"
@@ -158,6 +166,44 @@ class Reaction(SQLModel, table=True):
)
class Job(SQLModel, table=True):
    """Row of the ``jobs`` table: one queued unit of work for an account."""

    __tablename__ = "jobs"

    id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
    account_id: int
    # Name of the job type.
    kind: str
    # Holds one of the JobStatus values.
    status: str = JobStatus.PENDING
    # Arbitrary JSON parameters for the job.
    params: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )
    # Optional JSON cursor; nullable.
    cursor: dict[str, Any] | None = Field(default=None, sa_column=Column(JSONB))
    # JSON progress data; defaults to an empty object.
    progress: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )
    # Counters, both starting at zero.
    flood_waits: int = 0
    attempts: int = 0
    # Optional error text.
    error: str | None = None
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )
    # Defaults to now() on insert and is set to now() again on update.
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            nullable=False,
            server_default=func.now(),
            onupdate=func.now(),
        )
    )
    started_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
    finished_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
class CapturePolicy(SQLModel, table=True):
__tablename__ = "capture_policy"