feat: add jobs queue

This commit is contained in:
h
2026-05-29 19:04:36 +02:00
parent 3c1a12750c
commit 4a471df8f1
8 changed files with 356 additions and 0 deletions
@@ -0,0 +1,72 @@
"""jobs queue
Revision ID: d4b9f2e6a1c7
Revises: c3a8e5f1b6d2
Create Date: 2026-05-29 19:30:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# Alembic revision identifiers: this migration's id and the revision it
# is applied on top of.
revision: str = "d4b9f2e6a1c7"
down_revision: str | None = "c3a8e5f1b6d2"
# No branch labels or cross-branch dependencies are declared.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Create the ``jobs`` table and a partial index over pending rows."""

    def jsonb():
        # A fresh type instance per column, as in a hand-written column list.
        return postgresql.JSONB(astext_type=sa.Text())

    def required_jsonb(name):
        # NOT NULL JSONB column defaulting to an empty object.
        return sa.Column(
            name,
            jsonb(),
            nullable=False,
            server_default=sa.text("'{}'::jsonb"),
        )

    def counter(name):
        # NOT NULL integer column starting at zero.
        return sa.Column(
            name, sa.Integer(), nullable=False, server_default=sa.text("0")
        )

    def stamped(name):
        # NOT NULL timezone-aware timestamp filled in by the server.
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    def optional_timestamp(name):
        # Nullable timezone-aware timestamp with no default.
        return sa.Column(name, sa.DateTime(timezone=True), nullable=True)

    columns = [
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("account_id", sa.Integer(), nullable=False),
        sa.Column("kind", sa.String(), nullable=False),
        sa.Column(
            "status", sa.String(), nullable=False, server_default=sa.text("'pending'")
        ),
        required_jsonb("params"),
        sa.Column("cursor", jsonb(), nullable=True),
        required_jsonb("progress"),
        counter("flood_waits"),
        counter("attempts"),
        sa.Column("error", sa.String(), nullable=True),
        stamped("created_at"),
        stamped("updated_at"),
        optional_timestamp("started_at"),
        optional_timestamp("finished_at"),
    ]
    op.create_table("jobs", *columns, sa.PrimaryKeyConstraint("id"))

    # Partial index: only rows whose status is 'pending' are indexed,
    # keyed by (account_id, id).
    pending_index_ddl = (
        "CREATE INDEX ix_jobs_pending ON jobs (account_id, id) WHERE status = 'pending'"
    )
    op.execute(pending_index_ddl)
def downgrade() -> None:
    """Drop the ``jobs`` table created by ``upgrade``.

    PostgreSQL removes a table's indexes together with the table, so
    ``ix_jobs_pending`` needs no separate drop.
    """
    table_name = "jobs"
    op.drop_table(table_name)
@@ -0,0 +1,12 @@
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
# Names re-exported as the public API of the jobs package.
__all__ = [
    "JOBS_CHANGED_CHANNEL",
    "JOB_HANDLERS",
    "JobConsumer",
    "JobContext",
    "JobHandler",
    "register",
]
@@ -0,0 +1,85 @@
import asyncio
import contextlib
import asyncpg
from pyrogram import Client
from pyrogram.errors import FloodPremiumWait, FloodWait
from userbot.modules.jobs import repository
from userbot.modules.jobs.context import JobContext
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler
from utils.db.models import Job, JobStatus
from utils.env import env
from utils.logging import logger
# PostgreSQL notification channel the consumer listens on for wake-ups.
JOBS_CHANGED_CHANNEL = "jobs_changed"
# Upper bound, in seconds, on how long the consumer waits for a notification
# before checking the queue again.
POLL_INTERVAL_SECONDS = 60.0
class JobConsumer:
    """Run queued jobs for a single account, one at a time.

    Pending jobs are claimed through ``repository``, dispatched to the
    handler registered for their ``kind`` and marked done or failed. Between
    batches the consumer sleeps until a notification arrives on
    ``JOBS_CHANGED_CHANNEL`` or ``POLL_INTERVAL_SECONDS`` elapse.
    """

    def __init__(
        self,
        client: Client | None,
        pool: asyncpg.Pool,
        account_id: int,
        registry: dict[str, JobHandler] | None = None,
    ) -> None:
        """Store the collaborators used to claim and execute jobs.

        Args:
            client: Pyrogram client handed to handlers through the context;
                may be ``None``.
            pool: Connection pool used for all queue bookkeeping.
            account_id: Account whose jobs this consumer processes.
            registry: Mapping of job kind to handler. Defaults to the
                module-level ``JOB_HANDLERS`` when ``None``.
        """
        self.client = client
        self.pool = pool
        self.account_id = account_id
        self.registry = registry if registry is not None else JOB_HANDLERS
        # Set by the notification listener to cut the idle wait short.
        self._wake = asyncio.Event()

    async def execute(self, job: Job) -> None:
        """Run one claimed job to completion and record its final status.

        A job whose kind has no registered handler is failed immediately.
        A flood-wait error is counted, slept through and the handler is
        invoked again with the same context. Any other ``Exception`` fails
        the job with the exception text as its error.
        """
        ctx = JobContext(self.client, self.pool, self.account_id, job)
        handler = self.registry.get(job.kind)
        if handler is None:
            await repository.finish(
                self.pool, ctx.job_id, JobStatus.FAILED, f"no handler for {job.kind!r}"
            )
            return

        while True:
            try:
                await handler(ctx)
            except (FloodWait, FloodPremiumWait) as exc:
                await repository.bump_flood_wait(self.pool, ctx.job_id)
                # Fall back to one second when the error carries no integer delay.
                wait = exc.value if isinstance(exc.value, int) else 1
                logger.warning(
                    f"[yellow]FloodWait {wait}s on job {ctx.job_id} ({job.kind})[/]"
                )
                await asyncio.sleep(wait)
            except Exception as exc:
                logger.exception(f"Job {ctx.job_id} ({job.kind}) failed")
                await repository.finish(
                    self.pool, ctx.job_id, JobStatus.FAILED, str(exc)
                )
                return
            else:
                await repository.finish(self.pool, ctx.job_id, JobStatus.DONE)
                return

    async def drain(self) -> None:
        """Claim and execute jobs until none can be claimed."""
        while True:
            job = await repository.claim_next(self.pool, self.account_id)
            if job is None:
                return
            await self.execute(job)

    async def run(self) -> None:
        """Process jobs forever; intended to run as a long-lived task.

        On start-up, ``repository.requeue_running`` is called for this
        account and the number of requeued jobs is logged. A dedicated
        connection then listens for notifications so new work is picked up
        without waiting for the poll interval. The connection is closed
        when the loop exits, including on cancellation.
        """
        requeued = await repository.requeue_running(self.pool, self.account_id)
        if requeued:
            logger.info(f"[yellow]Requeued {requeued} stale running job(s).[/]")

        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            # Registered inside the try block so the connection is closed
            # even if installing the listener fails.
            await conn.add_listener(JOBS_CHANGED_CHANNEL, lambda *_: self._wake.set())
            logger.info(
                f"[green]Job consumer running for account {self.account_id}.[/]"
            )
            while True:
                # Clear *before* draining: a notification that arrives while
                # the queue is being checked then leaves the event set, so
                # the wait below returns at once instead of the wake-up being
                # erased and the job sitting until the next poll.
                self._wake.clear()
                await self.drain()
                with contextlib.suppress(TimeoutError):
                    await asyncio.wait_for(
                        self._wake.wait(), timeout=POLL_INTERVAL_SECONDS
                    )
        finally:
            # Best-effort close; shutdown must not fail on a broken connection.
            with contextlib.suppress(Exception):
                await conn.close()
@@ -0,0 +1,26 @@
from typing import Any
import asyncpg
from pyrogram import Client
from userbot.modules.jobs import repository
from utils.db.models import Job
class JobContext:
    """Everything a job handler needs: client, pool, account and job row.

    Also offers helpers that update the in-memory job and persist the same
    value through ``repository``.
    """

    def __init__(
        self, client: Client | None, pool: asyncpg.Pool, account_id: int, job: Job
    ) -> None:
        """Bind the handler's collaborators and remember the job's id."""
        self.client, self.pool = client, pool
        self.account_id = account_id
        self.job = job
        # A falsy ``job.id`` (e.g. ``None``) is represented as 0.
        self.job_id = job.id if job.id else 0

    async def save_cursor(self, cursor: dict[str, Any]) -> None:
        """Set the job's cursor locally, then store it for this job id."""
        self.job.cursor = cursor
        job_id = self.job_id
        await repository.save_cursor(self.pool, job_id, cursor)

    async def report_progress(self, progress: dict[str, Any]) -> None:
        """Set the job's progress locally, then store it for this job id."""
        self.job.progress = progress
        job_id = self.job_id
        await repository.report_progress(self.pool, job_id, progress)
@@ -0,0 +1,15 @@
from collections.abc import Awaitable, Callable
from userbot.modules.jobs.context import JobContext
# A job handler is an async callable that receives the job's context.
JobHandler = Callable[[JobContext], Awaitable[None]]
# Maps a job kind to its handler; starts empty.
JOB_HANDLERS: dict[str, JobHandler] = {}
def register(kind: str) -> Callable[[JobHandler], JobHandler]:
    """Return a decorator that files a handler under ``kind``.

    The decorated function is stored in ``JOB_HANDLERS`` (replacing any
    previous entry for the same kind) and returned unchanged.
    """

    def _bind(fn: JobHandler) -> JobHandler:
        JOB_HANDLERS.update({kind: fn})
        return fn

    return _bind
@@ -0,0 +1,91 @@
import json
from typing import Any
import asyncpg
from utils.db.models import Job, JobStatus
def _row_to_job(row: asyncpg.Record) -> Job:
    """Build a ``Job`` from a database row, decoding its JSON columns.

    ``params`` and ``progress`` are always decoded; ``cursor`` is decoded
    only when it is not ``None``.
    """
    fields = dict(row)
    for column in ("params", "progress"):
        fields[column] = json.loads(fields[column])
    raw_cursor = fields["cursor"]
    fields["cursor"] = None if raw_cursor is None else json.loads(raw_cursor)
    return Job(**fields)
async def enqueue(
    pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
    """Insert a new job row and return its generated id.

    ``params`` is serialised to JSON and stored as ``jsonb``; all other
    columns take their database defaults.

    NOTE(review): no NOTIFY is issued here — confirm that something else
    signals waiting consumers when a job is added.
    """
    insert_sql = (
        "INSERT INTO jobs (account_id, kind, params) "
        "VALUES ($1, $2, $3::jsonb) RETURNING id"
    )
    encoded_params = json.dumps(params)
    new_id = await pool.fetchval(insert_sql, account_id, kind, encoded_params)
    return new_id
async def claim_next(pool: asyncpg.Pool, account_id: int) -> Job | None:
    """Claim the lowest-id pending job of an account, if there is one.

    In a single statement the chosen row is switched to ``running``, its
    ``started_at``/``updated_at`` are stamped and ``attempts`` is
    incremented. ``FOR UPDATE SKIP LOCKED`` makes the sub-select pass over
    rows another transaction has locked.

    Returns:
        The claimed job, or ``None`` when no row was updated.
    """
    claim_sql = (
        "UPDATE jobs SET status = 'running', started_at = now(), "
        "updated_at = now(), attempts = attempts + 1 "
        "WHERE id = ("
        " SELECT id FROM jobs WHERE account_id = $1 AND status = 'pending' "
        " ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1"
        ") RETURNING *"
    )
    claimed = await pool.fetchrow(claim_sql, account_id)
    if not claimed:
        return None
    return _row_to_job(claimed)
async def requeue_running(pool: asyncpg.Pool, account_id: int) -> int:
    """Put every ``running`` job of an account back to ``pending``.

    Returns:
        The row count parsed from the last word of the command status
        string returned by ``pool.execute``.
    """
    requeue_sql = (
        "UPDATE jobs SET status = 'pending', updated_at = now() "
        "WHERE account_id = $1 AND status = 'running'"
    )
    status_tag = await pool.execute(requeue_sql, account_id)
    affected = status_tag.split()[-1]
    return int(affected)
async def save_cursor(pool: asyncpg.Pool, job_id: int, cursor: dict[str, Any]) -> None:
    """Store ``cursor`` as JSON on the given job and bump ``updated_at``."""
    update_sql = "UPDATE jobs SET cursor = $2::jsonb, updated_at = now() WHERE id = $1"
    encoded_cursor = json.dumps(cursor)
    await pool.execute(update_sql, job_id, encoded_cursor)
async def report_progress(
    pool: asyncpg.Pool, job_id: int, progress: dict[str, Any]
) -> None:
    """Store ``progress`` as JSON on the given job and bump ``updated_at``."""
    update_sql = (
        "UPDATE jobs SET progress = $2::jsonb, updated_at = now() WHERE id = $1"
    )
    encoded_progress = json.dumps(progress)
    await pool.execute(update_sql, job_id, encoded_progress)
async def bump_flood_wait(pool: asyncpg.Pool, job_id: int) -> None:
    """Increment the job's ``flood_waits`` counter and bump ``updated_at``."""
    bump_sql = (
        "UPDATE jobs SET flood_waits = flood_waits + 1, updated_at = now() "
        "WHERE id = $1"
    )
    await pool.execute(bump_sql, job_id)
async def finish(
    pool: asyncpg.Pool, job_id: int, status: JobStatus, error: str | None = None
) -> None:
    """Record a job's final status and error, stamping ``finished_at``.

    Args:
        pool: Connection pool to run the update on.
        job_id: Id of the job to update.
        status: Status to store; its ``value`` is what gets written.
        error: Error text to store, or ``None`` to store NULL.
    """
    finish_sql = (
        "UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
        "updated_at = now() WHERE id = $1"
    )
    stored_status = status.value
    await pool.execute(finish_sql, job_id, stored_status, error)
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
    """Fetch one job by id, or return ``None`` when no row matches."""
    found = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
    if not found:
        return None
    return _row_to_job(found)
+9
View File
@@ -9,6 +9,7 @@ import uvloop
from dependencies.container import container from dependencies.container import container
from userbot import PyroClient from userbot import PyroClient
from userbot.modules.capture import build_capture_context from userbot.modules.capture import build_capture_context
from userbot.modules.jobs import JobConsumer
from utils.env import env from utils.env import env
from utils.logging import logger, setup_logging from utils.logging import logger, setup_logging
from utils.storage import ContentAddressedStorage from utils.storage import ContentAddressedStorage
@@ -101,6 +102,7 @@ async def runner() -> None:
clients: list[PyroClient] = [] clients: list[PyroClient] = []
reload_tasks: set[asyncio.Task] = set() reload_tasks: set[asyncio.Task] = set()
consumer_tasks: list[asyncio.Task] = []
listen_conn: asyncpg.Connection | None = None listen_conn: asyncpg.Connection | None = None
try: try:
for session_path in session_files: for session_path in session_files:
@@ -116,12 +118,19 @@ async def runner() -> None:
account_id = await _sync_account(pool, client, session_name) account_id = await _sync_account(pool, client, session_name)
if account_id is not None: if account_id is not None:
await _setup_capture(pool, client, account_id, storage) await _setup_capture(pool, client, account_id, storage)
consumer = JobConsumer(client, pool, account_id)
consumer_tasks.append(asyncio.create_task(consumer.run()))
if clients: if clients:
listen_conn = await _listen_policy_changes(clients, reload_tasks) listen_conn = await _listen_policy_changes(clients, reload_tasks)
logger.info("[green]Userbot running.[/]") logger.info("[green]Userbot running.[/]")
await asyncio.Event().wait() await asyncio.Event().wait()
finally: finally:
for task in consumer_tasks:
task.cancel()
for task in consumer_tasks:
with contextlib.suppress(asyncio.CancelledError):
await task
if listen_conn is not None: if listen_conn is not None:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
await listen_conn.close() await listen_conn.close()
+46
View File
@@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import StrEnum
from typing import Any from typing import Any
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
@@ -6,6 +7,13 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class JobStatus(StrEnum):
    """Status values a job row can hold, as stored in the ``status`` column."""

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"
class Account(SQLModel, table=True): class Account(SQLModel, table=True):
__tablename__ = "accounts" __tablename__ = "accounts"
@@ -158,6 +166,44 @@ class Reaction(SQLModel, table=True):
) )
class Job(SQLModel, table=True):
    """ORM mapping for one row of the ``jobs`` table."""

    __tablename__ = "jobs"

    # BigInteger primary key; None until the database assigns a value.
    id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
    account_id: int
    # Handler selector; free-form string.
    kind: str
    # Stored as a plain string; defaults to the pending status.
    status: str = JobStatus.PENDING
    # JSONB payloads: params and progress are required, cursor is optional.
    params: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )
    cursor: dict[str, Any] | None = Field(default=None, sa_column=Column(JSONB))
    progress: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )
    # Counters, both starting at zero.
    flood_waits: int = 0
    attempts: int = 0
    error: str | None = None
    # Timezone-aware timestamps; created_at and updated_at default to the
    # server's now(), and updated_at also declares an onupdate of now().
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            nullable=False,
            server_default=func.now(),
            onupdate=func.now(),
        )
    )
    # Nullable timezone-aware timestamps with no default.
    started_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
    finished_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
class CapturePolicy(SQLModel, table=True): class CapturePolicy(SQLModel, table=True):
__tablename__ = "capture_policy" __tablename__ = "capture_policy"