Files
stealth-ai-relay/backend/src/bot/handlers/message/handler.py
2026-01-29 16:28:49 +01:00

762 lines
24 KiB
Python

import asyncio
import base64
import contextlib
import io
import time
from collections.abc import Awaitable, Callable
from typing import Any
from aiogram import BaseMiddleware, Bot, F, Router, html, types
from aiogram.enums import ChatAction
from aiogram.types import (
BufferedInputFile,
InputMediaPhoto,
KeyboardButton,
ReplyKeyboardMarkup,
ReplyKeyboardRemove,
TelegramObject,
)
from convex import ConvexInt64
from bot.handlers.proxy.handler import get_proxy_config, increment_proxy_count
from bot.modules.ai import (
SUMMARIZE_PROMPT,
AgentDeps,
ImageData,
create_follow_up_agent,
create_text_agent,
get_follow_ups,
stream_response,
)
from utils import env
from utils.convex import ConvexClient
# Router that every message handler in this module registers on.
router = Router()
# Module-level Convex client shared by all handlers and helpers below.
convex = ConvexClient(env.convex_url)
# Seconds AlbumMiddleware sleeps while the remaining items of a media group
# arrive, before dispatching the collected album as one handler call.
ALBUM_COLLECT_DELAY = 0.5
class AlbumMiddleware(BaseMiddleware):
    """Collect the messages of a Telegram media group into one handler call.

    Telegram delivers every item of an album as a separate message sharing a
    ``media_group_id``. The first call for a group waits ``collect_delay``
    seconds and then invokes the handler once, with the earliest message as
    the event and the whole group in ``data["album"]``. Later calls for the
    same group only register their message and return ``None``. Messages
    without a ``media_group_id`` (and non-message events) pass straight
    through.
    """

    def __init__(self, collect_delay: float = ALBUM_COLLECT_DELAY) -> None:
        # How long to wait for the rest of an album before dispatching it.
        self.collect_delay = collect_delay
        # media_group_id -> messages received so far for that album.
        self.albums: dict[str, list[types.Message]] = {}
        # Album ids for which a call is already waiting to dispatch.
        self.scheduled: set[str] = set()

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:  # noqa: ANN401
        if not isinstance(event, types.Message) or not event.media_group_id:
            return await handler(event, data)

        album_id = event.media_group_id
        self.albums.setdefault(album_id, []).append(event)
        if album_id in self.scheduled:
            # Another call for this album is already waiting to dispatch it.
            return None

        self.scheduled.add(album_id)
        try:
            await asyncio.sleep(self.collect_delay)
        finally:
            # Release the album even if the wait is cancelled. Otherwise its
            # id and messages would linger in the dicts forever.
            messages = self.albums.pop(album_id, [])
            self.scheduled.discard(album_id)

        if not messages:
            return None
        # Concurrent updates may be appended in any order. Sort so the
        # dispatched "first" message is deterministic.
        messages.sort(key=lambda msg: msg.message_id)
        data["album"] = messages
        return await handler(messages[0], data)
# Attach the album collector to this router's message observer so album
# handlers receive the grouped messages as the `album` argument.
router.message.middleware(AlbumMiddleware())
# Minimum seconds between two non-forced live edits of the same message.
EDIT_THROTTLE_SECONDS = 1.0
# Per-message text limit used to truncate live edits and to split replies.
TELEGRAM_MAX_LENGTH = 4096
async def fetch_chat_images(chat_id: str) -> list[ImageData]:
    """Load every image stored for a Convex chat, decoded for the AI agent.

    Each stored entry provides base64 data and a media type. An empty or
    missing query result yields an empty list.
    """
    stored_images = await convex.query("messages:getChatImages", {"chatId": chat_id})
    decoded: list[ImageData] = []
    for entry in stored_images or []:
        raw_bytes = base64.b64decode(entry["base64"])
        decoded.append(ImageData(data=raw_bytes, media_type=entry["mediaType"]))
    return decoded
def make_follow_up_keyboard(options: list[str]) -> ReplyKeyboardMarkup:
    """Build a resized, one-time reply keyboard with one option per row."""
    return ReplyKeyboardMarkup(
        keyboard=[[KeyboardButton(text=label)] for label in options],
        resize_keyboard=True,
        one_time_keyboard=True,
    )
def split_message(text: str, max_length: int = TELEGRAM_MAX_LENGTH) -> list[str]:
    """Split *text* into chunks of at most *max_length* characters.

    Each cut prefers the last newline inside the window, then the last space,
    and otherwise falls back to a hard cut at *max_length*. Leading whitespace
    of every following chunk is stripped. Text that already fits is returned
    unchanged as a single chunk, so ``""`` yields ``[""]``. Every chunk
    produced by an actual split is non-empty.

    Raises:
        ValueError: if *max_length* is smaller than 1, since no split could
            make progress.
    """
    if max_length < 1:
        msg = "max_length must be at least 1"
        raise ValueError(msg)
    if len(text) <= max_length:
        return [text]

    parts: list[str] = []
    remaining = text
    while len(remaining) > max_length:
        # Only accept break points after the first character. A separator at
        # index 0 would produce an empty chunk, and Telegram rejects empty
        # messages.
        split_pos = remaining.rfind("\n", 1, max_length)
        if split_pos == -1:
            split_pos = remaining.rfind(" ", 1, max_length)
        if split_pos == -1:
            split_pos = max_length
        parts.append(remaining[:split_pos])
        remaining = remaining[split_pos:].lstrip()
    if remaining:
        parts.append(remaining)
    return parts
class StreamingState:
    """Live-edit a Telegram placeholder message while a reply streams in.

    Edits are throttled to one per ``EDIT_THROTTLE_SECONDS`` unless forced.
    Throttled content is remembered in ``pending_content`` so ``flush()`` can
    publish it. Text is HTML-escaped with ``html.quote`` before editing.
    Content longer than ``TELEGRAM_MAX_LENGTH`` is shown truncated with a
    trailing ``"..."``. The class can also keep a "typing…" chat action alive
    in the background.
    """

    def __init__(self, bot: Bot, chat_id: int, message: types.Message) -> None:
        self.bot = bot
        self.chat_id = chat_id
        self.message = message  # placeholder message that gets edited
        self.last_edit_time = 0.0
        self.last_content = ""
        self.pending_content: str | None = None
        # Text most recently sent to Telegram (after truncation). It stays
        # None until the first edit.
        self._last_display: str | None = None
        self._typing_task: asyncio.Task[None] | None = None

    async def start_typing(self) -> None:
        """Start refreshing the typing indicator every 4 seconds (idempotent)."""
        if self._typing_task is not None and not self._typing_task.done():
            # Already running. A second task would never be cancelled.
            return

        async def typing_loop() -> None:
            while True:
                # Best-effort: a transient API error must not end the loop. A
                # failed task would also make stop_typing() re-raise the error.
                with contextlib.suppress(Exception):
                    await self.bot.send_chat_action(self.chat_id, ChatAction.TYPING)
                await asyncio.sleep(4)

        self._typing_task = asyncio.create_task(typing_loop())

    async def stop_typing(self) -> None:
        """Cancel the typing task, if any. Safe to call more than once."""
        task, self._typing_task = self._typing_task, None
        if task is None:
            return
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError, Exception):
            await task

    async def update_message(self, content: str, *, force: bool = False) -> None:
        """Show *content* in the placeholder, respecting the edit throttle.

        Args:
            content: Full reply text accumulated so far.
            force: Edit immediately even if the throttle window has not passed.
        """
        if content == self.last_content:
            return
        now = time.monotonic()
        if not force and now - self.last_edit_time < EDIT_THROTTLE_SECONDS:
            self.pending_content = content
            return

        if len(content) > TELEGRAM_MAX_LENGTH:
            display_content = content[: TELEGRAM_MAX_LENGTH - 3] + "..."
        else:
            display_content = content
        # Once the reply outgrows the limit, every truncated preview is the
        # same. Skip those no-op edits instead of spending an API call on each.
        if display_content != self._last_display:
            # Best-effort edit; failures such as "not modified" are ignored.
            with contextlib.suppress(Exception):
                await self.message.edit_text(html.quote(display_content))
            self._last_display = display_content
            self.last_edit_time = now
        self.last_content = content
        self.pending_content = None

    async def flush(self) -> None:
        """Force out content that was held back by the throttle, if any."""
        if self.pending_content and self.pending_content != self.last_content:
            await self.update_message(self.pending_content, force=True)
class ProxyStreamingState:
    """Live-edit the placeholder message on the proxy bot while a reply streams.

    Uses the same throttling as the main bot's streaming: at most one edit per
    ``EDIT_THROTTLE_SECONDS`` unless forced, with held-back content kept in
    ``pending_content`` for ``flush()``. Text is sent as-is (not HTML-quoted).
    NOTE(review): presumably the proxy bot has no HTML parse mode; confirm.
    """

    def __init__(self, bot: Bot, chat_id: int, message: types.Message) -> None:
        self.bot = bot
        self.chat_id = chat_id
        self.message = message  # proxy-side placeholder that gets edited
        self.last_edit_time = 0.0
        self.last_content = ""
        self.pending_content: str | None = None
        # Text most recently sent to Telegram (after truncation). It stays
        # None until the first edit.
        self._last_display: str | None = None

    async def update_message(self, content: str, *, force: bool = False) -> None:
        """Show *content* in the proxy placeholder, respecting the throttle.

        Args:
            content: Full reply text accumulated so far.
            force: Edit immediately even if the throttle window has not passed.
        """
        if content == self.last_content:
            return
        now = time.monotonic()
        if not force and now - self.last_edit_time < EDIT_THROTTLE_SECONDS:
            self.pending_content = content
            return

        if len(content) > TELEGRAM_MAX_LENGTH:
            display_content = content[: TELEGRAM_MAX_LENGTH - 3] + "..."
        else:
            display_content = content
        # Past the length limit every truncated preview is identical. Skip
        # those no-op edits.
        if display_content != self._last_display:
            with contextlib.suppress(Exception):
                await self.message.edit_text(display_content)
            self._last_display = display_content
            self.last_edit_time = now
        self.last_content = content
        self.pending_content = None

    async def flush(self) -> None:
        """Force out content that was held back by the throttle, if any."""
        if self.pending_content and self.pending_content != self.last_content:
            await self.update_message(self.pending_content, force=True)
async def send_long_message(
    bot: Bot, chat_id: int, text: str, reply_markup: ReplyKeyboardMarkup | None = None
) -> None:
    """Send *text* as one or more HTML-escaped messages.

    The text is chunked with ``split_message``. Only the final chunk carries
    *reply_markup*, so the keyboard appears under the end of the reply.
    """
    *leading_chunks, final_chunk = split_message(text)
    for chunk in leading_chunks:
        await bot.send_message(chat_id, html.quote(chunk), reply_markup=None)
    await bot.send_message(chat_id, html.quote(final_chunk), reply_markup=reply_markup)
async def _echo_web_message_to_telegram(
    bot: Bot,
    tg_chat_id: int,
    text: str,
    images_base64: list[str] | None,
    images_media_types: list[str] | None,
) -> None:
    """Mirror a web-originated user message (text and/or photos) into Telegram."""
    if not (images_base64 and images_media_types):
        await bot.send_message(
            tg_chat_id, f"📱 {html.quote(text)}", reply_markup=ReplyKeyboardRemove()
        )
        return

    # Captions go through the same parse mode as message text. Escape them
    # like the plain-text branch does, so user text containing "<" or "&"
    # cannot make the send fail.
    quoted_caption = f"📱 {html.quote(text)}" if text else "📱"
    if len(images_base64) == 1:
        await bot.send_photo(
            tg_chat_id,
            BufferedInputFile(base64.b64decode(images_base64[0]), "photo.jpg"),
            caption=quoted_caption,
            reply_markup=ReplyKeyboardRemove(),
        )
        return

    # In a media group, only the first item carries the caption, and only
    # when there is text.
    media = [
        InputMediaPhoto(
            media=BufferedInputFile(base64.b64decode(img_b64), f"photo_{i}.jpg"),
            caption=quoted_caption if i == 0 and text else None,
        )
        for i, img_b64 in enumerate(images_base64)
    ]
    await bot.send_media_group(tg_chat_id, media)


async def process_message_from_web(  # noqa: C901, PLR0912, PLR0913, PLR0915
    convex_user_id: str,
    text: str,
    bot: Bot,
    convex_chat_id: str,
    images_base64: list[str] | None = None,
    images_media_types: list[str] | None = None,
) -> None:
    """Generate and stream an assistant reply for a message sent from the web UI.

    Steps:
    - The user's message is echoed into the linked Telegram chat, if any.
    - A streaming assistant message is created in Convex and updated chunk by
      chunk.
    - The final answer and its follow-up suggestions are saved.
    - When a Telegram chat is linked, the answer is delivered there too.

    ``text == "/summarize"`` switches to summary mode: nothing is echoed, the
    chat is cleared (images preserved) and a new assistant message holding
    the summary is created. Generation errors are caught and stored as
    ``"Error: ..."`` in the assistant message.

    Args:
        convex_user_id: Convex id of the user.
        text: The user's message, or ``"/summarize"``.
        bot: Telegram bot used for mirroring the conversation.
        convex_chat_id: Convex chat the message belongs to.
        images_base64: Base64-encoded images attached to the message, if any.
        images_media_types: Media types for those images. Here only its
            presence is checked.

    Returns silently when the user is unknown or has no Gemini API key.
    """
    user = await convex.query("users:getById", {"userId": convex_user_id})
    if not user or not user.get("geminiApiKey"):
        return

    tg_chat_id = user["telegramChatId"].value if user.get("telegramChatId") else None
    is_summarize = text == "/summarize"

    if tg_chat_id and not is_summarize:
        await _echo_web_message_to_telegram(
            bot, tg_chat_id, text, images_base64, images_media_types
        )

    api_key = user["geminiApiKey"]
    model_name = user.get("model", "gemini-3-pro-preview")

    rag_connections = await convex.query(
        "ragConnections:getActiveForUser", {"userId": convex_user_id}
    )
    rag_db_names: list[str] = []
    for conn in rag_connections or []:
        db = await convex.query(
            "rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
        )
        if db:
            rag_db_names.append(db["name"])

    assistant_message_id = await convex.mutation(
        "messages:create",
        {
            "chatId": convex_chat_id,
            "role": "assistant",
            "content": "",
            "source": "web",
            "isStreaming": True,
        },
    )
    history = await convex.query(
        "messages:getHistoryForAI", {"chatId": convex_chat_id, "limit": 50}
    )
    system_prompt = SUMMARIZE_PROMPT if is_summarize else user.get("systemPrompt")
    text_agent = create_text_agent(
        api_key=api_key,
        model_name=model_name,
        system_prompt=system_prompt,
        rag_db_names=rag_db_names if rag_db_names else None,
    )
    agent_deps = (
        AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
        if rag_db_names
        else None
    )

    processing_msg: types.Message | None = None
    state: StreamingState | None = None
    try:
        if tg_chat_id:
            # Created inside the try: if Telegram fails here, the except
            # branch still finalizes the Convex message instead of leaving it
            # stuck with isStreaming=True.
            processing_msg = await bot.send_message(tg_chat_id, "...")
            state = StreamingState(bot, tg_chat_id, processing_msg)
            await state.start_typing()

        async def on_chunk(content: str) -> None:
            if state:
                await state.update_message(content)
            await convex.mutation(
                "messages:update",
                {"messageId": assistant_message_id, "content": content},
            )

        # NOTE(review): this path drops one trailing history entry (two for
        # /summarize), while process_message() drops two for an ordinary
        # message. Confirm which matches what messages:getHistoryForAI returns.
        if is_summarize:
            prompt_text = "Summarize what was done in this conversation."
            hist = history[:-2]
        else:
            prompt_text = text
            hist = history[:-1]
        chat_images = await fetch_chat_images(convex_chat_id)
        final_answer = await stream_response(
            text_agent, prompt_text, hist, on_chunk, images=chat_images, deps=agent_deps
        )
        if state:
            await state.flush()

        full_history = [*history, {"role": "assistant", "content": final_answer}]
        follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite")
        follow_up_prompt = user.get("followUpPrompt")
        follow_up_agent = create_follow_up_agent(
            api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt
        )
        follow_ups = await get_follow_ups(follow_up_agent, full_history, chat_images)
        if state:
            await state.stop_typing()

        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": final_answer,
                "followUpOptions": follow_ups,
                "isStreaming": False,
            },
        )
        if is_summarize:
            # Replace the conversation with the summary, keeping stored images.
            await convex.mutation(
                "chats:clear", {"chatId": convex_chat_id, "preserveImages": True}
            )
            await convex.mutation(
                "messages:create",
                {
                    "chatId": convex_chat_id,
                    "role": "assistant",
                    "content": final_answer,
                    "source": "web",
                    "followUpOptions": follow_ups,
                },
            )

        if tg_chat_id and processing_msg:
            # Swap the live-edited placeholder for the final (possibly split)
            # reply with the follow-up keyboard.
            with contextlib.suppress(Exception):
                await processing_msg.delete()
            keyboard = make_follow_up_keyboard(follow_ups)
            await send_long_message(bot, tg_chat_id, final_answer, keyboard)
    except Exception as e:  # noqa: BLE001
        if state:
            await state.stop_typing()
        error_msg = f"Error: {e}"
        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": error_msg,
                "isStreaming": False,
            },
        )
        if tg_chat_id and processing_msg:
            with contextlib.suppress(Exception):
                truncated = html.quote(error_msg[:TELEGRAM_MAX_LENGTH])
                await processing_msg.edit_text(truncated)
async def process_message(  # noqa: C901, PLR0912, PLR0913, PLR0915
    user_id: int,
    text: str,
    bot: Bot,
    chat_id: int,
    *,
    skip_user_message: bool = False,
    skip_proxy_user_message: bool = False,
) -> None:
    """Run the AI pipeline for a Telegram user's message and stream the reply.

    Steps:
    - Look the user up by Telegram id. If the API key or active chat is
      missing, reply with a hint and return.
    - Optionally mirror the text to the proxy chat.
    - Store the user message in Convex.
    - Stream the answer into a "..." placeholder, and into a proxy
      placeholder when a proxy is configured.
    - Save the final answer with follow-up suggestions and deliver it.

    Generation errors are caught: the assistant message is finalized with
    ``"Error: ..."`` and the placeholders are edited to show it. Proxy
    mirroring is best-effort throughout and never aborts the reply.

    Args:
        user_id: Telegram id of the sender.
        text: Text (or image caption) sent to the model.
        bot: Bot used to reply.
        chat_id: Telegram chat to reply in.
        skip_user_message: Do not create the user message in Convex, because
            the caller already stored it (e.g. together with images).
        skip_proxy_user_message: Do not mirror *text* to the proxy chat,
            because the caller already did.
    """
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(user_id)}
    )
    if not user or not user.get("geminiApiKey"):
        await bot.send_message(chat_id, "Use /apikey first to set your Gemini API key.")
        return
    if not user.get("activeChatId"):
        await bot.send_message(chat_id, "Use /new first to create a chat.")
        return

    active_chat_id = user["activeChatId"]
    api_key = user["geminiApiKey"]
    model_name = user.get("model", "gemini-3-pro-preview")
    convex_user_id = user["_id"]

    proxy_config = get_proxy_config(chat_id)
    if proxy_config and not skip_proxy_user_message:
        with contextlib.suppress(Exception):
            await proxy_config.proxy_bot.send_message(
                proxy_config.target_chat_id, f"👤 {text}"
            )
            await increment_proxy_count(chat_id)
            # Re-read after counting. Presumably the increment can change or
            # end the proxy session.
            proxy_config = get_proxy_config(chat_id)

    rag_connections = await convex.query(
        "ragConnections:getActiveForUser", {"userId": convex_user_id}
    )
    rag_db_names: list[str] = []
    for conn in rag_connections or []:
        db = await convex.query(
            "rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
        )
        if db:
            rag_db_names.append(db["name"])

    if not skip_user_message:
        await convex.mutation(
            "messages:create",
            {
                "chatId": active_chat_id,
                "role": "user",
                "content": text,
                "source": "telegram",
            },
        )
    assistant_message_id = await convex.mutation(
        "messages:create",
        {
            "chatId": active_chat_id,
            "role": "assistant",
            "content": "",
            "source": "telegram",
            "isStreaming": True,
        },
    )
    history = await convex.query(
        "messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50}
    )
    text_agent = create_text_agent(
        api_key=api_key,
        model_name=model_name,
        system_prompt=user.get("systemPrompt"),
        rag_db_names=rag_db_names if rag_db_names else None,
    )
    agent_deps = (
        AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
        if rag_db_names
        else None
    )

    processing_msg: types.Message | None = None
    state: StreamingState | None = None
    proxy_state: ProxyStreamingState | None = None
    try:
        # Placeholders are created inside the try. Any failure here still
        # finalizes the Convex message below instead of leaving it stuck with
        # isStreaming=True.
        processing_msg = await bot.send_message(chat_id, "...")
        state = StreamingState(bot, chat_id, processing_msg)
        if proxy_config:
            # Best-effort: a failing proxy bot only disables proxy streaming.
            with contextlib.suppress(Exception):
                proxy_processing_msg = await proxy_config.proxy_bot.send_message(
                    proxy_config.target_chat_id, "..."
                )
                proxy_state = ProxyStreamingState(
                    proxy_config.proxy_bot,
                    proxy_config.target_chat_id,
                    proxy_processing_msg,
                )
        await state.start_typing()

        async def on_chunk(content: str) -> None:
            if state:
                await state.update_message(content)
            if proxy_state:
                await proxy_state.update_message(content)
            await convex.mutation(
                "messages:update",
                {"messageId": assistant_message_id, "content": content},
            )

        chat_images = await fetch_chat_images(active_chat_id)
        # NOTE(review): drops the last two history entries, while
        # process_message_from_web() drops one for an ordinary message.
        # Confirm against what messages:getHistoryForAI returns.
        final_answer = await stream_response(
            text_agent,
            text,
            history[:-2],
            on_chunk,
            images=chat_images,
            deps=agent_deps,
        )
        await state.flush()
        if proxy_state:
            await proxy_state.flush()

        full_history = [*history[:-1], {"role": "assistant", "content": final_answer}]
        follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite")
        follow_up_prompt = user.get("followUpPrompt")
        follow_up_agent = create_follow_up_agent(
            api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt
        )
        follow_ups = await get_follow_ups(follow_up_agent, full_history, chat_images)
        await state.stop_typing()

        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": final_answer,
                "followUpOptions": follow_ups,
                "isStreaming": False,
            },
        )
        with contextlib.suppress(Exception):
            await processing_msg.delete()
        keyboard = make_follow_up_keyboard(follow_ups)
        await send_long_message(bot, chat_id, final_answer, keyboard)

        if proxy_state and proxy_config:
            # Best-effort: the answer is already saved and delivered, so a
            # proxy failure must not fall into the except branch and
            # overwrite the stored answer with an error.
            with contextlib.suppress(Exception):
                await proxy_state.message.delete()
            with contextlib.suppress(Exception):
                for part in split_message(final_answer):
                    await proxy_config.proxy_bot.send_message(
                        proxy_config.target_chat_id, part
                    )
                await increment_proxy_count(chat_id)
    except Exception as e:  # noqa: BLE001
        if state:
            await state.stop_typing()
        error_msg = f"Error: {e}"
        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": error_msg,
                "isStreaming": False,
            },
        )
        if processing_msg:
            with contextlib.suppress(Exception):
                await processing_msg.edit_text(html.quote(error_msg[:TELEGRAM_MAX_LENGTH]))
        if proxy_state:
            with contextlib.suppress(Exception):
                await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH])
async def send_to_telegram(user_id: int, text: str, bot: Bot) -> None:
    """Echo *text* (HTML-escaped, prefixed with 📱) into the user's linked chat.

    The user is looked up by Telegram id. Nothing is sent when the user is
    unknown or has no stored ``telegramChatId``. Any reply keyboard is
    removed.
    """
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(user_id)}
    )
    if not user or not user.get("telegramChatId"):
        return
    raw_chat_id = user["telegramChatId"]
    # The stored id can come back wrapped in ConvexInt64 (elsewhere in this
    # module the same field is read via `.value`). aiogram needs the plain
    # integer, so unwrap it.
    tg_chat_id = (
        raw_chat_id.value if isinstance(raw_chat_id, ConvexInt64) else raw_chat_id
    )
    await bot.send_message(
        tg_chat_id, f"📱 {html.quote(text)}", reply_markup=ReplyKeyboardRemove()
    )
@router.message(F.text & ~F.text.startswith("/"))
async def on_text_message(message: types.Message, bot: Bot) -> None:
    """Register the sender and pass a plain (non-command) text to the AI pipeline."""
    sender = message.from_user
    body = message.text
    if not sender or not body:
        return
    ids_payload = {
        "telegramId": ConvexInt64(sender.id),
        "telegramChatId": ConvexInt64(message.chat.id),
    }
    await convex.mutation("users:getOrCreate", ids_payload)
    await process_message(sender.id, body, bot, message.chat.id)
@router.message(F.media_group_id, F.photo)
async def on_album_message(
    message: types.Message, bot: Bot, album: list[types.Message]
) -> None:
    """Handle a photo album (collected by AlbumMiddleware) as one request.

    Steps:
    - Register the sender and require an active chat.
    - Download the last photo size of every album item.
    - Mirror the photos to the proxy chat when one is configured.
    - Store everything as a single user message with the album caption.
    - Answer it via ``process_message``.
    """
    if not message.from_user:
        return
    await convex.mutation(
        "users:getOrCreate",
        {
            "telegramId": ConvexInt64(message.from_user.id),
            "telegramChatId": ConvexInt64(message.chat.id),
        },
    )
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(message.from_user.id)}
    )
    if not user or not user.get("activeChatId"):
        await message.answer("Use /new first to create a chat.")
        return

    # Every message of a media group has its own caption field, and the
    # user's caption need not be on the item this handler was invoked with.
    # Prefer this message's caption, then the first captioned album item.
    album_caption = next(
        (msg.caption for msg in (message, *album) if msg.caption), None
    )
    caption = album_caption or "Process the images according to your task"

    images_base64: list[str] = []
    images_media_types: list[str] = []
    photos_bytes: list[bytes] = []
    for msg in album:
        if not msg.photo:
            continue
        photo = msg.photo[-1]
        file = await bot.get_file(photo.file_id)
        if not file.file_path:
            continue
        buffer = io.BytesIO()
        await bot.download_file(file.file_path, buffer)
        image_bytes = buffer.getvalue()
        photos_bytes.append(image_bytes)
        images_base64.append(base64.b64encode(image_bytes).decode())
        ext = file.file_path.rsplit(".", 1)[-1].lower()
        media_type = f"image/{ext}" if ext in ("png", "gif", "webp") else "image/jpeg"
        images_media_types.append(media_type)

    if not images_base64:
        await message.answer("Failed to get photos.")
        return

    proxy_config = get_proxy_config(message.chat.id)
    if proxy_config:
        # Best-effort mirror of the album to the proxy chat.
        with contextlib.suppress(Exception):
            media = [
                InputMediaPhoto(
                    media=BufferedInputFile(photo_bytes, f"photo_{i}.jpg"),
                    caption=f"👤 {caption}" if i == 0 else None,
                )
                for i, photo_bytes in enumerate(photos_bytes)
            ]
            await proxy_config.proxy_bot.send_media_group(
                proxy_config.target_chat_id, media
            )
            await increment_proxy_count(message.chat.id)

    active_chat_id = user["activeChatId"]
    await convex.mutation(
        "messages:create",
        {
            "chatId": active_chat_id,
            "role": "user",
            "content": caption,
            "source": "telegram",
            "imagesBase64": images_base64,
            "imagesMediaTypes": images_media_types,
        },
    )
    # The user message and proxy mirror were already handled above.
    await process_message(
        message.from_user.id,
        caption,
        bot,
        message.chat.id,
        skip_user_message=True,
        skip_proxy_user_message=True,
    )
@router.message(F.photo)
async def on_photo_message(message: types.Message, bot: Bot) -> None:
    """Handle a single photo: store it with its caption and answer it.

    Steps:
    - Register the sender and require an active chat.
    - Download the last entry of ``message.photo`` (presumably the largest
      size).
    - Mirror the photo to the proxy chat when configured.
    - Store it as a user message.
    - Run ``process_message`` on the caption.
    """
    sender = message.from_user
    sizes = message.photo
    if not sender or not sizes:
        return

    await convex.mutation(
        "users:getOrCreate",
        {
            "telegramId": ConvexInt64(sender.id),
            "telegramChatId": ConvexInt64(message.chat.id),
        },
    )
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(sender.id)}
    )
    if not (user and user.get("activeChatId")):
        await message.answer("Use /new first to create a chat.")
        return

    caption = message.caption or "Process the image according to your task"
    tg_file = await bot.get_file(sizes[-1].file_id)
    file_path = tg_file.file_path
    if not file_path:
        await message.answer("Failed to get photo.")
        return

    with io.BytesIO() as buffer:
        await bot.download_file(file_path, buffer)
        image_bytes = buffer.getvalue()
    encoded_image = base64.b64encode(image_bytes).decode()
    extension = file_path.rsplit(".", 1)[-1].lower()
    if extension in ("png", "gif", "webp"):
        media_type = f"image/{extension}"
    else:
        media_type = "image/jpeg"

    proxy_config = get_proxy_config(message.chat.id)
    if proxy_config:
        # Best-effort mirror of the photo to the proxy chat.
        with contextlib.suppress(Exception):
            await proxy_config.proxy_bot.send_photo(
                proxy_config.target_chat_id,
                BufferedInputFile(image_bytes, "photo.jpg"),
                caption=f"👤 {caption}",
            )
            await increment_proxy_count(message.chat.id)

    await convex.mutation(
        "messages:create",
        {
            "chatId": user["activeChatId"],
            "role": "user",
            "content": caption,
            "source": "telegram",
            "imageBase64": encoded_image,
            "imageMediaType": media_type,
        },
    )
    await process_message(
        sender.id,
        caption,
        bot,
        message.chat.id,
        skip_user_message=True,
        skip_proxy_user_message=True,
    )