feat: add jobs queue
This commit is contained in:
@@ -0,0 +1,72 @@
|
||||
"""jobs queue
|
||||
|
||||
Revision ID: d4b9f2e6a1c7
|
||||
Revises: c3a8e5f1b6d2
|
||||
Create Date: 2026-05-29 19:30:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "d4b9f2e6a1c7"
|
||||
down_revision: str | None = "c3a8e5f1b6d2"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``jobs`` table plus a partial index covering pending rows."""

    def jsonb() -> postgresql.JSONB:
        # A fresh type object per column, exactly as a hand-written list would have.
        return postgresql.JSONB(astext_type=sa.Text())

    def timestamptz() -> sa.DateTime:
        return sa.DateTime(timezone=True)

    columns = [
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("account_id", sa.Integer(), nullable=False),
        sa.Column("kind", sa.String(), nullable=False),
        sa.Column(
            "status", sa.String(), nullable=False, server_default=sa.text("'pending'")
        ),
        # JSON payloads: params/progress default to an empty object, cursor is optional.
        sa.Column(
            "params", jsonb(), nullable=False, server_default=sa.text("'{}'::jsonb")
        ),
        sa.Column("cursor", jsonb(), nullable=True),
        sa.Column(
            "progress", jsonb(), nullable=False, server_default=sa.text("'{}'::jsonb")
        ),
        # Counters start at zero.
        sa.Column(
            "flood_waits", sa.Integer(), nullable=False, server_default=sa.text("0")
        ),
        sa.Column(
            "attempts", sa.Integer(), nullable=False, server_default=sa.text("0")
        ),
        sa.Column("error", sa.String(), nullable=True),
        # Bookkeeping timestamps; created/updated are filled in by the server.
        sa.Column(
            "created_at", timestamptz(), server_default=sa.text("now()"), nullable=False
        ),
        sa.Column(
            "updated_at", timestamptz(), server_default=sa.text("now()"), nullable=False
        ),
        sa.Column("started_at", timestamptz(), nullable=True),
        sa.Column("finished_at", timestamptz(), nullable=True),
    ]
    op.create_table("jobs", *columns, sa.PrimaryKeyConstraint("id"))

    # Partial index: only rows still in 'pending' are indexed, keyed by
    # (account_id, id).
    op.execute(
        "CREATE INDEX ix_jobs_pending ON jobs (account_id, id) WHERE status = 'pending'"
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert this migration by dropping the ``jobs`` table.

    No separate statement is issued for ``ix_jobs_pending``: PostgreSQL removes
    a table's indexes together with the table.
    """
    op.drop_table("jobs")
|
||||
@@ -0,0 +1,12 @@
|
||||
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
|
||||
|
||||
__all__ = [
|
||||
"JOBS_CHANGED_CHANNEL",
|
||||
"JOB_HANDLERS",
|
||||
"JobConsumer",
|
||||
"JobContext",
|
||||
"JobHandler",
|
||||
"register",
|
||||
]
|
||||
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import asyncpg
|
||||
from pyrogram import Client
|
||||
from pyrogram.errors import FloodPremiumWait, FloodWait
|
||||
|
||||
from userbot.modules.jobs import repository
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler
|
||||
from utils.db.models import Job, JobStatus
|
||||
from utils.env import env
|
||||
from utils.logging import logger
|
||||
|
||||
# Name of the Postgres notification channel JobConsumer.run() listens on.
JOBS_CHANGED_CHANNEL = "jobs_changed"
|
||||
# Upper bound, in seconds, on how long JobConsumer.run() waits for a wake-up
# before checking the queue again.
POLL_INTERVAL_SECONDS = 60.0
|
||||
|
||||
|
||||
class JobConsumer:
    """Runs queued jobs for one account, one at a time.

    Pending rows are claimed through ``repository.claim_next``, dispatched to
    the handler registered for their ``kind``, and their outcome is recorded
    with ``repository.finish``. Between batches the consumer sleeps until a
    notification arrives on ``JOBS_CHANGED_CHANNEL`` or
    ``POLL_INTERVAL_SECONDS`` elapses.
    """

    def __init__(
        self,
        client: Client | None,
        pool: asyncpg.Pool,
        account_id: int,
        registry: dict[str, JobHandler] | None = None,
    ) -> None:
        """Bind the consumer to an account.

        Args:
            client: Pyrogram client handed to handlers through ``JobContext``;
                may be ``None``.
            pool: asyncpg pool used for all job bookkeeping.
            account_id: Only jobs carrying this ``account_id`` are claimed.
            registry: Mapping of job kind to handler. Falls back to the
                module-wide ``JOB_HANDLERS`` when omitted.
        """
        self.client = client
        self.pool = pool
        self.account_id = account_id
        self.registry = registry if registry is not None else JOB_HANDLERS
        # Set by the LISTEN callback to cut the idle wait in run() short.
        self._wake = asyncio.Event()

    async def execute(self, job: Job) -> None:
        """Run ``job`` to completion and persist its final status.

        A job whose kind has no registered handler is marked FAILED right
        away. ``FloodWait`` / ``FloodPremiumWait`` raised by the handler are
        counted, slept out, and the handler is invoked again (there is no
        retry cap). Any other ``Exception`` marks the job FAILED with
        ``str(exc)`` as the error; a clean return marks it DONE.
        """
        ctx = JobContext(self.client, self.pool, self.account_id, job)
        handler = self.registry.get(job.kind)
        if handler is None:
            await repository.finish(
                self.pool, ctx.job_id, JobStatus.FAILED, f"no handler for {job.kind!r}"
            )
            return

        while True:
            try:
                await handler(ctx)
            except (FloodWait, FloodPremiumWait) as exc:
                await repository.bump_flood_wait(self.pool, ctx.job_id)
                # Fall back to a one second pause when the reported wait is
                # not an integer.
                wait = exc.value if isinstance(exc.value, int) else 1
                logger.warning(
                    f"[yellow]FloodWait {wait}s on job {ctx.job_id} ({job.kind})[/]"
                )
                await asyncio.sleep(wait)
            except Exception as exc:
                logger.exception(f"Job {ctx.job_id} ({job.kind}) failed")
                await repository.finish(
                    self.pool, ctx.job_id, JobStatus.FAILED, str(exc)
                )
                return
            else:
                await repository.finish(self.pool, ctx.job_id, JobStatus.DONE)
                return

    async def drain(self) -> None:
        """Claim and execute pending jobs until none can be claimed."""
        while True:
            job = await repository.claim_next(self.pool, self.account_id)
            if job is None:
                return
            await self.execute(job)

    async def run(self) -> None:
        """Consume jobs forever; intended to run as a long-lived task.

        On start, rows of this account still marked ``running`` are put back to
        ``pending``. A dedicated connection then listens on
        ``JOBS_CHANGED_CHANNEL`` and the loop alternates between draining the
        queue and waiting for a wake-up or the poll timeout. The listener
        connection is closed when the coroutine exits for any reason,
        including cancellation.
        """
        requeued = await repository.requeue_running(self.pool, self.account_id)
        if requeued:
            logger.info(f"[yellow]Requeued {requeued} stale running job(s).[/]")

        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            # Registered inside the try so a failing LISTEN cannot leak the
            # freshly opened connection.
            await conn.add_listener(JOBS_CHANGED_CHANNEL, lambda *_: self._wake.set())
            logger.info(
                f"[green]Job consumer running for account {self.account_id}.[/]"
            )
            while True:
                # Clear BEFORE draining. A notification that lands while the
                # last (empty) claim query is in flight then stays set and ends
                # the wait below immediately, instead of being erased and
                # leaving the new job idle until the poll timeout.
                self._wake.clear()
                await self.drain()
                with contextlib.suppress(TimeoutError):
                    await asyncio.wait_for(
                        self._wake.wait(), timeout=POLL_INTERVAL_SECONDS
                    )
        finally:
            # Best-effort close; errors while tearing down are irrelevant.
            with contextlib.suppress(Exception):
                await conn.close()
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from pyrogram import Client
|
||||
|
||||
from userbot.modules.jobs import repository
|
||||
from utils.db.models import Job
|
||||
|
||||
|
||||
class JobContext:
    """Everything a job handler receives: client, pool, account and the job row.

    Also offers helpers that update the job's ``cursor`` / ``progress`` both on
    the in-memory ``Job`` and in the database.
    """

    def __init__(
        self, client: Client | None, pool: asyncpg.Pool, account_id: int, job: Job
    ) -> None:
        """Store the collaborators and cache the job's id (0 when it has none)."""
        self.client, self.pool = client, pool
        self.account_id = account_id
        self.job = job
        self.job_id = job.id or 0

    async def save_cursor(self, cursor: dict[str, Any]) -> None:
        """Set ``job.cursor`` locally, then write it to the job's row."""
        self.job.cursor = cursor
        await repository.save_cursor(self.pool, self.job_id, cursor)

    async def report_progress(self, progress: dict[str, Any]) -> None:
        """Set ``job.progress`` locally, then write it to the job's row."""
        self.job.progress = progress
        await repository.report_progress(self.pool, self.job_id, progress)
|
||||
@@ -0,0 +1,15 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
|
||||
# Signature of a job handler: an async callable that takes the JobContext and
# returns nothing.
JobHandler = Callable[[JobContext], Awaitable[None]]
|
||||
|
||||
# Module-wide registry mapping a job kind to its handler; filled by @register.
JOB_HANDLERS: dict[str, JobHandler] = {}
|
||||
|
||||
|
||||
def register(kind: str) -> Callable[[JobHandler], JobHandler]:
    """Return a decorator that files a handler under ``kind`` in ``JOB_HANDLERS``.

    The decorated function is returned unchanged. Registering a second handler
    for the same kind replaces the earlier entry.
    """

    def _store(fn: JobHandler) -> JobHandler:
        JOB_HANDLERS[kind] = fn
        return fn

    return _store
|
||||
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from utils.db.models import Job, JobStatus
|
||||
|
||||
|
||||
def _row_to_job(row: asyncpg.Record) -> Job:
    """Build a ``Job`` from a database record, decoding its JSON text columns.

    ``params`` and ``progress`` are always decoded; ``cursor`` is decoded only
    when it is not NULL.
    """
    fields = dict(row)
    for column in ("params", "progress"):
        fields[column] = json.loads(fields[column])
    raw_cursor = fields["cursor"]
    if raw_cursor is not None:
        fields["cursor"] = json.loads(raw_cursor)
    return Job(**fields)
|
||||
|
||||
|
||||
async def enqueue(
    pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
    """Insert a new job row and return its id.

    ``params`` is serialised to JSON text and cast to jsonb by the statement;
    every other column takes its table default.

    NOTE(review): this function issues no NOTIFY itself. Whether the channel
    the consumer listens on is signalled by a trigger or by the caller is not
    visible here - confirm.
    """
    insert_sql = (
        "INSERT INTO jobs (account_id, kind, params) "
        "VALUES ($1, $2, $3::jsonb) RETURNING id"
    )
    encoded_params = json.dumps(params)
    new_id = await pool.fetchval(insert_sql, account_id, kind, encoded_params)
    return new_id
|
||||
|
||||
|
||||
async def claim_next(pool: asyncpg.Pool, account_id: int) -> Job | None:
    """Atomically move the oldest pending job of ``account_id`` to 'running'.

    The inner SELECT picks the lowest pending id and skips rows locked by other
    transactions; the UPDATE stamps ``started_at`` / ``updated_at`` and bumps
    ``attempts``. Returns the claimed job, or ``None`` when nothing was claimed.
    """
    claim_sql = (
        "UPDATE jobs SET status = 'running', started_at = now(), "
        "updated_at = now(), attempts = attempts + 1 "
        "WHERE id = ("
        " SELECT id FROM jobs WHERE account_id = $1 AND status = 'pending' "
        " ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1"
        ") RETURNING *"
    )
    claimed = await pool.fetchrow(claim_sql, account_id)
    if not claimed:
        return None
    return _row_to_job(claimed)
|
||||
|
||||
|
||||
async def requeue_running(pool: asyncpg.Pool, account_id: int) -> int:
    """Reset every 'running' job of ``account_id`` to 'pending'.

    Returns the number of rows changed, read from the trailing token of the
    command status string that ``pool.execute`` returns.
    """
    reset_sql = (
        "UPDATE jobs SET status = 'pending', updated_at = now() "
        "WHERE account_id = $1 AND status = 'running'"
    )
    status_line = await pool.execute(reset_sql, account_id)
    affected = status_line.split()[-1]
    return int(affected)
|
||||
|
||||
|
||||
async def save_cursor(pool: asyncpg.Pool, job_id: int, cursor: dict[str, Any]) -> None:
    """Overwrite the job's ``cursor`` column with ``cursor`` and touch ``updated_at``."""
    update_sql = "UPDATE jobs SET cursor = $2::jsonb, updated_at = now() WHERE id = $1"
    encoded_cursor = json.dumps(cursor)
    await pool.execute(update_sql, job_id, encoded_cursor)
|
||||
|
||||
|
||||
async def report_progress(
    pool: asyncpg.Pool, job_id: int, progress: dict[str, Any]
) -> None:
    """Overwrite the job's ``progress`` column and touch ``updated_at``."""
    update_sql = (
        "UPDATE jobs SET progress = $2::jsonb, updated_at = now() WHERE id = $1"
    )
    encoded_progress = json.dumps(progress)
    await pool.execute(update_sql, job_id, encoded_progress)
|
||||
|
||||
|
||||
async def bump_flood_wait(pool: asyncpg.Pool, job_id: int) -> None:
    """Add one to the job's ``flood_waits`` counter and touch ``updated_at``."""
    bump_sql = (
        "UPDATE jobs SET flood_waits = flood_waits + 1, updated_at = now() "
        "WHERE id = $1"
    )
    await pool.execute(bump_sql, job_id)
|
||||
|
||||
|
||||
async def finish(
    pool: asyncpg.Pool, job_id: int, status: JobStatus, error: str | None = None
) -> None:
    """Record a job's outcome.

    Writes ``status`` (as its string value) and ``error`` - NULL when omitted -
    and stamps both ``finished_at`` and ``updated_at`` with the current time.
    """
    finish_sql = (
        "UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
        "updated_at = now() WHERE id = $1"
    )
    await pool.execute(finish_sql, job_id, status.value, error)
|
||||
|
||||
|
||||
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
    """Fetch a single job by id; ``None`` when no such row exists."""
    record = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
    if not record:
        return None
    return _row_to_job(record)
|
||||
@@ -9,6 +9,7 @@ import uvloop
|
||||
from dependencies.container import container
|
||||
from userbot import PyroClient
|
||||
from userbot.modules.capture import build_capture_context
|
||||
from userbot.modules.jobs import JobConsumer
|
||||
from utils.env import env
|
||||
from utils.logging import logger, setup_logging
|
||||
from utils.storage import ContentAddressedStorage
|
||||
@@ -101,6 +102,7 @@ async def runner() -> None:
|
||||
|
||||
clients: list[PyroClient] = []
|
||||
reload_tasks: set[asyncio.Task] = set()
|
||||
consumer_tasks: list[asyncio.Task] = []
|
||||
listen_conn: asyncpg.Connection | None = None
|
||||
try:
|
||||
for session_path in session_files:
|
||||
@@ -116,12 +118,19 @@ async def runner() -> None:
|
||||
account_id = await _sync_account(pool, client, session_name)
|
||||
if account_id is not None:
|
||||
await _setup_capture(pool, client, account_id, storage)
|
||||
consumer = JobConsumer(client, pool, account_id)
|
||||
consumer_tasks.append(asyncio.create_task(consumer.run()))
|
||||
|
||||
if clients:
|
||||
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
||||
logger.info("[green]Userbot running.[/]")
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
for task in consumer_tasks:
|
||||
task.cancel()
|
||||
for task in consumer_tasks:
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
if listen_conn is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
await listen_conn.close()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
|
||||
@@ -6,6 +7,13 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class JobStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
DONE = "done"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
__tablename__ = "accounts"
|
||||
|
||||
@@ -158,6 +166,44 @@ class Reaction(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class Job(SQLModel, table=True):
    """ORM mapping of one row in the ``jobs`` queue table."""

    __tablename__ = "jobs"

    # Identity and routing.
    id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
    account_id: int
    kind: str
    status: str = JobStatus.PENDING

    # JSONB payloads: params/progress default to an empty dict, cursor is optional.
    params: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )
    cursor: dict[str, Any] | None = Field(default=None, sa_column=Column(JSONB))
    progress: dict[str, Any] = Field(
        default_factory=dict, sa_column=Column(JSONB, nullable=False)
    )

    # Counters and the last recorded error message.
    flood_waits: int = 0
    attempts: int = 0
    error: str | None = None

    # Timezone-aware timestamps; created/updated default to the server clock.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), nullable=False, server_default=func.now()
        )
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            nullable=False,
            server_default=func.now(),
            onupdate=func.now(),
        )
    )
    started_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
    finished_at: datetime | None = Field(
        default=None, sa_column=Column(DateTime(timezone=True))
    )
|
||||
|
||||
|
||||
class CapturePolicy(SQLModel, table=True):
|
||||
__tablename__ = "capture_policy"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user