feat: add jobs queue
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
from userbot.modules.jobs.consumer import JOBS_CHANGED_CHANNEL, JobConsumer
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler, register
|
||||
|
||||
__all__ = [
|
||||
"JOBS_CHANGED_CHANNEL",
|
||||
"JOB_HANDLERS",
|
||||
"JobConsumer",
|
||||
"JobContext",
|
||||
"JobHandler",
|
||||
"register",
|
||||
]
|
||||
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
|
||||
import asyncpg
|
||||
from pyrogram import Client
|
||||
from pyrogram.errors import FloodPremiumWait, FloodWait
|
||||
|
||||
from userbot.modules.jobs import repository
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import JOB_HANDLERS, JobHandler
|
||||
from utils.db.models import Job, JobStatus
|
||||
from utils.env import env
|
||||
from utils.logging import logger
|
||||
|
||||
JOBS_CHANGED_CHANNEL = "jobs_changed"
|
||||
POLL_INTERVAL_SECONDS = 60.0
|
||||
|
||||
|
||||
class JobConsumer:
|
||||
def __init__(
|
||||
self,
|
||||
client: Client | None,
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
registry: dict[str, JobHandler] | None = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.pool = pool
|
||||
self.account_id = account_id
|
||||
self.registry = registry if registry is not None else JOB_HANDLERS
|
||||
self._wake = asyncio.Event()
|
||||
|
||||
async def execute(self, job: Job) -> None:
|
||||
ctx = JobContext(self.client, self.pool, self.account_id, job)
|
||||
handler = self.registry.get(job.kind)
|
||||
if handler is None:
|
||||
await repository.finish(
|
||||
self.pool, ctx.job_id, JobStatus.FAILED, f"no handler for {job.kind!r}"
|
||||
)
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
await handler(ctx)
|
||||
except (FloodWait, FloodPremiumWait) as exc:
|
||||
await repository.bump_flood_wait(self.pool, ctx.job_id)
|
||||
wait = exc.value if isinstance(exc.value, int) else 1
|
||||
logger.warning(
|
||||
f"[yellow]FloodWait {wait}s on job {ctx.job_id} ({job.kind})[/]"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
except Exception as exc:
|
||||
logger.exception(f"Job {ctx.job_id} ({job.kind}) failed")
|
||||
await repository.finish(
|
||||
self.pool, ctx.job_id, JobStatus.FAILED, str(exc)
|
||||
)
|
||||
return
|
||||
else:
|
||||
await repository.finish(self.pool, ctx.job_id, JobStatus.DONE)
|
||||
return
|
||||
|
||||
async def drain(self) -> None:
|
||||
while True:
|
||||
job = await repository.claim_next(self.pool, self.account_id)
|
||||
if job is None:
|
||||
return
|
||||
await self.execute(job)
|
||||
|
||||
async def run(self) -> None:
|
||||
requeued = await repository.requeue_running(self.pool, self.account_id)
|
||||
if requeued:
|
||||
logger.info(f"[yellow]Requeued {requeued} stale running job(s).[/]")
|
||||
conn = await asyncpg.connect(dsn=env.db.connection_url)
|
||||
await conn.add_listener(JOBS_CHANGED_CHANNEL, lambda *_: self._wake.set())
|
||||
logger.info(f"[green]Job consumer running for account {self.account_id}.[/]")
|
||||
try:
|
||||
while True:
|
||||
await self.drain()
|
||||
self._wake.clear()
|
||||
with contextlib.suppress(TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
self._wake.wait(), timeout=POLL_INTERVAL_SECONDS
|
||||
)
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
await conn.close()
|
||||
@@ -0,0 +1,26 @@
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from pyrogram import Client
|
||||
|
||||
from userbot.modules.jobs import repository
|
||||
from utils.db.models import Job
|
||||
|
||||
|
||||
class JobContext:
|
||||
def __init__(
|
||||
self, client: Client | None, pool: asyncpg.Pool, account_id: int, job: Job
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.pool = pool
|
||||
self.account_id = account_id
|
||||
self.job = job
|
||||
self.job_id = job.id or 0
|
||||
|
||||
async def save_cursor(self, cursor: dict[str, Any]) -> None:
|
||||
self.job.cursor = cursor
|
||||
await repository.save_cursor(self.pool, self.job_id, cursor)
|
||||
|
||||
async def report_progress(self, progress: dict[str, Any]) -> None:
|
||||
self.job.progress = progress
|
||||
await repository.report_progress(self.pool, self.job_id, progress)
|
||||
@@ -0,0 +1,15 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
|
||||
JobHandler = Callable[[JobContext], Awaitable[None]]
|
||||
|
||||
JOB_HANDLERS: dict[str, JobHandler] = {}
|
||||
|
||||
|
||||
def register(kind: str) -> Callable[[JobHandler], JobHandler]:
|
||||
def decorator(handler: JobHandler) -> JobHandler:
|
||||
JOB_HANDLERS[kind] = handler
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from utils.db.models import Job, JobStatus
|
||||
|
||||
|
||||
def _row_to_job(row: asyncpg.Record) -> Job:
|
||||
data = dict(row)
|
||||
data["params"] = json.loads(data["params"])
|
||||
data["progress"] = json.loads(data["progress"])
|
||||
data["cursor"] = json.loads(data["cursor"]) if data["cursor"] is not None else None
|
||||
return Job(**data)
|
||||
|
||||
|
||||
async def enqueue(
|
||||
pool: asyncpg.Pool, account_id: int, kind: str, params: dict[str, Any]
|
||||
) -> int:
|
||||
return await pool.fetchval(
|
||||
"INSERT INTO jobs (account_id, kind, params) "
|
||||
"VALUES ($1, $2, $3::jsonb) RETURNING id",
|
||||
account_id,
|
||||
kind,
|
||||
json.dumps(params),
|
||||
)
|
||||
|
||||
|
||||
async def claim_next(pool: asyncpg.Pool, account_id: int) -> Job | None:
|
||||
row = await pool.fetchrow(
|
||||
"UPDATE jobs SET status = 'running', started_at = now(), "
|
||||
"updated_at = now(), attempts = attempts + 1 "
|
||||
"WHERE id = ("
|
||||
" SELECT id FROM jobs WHERE account_id = $1 AND status = 'pending' "
|
||||
" ORDER BY id FOR UPDATE SKIP LOCKED LIMIT 1"
|
||||
") RETURNING *",
|
||||
account_id,
|
||||
)
|
||||
return _row_to_job(row) if row else None
|
||||
|
||||
|
||||
async def requeue_running(pool: asyncpg.Pool, account_id: int) -> int:
|
||||
result = await pool.execute(
|
||||
"UPDATE jobs SET status = 'pending', updated_at = now() "
|
||||
"WHERE account_id = $1 AND status = 'running'",
|
||||
account_id,
|
||||
)
|
||||
return int(result.split()[-1])
|
||||
|
||||
|
||||
async def save_cursor(pool: asyncpg.Pool, job_id: int, cursor: dict[str, Any]) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE jobs SET cursor = $2::jsonb, updated_at = now() WHERE id = $1",
|
||||
job_id,
|
||||
json.dumps(cursor),
|
||||
)
|
||||
|
||||
|
||||
async def report_progress(
|
||||
pool: asyncpg.Pool, job_id: int, progress: dict[str, Any]
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE jobs SET progress = $2::jsonb, updated_at = now() WHERE id = $1",
|
||||
job_id,
|
||||
json.dumps(progress),
|
||||
)
|
||||
|
||||
|
||||
async def bump_flood_wait(pool: asyncpg.Pool, job_id: int) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE jobs SET flood_waits = flood_waits + 1, updated_at = now() "
|
||||
"WHERE id = $1",
|
||||
job_id,
|
||||
)
|
||||
|
||||
|
||||
async def finish(
|
||||
pool: asyncpg.Pool, job_id: int, status: JobStatus, error: str | None = None
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
|
||||
"updated_at = now() WHERE id = $1",
|
||||
job_id,
|
||||
status.value,
|
||||
error,
|
||||
)
|
||||
|
||||
|
||||
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
|
||||
row = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
|
||||
return _row_to_job(row) if row else None
|
||||
Reference in New Issue
Block a user