diff --git a/backend/src/bot/handlers/message/handler.py b/backend/src/bot/handlers/message/handler.py index dc77c47..6d5c784 100644 --- a/backend/src/bot/handlers/message/handler.py +++ b/backend/src/bot/handlers/message/handler.py @@ -28,6 +28,7 @@ from bot.modules.ai import ( get_follow_ups, stream_response, ) +from bot.user_lock import get_user_lock from utils import env from utils.convex import ConvexClient @@ -256,15 +257,34 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915 convex_chat_id: str, images_base64: list[str] | None = None, images_media_types: list[str] | None = None, + pending_generation_id: str | None = None, ) -> None: user = await convex.query("users:getById", {"userId": convex_user_id}) if not user or not user.get("geminiApiKey"): return - tg_chat_id = user["telegramChatId"].value if user.get("telegramChatId") else None is_summarize = text == "/summarize" + if not is_summarize: + user_message_args: dict = { + "chatId": convex_chat_id, + "role": "user", + "content": text, + "source": "web", + } + if images_base64 and images_media_types: + user_message_args["imagesBase64"] = images_base64 + user_message_args["imagesMediaTypes"] = images_media_types + await convex.mutation("messages:createFromBackend", user_message_args) + + if pending_generation_id: + await convex.mutation( + "pendingGenerations:remove", {"id": pending_generation_id} + ) + + tg_chat_id = user["telegramChatId"].value if user.get("telegramChatId") else None + if tg_chat_id and not is_summarize: if images_base64 and images_media_types: if len(images_base64) == 1: @@ -387,7 +407,7 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915 await state.flush() full_history = [*history, {"role": "assistant", "content": final_answer}] - follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite") + follow_up_model = user.get("followUpModel", "gemini-3.1-flash-lite-preview") follow_up_prompt = user.get("followUpPrompt") follow_up_agent = create_follow_up_agent( api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt @@ -485,174 +505,189 @@ async def process_message( # noqa: C901, PLR0912, PLR0913, PLR0915 model_name = user.get("model", "gemini-3-pro-preview") convex_user_id = user["_id"] - proxy_config = get_proxy_config(chat_id) - proxy_state: ProxyStreamingState | None = None + async with get_user_lock(convex_user_id): + proxy_config = get_proxy_config(chat_id) + proxy_state: ProxyStreamingState | None = None - if proxy_config and not skip_proxy_user_message: - with contextlib.suppress(Exception): - await proxy_config.proxy_bot.send_message( - proxy_config.target_chat_id, f"👤 {text}" + if proxy_config and not skip_proxy_user_message: + with contextlib.suppress(Exception): + await proxy_config.proxy_bot.send_message( + proxy_config.target_chat_id, f"👤 {text}" + ) + await increment_proxy_count(chat_id) + proxy_config = get_proxy_config(chat_id) + + rag_connections = await convex.query( + "ragConnections:getActiveForUser", {"userId": convex_user_id} + ) + rag_db_names: list[str] = [] + if rag_connections: + for conn in rag_connections: + db = await convex.query( + "rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]} + ) + if db: + rag_db_names.append(db["name"]) + + inject_connections = await convex.query( + "injectConnections:getActiveForUser", {"userId": convex_user_id} + ) + inject_content = "" + if inject_connections: + for conn in inject_connections: + db = await convex.query( + "inject:getDatabaseById", + {"injectDatabaseId": conn["injectDatabaseId"]}, + ) + if db and db.get("content"): + inject_content += db["content"] + "\n\n" + inject_content = inject_content.strip() + + if not skip_user_message: + await convex.mutation( + "messages:create", + { + "chatId": active_chat_id, + "role": "user", + "content": text, + "source": "telegram", + }, ) - await increment_proxy_count(chat_id) - proxy_config = get_proxy_config(chat_id) - rag_connections = await convex.query( - "ragConnections:getActiveForUser", {"userId": convex_user_id} - ) - rag_db_names: list[str] = [] - if rag_connections: - for conn in rag_connections: - db = await convex.query( - "rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]} - ) - if db: - rag_db_names.append(db["name"]) - - inject_connections = await convex.query( - "injectConnections:getActiveForUser", {"userId": convex_user_id} - ) - inject_content = "" - if inject_connections: - for conn in inject_connections: - db = await convex.query( - "inject:getDatabaseById", {"injectDatabaseId": conn["injectDatabaseId"]} - ) - if db and db.get("content"): - inject_content += db["content"] + "\n\n" - inject_content = inject_content.strip() - - if not skip_user_message: - await convex.mutation( + assistant_message_id = await convex.mutation( "messages:create", { "chatId": active_chat_id, - "role": "user", - "content": text, + "role": "assistant", + "content": "", "source": "telegram", + "isStreaming": True, }, ) - assistant_message_id = await convex.mutation( - "messages:create", - { - "chatId": active_chat_id, - "role": "assistant", - "content": "", - "source": "telegram", - "isStreaming": True, - }, - ) - - history = await convex.query( - "messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50} - ) - - system_prompt = user.get("systemPrompt") - if system_prompt and inject_content: - system_prompt = system_prompt.replace("{theory_database}", inject_content) - text_agent = create_text_agent( - api_key=api_key, - model_name=model_name, - system_prompt=system_prompt, - rag_db_names=rag_db_names or None, - ) - - agent_deps = ( - AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names) - if rag_db_names - else None - ) - - processing_msg = await bot.send_message(chat_id, "...") - state = StreamingState(bot, chat_id, processing_msg) - - if proxy_config: - proxy_processing_msg = await proxy_config.proxy_bot.send_message( - proxy_config.target_chat_id, "..." - ) - proxy_state = ProxyStreamingState( - proxy_config.proxy_bot, proxy_config.target_chat_id, proxy_processing_msg + history = await convex.query( + "messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50} ) - try: - await state.start_typing() + system_prompt = user.get("systemPrompt") + if system_prompt and inject_content: + system_prompt = system_prompt.replace("{theory_database}", inject_content) + text_agent = create_text_agent( + api_key=api_key, + model_name=model_name, + system_prompt=system_prompt, + rag_db_names=rag_db_names or None, + ) - async def on_chunk(content: str) -> None: - await state.update_message(content) - if proxy_state: - await proxy_state.update_message(content) - await convex.mutation( - "messages:update", - {"messageId": assistant_message_id, "content": content}, + agent_deps = ( + AgentDeps( + user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names + ) + if rag_db_names + else None + ) + + processing_msg = await bot.send_message(chat_id, "...") + state = StreamingState(bot, chat_id, processing_msg) + + if proxy_config: + proxy_processing_msg = await proxy_config.proxy_bot.send_message( + proxy_config.target_chat_id, "..." + ) + proxy_state = ProxyStreamingState( + proxy_config.proxy_bot, + proxy_config.target_chat_id, + proxy_processing_msg, ) - chat_images = await fetch_chat_images(active_chat_id) + try: + await state.start_typing() - final_answer = await stream_response( - text_agent, - text, - history[:-2], - on_chunk, - images=chat_images, - deps=agent_deps, - ) - - await state.flush() - if proxy_state: - await proxy_state.flush() - - full_history = [*history[:-1], {"role": "assistant", "content": final_answer}] - follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite") - follow_up_prompt = user.get("followUpPrompt") - follow_up_agent = create_follow_up_agent( - api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt - ) - follow_ups = await get_follow_ups(follow_up_agent, full_history, chat_images) - - await state.stop_typing() - - await convex.mutation( - "messages:update", - { - "messageId": assistant_message_id, - "content": final_answer, - "followUpOptions": follow_ups, - "isStreaming": False, - }, - ) - - with contextlib.suppress(Exception): - await processing_msg.delete() - - keyboard = make_follow_up_keyboard(follow_ups) - await send_long_message(bot, chat_id, final_answer, keyboard) - - if proxy_state and proxy_config: - with contextlib.suppress(Exception): - await proxy_state.message.delete() - parts = split_message(final_answer) - for part in parts: - await proxy_config.proxy_bot.send_message( - proxy_config.target_chat_id, part + async def on_chunk(content: str) -> None: + await state.update_message(content) + if proxy_state: + await proxy_state.update_message(content) + await convex.mutation( + "messages:update", + {"messageId": assistant_message_id, "content": content}, ) - await increment_proxy_count(chat_id) - except Exception as e: # noqa: BLE001 - await state.stop_typing() - error_msg = f"Error: {e}" - await convex.mutation( - "messages:update", - { - "messageId": assistant_message_id, - "content": error_msg, - "isStreaming": False, - }, - ) - with contextlib.suppress(Exception): - await processing_msg.edit_text(html.quote(error_msg[:TELEGRAM_MAX_LENGTH])) - if proxy_state: + chat_images = await fetch_chat_images(active_chat_id) + + final_answer = await stream_response( + text_agent, + text, + history[:-2], + on_chunk, + images=chat_images, + deps=agent_deps, + ) + + await state.flush() + if proxy_state: + await proxy_state.flush() + + full_history = [ + *history[:-1], + {"role": "assistant", "content": final_answer}, + ] + follow_up_model = user.get("followUpModel", "gemini-3.1-flash-lite-preview") + follow_up_prompt = user.get("followUpPrompt") + follow_up_agent = create_follow_up_agent( + api_key=api_key, + model_name=follow_up_model, + system_prompt=follow_up_prompt, + ) + follow_ups = await get_follow_ups( + follow_up_agent, full_history, chat_images + ) + + await state.stop_typing() + + await convex.mutation( + "messages:update", + { + "messageId": assistant_message_id, + "content": final_answer, + "followUpOptions": follow_ups, + "isStreaming": False, + }, + ) + with contextlib.suppress(Exception): - await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH]) + await processing_msg.delete() + + keyboard = make_follow_up_keyboard(follow_ups) + await send_long_message(bot, chat_id, final_answer, keyboard) + + if proxy_state and proxy_config: + with contextlib.suppress(Exception): + await proxy_state.message.delete() + parts = split_message(final_answer) + for part in parts: + await proxy_config.proxy_bot.send_message( + proxy_config.target_chat_id, part + ) + await increment_proxy_count(chat_id) + + except Exception as e: # noqa: BLE001 + await state.stop_typing() + error_msg = f"Error: {e}" + await convex.mutation( + "messages:update", + { + "messageId": assistant_message_id, + "content": error_msg, + "isStreaming": False, + }, + ) + with contextlib.suppress(Exception): + await processing_msg.edit_text( + html.quote(error_msg[:TELEGRAM_MAX_LENGTH]) + ) + if proxy_state: + with contextlib.suppress(Exception): + await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH]) async def send_to_telegram(user_id: int, text: str, bot: Bot) -> None: diff --git a/backend/src/bot/modules/ai/agent.py b/backend/src/bot/modules/ai/agent.py index ea39216..6b79137 100644 --- a/backend/src/bot/modules/ai/agent.py +++ b/backend/src/bot/modules/ai/agent.py @@ -63,7 +63,7 @@ def create_text_agent( if rag_db_names: full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}" - agent: Agent[None, str] = Agent( + agent: Agent[AgentDeps, str] = Agent( model, instructions=full_prompt, deps_type=AgentDeps ) @@ -101,13 +101,16 @@ def create_text_agent( def create_follow_up_agent( api_key: str, - model_name: str = "gemini-2.5-flash-lite", + model_name: str = "gemini-3.1-flash-lite-preview", system_prompt: str | None = None, ) -> Agent[None, FollowUpOptions]: provider = GoogleProvider(api_key=api_key) model = GoogleModel(model_name, provider=provider) prompt = system_prompt or DEFAULT_FOLLOW_UP - return Agent(model, output_type=FollowUpOptions, instructions=prompt) + agent: Agent[None, FollowUpOptions] = Agent( # ty: ignore[invalid-assignment] + model, output_type=FollowUpOptions, instructions=prompt + ) + return agent def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]: diff --git a/backend/src/bot/sync.py b/backend/src/bot/sync.py index f47860b..1467462 100644 --- a/backend/src/bot/sync.py +++ b/backend/src/bot/sync.py @@ -5,6 +5,7 @@ import time from aiogram import Bot from bot.handlers.message.handler import process_message_from_web +from bot.user_lock import get_user_lock from utils import env from utils.collaborative import ( CollaborativeClient, @@ -53,14 +54,16 @@ async def start_sync_listener(bot: Bot) -> None: async def handle_pending_generation(bot: Bot, item: dict, item_id: str) -> None: try: - await process_message_from_web( - convex_user_id=item["userId"], - text=item["userMessage"], - bot=bot, - convex_chat_id=item["chatId"], - images_base64=item.get("imagesBase64"), - images_media_types=item.get("imagesMediaTypes"), - ) + async with get_user_lock(item["userId"]): + await process_message_from_web( + convex_user_id=item["userId"], + text=item["userMessage"], + bot=bot, + convex_chat_id=item["chatId"], + images_base64=item.get("imagesBase64"), + images_media_types=item.get("imagesMediaTypes"), + pending_generation_id=item_id, + ) except Exception as e: # noqa: BLE001 logger.error(f"Error processing {item_id}: {e}") finally: diff --git a/backend/src/bot/user_lock.py b/backend/src/bot/user_lock.py new file mode 100644 index 0000000..3141917 --- /dev/null +++ b/backend/src/bot/user_lock.py @@ -0,0 +1,11 @@ +import asyncio + +_locks: dict[str, asyncio.Lock] = {} + + +def get_user_lock(convex_user_id: str) -> asyncio.Lock: + lock = _locks.get(convex_user_id) + if lock is None: + lock = asyncio.Lock() + _locks[convex_user_id] = lock + return lock diff --git a/frontend/src/lib/components/PendingMessageBubble.svelte b/frontend/src/lib/components/PendingMessageBubble.svelte new file mode 100644 index 0000000..a71bcba --- /dev/null +++ b/frontend/src/lib/components/PendingMessageBubble.svelte @@ -0,0 +1,48 @@ + + +