"""Telegram and web message handling: stream Gemini answers back to the user.

Incoming Telegram text, photos and albums are stored in Convex, answered by the
text agent with throttled in-place edits of a placeholder message, and
optionally mirrored to a proxy chat. Messages coming from the web UI are
processed by ``process_message_from_web`` and echoed into the user's linked
Telegram chat.
"""

import asyncio
import base64
import contextlib
import io
import time
from collections.abc import Awaitable, Callable
from typing import Any

from aiogram import BaseMiddleware, Bot, F, Router, html, types
from aiogram.enums import ChatAction
from aiogram.types import (
    BufferedInputFile,
    InputMediaPhoto,
    KeyboardButton,
    ReplyKeyboardMarkup,
    ReplyKeyboardRemove,
    TelegramObject,
)
from convex import ConvexInt64

from bot.handlers.proxy.handler import get_proxy_config, increment_proxy_count
from bot.modules.ai import (
    SUMMARIZE_PROMPT,
    AgentDeps,
    ImageData,
    create_follow_up_agent,
    create_text_agent,
    get_follow_ups,
    stream_response,
)
from utils import env
from utils.convex import ConvexClient

router = Router()
convex = ConvexClient(env.convex_url)

# Seconds the first message of a media group waits for its siblings to arrive.
ALBUM_COLLECT_DELAY = 0.5


class AlbumMiddleware(BaseMiddleware):
    """Collect the messages of one Telegram media group and dispatch them once.

    The first message of a group waits ``ALBUM_COLLECT_DELAY`` seconds while
    later messages with the same ``media_group_id`` are buffered. The handler
    is then called once with the first message and ``data["album"]`` holding
    every collected message. Later messages of the group do not reach the
    handler themselves. Non-album events pass straight through.
    """

    def __init__(self) -> None:
        # media_group_id -> messages collected so far for that group.
        self.albums: dict[str, list[types.Message]] = {}
        # media_group_ids whose first message is currently waiting to dispatch.
        self.scheduled: set[str] = set()

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:  # noqa: ANN401
        if not isinstance(event, types.Message) or not event.media_group_id:
            return await handler(event, data)

        album_id = event.media_group_id
        self.albums.setdefault(album_id, []).append(event)
        if album_id in self.scheduled:
            # Another message of this group is already waiting to dispatch.
            return None

        self.scheduled.add(album_id)
        try:
            await asyncio.sleep(ALBUM_COLLECT_DELAY)
        finally:
            # Release the buffer even if the wait is cancelled, so state for a
            # group can never leak.
            messages = self.albums.pop(album_id, [])
            self.scheduled.discard(album_id)

        if messages:
            data["album"] = messages
            return await handler(messages[0], data)
        return None


router.message.middleware(AlbumMiddleware())

# Minimum interval between in-place edits of a streaming message.
EDIT_THROTTLE_SECONDS = 1.0
# Telegram's per-message text length limit.
TELEGRAM_MAX_LENGTH = 4096
# A chat action is shown for at most ~5 s, so it is re-sent slightly sooner.
TYPING_REFRESH_SECONDS = 4


def _unwrap_int64(value: Any) -> int:  # noqa: ANN401
    """Return the plain int behind a Convex int64 field.

    Accepts both a ``ConvexInt64``-style wrapper (anything with ``.value``) and
    a plain number, so callers need not know which form the client returned.
    """
    return int(getattr(value, "value", value))


async def fetch_chat_images(chat_id: str) -> list[ImageData]:
    """Load every image stored for a chat and decode it from base64."""
    chat_images = await convex.query("messages:getChatImages", {"chatId": chat_id})
    return [
        ImageData(data=base64.b64decode(img["base64"]), media_type=img["mediaType"])
        for img in (chat_images or [])
    ]


async def _get_rag_db_names(convex_user_id: str) -> list[str]:
    """Return the names of the RAG databases the user has active connections to."""
    rag_connections = await convex.query(
        "ragConnections:getActiveForUser", {"userId": convex_user_id}
    )
    names: list[str] = []
    for conn in rag_connections or []:
        db = await convex.query(
            "rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
        )
        if db:
            names.append(db["name"])
    return names


async def _generate_follow_ups(
    user: dict[str, Any],
    api_key: str,
    history: list[Any],
    images: list[ImageData],
) -> list[str]:
    """Ask the user's follow-up model for suggested next messages."""
    follow_up_agent = create_follow_up_agent(
        api_key=api_key,
        model_name=user.get("followUpModel", "gemini-2.5-flash-lite"),
        system_prompt=user.get("followUpPrompt"),
    )
    return await get_follow_ups(follow_up_agent, history, images)


def make_follow_up_keyboard(options: list[str]) -> ReplyKeyboardMarkup:
    """Build a one-time reply keyboard with one button row per option."""
    buttons = [[KeyboardButton(text=opt)] for opt in options]
    return ReplyKeyboardMarkup(
        keyboard=buttons, resize_keyboard=True, one_time_keyboard=True
    )


def split_message(text: str, max_length: int = TELEGRAM_MAX_LENGTH) -> list[str]:
    """Split *text* into chunks of at most *max_length* characters.

    Breaks at the last newline in the window, else at the last space, else
    hard-cuts at *max_length*. Whitespace leading a following chunk is dropped.
    Text that already fits is returned unchanged as a single chunk.
    """
    if len(text) <= max_length:
        return [text]
    parts: list[str] = []
    while text:
        if len(text) <= max_length:
            parts.append(text)
            break
        # A separator at index 0 would produce an empty chunk (Telegram rejects
        # empty messages), so treat it the same as "not found".
        split_pos = text.rfind("\n", 0, max_length)
        if split_pos <= 0:
            split_pos = text.rfind(" ", 0, max_length)
        if split_pos <= 0:
            split_pos = max_length
        parts.append(text[:split_pos])
        text = text[split_pos:].lstrip()
    return parts


class _ThrottledMessageEditor:
    """Keep one Telegram message in sync with growing text via throttled edits.

    An edit is issued at most once per ``EDIT_THROTTLE_SECONDS`` unless forced.
    Content arriving in between is parked in ``pending_content`` until the next
    edit or :meth:`flush`. Edit failures are ignored (best effort).
    """

    # Whether text is HTML-escaped before editing. Main-bot text is html-quoted
    # everywhere in this module; proxy-bot text is sent raw.
    escape_html = True

    def __init__(self, bot: Bot, chat_id: int, message: types.Message) -> None:
        self.bot = bot
        self.chat_id = chat_id
        self.message = message  # the placeholder message that gets edited
        self.last_edit_time = 0.0  # time.monotonic() of the last edit attempt
        self.last_content = ""  # full, untruncated text of the last edit
        self.pending_content: str | None = None  # deferred by the throttle

    async def update_message(self, content: str, *, force: bool = False) -> None:
        """Show *content* now, or defer it if the last edit was too recent.

        Args:
            content: Full text to display; truncated with "..." past the limit.
            force: Edit immediately, ignoring the throttle window.
        """
        if content == self.last_content:
            return
        now = time.monotonic()
        if not force and now - self.last_edit_time < EDIT_THROTTLE_SECONDS:
            self.pending_content = content
            return

        if len(content) > TELEGRAM_MAX_LENGTH:
            display_content = content[: TELEGRAM_MAX_LENGTH - 3] + "..."
        else:
            display_content = content
        if self.escape_html:
            display_content = html.quote(display_content)

        # Best effort: an edit error (e.g. Telegram refusing an unchanged text)
        # must not abort the stream.
        with contextlib.suppress(Exception):
            await self.message.edit_text(display_content)
        self.last_edit_time = now
        self.last_content = content
        self.pending_content = None

    async def flush(self) -> None:
        """Force out any content that was deferred by the throttle."""
        if self.pending_content and self.pending_content != self.last_content:
            await self.update_message(self.pending_content, force=True)


class StreamingState(_ThrottledMessageEditor):
    """Streaming editor for the main bot (HTML-escaped) plus a typing indicator."""

    def __init__(self, bot: Bot, chat_id: int, message: types.Message) -> None:
        super().__init__(bot, chat_id, message)
        self._typing_task: asyncio.Task[None] | None = None

    async def start_typing(self) -> None:
        """Start re-sending the "typing" chat action in a background task."""
        await self.stop_typing()  # never leak a previously started loop

        async def typing_loop() -> None:
            while True:
                # The indicator is cosmetic: swallow API errors so the task
                # doesn't die and re-raise them later from stop_typing().
                with contextlib.suppress(Exception):
                    await self.bot.send_chat_action(self.chat_id, ChatAction.TYPING)
                await asyncio.sleep(TYPING_REFRESH_SECONDS)

        self._typing_task = asyncio.create_task(typing_loop())

    async def stop_typing(self) -> None:
        """Cancel the typing loop and wait for it to finish; safe to repeat."""
        task, self._typing_task = self._typing_task, None
        if task is None:
            return
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task


class ProxyStreamingState(_ThrottledMessageEditor):
    """Streaming editor for the proxy mirror; text is sent without escaping."""

    escape_html = False


async def send_long_message(
    bot: Bot, chat_id: int, text: str, reply_markup: ReplyKeyboardMarkup | None = None
) -> None:
    """Send *text* as HTML-escaped chunks; only the last carries *reply_markup*."""
    parts = split_message(text)
    for i, part in enumerate(parts):
        is_last = i == len(parts) - 1
        await bot.send_message(
            chat_id, html.quote(part), reply_markup=reply_markup if is_last else None
        )


async def _echo_web_message_to_telegram(
    bot: Bot,
    tg_chat_id: int,
    text: str,
    images_base64: list[str] | None,
    images_media_types: list[str] | None,
) -> None:
    """Mirror a web-UI message (text and optional images) into Telegram.

    Captions are html-quoted like every other main-bot text in this module, so
    user text containing ``<`` or ``&`` cannot break message parsing.
    """
    quoted = html.quote(text)
    if not (images_base64 and images_media_types):
        await bot.send_message(
            tg_chat_id, f"📱 {quoted}", reply_markup=ReplyKeyboardRemove()
        )
        return

    if len(images_base64) == 1:
        photo_bytes = base64.b64decode(images_base64[0])
        await bot.send_photo(
            tg_chat_id,
            BufferedInputFile(photo_bytes, "photo.jpg"),
            caption=f"📱 {quoted}" if text else "📱",
            reply_markup=ReplyKeyboardRemove(),
        )
        return

    media = []
    img_pairs = zip(images_base64, images_media_types, strict=True)
    for i, (img_b64, _media_type) in enumerate(img_pairs):
        photo_bytes = base64.b64decode(img_b64)
        media.append(
            InputMediaPhoto(
                media=BufferedInputFile(photo_bytes, f"photo_{i}.jpg"),
                # A media group shows one caption: attach it to the first item.
                caption=f"📱 {quoted}" if i == 0 and text else None,
            )
        )
    await bot.send_media_group(tg_chat_id, media)


async def process_message_from_web(  # noqa: C901, PLR0912, PLR0913, PLR0915
    convex_user_id: str,
    text: str,
    bot: Bot,
    convex_chat_id: str,
    images_base64: list[str] | None = None,
    images_media_types: list[str] | None = None,
) -> None:
    """Answer a message sent from the web UI, mirroring it to Telegram.

    Presumably the user message itself was already stored by the caller (this
    function does not create it). It creates a streaming assistant message,
    streams the answer into it, and into a Telegram placeholder when the user
    has a linked chat. It then stores follow-up suggestions. The text
    ``"/summarize"`` summarizes the conversation, clears the chat (keeping
    images) and re-creates the summary as a new assistant message.

    Errors during generation are written into the assistant message instead of
    being raised.
    """
    user = await convex.query("users:getById", {"userId": convex_user_id})
    if not user or not user.get("geminiApiKey"):
        return

    tg_chat_id = (
        _unwrap_int64(user["telegramChatId"]) if user.get("telegramChatId") else None
    )
    is_summarize = text == "/summarize"

    if tg_chat_id and not is_summarize:
        await _echo_web_message_to_telegram(
            bot, tg_chat_id, text, images_base64, images_media_types
        )

    api_key = user["geminiApiKey"]
    model_name = user.get("model", "gemini-3-pro-preview")
    rag_db_names = await _get_rag_db_names(convex_user_id)

    assistant_message_id = await convex.mutation(
        "messages:create",
        {
            "chatId": convex_chat_id,
            "role": "assistant",
            "content": "",
            "source": "web",
            "isStreaming": True,
        },
    )
    history = await convex.query(
        "messages:getHistoryForAI", {"chatId": convex_chat_id, "limit": 50}
    )

    system_prompt = SUMMARIZE_PROMPT if is_summarize else user.get("systemPrompt")
    text_agent = create_text_agent(
        api_key=api_key,
        model_name=model_name,
        system_prompt=system_prompt,
        rag_db_names=rag_db_names or None,
    )
    agent_deps = (
        AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
        if rag_db_names
        else None
    )

    processing_msg: types.Message | None = None
    state: StreamingState | None = None
    if tg_chat_id:
        processing_msg = await bot.send_message(tg_chat_id, "...")
        state = StreamingState(bot, tg_chat_id, processing_msg)

    try:
        if state:
            await state.start_typing()

        async def on_chunk(content: str) -> None:
            if state:
                await state.update_message(content)
            await convex.mutation(
                "messages:update",
                {"messageId": assistant_message_id, "content": content},
            )

        # NOTE(review): this slices [:-1] / [:-2] while process_message uses
        # [:-2] for a normal turn; confirm which trailing entries (user turn,
        # empty streaming placeholder) messages:getHistoryForAI returns.
        if is_summarize:
            prompt_text = "Summarize what was done in this conversation."
            hist = history[:-2]
        else:
            prompt_text = text
            hist = history[:-1]

        chat_images = await fetch_chat_images(convex_chat_id)
        final_answer = await stream_response(
            text_agent, prompt_text, hist, on_chunk, images=chat_images, deps=agent_deps
        )
        if state:
            await state.flush()

        full_history = [*history, {"role": "assistant", "content": final_answer}]
        follow_ups = await _generate_follow_ups(
            user, api_key, full_history, chat_images
        )
        if state:
            await state.stop_typing()

        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": final_answer,
                "followUpOptions": follow_ups,
                "isStreaming": False,
            },
        )

        if is_summarize:
            await convex.mutation(
                "chats:clear", {"chatId": convex_chat_id, "preserveImages": True}
            )
            await convex.mutation(
                "messages:create",
                {
                    "chatId": convex_chat_id,
                    "role": "assistant",
                    "content": final_answer,
                    "source": "web",
                    "followUpOptions": follow_ups,
                },
            )

        if tg_chat_id and processing_msg:
            with contextlib.suppress(Exception):
                await processing_msg.delete()
            keyboard = make_follow_up_keyboard(follow_ups)
            await send_long_message(bot, tg_chat_id, final_answer, keyboard)
    except Exception as e:  # noqa: BLE001
        if state:
            await state.stop_typing()
        error_msg = f"Error: {e}"
        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": error_msg,
                "isStreaming": False,
            },
        )
        if tg_chat_id and processing_msg:
            with contextlib.suppress(Exception):
                truncated = html.quote(error_msg[:TELEGRAM_MAX_LENGTH])
                await processing_msg.edit_text(truncated)
    finally:
        # Safety net: never leave the typing loop running, whatever happened.
        if state:
            await state.stop_typing()


async def process_message(  # noqa: C901, PLR0912, PLR0913, PLR0915
    user_id: int,
    text: str,
    bot: Bot,
    chat_id: int,
    *,
    skip_user_message: bool = False,
    skip_proxy_user_message: bool = False,
) -> None:
    """Answer a Telegram message by streaming the agent's reply into the chat.

    Args:
        user_id: Telegram user id used to look up the Convex user.
        text: The user's prompt.
        bot: Bot used to reply in *chat_id*.
        chat_id: Telegram chat to answer in.
        skip_user_message: Don't store *text* as a user message (the caller
            already stored it, e.g. together with images).
        skip_proxy_user_message: Don't mirror *text* to the proxy chat (the
            caller already mirrored it).

    When a proxy is configured for *chat_id*, the streamed answer is mirrored
    there best-effort; proxy failures never affect the main reply. Generation
    errors are written into the assistant message instead of being raised.
    """
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(user_id)}
    )
    if not user or not user.get("geminiApiKey"):
        await bot.send_message(chat_id, "Use /apikey first to set your Gemini API key.")
        return
    if not user.get("activeChatId"):
        await bot.send_message(chat_id, "Use /new first to create a chat.")
        return

    active_chat_id = user["activeChatId"]
    api_key = user["geminiApiKey"]
    model_name = user.get("model", "gemini-3-pro-preview")
    convex_user_id = user["_id"]

    proxy_config = get_proxy_config(chat_id)
    if proxy_config and not skip_proxy_user_message:
        with contextlib.suppress(Exception):
            await proxy_config.proxy_bot.send_message(
                proxy_config.target_chat_id, f"👤 {text}"
            )
        await increment_proxy_count(chat_id)
        # Re-read: incrementing the counter may change or end the proxy config.
        proxy_config = get_proxy_config(chat_id)

    rag_db_names = await _get_rag_db_names(convex_user_id)

    if not skip_user_message:
        await convex.mutation(
            "messages:create",
            {
                "chatId": active_chat_id,
                "role": "user",
                "content": text,
                "source": "telegram",
            },
        )

    assistant_message_id = await convex.mutation(
        "messages:create",
        {
            "chatId": active_chat_id,
            "role": "assistant",
            "content": "",
            "source": "telegram",
            "isStreaming": True,
        },
    )
    history = await convex.query(
        "messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50}
    )

    text_agent = create_text_agent(
        api_key=api_key,
        model_name=model_name,
        system_prompt=user.get("systemPrompt"),
        rag_db_names=rag_db_names or None,
    )
    agent_deps = (
        AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
        if rag_db_names
        else None
    )

    processing_msg = await bot.send_message(chat_id, "...")
    state = StreamingState(bot, chat_id, processing_msg)

    proxy_state: ProxyStreamingState | None = None
    if proxy_config:
        # Best effort like every other proxy send: a failing proxy must not
        # abort the request and leave the assistant message stuck streaming.
        with contextlib.suppress(Exception):
            proxy_processing_msg = await proxy_config.proxy_bot.send_message(
                proxy_config.target_chat_id, "..."
            )
            proxy_state = ProxyStreamingState(
                proxy_config.proxy_bot, proxy_config.target_chat_id, proxy_processing_msg
            )

    try:
        await state.start_typing()

        async def on_chunk(content: str) -> None:
            await state.update_message(content)
            if proxy_state:
                await proxy_state.update_message(content)
            await convex.mutation(
                "messages:update",
                {"messageId": assistant_message_id, "content": content},
            )

        chat_images = await fetch_chat_images(active_chat_id)
        # NOTE(review): [:-2] presumably drops this turn's user message and the
        # empty streaming placeholder; process_message_from_web slices
        # differently. Confirm against messages:getHistoryForAI.
        final_answer = await stream_response(
            text_agent,
            text,
            history[:-2],
            on_chunk,
            images=chat_images,
            deps=agent_deps,
        )
        await state.flush()
        if proxy_state:
            await proxy_state.flush()

        full_history = [*history[:-1], {"role": "assistant", "content": final_answer}]
        follow_ups = await _generate_follow_ups(
            user, api_key, full_history, chat_images
        )
        await state.stop_typing()

        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": final_answer,
                "followUpOptions": follow_ups,
                "isStreaming": False,
            },
        )

        with contextlib.suppress(Exception):
            await processing_msg.delete()
        keyboard = make_follow_up_keyboard(follow_ups)
        await send_long_message(bot, chat_id, final_answer, keyboard)

        if proxy_state and proxy_config:
            with contextlib.suppress(Exception):
                await proxy_state.message.delete()
            # The answer is already stored and delivered at this point; a proxy
            # failure must not overwrite it with an error via the except branch.
            with contextlib.suppress(Exception):
                for part in split_message(final_answer):
                    await proxy_config.proxy_bot.send_message(
                        proxy_config.target_chat_id, part
                    )
            await increment_proxy_count(chat_id)
    except Exception as e:  # noqa: BLE001
        await state.stop_typing()
        error_msg = f"Error: {e}"
        await convex.mutation(
            "messages:update",
            {
                "messageId": assistant_message_id,
                "content": error_msg,
                "isStreaming": False,
            },
        )
        with contextlib.suppress(Exception):
            await processing_msg.edit_text(html.quote(error_msg[:TELEGRAM_MAX_LENGTH]))
        if proxy_state:
            with contextlib.suppress(Exception):
                await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH])
    finally:
        # Safety net: never leave the typing loop running, whatever happened.
        await state.stop_typing()


async def send_to_telegram(user_id: int, text: str, bot: Bot) -> None:
    """Echo *text* into the Telegram chat linked to the given Telegram user."""
    user = await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(user_id)}
    )
    if not user or not user.get("telegramChatId"):
        return
    # Unwrap like process_message_from_web does, so the bot gets a plain int.
    tg_chat_id = _unwrap_int64(user["telegramChatId"])
    await bot.send_message(
        tg_chat_id, f"📱 {html.quote(text)}", reply_markup=ReplyKeyboardRemove()
    )


async def _register_user(telegram_id: int, chat_id: int) -> None:
    """Create the Convex user for this Telegram account if it doesn't exist."""
    await convex.mutation(
        "users:getOrCreate",
        {
            "telegramId": ConvexInt64(telegram_id),
            "telegramChatId": ConvexInt64(chat_id),
        },
    )


async def _register_and_load_user(
    telegram_id: int, chat_id: int
) -> dict[str, Any] | None:
    """Register the Telegram account, then return its Convex user (or None)."""
    await _register_user(telegram_id, chat_id)
    return await convex.query(
        "users:getByTelegramId", {"telegramId": ConvexInt64(telegram_id)}
    )


def _media_type_for(file_path: str) -> str:
    """Guess an image MIME type from a file extension, defaulting to JPEG."""
    ext = file_path.rsplit(".", 1)[-1].lower()
    return f"image/{ext}" if ext in ("png", "gif", "webp") else "image/jpeg"


async def _download_photo(
    bot: Bot, photo: types.PhotoSize
) -> tuple[bytes, str] | None:
    """Download a photo; return ``(bytes, media_type)`` or None if unavailable."""
    file = await bot.get_file(photo.file_id)
    if not file.file_path:
        return None
    buffer = io.BytesIO()
    await bot.download_file(file.file_path, buffer)
    return buffer.getvalue(), _media_type_for(file.file_path)


@router.message(F.text & ~F.text.startswith("/"))
async def on_text_message(message: types.Message, bot: Bot) -> None:
    """Handle plain (non-command) text messages."""
    if not message.from_user or not message.text:
        return
    await _register_user(message.from_user.id, message.chat.id)
    await process_message(message.from_user.id, message.text, bot, message.chat.id)


@router.message(F.media_group_id, F.photo)
async def on_album_message(
    message: types.Message, bot: Bot, album: list[types.Message]
) -> None:
    """Handle a photo album collected by AlbumMiddleware as a single prompt."""
    if not message.from_user:
        return
    user = await _register_and_load_user(message.from_user.id, message.chat.id)
    if not user or not user.get("activeChatId"):
        await message.answer("Use /new first to create a chat.")
        return

    caption = message.caption or "Process the images according to your task"
    images_base64: list[str] = []
    images_media_types: list[str] = []
    photos_bytes: list[bytes] = []
    for msg in album:
        if not msg.photo:
            continue
        # The last PhotoSize entry is used as the image to download.
        downloaded = await _download_photo(bot, msg.photo[-1])
        if downloaded is None:
            continue
        image_bytes, media_type = downloaded
        photos_bytes.append(image_bytes)
        images_base64.append(base64.b64encode(image_bytes).decode())
        images_media_types.append(media_type)

    if not images_base64:
        await message.answer("Failed to get photos.")
        return

    proxy_config = get_proxy_config(message.chat.id)
    if proxy_config:
        with contextlib.suppress(Exception):
            media = [
                InputMediaPhoto(
                    media=BufferedInputFile(photo_bytes, f"photo_{i}.jpg"),
                    caption=f"👤 {caption}" if i == 0 else None,
                )
                for i, photo_bytes in enumerate(photos_bytes)
            ]
            await proxy_config.proxy_bot.send_media_group(
                proxy_config.target_chat_id, media
            )
        await increment_proxy_count(message.chat.id)

    await convex.mutation(
        "messages:create",
        {
            "chatId": user["activeChatId"],
            "role": "user",
            "content": caption,
            "source": "telegram",
            "imagesBase64": images_base64,
            "imagesMediaTypes": images_media_types,
        },
    )
    await process_message(
        message.from_user.id,
        caption,
        bot,
        message.chat.id,
        skip_user_message=True,
        skip_proxy_user_message=True,
    )


@router.message(F.photo)
async def on_photo_message(message: types.Message, bot: Bot) -> None:
    """Handle a single photo (optionally captioned) as a prompt."""
    if not message.from_user or not message.photo:
        return
    user = await _register_and_load_user(message.from_user.id, message.chat.id)
    if not user or not user.get("activeChatId"):
        await message.answer("Use /new first to create a chat.")
        return

    caption = message.caption or "Process the image according to your task"
    # The last PhotoSize entry is used as the image to download.
    downloaded = await _download_photo(bot, message.photo[-1])
    if downloaded is None:
        await message.answer("Failed to get photo.")
        return
    image_bytes, media_type = downloaded

    proxy_config = get_proxy_config(message.chat.id)
    if proxy_config:
        with contextlib.suppress(Exception):
            await proxy_config.proxy_bot.send_photo(
                proxy_config.target_chat_id,
                BufferedInputFile(image_bytes, "photo.jpg"),
                caption=f"👤 {caption}",
            )
        await increment_proxy_count(message.chat.id)

    await convex.mutation(
        "messages:create",
        {
            "chatId": user["activeChatId"],
            "role": "user",
            "content": caption,
            "source": "telegram",
            "imageBase64": base64.b64encode(image_bytes).decode(),
            "imageMediaType": media_type,
        },
    )
    await process_message(
        message.from_user.id,
        caption,
        bot,
        message.chat.id,
        skip_user_message=True,
        skip_proxy_user_message=True,
    )