feat: add jobs queue
This commit is contained in:
@@ -0,0 +1,72 @@
|
|||||||
|
"""jobs queue
|
||||||
|
|
||||||
|
Revision ID: d4b9f2e6a1c7
|
||||||
|
Revises: c3a8e5f1b6d2
|
||||||
|
Create Date: 2026-05-29 19:30:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "d4b9f2e6a1c7"
|
||||||
|
down_revision: str | None = "c3a8e5f1b6d2"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``jobs`` table, then a partial index over its pending rows."""

    def jsonb_type() -> postgresql.JSONB:
        # A new type object per column, as if each were spelled out inline.
        return postgresql.JSONB(astext_type=sa.Text())

    def json_object(name: str) -> sa.Column:
        # Required JSONB column whose server-side default is an empty object.
        return sa.Column(
            name,
            jsonb_type(),
            nullable=False,
            server_default=sa.text("'{}'::jsonb"),
        )

    def zeroed_counter(name: str) -> sa.Column:
        # Required integer column whose server-side default is 0.
        return sa.Column(
            name, sa.Integer(), nullable=False, server_default=sa.text("0")
        )

    def auto_timestamp(name: str) -> sa.Column:
        # Required timezone-aware timestamp filled by the server with now().
        return sa.Column(
            name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    def optional_timestamp(name: str) -> sa.Column:
        # Nullable timezone-aware timestamp with no default.
        return sa.Column(name, sa.DateTime(timezone=True), nullable=True)

    # Column order is kept exactly as the table is meant to be laid out.
    columns = [
        sa.Column("id", sa.BigInteger(), nullable=False),
        sa.Column("account_id", sa.Integer(), nullable=False),
        sa.Column("kind", sa.String(), nullable=False),
        sa.Column(
            "status", sa.String(), nullable=False, server_default=sa.text("'pending'")
        ),
        json_object("params"),
        sa.Column("cursor", jsonb_type(), nullable=True),
        json_object("progress"),
        zeroed_counter("flood_waits"),
        zeroed_counter("attempts"),
        sa.Column("error", sa.String(), nullable=True),
        auto_timestamp("created_at"),
        auto_timestamp("updated_at"),
        optional_timestamp("started_at"),
        optional_timestamp("finished_at"),
    ]
    op.create_table("jobs", *columns, sa.PrimaryKeyConstraint("id"))

    # Partial index: only rows whose status is 'pending' are indexed,
    # keyed by (account_id, id).
    op.execute(
        "CREATE INDEX ix_jobs_pending ON jobs (account_id, id) WHERE status = 'pending'"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Revert this revision by dropping the ``jobs`` table."""
    op.drop_table(table_name="jobs")
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
|
||||||
|
from userbot.modules.jobs.context import JobContext
|
||||||
|
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"JOBS_CHANGED_CHANNEL",
|
||||||
|
"JOB_HANDLERS",
|
||||||
|
"JobConsumer",
|
||||||
|
"JobContext",
|
||||||
|
"JobHandler",
|
||||||
|
"register",
|
||||||
|
]
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from pyrogram import Client
|
||||||
|
from pyrogram.errors import FloodPremiumWait, FloodWait
|
||||||
|
|
||||||
|
from userbot.modules.jobs import repository
|
||||||
|
from userbot.modules.jobs.context import JobContext
|
||||||
|
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler
|
||||||
|
from utils.db.models import Job, JobStatus
|
||||||
|
from utils.env import env
|
||||||
|
from utils.logging import logger
|
||||||
|
|
||||||
|
# Notification channel the job consumer listens on; any message on it wakes
# the consumer loop so it re-checks the queue.
JOBS_CHANGED_CHANNEL = "jobs_changed"
|
||||||
|
# Longest time, in seconds, the consumer waits for a wake-up before it
# re-checks the queue anyway.
POLL_INTERVAL_SECONDS = 60.0
|
||||||
|
|
||||||
|
|
||||||
|
class JobConsumer:
    """Claim and run queued jobs for a single account, one job at a time.

    The consumer drains every pending job, then sleeps until either a
    notification arrives on ``JOBS_CHANGED_CHANNEL`` or
    ``POLL_INTERVAL_SECONDS`` elapses, and repeats.
    """

    def __init__(
        self,
        client: Client | None,
        pool: asyncpg.Pool,
        account_id: int,
        registry: dict[str, JobHandler] | None = None,
    ) -> None:
        """Store the collaborators used to run jobs.

        Args:
            client: Client handed to every ``JobContext``; may be ``None``.
            pool: Connection pool used for all job bookkeeping queries.
            account_id: Account whose jobs this consumer claims.
            registry: Mapping of job kind to handler. Defaults to the
                module-level ``JOB_HANDLERS`` when ``None``.
        """
        self.client = client
        self.pool = pool
        self.account_id = account_id
        self.registry = registry if registry is not None else JOB_HANDLERS
        # Set by the notification listener to cut the idle wait short.
        self._wake = asyncio.Event()

    async def execute(self, job: Job) -> None:
        """Run one claimed job to completion and record its outcome.

        A job whose kind has no registered handler is marked failed
        immediately. On a flood-wait error the counter is bumped, the
        coroutine sleeps for the requested time and the handler is invoked
        again. Any other exception marks the job failed; a clean return
        marks it done.
        """
        ctx = JobContext(self.client, self.pool, self.account_id, job)
        handler = self.registry.get(job.kind)
        if handler is None:
            await repository.finish(
                self.pool, ctx.job_id, JobStatus.FAILED, f"no handler for {job.kind!r}"
            )
            return
        while True:
            try:
                await handler(ctx)
            except (FloodWait, FloodPremiumWait) as exc:
                await repository.bump_flood_wait(self.pool, ctx.job_id)
                # Fall back to a one-second pause if the value is not an int.
                wait = exc.value if isinstance(exc.value, int) else 1
                logger.warning(
                    f"[yellow]FloodWait {wait}s on job {ctx.job_id} ({job.kind})[/]"
                )
                await asyncio.sleep(wait)
            except Exception as exc:
                logger.exception(f"Job {ctx.job_id} ({job.kind}) failed")
                # Some exceptions stringify to "", which would store a blank
                # error; fall back to the repr so the failure stays traceable.
                await repository.finish(
                    self.pool, ctx.job_id, JobStatus.FAILED, str(exc) or repr(exc)
                )
                return
            else:
                await repository.finish(self.pool, ctx.job_id, JobStatus.DONE)
                return

    async def drain(self) -> None:
        """Claim and execute jobs until no pending job is left for the account."""
        while True:
            job = await repository.claim_next(self.pool, self.account_id)
            if job is None:
                return
            await self.execute(job)

    async def run(self) -> None:
        """Requeue stale jobs, then loop forever draining the queue.

        Opens a dedicated connection for notifications and always closes it
        when the loop exits (including on cancellation).
        """
        requeued = await repository.requeue_running(self.pool, self.account_id)
        if requeued:
            logger.info(f"[yellow]Requeued {requeued} stale running job(s).[/]")
        conn = await asyncpg.connect(dsn=env.db.connection_url)
        try:
            # Registered inside the try so a failure here cannot leak ``conn``.
            await conn.add_listener(JOBS_CHANGED_CHANNEL, lambda *_: self._wake.set())
            logger.info(f"[green]Job consumer running for account {self.account_id}.[/]")
            while True:
                # Clear BEFORE draining: a notification that arrives while the
                # final (empty) claim query is in flight then leaves the event
                # set, so the wait below returns at once instead of the job
                # sitting idle for a full poll interval.
                self._wake.clear()
                await self.drain()
                with contextlib.suppress(TimeoutError):
                    await asyncio.wait_for(
                        self._wake.wait(), timeout=POLL_INTERVAL_SECONDS
                    )
        finally:
            # Best-effort close; errors while shutting down are ignored.
            with contextlib.suppress(Exception):
                await conn.close()
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from pyrogram import Client
|
||||||
|
|
||||||
|
from userbot.modules.jobs import repository
|
||||||
|
from utils.db.models import Job
|
||||||
|
|
||||||
|
|
||||||
|
class JobContext:
    """Per-job handle given to a handler: the client, pool, account and job,
    plus helpers that persist the job's cursor and progress."""

    def __init__(
        self, client: Client | None, pool: asyncpg.Pool, account_id: int, job: Job
    ) -> None:
        """Bind the context to one job and the resources used to run it."""
        self.client, self.pool = client, pool
        self.account_id = account_id
        self.job = job
        # A missing (falsy) id is stored as 0.
        self.job_id = job.id if job.id else 0

    async def save_cursor(self, cursor: dict[str, Any]) -> None:
        """Set the cursor on the in-memory job, then write it to the database."""
        self.job.cursor = cursor
        job_id = self.job_id
        await repository.save_cursor(self.pool, job_id, cursor)

    async def report_progress(self, progress: dict[str, Any]) -> None:
        """Set the progress on the in-memory job, then write it to the database."""
        self.job.progress = progress
        job_id = self.job_id
        await repository.report_progress(self.pool, job_id, progress)
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from userbot.modules.jobs.context import JobContext
|
||||||
|
|
||||||
|
# Signature of a job handler: a callable that takes the job's JobContext and
# returns an awaitable resolving to None.
JobHandler = Callable[[JobContext], Awaitable[None]]
|
||||||
|
|
||||||
|
# Module-level registry mapping a job kind to its handler; populated by the
# ``register`` decorator.
JOB_HANDLERS: dict[str, JobHandler] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register(kind: str) -> Callable[[JobHandler], JobHandler]:
    """Return a decorator that files a handler under ``kind`` in ``JOB_HANDLERS``.

    The decorated function is returned unchanged. Registering the same kind
    again replaces the earlier entry.
    """

    def _bind(fn: JobHandler) -> JobHandler:
        JOB_HANDLERS.update({kind: fn})
        return fn

    return _bind
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from utils.db.models import Job, JobStatus
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_job(row: asyncpg.Record) -> Job:
    """Convert a ``jobs`` row into a ``Job``, decoding its JSON columns.

    ``params`` and ``progress`` are always decoded; ``cursor`` is decoded
    only when it is not ``None``.
    NOTE(review): assumes the driver hands JSONB values back as JSON text
    (no custom codec registered on the pool) — confirm.
    """
    fields = dict(row)
    for column in ("params", "progress"):
        fields[column] = json.loads(fields[column])
    raw_cursor = fields["cursor"]
    if raw_cursor is not None:
        fields["cursor"] = json.loads(raw_cursor)
    return Job(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue(
    pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
) -> int:
    """Insert a new job for ``account_id`` and return the generated id.

    ``params`` is serialized to JSON and stored in the ``params`` column;
    every other column takes its database default.
    NOTE(review): nothing here signals the consumer's notification channel —
    confirm a trigger or the caller does, otherwise pickup waits for the poll.
    """
    insert_sql = (
        "INSERT INTO jobs (account_id, kind, params) "
        "VALUES ($1, $2, $3::jsonb) RETURNING id"
    )
    encoded_params = json.dumps(params)
    new_id = await pool.fetchval(insert_sql, account_id, kind, encoded_params)
    return new_id
|
||||||
|
|
||||||
|
|
||||||
|
async def claim_next(pool: asyncpg.Pool, account_id: int) -> Job | None:
    """Claim the lowest-id pending job of ``account_id``, or return ``None``.

    In a single statement the chosen row is switched to ``running``, its
    ``started_at``/``updated_at`` are set to now and ``attempts`` is
    incremented. The inner select uses ``FOR UPDATE SKIP LOCKED`` so rows
    locked by another transaction are passed over.
    """
    claim_sql = (
        "UPDATE jobs SET status = 'running', started_at = now(), "
        "updated_at = now(), attempts = attempts + 1 "
        "WHERE id = ("
        " SELECT id FROM jobs WHERE account_id = $1 AND status = 'pending' "
        " ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1"
        ") RETURNING *"
    )
    claimed = await pool.fetchrow(claim_sql, account_id)
    if not claimed:
        return None
    return _row_to_job(claimed)
|
||||||
|
|
||||||
|
|
||||||
|
async def requeue_running(pool: asyncpg.Pool, account_id: int) -> int:
    """Move every ``running`` job of ``account_id`` back to ``pending``.

    Returns the number of rows changed, read from the last word of the
    status string returned by the driver (e.g. ``"UPDATE 3"``).
    """
    requeue_sql = (
        "UPDATE jobs SET status = 'pending', updated_at = now() "
        "WHERE account_id = $1 AND status = 'running'"
    )
    command_tag = await pool.execute(requeue_sql, account_id)
    words = command_tag.split()
    return int(words[-1])
|
||||||
|
|
||||||
|
|
||||||
|
async def save_cursor(pool: asyncpg.Pool, job_id: int, cursor: dict[str, Any]) -> None:
    """Store ``cursor`` (serialized to JSON) on job ``job_id`` and touch ``updated_at``."""
    update_sql = "UPDATE jobs SET cursor = $2::jsonb, updated_at = now() WHERE id = $1"
    encoded_cursor = json.dumps(cursor)
    await pool.execute(update_sql, job_id, encoded_cursor)
|
||||||
|
|
||||||
|
|
||||||
|
async def report_progress(
    pool: asyncpg.Pool, job_id: int, progress: dict[str, Any]
) -> None:
    """Store ``progress`` (serialized to JSON) on job ``job_id`` and touch ``updated_at``."""
    update_sql = (
        "UPDATE jobs SET progress = $2::jsonb, updated_at = now() WHERE id = $1"
    )
    encoded_progress = json.dumps(progress)
    await pool.execute(update_sql, job_id, encoded_progress)
|
||||||
|
|
||||||
|
|
||||||
|
async def bump_flood_wait(pool: asyncpg.Pool, job_id: int) -> None:
    """Increment the ``flood_waits`` counter of job ``job_id`` and touch ``updated_at``."""
    bump_sql = (
        "UPDATE jobs SET flood_waits = flood_waits + 1, updated_at = now() "
        "WHERE id = $1"
    )
    await pool.execute(bump_sql, job_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def finish(
    pool: asyncpg.Pool, job_id: int, status: JobStatus, error: str | None = None
) -> None:
    """Record the final state of job ``job_id``.

    Writes ``status`` (its string value) and ``error``, and sets both
    ``finished_at`` and ``updated_at`` to the current time.
    """
    finish_sql = (
        "UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
        "updated_at = now() WHERE id = $1"
    )
    status_text = status.value
    await pool.execute(finish_sql, job_id, status_text, error)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
    """Fetch job ``job_id`` as a ``Job``, or ``None`` when no such row exists."""
    record = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
    if not record:
        return None
    return _row_to_job(record)
|
||||||
@@ -9,6 +9,7 @@ import uvloop
|
|||||||
from dependencies.container import container
|
from dependencies.container import container
|
||||||
from userbot import PyroClient
|
from userbot import PyroClient
|
||||||
from userbot.modules.capture import build_capture_context
|
from userbot.modules.capture import build_capture_context
|
||||||
|
from userbot.modules.jobs import JobConsumer
|
||||||
from utils.env import env
|
from utils.env import env
|
||||||
from utils.logging import logger, setup_logging
|
from utils.logging import logger, setup_logging
|
||||||
from utils.storage import ContentAddressedStorage
|
from utils.storage import ContentAddressedStorage
|
||||||
@@ -101,6 +102,7 @@ async def runner() -> None:
|
|||||||
|
|
||||||
clients: list[PyroClient] = []
|
clients: list[PyroClient] = []
|
||||||
reload_tasks: set[asyncio.Task] = set()
|
reload_tasks: set[asyncio.Task] = set()
|
||||||
|
consumer_tasks: list[asyncio.Task] = []
|
||||||
listen_conn: asyncpg.Connection | None = None
|
listen_conn: asyncpg.Connection | None = None
|
||||||
try:
|
try:
|
||||||
for session_path in session_files:
|
for session_path in session_files:
|
||||||
@@ -116,12 +118,19 @@ async def runner() -> None:
|
|||||||
account_id = await _sync_account(pool, client, session_name)
|
account_id = await _sync_account(pool, client, session_name)
|
||||||
if account_id is not None:
|
if account_id is not None:
|
||||||
await _setup_capture(pool, client, account_id, storage)
|
await _setup_capture(pool, client, account_id, storage)
|
||||||
|
consumer = JobConsumer(client, pool, account_id)
|
||||||
|
consumer_tasks.append(asyncio.create_task(consumer.run()))
|
||||||
|
|
||||||
if clients:
|
if clients:
|
||||||
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
listen_conn = await _listen_policy_changes(clients, reload_tasks)
|
||||||
logger.info("[green]Userbot running.[/]")
|
logger.info("[green]Userbot running.[/]")
|
||||||
await asyncio.Event().wait()
|
await asyncio.Event().wait()
|
||||||
finally:
|
finally:
|
||||||
|
for task in consumer_tasks:
|
||||||
|
task.cancel()
|
||||||
|
for task in consumer_tasks:
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
if listen_conn is not None:
|
if listen_conn is not None:
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
await listen_conn.close()
|
await listen_conn.close()
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
|
from sqlalchemy import BigInteger, Column, DateTime, LargeBinary, func
|
||||||
@@ -6,6 +7,13 @@ from sqlalchemy.dialects.postgresql import JSONB
|
|||||||
from sqlmodel import Field, SQLModel
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(StrEnum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
DONE = "done"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
class Account(SQLModel, table=True):
|
class Account(SQLModel, table=True):
|
||||||
__tablename__ = "accounts"
|
__tablename__ = "accounts"
|
||||||
|
|
||||||
@@ -158,6 +166,44 @@ class Reaction(SQLModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Job(SQLModel, table=True):
    """ORM model for one row of the ``jobs`` table."""

    __tablename__ = "jobs"

    # Identity and routing.
    id: int | None = Field(default=None, sa_column=Column(BigInteger, primary_key=True))
    account_id: int
    kind: str
    status: str = JobStatus.PENDING

    # JSON payloads: params/progress are required objects, cursor is optional.
    params: dict[str, Any] = Field(
        sa_column=Column(JSONB, nullable=False), default_factory=dict
    )
    cursor: dict[str, Any] | None = Field(sa_column=Column(JSONB), default=None)
    progress: dict[str, Any] = Field(
        sa_column=Column(JSONB, nullable=False), default_factory=dict
    )

    # Counters and last failure text.
    flood_waits: int = 0
    attempts: int = 0
    error: str | None = None

    # Timestamps; created_at/updated_at are filled by the database, and
    # updated_at is additionally refreshed on ORM updates.
    created_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True), server_default=func.now(), nullable=False
        )
    )
    updated_at: datetime = Field(
        sa_column=Column(
            DateTime(timezone=True),
            server_default=func.now(),
            onupdate=func.now(),
            nullable=False,
        )
    )
    started_at: datetime | None = Field(
        sa_column=Column(DateTime(timezone=True)), default=None
    )
    finished_at: datetime | None = Field(
        sa_column=Column(DateTime(timezone=True)), default=None
    )
|
||||||
|
|
||||||
|
|
||||||
class CapturePolicy(SQLModel, table=True):
|
class CapturePolicy(SQLModel, table=True):
|
||||||
__tablename__ = "capture_policy"
|
__tablename__ = "capture_policy"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user