Compare commits
7 Commits
2465bcd184
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 6ed392617f | |||
| 4fdf70a898 | |||
| f3712cfe36 | |||
| f688530eac | |||
| 3aaa3c757f | |||
| 17cd31c41e | |||
| c6984a7286 |
@@ -1,6 +1,9 @@
|
||||
COMPOSE_PROFILES=db,userbot,api
|
||||
RUN_ENVIRONMENT=prod
|
||||
|
||||
# Leave empty for *.localhost.
|
||||
FRONTEND_DEV_HOST=
|
||||
|
||||
DB__HOST=postgres
|
||||
DB__PORT=5432
|
||||
DB__USER=beavergram
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
.PHONY: recreate down restart rebuild deploy migrate session-create
|
||||
.PHONY: recreate down restart rebuild deploy migrate session-create frontend
|
||||
|
||||
recreate:
|
||||
docker compose up -d --force-recreate
|
||||
@@ -13,7 +13,11 @@ rebuild:
|
||||
docker compose build
|
||||
docker compose up -d
|
||||
|
||||
frontend:
|
||||
docker compose run --rm --no-deps frontend-dev sh -c "bun install && bun run build"
|
||||
|
||||
deploy:
|
||||
$(MAKE) frontend
|
||||
$(MAKE) rebuild
|
||||
|
||||
migrate:
|
||||
|
||||
@@ -4,8 +4,8 @@ from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
from dishka.integrations.fastapi import DishkaRoute, FromDishka, setup_dishka
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from fastmcp.utilities.lifespan import combine_lifespans
|
||||
from starlette.applications import Starlette
|
||||
|
||||
@@ -88,6 +88,13 @@ app.include_router(watches.router)
|
||||
|
||||
app.mount("/mcp", mcp_app)
|
||||
|
||||
|
||||
@app.api_route("/mcp", methods=["GET", "POST", "DELETE"])
async def mcp_trailing_slash(request: Request) -> RedirectResponse:
    # Send the bare "/mcp" path to "/mcp/", carrying the query string along.
    # 307 tells the client to repeat the request with the same method and body.
    target = "/mcp/"
    if request.url.query:
        target = f"/mcp/?{request.url.query}"
    return RedirectResponse(target, status_code=307)
|
||||
|
||||
|
||||
_spa_dir = Path(env.api.static_dir).resolve()
|
||||
if _spa_dir.is_dir():
|
||||
_spa_index = _spa_dir / "index.html"
|
||||
|
||||
+13
-3
@@ -1,3 +1,6 @@
|
||||
from secrets import compare_digest
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
PROTECTED_PREFIXES = ("/api", "/mcp")
|
||||
@@ -9,6 +12,15 @@ class BearerAuthMiddleware:
|
||||
self.app = app
|
||||
self.token = token
|
||||
|
||||
def _authorized(self, scope: Scope) -> bool:
    """Return True when the request carries the configured token.

    The token is accepted either in an ``Authorization: Bearer <token>``
    header or in a ``token`` query parameter.

    All comparisons are done on raw bytes. ``compare_digest`` raises
    ``TypeError`` for non-ASCII ``str`` arguments and ``bytes.decode`` raises
    on invalid UTF-8, so decoding client-supplied values first would let an
    unauthenticated request crash the middleware instead of being rejected.
    """
    expected = self.token.encode()
    header = dict(scope["headers"]).get(b"authorization", b"")
    if header.startswith(b"Bearer ") and compare_digest(header[7:], expected):
        return True
    # parse_qs accepts bytes and then returns bytes keys and values.
    supplied = parse_qs(scope["query_string"]).get(b"token", [b""])[0]
    return bool(supplied) and compare_digest(supplied, expected)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
@@ -18,9 +30,7 @@ class BearerAuthMiddleware:
|
||||
):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
headers = dict(scope["headers"])
|
||||
authorization = headers.get(b"authorization", b"").decode()
|
||||
if authorization == f"Bearer {self.token}":
|
||||
if self._authorized(scope):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
await send(
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from fastmcp.utilities.types import Image
|
||||
from mcp.types import TextContent
|
||||
|
||||
from utils.read import peers
|
||||
from utils.read.models import MediaRef, MessageView
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
_VOICE_KINDS = {"voice", "video_note"}
|
||||
|
||||
|
||||
def _ts(value: datetime) -> str:
    """Render ``value`` as ``YYYY-MM-DD HH:MM`` (no seconds, no timezone)."""
    return f"{value:%Y-%m-%d %H:%M}"
|
||||
|
||||
|
||||
def _name(sender_id: int | None, names: dict[int, str], self_id: int | None) -> str:
|
||||
if sender_id is None:
|
||||
return "Unknown"
|
||||
if sender_id == self_id:
|
||||
return "Me"
|
||||
return names.get(sender_id) or str(sender_id)
|
||||
|
||||
|
||||
def _media_note(media: list[MediaRef]) -> list[str]:
    """Describe each media attachment as a short inline marker.

    Kinds listed in ``_VOICE_KINDS`` show their transcript when
    ``extracted_text`` is set, photos show their message id plus a
    "not downloaded" flag, and any other kind is rendered as ``[kind]``.
    """

    def describe(item: MediaRef) -> str:
        if item.kind in _VOICE_KINDS:
            if not item.extracted_text:
                return "(Voice message, not transcribed)"
            return f"(Voice message, STT Content: {item.extracted_text})"
        if item.kind == "photo":
            state = ", not downloaded" if not item.downloaded else ""
            return f"[photo #{item.message_id}{state}]"
        return f"[{item.kind}]"

    return [describe(item) for item in media]
|
||||
|
||||
|
||||
def _line(
    view: MessageView,
    names: dict[int, str],
    self_id: int | None,
    notes: dict[int, list[str]],
) -> str:
    """Format one message as a transcript line, followed by its private notes.

    The line has the form ``#<message_id> <sender> (<time>): <body><flags>``.
    The body joins the reply marker, the message text and the media markers;
    the flags mark edited and deleted messages. Each note for the message is
    appended on its own line.
    """
    reply = view.reply
    segments: list[str] = []
    if reply and (reply.sender_name or reply.text):
        target = reply.sender_name or str(reply.sender_id or "?")
        segments.append(f"(reply to {target})")
    if view.text:
        segments.append(view.text)
    segments += _media_note(view.media)

    flags = "".join(
        label
        for label, stamp in (
            (" (edited)", view.edited_at),
            (" (deleted)", view.deleted_at),
        )
        if stamp
    )
    body = " ".join(filter(None, segments)) or "(no text)"
    sender = _name(view.sender_id, names, self_id)
    head = f"#{view.message_id} {sender} ({_ts(view.date)}): {body}{flags}"
    tail = "".join(
        f"\n 📝 [your private note, NOT in Telegram]: {note}"
        for note in notes.get(view.message_id, [])
    )
    return head + tail
|
||||
|
||||
|
||||
async def load_notes(
    pool: asyncpg.Pool, account_id: int, chat_id: int, views: list[MessageView]
) -> dict[int, list[str]]:
    """Fetch annotation texts for the given messages, grouped by message id.

    Returns an empty dict without querying when ``views`` is empty. Texts for
    one message keep the query order (``ORDER BY created_at``).
    """
    if not views:
        return {}
    message_ids = [view.message_id for view in views]
    records = await pool.fetch(
        "SELECT message_id, text FROM annotations "
        "WHERE account_id = $1 AND chat_id = $2 AND message_id = ANY($3::bigint[]) "
        "ORDER BY created_at",
        account_id,
        chat_id,
        message_ids,
    )
    grouped: dict[int, list[str]] = {}
    for record in records:
        key = record["message_id"]
        if key in grouped:
            grouped[key].append(record["text"])
        else:
            grouped[key] = [record["text"]]
    return grouped
|
||||
|
||||
|
||||
async def resolve_names(
    pool: asyncpg.Pool, account_id: int, views: list[MessageView]
) -> dict[int, str]:
    """Map the sender ids found in ``views`` to display names.

    A name is "first last" built from whichever parts are set, falling back to
    the username. Peers with no usable name are left out, so callers must
    handle missing keys.
    """
    ids = list({view.sender_id for view in views if view.sender_id is not None})
    if not ids:
        # Nothing to resolve: skip the peer lookup, as load_notes does for
        # an empty message list.
        return {}
    found = await peers.get_peers(pool, account_id, ids)
    names: dict[int, str] = {}
    for peer in found:
        name = (
            " ".join(part for part in (peer.first_name, peer.last_name) if part)
            or peer.username
        )
        if name:
            names[peer.peer_id] = name
    return names
|
||||
|
||||
|
||||
async def load_photos(
    pool: asyncpg.Pool,
    storage: ContentAddressedStorage,
    account_id: int,
    views: list[MessageView],
    *,
    limit: int,
) -> tuple[list[tuple[int, bytes, str]], bool]:
    """Load the bytes of downloaded photos attached to ``views``.

    Only the last ``limit`` photo references (in ``views`` order) are loaded.

    Returns:
        ``(photos, truncated)``. ``photos`` holds ``(message_id, data, format)``
        tuples, where ``format`` is the subtype of the stored mime
        ("image/jpeg" when the mime is missing). ``truncated`` is True when
        more photos were referenced than ``limit`` allows. Photos with no
        media row, no storage key, or an ``OSError`` on read are skipped.
    """
    refs = [
        (item.message_id, item.id)
        for view in views
        for item in view.media
        if item.kind == "photo" and item.downloaded and item.id is not None
    ]
    # A negative limit is treated as zero.
    limit = max(limit, 0)
    truncated = len(refs) > limit
    # refs[-limit:] must not be used for limit == 0: -0 == 0, so the slice
    # would return every reference instead of none.
    refs = refs[-limit:] if limit else []
    if not refs:
        return [], truncated
    rows = await pool.fetch(
        "SELECT id, storage_key, mime FROM media "
        "WHERE account_id = $1 AND id = ANY($2::bigint[])",
        account_id,
        [media_id for _, media_id in refs],
    )
    by_id = {row["id"]: row for row in rows}
    out: list[tuple[int, bytes, str]] = []
    for message_id, media_id in refs:
        row = by_id.get(media_id)
        if row is None or not row["storage_key"]:
            continue
        try:
            data = storage.get(row["storage_key"])
        except OSError:
            # Best effort: an unreadable file drops this photo only.
            continue
        fmt = (row["mime"] or "image/jpeg").split("/")[-1]
        out.append((message_id, data, fmt))
    return out, truncated
|
||||
|
||||
|
||||
def build_transcript(
    views: list[MessageView],
    names: dict[int, str],
    self_id: int | None,
    photos: list[tuple[int, bytes, str]],
    *,
    notes: dict[int, list[str]] | None = None,
    truncated: bool = False,
) -> list[Any]:
    """Assemble the content blocks for a chat transcript.

    The first block is the text transcript (a header plus one line per view).
    If ``truncated`` is set, a notice about dropped images follows. Each photo
    then contributes a caption block and an image block. An empty ``views``
    yields a single "No messages." block.
    """
    if not views:
        return [TextContent(type="text", text="No messages.")]

    note_map = notes or {}
    lines = [_line(view, names, self_id, note_map) for view in views]
    text = f"{len(views)} messages (oldest first)\n\n" + "\n".join(lines)
    blocks: list[Any] = [TextContent(type="text", text=text)]

    if truncated:
        notice = "(images truncated to the most recent; narrow the range for more)"
        blocks.append(TextContent(type="text", text=notice))

    for message_id, data, fmt in photos:
        caption = f"Image attached to message #{message_id}:"
        blocks += [
            TextContent(type="text", text=caption),
            Image(data=data, format=fmt),
        ]
    return blocks
|
||||
@@ -6,14 +6,57 @@ import asyncpg
|
||||
from fastmcp import FastMCP
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.mcp.format import build_transcript, load_notes, load_photos, resolve_names
|
||||
from dependencies.container import container
|
||||
from utils.jobs import enqueue
|
||||
from utils.read import annotations, chats, media, peers, presence, social, watches
|
||||
from utils.read.accounts import self_user_id
|
||||
from utils.read.models import DEFAULT_LIMIT, Page
|
||||
from utils.search.models import SearchFilters
|
||||
from utils.search.repository import search_messages
|
||||
from utils.storage import ContentAddressedStorage
|
||||
|
||||
mcp: FastMCP = FastMCP("beavergram")
|
||||
INSTRUCTIONS = """\
|
||||
beavergram archives Telegram data (chats, messages, media, presence, stories,
|
||||
peer history) into Postgres and exposes it read-only over these tools.
|
||||
|
||||
Account scoping:
|
||||
- Every tool takes `account_id`. It selects which archived Telegram account to
|
||||
read. Unless the user names a different one, always pass `account_id=1`.
|
||||
|
||||
Identifiers:
|
||||
- `chat_id` and `peer_id` are Telegram IDs (negative for groups/channels,
|
||||
positive for users/bots). Discover them with `list_chats` before calling
|
||||
tools that need a specific chat.
|
||||
|
||||
Typical flows:
|
||||
- Read a chat: `list_chats` -> pick a `chat_id` -> `get_chat_history` for a
|
||||
human-readable transcript (names, timestamps, voice STT, inline photos), or
|
||||
`get_chat_history_raw` for structured JSON. Both return messages oldest
|
||||
first; default gives the latest `limit`. Page to older messages with
|
||||
`before_id` (first id you saw), to newer with `after_id` (last id you saw);
|
||||
begin a full forward walk with `after_id=0`. Each transcript line starts
|
||||
with its `#message_id`, and photos show as `[photo #message_id]` — pass that
|
||||
id to `get_media` (with `fetch=True` to download) or to annotation tools.
|
||||
Lines marked
|
||||
"📝 [your private note, NOT in Telegram]" are the user's own annotations
|
||||
attached locally via the web UI; they never existed in Telegram. Treat them
|
||||
as private notes from the user to you, not as chat content.
|
||||
- Find something: `search_messages_tool` (full-text over message text and STT
|
||||
transcripts; supports chat/sender/date filters and regex).
|
||||
- Forensics: `get_deleted_messages` and `get_message_versions` recover content
|
||||
removed or edited in Telegram but kept in the archive.
|
||||
- Media: `get_media` returns metadata; pass `fetch=True` to enqueue a lazy
|
||||
download if the file isn't stored yet.
|
||||
- Monitoring: `set_watch` creates a local rule, `list_watches` lists rules, and
|
||||
`list_alerts` reads what those rules fired.
|
||||
|
||||
Writes: everything is read-only except `set_watch`, the only allowed write.
|
||||
|
||||
Paging: list/search tools accept `limit` and `offset`.
|
||||
"""
|
||||
|
||||
mcp: FastMCP = FastMCP("beavergram", instructions=INSTRUCTIONS)
|
||||
|
||||
|
||||
async def _pool() -> asyncpg.Pool:
|
||||
@@ -62,24 +105,80 @@ async def list_chats(
|
||||
return _dump(await chats.list_chats(await _pool(), account_id, page))
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp.tool(output_schema=None)
|
||||
async def get_chat_history(
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
offset: int = 0,
|
||||
before_id: int | None = None,
|
||||
after_id: int | None = None,
|
||||
include_deleted: bool = True,
|
||||
include_images: bool = True,
|
||||
max_images: int = 20,
|
||||
) -> list[Any]:
|
||||
"""Read a chat as a readable transcript, oldest first.
|
||||
|
||||
Renders "Name (time): text", voice notes as
|
||||
"(Voice message, STT Content: ...)", and inlines downloaded photos as
|
||||
images. Default returns the latest `limit` messages; page to older
|
||||
messages with `before_id` (the first id you saw) or to newer ones with
|
||||
`after_id` (the last id you saw).
|
||||
"""
|
||||
pool = await _pool()
|
||||
views = await chats.get_chat_history(
|
||||
pool,
|
||||
account_id,
|
||||
chat_id,
|
||||
Page(limit=limit),
|
||||
include_deleted=include_deleted,
|
||||
before_id=before_id,
|
||||
after_id=after_id,
|
||||
)
|
||||
if after_id is None:
|
||||
views = list(reversed(views))
|
||||
names = await resolve_names(pool, account_id, views)
|
||||
notes = await load_notes(pool, account_id, chat_id, views)
|
||||
self_id = await self_user_id(pool, account_id)
|
||||
photos: list[tuple[int, bytes, str]] = []
|
||||
truncated = False
|
||||
if include_images:
|
||||
storage = await container.get(ContentAddressedStorage)
|
||||
photos, truncated = await load_photos(
|
||||
pool, storage, account_id, views, limit=max_images
|
||||
)
|
||||
return build_transcript(
|
||||
views, names, self_id, photos, notes=notes, truncated=truncated
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_chat_history_raw(
|
||||
account_id: int,
|
||||
chat_id: int,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
before_id: int | None = None,
|
||||
after_id: int | None = None,
|
||||
include_deleted: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Read archived messages of a chat, newest first."""
|
||||
return _dump(
|
||||
await chats.get_chat_history(
|
||||
"""Structured chat messages as JSON, oldest first.
|
||||
|
||||
Default returns the latest `limit` messages. Walk the whole chat forward
|
||||
with `after_id` set to the last returned message_id (begin at
|
||||
`after_id=0`), or backward with `before_id` set to the first returned
|
||||
message_id.
|
||||
"""
|
||||
views = await chats.get_chat_history(
|
||||
await _pool(),
|
||||
account_id,
|
||||
chat_id,
|
||||
Page(limit=limit, offset=offset),
|
||||
Page(limit=limit),
|
||||
include_deleted=include_deleted,
|
||||
before_id=before_id,
|
||||
after_id=after_id,
|
||||
)
|
||||
)
|
||||
if after_id is None:
|
||||
views = list(reversed(views))
|
||||
return _dump(views)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
|
||||
@@ -24,6 +24,12 @@ class FetchMediaRequest(BaseModel):
|
||||
message_id: int
|
||||
|
||||
|
||||
class TranscribeRequest(BaseModel):
|
||||
account_id: int
|
||||
chat_id: int
|
||||
message_id: int
|
||||
|
||||
|
||||
class SyncDialogsRequest(BaseModel):
|
||||
account_id: int
|
||||
|
||||
@@ -82,6 +88,19 @@ async def enqueue_fetch_media(
|
||||
return EnqueueResponse(job_id=job_id)
|
||||
|
||||
|
||||
@router.post("/media/transcribe", status_code=201)
async def enqueue_transcribe(
    pool: FromDishka[asyncpg.Pool], body: TranscribeRequest
) -> EnqueueResponse:
    # Queue a "transcribe" job for one message and reply with the new job id.
    payload = {"chat_id": body.chat_id, "message_id": body.message_id}
    queued_id = await enqueue(pool, body.account_id, "transcribe", payload)
    return EnqueueResponse(job_id=queued_id)
|
||||
|
||||
|
||||
@router.post("/dialogs/sync", status_code=201)
|
||||
async def enqueue_sync_dialogs(
|
||||
pool: FromDishka[asyncpg.Pool], body: SyncDialogsRequest
|
||||
@@ -115,3 +134,18 @@ async def get_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView:
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail="job not found")
|
||||
return _to_view(row)
|
||||
|
||||
|
||||
@router.post("/jobs/{job_id}/cancel")
async def cancel_job(pool: FromDishka[asyncpg.Pool], job_id: int) -> JobView:
    # Only pending or running jobs are switched to 'canceled'; for any other
    # job the UPDATE matches nothing and returns no row.
    canceled = await pool.fetchrow(
        "UPDATE jobs SET status = 'canceled', finished_at = now(), "
        "updated_at = now() WHERE id = $1 AND status IN ('pending', 'running') "
        "RETURNING *",
        job_id,
    )
    if canceled is None:
        # Tell an unknown id (404) apart from a job that already ended (409).
        exists = await pool.fetchval("SELECT 1 FROM jobs WHERE id = $1", job_id)
        if exists is None:
            raise HTTPException(status_code=404, detail="job not found")
        raise HTTPException(status_code=409, detail="job already finished")
    return _to_view(canceled)
|
||||
|
||||
@@ -24,3 +24,6 @@ class JobContext:
|
||||
async def report_progress(self, progress: dict[str, Any]) -> None:
|
||||
self.job.progress = progress
|
||||
await repository.report_progress(self.pool, self.job_id, progress)
|
||||
|
||||
async def is_canceled(self) -> bool:
    """Ask the repository whether this job has been canceled."""
    canceled = await repository.is_canceled(self.pool, self.job_id)
    return canceled
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from pyrogram.errors import PeerIdInvalid
|
||||
|
||||
from userbot.modules.capture import capture_message
|
||||
from userbot.modules.capture.chat_meta import meta_from_chat
|
||||
from userbot.modules.jobs.context import JobContext
|
||||
from userbot.modules.jobs.registry import register
|
||||
from userbot.modules.stt import repository as stt_repo
|
||||
from userbot.modules.stt import should_transcribe_on_backfill
|
||||
from userbot.modules.stt.gate import safe_transcribe
|
||||
from utils.policy.models import CaptureToggles
|
||||
|
||||
SAVE_EVERY = 100
|
||||
@@ -25,14 +29,24 @@ async def backfill(ctx: JobContext) -> None:
|
||||
max_id = (ctx.job.cursor or {}).get("max_id", 0)
|
||||
processed = ctx.job.progress.get("processed", 0)
|
||||
kwargs = {"max_id": max_id} if max_id else {}
|
||||
self_id = client.me.id if client.me else None
|
||||
try:
|
||||
async for message in client.get_chat_history(chat_id, **kwargs):
|
||||
await capture_message(client, message, capture, toggles)
|
||||
if should_transcribe_on_backfill(message, self_id) and message.chat:
|
||||
meta = meta_from_chat(message.chat, capture.contacts.ids)
|
||||
already = await stt_repo.is_transcribed(
|
||||
capture.pool, capture.account_id, chat_id, message.id
|
||||
)
|
||||
if capture.resolve(meta).stt and not already:
|
||||
await safe_transcribe(client, capture, chat_id, message.id)
|
||||
processed += 1
|
||||
if processed % SAVE_EVERY == 0:
|
||||
next_max = message.id - 1
|
||||
await ctx.save_cursor({"max_id": next_max})
|
||||
await ctx.report_progress({"processed": processed, "max_id": next_max})
|
||||
if await ctx.is_canceled():
|
||||
return
|
||||
except PeerIdInvalid:
|
||||
await ctx.report_progress({"processed": processed, "error": "peer_id_invalid"})
|
||||
return
|
||||
|
||||
@@ -79,13 +79,18 @@ async def finish(
|
||||
) -> None:
|
||||
await pool.execute(
|
||||
"UPDATE jobs SET status = $2, error = $3, finished_at = now(), "
|
||||
"updated_at = now() WHERE id = $1",
|
||||
"updated_at = now() WHERE id = $1 AND status = 'running'",
|
||||
job_id,
|
||||
status.value,
|
||||
error,
|
||||
)
|
||||
|
||||
|
||||
async def is_canceled(pool: asyncpg.Pool, job_id: int) -> bool:
    """Return True when the job's stored status is the canceled status.

    A job id with no row yields False, since the fetched status is then None.
    """
    current = await pool.fetchval("SELECT status FROM jobs WHERE id = $1", job_id)
    return JobStatus.CANCELED.value == current
|
||||
|
||||
|
||||
async def get_job(pool: asyncpg.Pool, job_id: int) -> Job | None:
|
||||
row = await pool.fetchrow("SELECT * FROM jobs WHERE id = $1", job_id)
|
||||
return _row_to_job(row) if row else None
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from userbot.modules.stt.service import is_transcribable, transcribe_message
|
||||
from userbot.modules.stt.service import (
|
||||
is_transcribable,
|
||||
should_transcribe_on_backfill,
|
||||
transcribe_message,
|
||||
)
|
||||
|
||||
__all__ = ["is_transcribable", "transcribe_message"]
|
||||
__all__ = ["is_transcribable", "should_transcribe_on_backfill", "transcribe_message"]
|
||||
|
||||
@@ -2,7 +2,6 @@ from pyrogram import Client
|
||||
from pyrogram.errors import FloodPremiumWait, FloodWait, RPCError
|
||||
|
||||
from userbot.modules.capture.context import CaptureContext
|
||||
from userbot.modules.jobs.repository import enqueue
|
||||
from userbot.modules.stt.service import transcribe_message
|
||||
from utils.logging import logger
|
||||
|
||||
@@ -13,6 +12,8 @@ async def safe_transcribe(
|
||||
try:
|
||||
await transcribe_message(client, ctx, chat_id, message_id)
|
||||
except (FloodWait, FloodPremiumWait):
|
||||
from userbot.modules.jobs.repository import enqueue # noqa: PLC0415
|
||||
|
||||
await enqueue(
|
||||
ctx.pool,
|
||||
ctx.account_id,
|
||||
|
||||
@@ -9,6 +9,11 @@ UPDATE media SET extracted_text = $4
|
||||
WHERE account_id = $1 AND chat_id = $2 AND message_id = $3
|
||||
"""
|
||||
|
||||
_IS_TRANSCRIBED = """
|
||||
SELECT extracted_text IS NOT NULL FROM media
|
||||
WHERE account_id = $1 AND chat_id = $2 AND message_id = $3
|
||||
"""
|
||||
|
||||
_VOICE_READS_BOX = """
|
||||
SELECT md.chat_id, md.message_id, m.sender_id,
|
||||
md.extracted_text IS NULL AS untranscribed
|
||||
@@ -34,6 +39,12 @@ async def set_extracted_text(
|
||||
await pool.execute(_SET_EXTRACTED_TEXT, account_id, chat_id, message_id, text)
|
||||
|
||||
|
||||
async def is_transcribed(
    pool: asyncpg.Pool, account_id: int, chat_id: int, message_id: int
) -> bool:
    """Return True when the ``_IS_TRANSCRIBED`` query yields a truthy value.

    The query result is coerced to bool, so a missing row (None) counts as
    not transcribed.
    """
    has_text = await pool.fetchval(_IS_TRANSCRIBED, account_id, chat_id, message_id)
    return True if has_text else False
|
||||
|
||||
|
||||
async def voice_reads(
|
||||
pool: asyncpg.Pool,
|
||||
account_id: int,
|
||||
|
||||
@@ -4,6 +4,7 @@ from pyrogram.types import Message
|
||||
from userbot.modules.capture.context import CaptureContext
|
||||
from userbot.modules.media import self_destruct_ttl
|
||||
from userbot.modules.stt import repository
|
||||
from utils.logging import logger
|
||||
|
||||
|
||||
def is_transcribable(message: Message) -> bool:
|
||||
@@ -12,6 +13,19 @@ def is_transcribable(message: Message) -> bool:
|
||||
return message.voice is not None or message.video_note is not None
|
||||
|
||||
|
||||
def should_transcribe_on_backfill(message: Message, self_id: int | None) -> bool:
    """Decide whether a backfilled message should be sent for transcription.

    Non-transcribable messages are always skipped. Outgoing messages and
    messages whose sender id equals ``self_id`` are always accepted. Any other
    message is accepted only when its ``unread_media`` flag is falsy.
    """
    if not is_transcribable(message):
        return False
    if message.outgoing:
        return True

    sender = None
    if message.from_user:
        sender = message.from_user.id
    # Fall back to the sender chat when there is no user id.
    if sender is None and message.sender_chat is not None:
        sender = message.sender_chat.id

    # NOTE(review): if both sender and self_id are None they compare equal and
    # the message is treated as sent by the account itself - confirm intended.
    return sender == self_id or not message.unread_media
|
||||
|
||||
|
||||
async def transcribe_message(
|
||||
client: Client, ctx: CaptureContext, chat_id: int, message_id: int
|
||||
) -> None:
|
||||
@@ -21,7 +35,18 @@ async def transcribe_message(
|
||||
result = await client.invoke(
|
||||
raw.functions.messages.TranscribeAudio(peer=peer, msg_id=message_id)
|
||||
)
|
||||
if not result.pending and result.text:
|
||||
if result.pending:
|
||||
logger.info(
|
||||
f"[yellow]STT pending {chat_id}/{message_id} "
|
||||
f"(trial_remains={result.trial_remains_num})[/]"
|
||||
)
|
||||
return
|
||||
if result.text:
|
||||
await repository.set_extracted_text(
|
||||
ctx.pool, ctx.account_id, chat_id, message_id, result.text
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[yellow]STT empty {chat_id}/{message_id} "
|
||||
f"(trial_remains={result.trial_remains_num})[/]"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ class JobStatus(StrEnum):
|
||||
RUNNING = "running"
|
||||
DONE = "done"
|
||||
FAILED = "failed"
|
||||
CANCELED = "canceled"
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
|
||||
@@ -25,7 +25,7 @@ async def _media_map(
|
||||
return {}
|
||||
media_rows = await pool.fetch(
|
||||
"SELECT id, chat_id, message_id, kind, downloaded, mime, file_size, "
|
||||
"ttl_seconds FROM media "
|
||||
"ttl_seconds, extracted_text FROM media "
|
||||
"WHERE account_id = $1 AND message_id = ANY($2::bigint[])",
|
||||
account_id,
|
||||
message_ids,
|
||||
|
||||
@@ -356,6 +356,7 @@ def media_ref_from(
|
||||
file_size=(media_row["file_size"] if media_row else None)
|
||||
or obj.get("file_size"),
|
||||
ttl_seconds=media_row["ttl_seconds"] if media_row else None,
|
||||
extracted_text=media_row["extracted_text"] if media_row else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -77,6 +77,7 @@ class MediaRef(BaseModel):
|
||||
mime: str | None = None
|
||||
file_size: int | None = None
|
||||
ttl_seconds: int | None = None
|
||||
extracted_text: str | None = None
|
||||
|
||||
|
||||
class ReactionCount(BaseModel):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
CLOUDFLARE_API_TOKEN=
|
||||
@@ -0,0 +1,2 @@
|
||||
Caddyfile
|
||||
.env
|
||||
@@ -0,0 +1,21 @@
|
||||
{
|
||||
admin off
|
||||
# acme_dns cloudflare {env.CLOUDFLARE_API_TOKEN}
|
||||
|
||||
log {
|
||||
format console
|
||||
}
|
||||
|
||||
servers {
|
||||
trusted_proxies cloudflare
|
||||
client_ip_headers Cf-Connecting-Ip
|
||||
}
|
||||
}
|
||||
|
||||
<DOMAIN> {
|
||||
reverse_proxy beavergram-api:8080
|
||||
}
|
||||
|
||||
dev.<DOMAIN> {
|
||||
reverse_proxy beavergram-frontend:5173
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
services:
|
||||
caddy:
|
||||
image: ghcr.io/caddybuilds/caddy-cloudflare:latest
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "0.0.0.0:80:80"
|
||||
- "0.0.0.0:443:443"
|
||||
- "0.0.0.0:443:443/udp"
|
||||
networks:
|
||||
- caddy
|
||||
volumes:
|
||||
- ./Caddyfile:/etc/caddy/Caddyfile
|
||||
- caddy_data:/data
|
||||
env_file:
|
||||
- .env
|
||||
|
||||
networks:
|
||||
caddy:
|
||||
external: true
|
||||
|
||||
volumes:
|
||||
caddy_data:
|
||||
@@ -1,12 +1,26 @@
|
||||
services:
|
||||
postgres:
|
||||
ports:
|
||||
- "127.0.0.1:5433:5432"
|
||||
# postgres:
|
||||
# ports:
|
||||
# - "127.0.0.1:5432:5432"
|
||||
|
||||
api:
|
||||
ports:
|
||||
- "127.0.0.1:8080:8080"
|
||||
# ports:
|
||||
# - "127.0.0.1:8080:8080"
|
||||
networks:
|
||||
default: {}
|
||||
caddy:
|
||||
aliases:
|
||||
- beavergram-api
|
||||
|
||||
frontend-dev:
|
||||
ports:
|
||||
- "127.0.0.1:5173:5173"
|
||||
# ports:
|
||||
# - "127.0.0.1:5173:5173"
|
||||
networks:
|
||||
default: { }
|
||||
caddy:
|
||||
aliases:
|
||||
- beavergram-frontend
|
||||
|
||||
networks:
|
||||
caddy:
|
||||
external: true
|
||||
|
||||
@@ -6,6 +6,7 @@ services:
|
||||
POSTGRES_USER: ${DB__USER:-beavergram}
|
||||
POSTGRES_PASSWORD: ${DB__PASSWORD:-beavergram}
|
||||
POSTGRES_DB: ${DB__DB_NAME:-beavergram}
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
- ./backend/migrations/init:/docker-entrypoint-initdb.d:ro
|
||||
@@ -26,6 +27,8 @@ services:
|
||||
required: false
|
||||
environment:
|
||||
RUN_ENVIRONMENT: prod
|
||||
STORAGE__ROOT: /app/storage
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./backend/src:/app/src
|
||||
- ./backend/sessions:/app/sessions
|
||||
@@ -45,8 +48,11 @@ services:
|
||||
required: false
|
||||
environment:
|
||||
RUN_ENVIRONMENT: prod
|
||||
STORAGE__ROOT: /app/storage
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./backend/src:/app/src
|
||||
- ./frontend/build:/app/static:ro
|
||||
- ${STORAGE__ROOT:-./storage}:/app/storage
|
||||
depends_on:
|
||||
postgres:
|
||||
@@ -79,6 +85,7 @@ services:
|
||||
command: ["sh", "-c", "bun install && bun run dev --host 0.0.0.0 --port 5173"]
|
||||
environment:
|
||||
API_PROXY_TARGET: http://api:8080
|
||||
ALLOWED_HOSTS: ${FRONTEND_DEV_HOST:-}
|
||||
volumes:
|
||||
- ./frontend:/app
|
||||
- frontend_node_modules:/app/node_modules
|
||||
|
||||
@@ -280,6 +280,10 @@ export function getJob(jobId: number): Promise<JobView> {
|
||||
return request<JobView>(`/jobs/${jobId}`, { account: true });
|
||||
}
|
||||
|
||||
/** POST to the cancel endpoint of a job; resolves with the job view the API returns. */
export function cancelJob(jobId: number): Promise<JobView> {
  const path = `/jobs/${jobId}/cancel`;
  return request<JobView>(path, { method: "POST" });
}
|
||||
|
||||
export function listJobs(status?: JobStatus): Promise<JobView[]> {
|
||||
return request<JobView[]>("/jobs", {
|
||||
account: true,
|
||||
@@ -350,6 +354,20 @@ export function fetchMedia(
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Ask the API to queue a transcription job for one message of the currently
 * selected account. Resolves with the id of the queued job.
 */
export function transcribeMedia(
  chatId: number,
  messageId: number
): Promise<{ job_id: number }> {
  const body = {
    account_id: accounts.selectedId,
    chat_id: chatId,
    message_id: messageId,
  };
  return request<{ job_id: number }>("/media/transcribe", {
    method: "POST",
    body,
  });
}
|
||||
|
||||
export function listWatches(): Promise<Watch[]> {
|
||||
return request<Watch[]>("/watches", { account: true });
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { requestMedia } from "$lib/api/client";
|
||||
import { getMessageMedia } from "$lib/api/endpoints";
|
||||
import { getMessageMedia, transcribeMedia } from "$lib/api/endpoints";
|
||||
import type { MediaRef } from "$lib/api/types";
|
||||
import { accounts } from "$lib/stores/accounts.svelte";
|
||||
|
||||
const TRANSCRIBE_TRIES = 10;
|
||||
const TRANSCRIBE_DELAY = 2000;
|
||||
|
||||
export type InlineMedia =
|
||||
| {
|
||||
state: "ready";
|
||||
@@ -143,6 +146,43 @@ export function loadMediaItem(media: MediaRef): Promise<InlineMedia> {
|
||||
return promise;
|
||||
}
|
||||
|
||||
/**
 * Store a transcript on the cached media entry of a message, if that entry
 * exists and is in the "ready" state. Does nothing when no account is selected.
 */
function patchTranscript(
  chatId: number,
  messageId: number,
  text: string
): void {
  const accountId = accounts.selectedId;
  if (accountId !== null) {
    const entry = ready.get(cacheKey(accountId, chatId, messageId));
    if (entry?.state === "ready") {
      entry.transcript = text;
    }
  }
}
|
||||
|
||||
/**
 * Queue a transcription job, then poll the message's media metadata until a
 * transcript shows up.
 *
 * Polls up to TRANSCRIBE_TRIES times, waiting TRANSCRIBE_DELAY ms before each
 * attempt. Resolves with the transcript (also written to the media cache), or
 * with null when the attempts run out or a poll throws.
 */
export async function requestTranscription(
  chatId: number,
  messageId: number
): Promise<string | null> {
  await transcribeMedia(chatId, messageId);
  let remaining = TRANSCRIBE_TRIES;
  while (remaining > 0) {
    remaining -= 1;
    await new Promise((resolve) => {
      setTimeout(resolve, TRANSCRIBE_DELAY);
    });
    try {
      const { extracted_text: text } = await getMessageMedia(chatId, messageId);
      if (text) {
        patchTranscript(chatId, messageId, text);
        return text;
      }
    } catch {
      // Any failure while polling ends the wait.
      return null;
    }
  }
  return null;
}
|
||||
|
||||
export function loadInlineMedia(
|
||||
chatId: number,
|
||||
messageId: number
|
||||
|
||||
@@ -106,7 +106,13 @@
|
||||
{:else if !loaded}
|
||||
<div class="media-skeleton"><Spinner /></div>
|
||||
{:else if ready && kind === "voice"}
|
||||
<VoiceMessage url={ready.url} transcript={ready.transcript} {own} />
|
||||
<VoiceMessage
|
||||
url={ready.url}
|
||||
transcript={ready.transcript}
|
||||
chatId={message.chat_id}
|
||||
messageId={message.message_id}
|
||||
{own}
|
||||
/>
|
||||
{:else if ready && kind === "video_note"}
|
||||
<VideoNote url={ready.url} transcript={ready.transcript} />
|
||||
{:else if ready && kind === "audio"}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
<script lang="ts">
|
||||
import { listJobs } from "$lib/api/endpoints";
|
||||
import { cancelJob, listJobs } from "$lib/api/endpoints";
|
||||
import type { JobStatus, JobView } from "$lib/api/types";
|
||||
import Spinner from "$lib/components/ui/Spinner.svelte";
|
||||
import { formatFull } from "$lib/format/datetime";
|
||||
import { toasts } from "$lib/stores/toasts.svelte";
|
||||
|
||||
interface Props {
|
||||
version?: number;
|
||||
@@ -33,12 +34,32 @@
|
||||
|
||||
let jobs = $state<JobView[]>([]);
|
||||
let loading = $state(true);
|
||||
let canceling = $state<number | null>(null);
|
||||
let timer: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
/**
 * Report whether at least one job in `list` is still pending or running.
 *
 * @param list - Jobs to inspect.
 * @returns `true` if any job has status "pending" or "running".
 */
function isActive(list: JobView[]): boolean {
  const unfinished = new Set<string>(["pending", "running"]);
  return list.some(({ status }) => unfinished.has(status));
}
|
||||
|
||||
/**
 * Decide whether the stop button applies to `job`.
 *
 * @param job - Job to check.
 * @returns `true` only for jobs whose status is "pending" or "running".
 */
function canCancel(job: JobView): boolean {
  switch (job.status) {
    case "pending":
    case "running":
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Request cancellation of `job` and then call `load()`.
 *
 * Only one cancellation runs at a time: while `canceling` holds a job id,
 * further calls return without doing anything. If either `cancelJob` or
 * `load` rejects, an error toast is shown. `canceling` is reset to `null`
 * in every case.
 *
 * @param job - Job to cancel.
 */
async function cancel(job: JobView) {
  if (canceling === null) {
    const { id } = job;
    // Mark this job as in flight so its button is disabled meanwhile.
    canceling = id;
    try {
      await cancelJob(id);
      await load();
    } catch {
      toasts.error("Не удалось остановить задачу");
    } finally {
      canceling = null;
    }
  }
}
|
||||
|
||||
function schedule() {
|
||||
if (timer) {
|
||||
clearTimeout(timer);
|
||||
@@ -96,6 +117,16 @@
|
||||
<div class="job">
|
||||
<div class="job-head">
|
||||
<span class="kind">{kindLabel(job.kind)}</span>
|
||||
{#if canCancel(job)}
|
||||
<button
|
||||
type="button"
|
||||
class="stop"
|
||||
onclick={() => cancel(job)}
|
||||
disabled={canceling === job.id}
|
||||
>
|
||||
Стоп
|
||||
</button>
|
||||
{/if}
|
||||
<span class="badge {job.status}">{STATUS_LABELS[job.status]}</span>
|
||||
</div>
|
||||
<div class="meta">
|
||||
@@ -178,6 +209,33 @@
|
||||
&.failed {
|
||||
background-color: var(--color-error);
|
||||
}
|
||||
|
||||
&.canceled {
|
||||
background-color: var(--color-text-secondary);
|
||||
}
|
||||
}
|
||||
|
||||
.stop {
|
||||
flex-shrink: 0;
|
||||
|
||||
padding: 0.0625rem 0.5rem;
|
||||
border: 1px solid var(--color-error);
|
||||
border-radius: 0.625rem;
|
||||
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-error);
|
||||
cursor: pointer;
|
||||
background-color: transparent;
|
||||
|
||||
&:hover:not(:disabled) {
|
||||
color: var(--color-white);
|
||||
background-color: var(--color-error);
|
||||
}
|
||||
|
||||
&:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: default;
|
||||
}
|
||||
}
|
||||
|
||||
.meta {
|
||||
|
||||
@@ -1,18 +1,51 @@
|
||||
<script lang="ts">
|
||||
import { requestTranscription } from "$lib/api/media";
|
||||
import Icon from "$lib/components/ui/Icon.svelte";
|
||||
import Spinner from "$lib/components/ui/Spinner.svelte";
|
||||
import { formatDuration } from "$lib/format/duration";
|
||||
import { claimPlayback, releasePlayback } from "$lib/media/playback";
|
||||
import { computeWaveform, flatWaveform } from "$lib/media/waveform";
|
||||
import { toasts } from "$lib/stores/toasts.svelte";
|
||||
|
||||
/** Props accepted by the voice message component. */
interface Props {
  /** Chat id forwarded to `requestTranscription`. */
  chatId: number;
  /** Message id forwarded to `requestTranscription`. */
  messageId: number;
  /** Selects the spinner color: "white" when true, "gray" otherwise. */
  own: boolean;
  /** Transcript supplied by the parent; used until one is fetched on demand. */
  transcript?: string | null;
  /** Source URL for the `<audio>` element. */
  url: string;
}
|
||||
|
||||
let { url, own, transcript = null }: Props = $props();
|
||||
let { url, own, transcript = null, chatId, messageId }: Props = $props();
|
||||
|
||||
let fetched = $state<string | null>(null);
|
||||
let pending = $state(false);
|
||||
const text = $derived(fetched ?? transcript);
|
||||
|
||||
let showTranscript = $state(false);
|
||||
|
||||
/**
 * Click handler for the transcription button.
 *
 * When transcript text is already available, it only toggles visibility.
 * Otherwise it requests a transcription, unless one is already pending.
 * On success it stores the result in `fetched` and shows it. On an empty
 * result or a thrown error it shows an error toast. `pending` is cleared
 * in every case.
 */
async function onTranscribe() {
  if (text) {
    showTranscript = !showTranscript;
    return;
  }

  // A request is already running: ignore repeated clicks.
  if (!pending) {
    pending = true;
    try {
      const transcribed = await requestTranscription(chatId, messageId);
      if (!transcribed) {
        toasts.error("Не удалось расшифровать");
      } else {
        fetched = transcribed;
        showTranscript = true;
      }
    } catch {
      toasts.error("Не удалось расшифровать");
    } finally {
      pending = false;
    }
  }
}
|
||||
let element = $state<HTMLAudioElement>();
|
||||
let peaks = $state<number[]>(flatWaveform());
|
||||
let currentTime = $state(0);
|
||||
@@ -68,7 +101,11 @@
|
||||
type="button"
|
||||
aria-label="Play voice"
|
||||
>
|
||||
<Icon name={paused ? "play" : "pause"} size="1.5rem" />
|
||||
<Icon
|
||||
name={paused ? "play" : "pause"}
|
||||
size="1.5rem"
|
||||
class={paused ? "nudge" : ""}
|
||||
/>
|
||||
</button>
|
||||
<div class="body">
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
@@ -83,17 +120,20 @@
|
||||
</div>
|
||||
<div class="time">{formatDuration(elapsed)}</div>
|
||||
</div>
|
||||
{#if transcript}
|
||||
<button
|
||||
class="transcribe"
|
||||
class:active={showTranscript}
|
||||
onclick={() => (showTranscript = !showTranscript)}
|
||||
class:active={showTranscript && Boolean(text)}
|
||||
disabled={pending}
|
||||
onclick={onTranscribe}
|
||||
type="button"
|
||||
aria-label="Show transcription"
|
||||
>
|
||||
{#if pending}
|
||||
<Spinner color={own ? "white" : "gray"} size="1rem" />
|
||||
{:else}
|
||||
<Icon name="transcribe" size="1.125rem" />
|
||||
</button>
|
||||
{/if}
|
||||
</button>
|
||||
<!-- biome-ignore lint/a11y/useMediaCaption: archived voice note has no captions -->
|
||||
<audio
|
||||
bind:this={element}
|
||||
@@ -106,8 +146,8 @@
|
||||
src={url}
|
||||
></audio>
|
||||
</div>
|
||||
{#if transcript && showTranscript}
|
||||
<div class="transcript">{transcript}</div>
|
||||
{#if text && showTranscript}
|
||||
<div class="transcript">{text}</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -158,6 +198,10 @@
|
||||
color: var(--toggle-fg);
|
||||
background-color: var(--active);
|
||||
}
|
||||
|
||||
&:disabled {
|
||||
cursor: default;
|
||||
}
|
||||
}
|
||||
|
||||
.transcript {
|
||||
@@ -187,6 +231,10 @@
|
||||
color: var(--toggle-fg);
|
||||
|
||||
background-color: var(--active);
|
||||
|
||||
:global(.nudge) {
|
||||
transform: translateX(0.0625rem);
|
||||
}
|
||||
}
|
||||
|
||||
.body {
|
||||
|
||||
@@ -23,6 +23,13 @@ body {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.icon {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.icon::before {
|
||||
font-family: "icons" !important;
|
||||
speak: none;
|
||||
|
||||
@@ -3,6 +3,9 @@ import tailwindcss from "@tailwindcss/vite";
|
||||
import { defineConfig } from "vite";
|
||||
|
||||
// Backend the dev server proxies to; overridable via API_PROXY_TARGET.
const proxyTarget = process.env.API_PROXY_TARGET ?? "http://127.0.0.1:8080";

/**
 * Parse a comma-separated host list into the value for `server.allowedHosts`.
 *
 * Entries are trimmed and empty entries are dropped, so values such as
 * "a.example, b.example," do not produce hosts with stray whitespace or
 * empty strings that could never match a request.
 *
 * @param raw - Raw ALLOWED_HOSTS value, possibly undefined or empty.
 * @returns The list of hosts, or `undefined` when no host is configured.
 */
function parseAllowedHosts(raw: string | undefined): string[] | undefined {
  const hosts = (raw ?? "")
    .split(",")
    .map((host) => host.trim())
    .filter((host) => host.length > 0);
  return hosts.length > 0 ? hosts : undefined;
}

const allowedHosts = parseAllowedHosts(process.env.ALLOWED_HOSTS);
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [tailwindcss(), sveltekit()],
|
||||
@@ -14,6 +17,7 @@ export default defineConfig({
|
||||
},
|
||||
},
|
||||
server: {
|
||||
allowedHosts,
|
||||
proxy: {
|
||||
"/api": { target: proxyTarget, changeOrigin: true },
|
||||
"/mcp": { target: proxyTarget, changeOrigin: true },
|
||||
|
||||
Reference in New Issue
Block a user