feat(*): make message processing sequential
This commit is contained in:
@@ -28,6 +28,7 @@ from bot.modules.ai import (
|
||||
get_follow_ups,
|
||||
stream_response,
|
||||
)
|
||||
from bot.user_lock import get_user_lock
|
||||
from utils import env
|
||||
from utils.convex import ConvexClient
|
||||
|
||||
@@ -256,15 +257,34 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
convex_chat_id: str,
|
||||
images_base64: list[str] | None = None,
|
||||
images_media_types: list[str] | None = None,
|
||||
pending_generation_id: str | None = None,
|
||||
) -> None:
|
||||
user = await convex.query("users:getById", {"userId": convex_user_id})
|
||||
|
||||
if not user or not user.get("geminiApiKey"):
|
||||
return
|
||||
|
||||
tg_chat_id = user["telegramChatId"].value if user.get("telegramChatId") else None
|
||||
is_summarize = text == "/summarize"
|
||||
|
||||
if not is_summarize:
|
||||
user_message_args: dict = {
|
||||
"chatId": convex_chat_id,
|
||||
"role": "user",
|
||||
"content": text,
|
||||
"source": "web",
|
||||
}
|
||||
if images_base64 and images_media_types:
|
||||
user_message_args["imagesBase64"] = images_base64
|
||||
user_message_args["imagesMediaTypes"] = images_media_types
|
||||
await convex.mutation("messages:createFromBackend", user_message_args)
|
||||
|
||||
if pending_generation_id:
|
||||
await convex.mutation(
|
||||
"pendingGenerations:remove", {"id": pending_generation_id}
|
||||
)
|
||||
|
||||
tg_chat_id = user["telegramChatId"].value if user.get("telegramChatId") else None
|
||||
|
||||
if tg_chat_id and not is_summarize:
|
||||
if images_base64 and images_media_types:
|
||||
if len(images_base64) == 1:
|
||||
@@ -387,7 +407,7 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
await state.flush()
|
||||
|
||||
full_history = [*history, {"role": "assistant", "content": final_answer}]
|
||||
follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite")
|
||||
follow_up_model = user.get("followUpModel", "gemini-3.1-flash-lite-preview")
|
||||
follow_up_prompt = user.get("followUpPrompt")
|
||||
follow_up_agent = create_follow_up_agent(
|
||||
api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt
|
||||
@@ -485,174 +505,189 @@ async def process_message( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
model_name = user.get("model", "gemini-3-pro-preview")
|
||||
convex_user_id = user["_id"]
|
||||
|
||||
proxy_config = get_proxy_config(chat_id)
|
||||
proxy_state: ProxyStreamingState | None = None
|
||||
async with get_user_lock(convex_user_id):
|
||||
proxy_config = get_proxy_config(chat_id)
|
||||
proxy_state: ProxyStreamingState | None = None
|
||||
|
||||
if proxy_config and not skip_proxy_user_message:
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, f"👤 {text}"
|
||||
if proxy_config and not skip_proxy_user_message:
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, f"👤 {text}"
|
||||
)
|
||||
await increment_proxy_count(chat_id)
|
||||
proxy_config = get_proxy_config(chat_id)
|
||||
|
||||
rag_connections = await convex.query(
|
||||
"ragConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
rag_db_names: list[str] = []
|
||||
if rag_connections:
|
||||
for conn in rag_connections:
|
||||
db = await convex.query(
|
||||
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
|
||||
)
|
||||
if db:
|
||||
rag_db_names.append(db["name"])
|
||||
|
||||
inject_connections = await convex.query(
|
||||
"injectConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
inject_content = ""
|
||||
if inject_connections:
|
||||
for conn in inject_connections:
|
||||
db = await convex.query(
|
||||
"inject:getDatabaseById",
|
||||
{"injectDatabaseId": conn["injectDatabaseId"]},
|
||||
)
|
||||
if db and db.get("content"):
|
||||
inject_content += db["content"] + "\n\n"
|
||||
inject_content = inject_content.strip()
|
||||
|
||||
if not skip_user_message:
|
||||
await convex.mutation(
|
||||
"messages:create",
|
||||
{
|
||||
"chatId": active_chat_id,
|
||||
"role": "user",
|
||||
"content": text,
|
||||
"source": "telegram",
|
||||
},
|
||||
)
|
||||
await increment_proxy_count(chat_id)
|
||||
proxy_config = get_proxy_config(chat_id)
|
||||
|
||||
rag_connections = await convex.query(
|
||||
"ragConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
rag_db_names: list[str] = []
|
||||
if rag_connections:
|
||||
for conn in rag_connections:
|
||||
db = await convex.query(
|
||||
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
|
||||
)
|
||||
if db:
|
||||
rag_db_names.append(db["name"])
|
||||
|
||||
inject_connections = await convex.query(
|
||||
"injectConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
inject_content = ""
|
||||
if inject_connections:
|
||||
for conn in inject_connections:
|
||||
db = await convex.query(
|
||||
"inject:getDatabaseById", {"injectDatabaseId": conn["injectDatabaseId"]}
|
||||
)
|
||||
if db and db.get("content"):
|
||||
inject_content += db["content"] + "\n\n"
|
||||
inject_content = inject_content.strip()
|
||||
|
||||
if not skip_user_message:
|
||||
await convex.mutation(
|
||||
assistant_message_id = await convex.mutation(
|
||||
"messages:create",
|
||||
{
|
||||
"chatId": active_chat_id,
|
||||
"role": "user",
|
||||
"content": text,
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"source": "telegram",
|
||||
"isStreaming": True,
|
||||
},
|
||||
)
|
||||
|
||||
assistant_message_id = await convex.mutation(
|
||||
"messages:create",
|
||||
{
|
||||
"chatId": active_chat_id,
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"source": "telegram",
|
||||
"isStreaming": True,
|
||||
},
|
||||
)
|
||||
|
||||
history = await convex.query(
|
||||
"messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50}
|
||||
)
|
||||
|
||||
system_prompt = user.get("systemPrompt")
|
||||
if system_prompt and inject_content:
|
||||
system_prompt = system_prompt.replace("{theory_database}", inject_content)
|
||||
text_agent = create_text_agent(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
system_prompt=system_prompt,
|
||||
rag_db_names=rag_db_names or None,
|
||||
)
|
||||
|
||||
agent_deps = (
|
||||
AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
|
||||
if rag_db_names
|
||||
else None
|
||||
)
|
||||
|
||||
processing_msg = await bot.send_message(chat_id, "...")
|
||||
state = StreamingState(bot, chat_id, processing_msg)
|
||||
|
||||
if proxy_config:
|
||||
proxy_processing_msg = await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, "..."
|
||||
)
|
||||
proxy_state = ProxyStreamingState(
|
||||
proxy_config.proxy_bot, proxy_config.target_chat_id, proxy_processing_msg
|
||||
history = await convex.query(
|
||||
"messages:getHistoryForAI", {"chatId": active_chat_id, "limit": 50}
|
||||
)
|
||||
|
||||
try:
|
||||
await state.start_typing()
|
||||
system_prompt = user.get("systemPrompt")
|
||||
if system_prompt and inject_content:
|
||||
system_prompt = system_prompt.replace("{theory_database}", inject_content)
|
||||
text_agent = create_text_agent(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
system_prompt=system_prompt,
|
||||
rag_db_names=rag_db_names or None,
|
||||
)
|
||||
|
||||
async def on_chunk(content: str) -> None:
|
||||
await state.update_message(content)
|
||||
if proxy_state:
|
||||
await proxy_state.update_message(content)
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{"messageId": assistant_message_id, "content": content},
|
||||
agent_deps = (
|
||||
AgentDeps(
|
||||
user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names
|
||||
)
|
||||
if rag_db_names
|
||||
else None
|
||||
)
|
||||
|
||||
processing_msg = await bot.send_message(chat_id, "...")
|
||||
state = StreamingState(bot, chat_id, processing_msg)
|
||||
|
||||
if proxy_config:
|
||||
proxy_processing_msg = await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, "..."
|
||||
)
|
||||
proxy_state = ProxyStreamingState(
|
||||
proxy_config.proxy_bot,
|
||||
proxy_config.target_chat_id,
|
||||
proxy_processing_msg,
|
||||
)
|
||||
|
||||
chat_images = await fetch_chat_images(active_chat_id)
|
||||
try:
|
||||
await state.start_typing()
|
||||
|
||||
final_answer = await stream_response(
|
||||
text_agent,
|
||||
text,
|
||||
history[:-2],
|
||||
on_chunk,
|
||||
images=chat_images,
|
||||
deps=agent_deps,
|
||||
)
|
||||
|
||||
await state.flush()
|
||||
if proxy_state:
|
||||
await proxy_state.flush()
|
||||
|
||||
full_history = [*history[:-1], {"role": "assistant", "content": final_answer}]
|
||||
follow_up_model = user.get("followUpModel", "gemini-2.5-flash-lite")
|
||||
follow_up_prompt = user.get("followUpPrompt")
|
||||
follow_up_agent = create_follow_up_agent(
|
||||
api_key=api_key, model_name=follow_up_model, system_prompt=follow_up_prompt
|
||||
)
|
||||
follow_ups = await get_follow_ups(follow_up_agent, full_history, chat_images)
|
||||
|
||||
await state.stop_typing()
|
||||
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{
|
||||
"messageId": assistant_message_id,
|
||||
"content": final_answer,
|
||||
"followUpOptions": follow_ups,
|
||||
"isStreaming": False,
|
||||
},
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await processing_msg.delete()
|
||||
|
||||
keyboard = make_follow_up_keyboard(follow_ups)
|
||||
await send_long_message(bot, chat_id, final_answer, keyboard)
|
||||
|
||||
if proxy_state and proxy_config:
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_state.message.delete()
|
||||
parts = split_message(final_answer)
|
||||
for part in parts:
|
||||
await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, part
|
||||
async def on_chunk(content: str) -> None:
|
||||
await state.update_message(content)
|
||||
if proxy_state:
|
||||
await proxy_state.update_message(content)
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{"messageId": assistant_message_id, "content": content},
|
||||
)
|
||||
await increment_proxy_count(chat_id)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
await state.stop_typing()
|
||||
error_msg = f"Error: {e}"
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{
|
||||
"messageId": assistant_message_id,
|
||||
"content": error_msg,
|
||||
"isStreaming": False,
|
||||
},
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await processing_msg.edit_text(html.quote(error_msg[:TELEGRAM_MAX_LENGTH]))
|
||||
if proxy_state:
|
||||
chat_images = await fetch_chat_images(active_chat_id)
|
||||
|
||||
final_answer = await stream_response(
|
||||
text_agent,
|
||||
text,
|
||||
history[:-2],
|
||||
on_chunk,
|
||||
images=chat_images,
|
||||
deps=agent_deps,
|
||||
)
|
||||
|
||||
await state.flush()
|
||||
if proxy_state:
|
||||
await proxy_state.flush()
|
||||
|
||||
full_history = [
|
||||
*history[:-1],
|
||||
{"role": "assistant", "content": final_answer},
|
||||
]
|
||||
follow_up_model = user.get("followUpModel", "gemini-3.1-flash-lite-preview")
|
||||
follow_up_prompt = user.get("followUpPrompt")
|
||||
follow_up_agent = create_follow_up_agent(
|
||||
api_key=api_key,
|
||||
model_name=follow_up_model,
|
||||
system_prompt=follow_up_prompt,
|
||||
)
|
||||
follow_ups = await get_follow_ups(
|
||||
follow_up_agent, full_history, chat_images
|
||||
)
|
||||
|
||||
await state.stop_typing()
|
||||
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{
|
||||
"messageId": assistant_message_id,
|
||||
"content": final_answer,
|
||||
"followUpOptions": follow_ups,
|
||||
"isStreaming": False,
|
||||
},
|
||||
)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH])
|
||||
await processing_msg.delete()
|
||||
|
||||
keyboard = make_follow_up_keyboard(follow_ups)
|
||||
await send_long_message(bot, chat_id, final_answer, keyboard)
|
||||
|
||||
if proxy_state and proxy_config:
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_state.message.delete()
|
||||
parts = split_message(final_answer)
|
||||
for part in parts:
|
||||
await proxy_config.proxy_bot.send_message(
|
||||
proxy_config.target_chat_id, part
|
||||
)
|
||||
await increment_proxy_count(chat_id)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
await state.stop_typing()
|
||||
error_msg = f"Error: {e}"
|
||||
await convex.mutation(
|
||||
"messages:update",
|
||||
{
|
||||
"messageId": assistant_message_id,
|
||||
"content": error_msg,
|
||||
"isStreaming": False,
|
||||
},
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await processing_msg.edit_text(
|
||||
html.quote(error_msg[:TELEGRAM_MAX_LENGTH])
|
||||
)
|
||||
if proxy_state:
|
||||
with contextlib.suppress(Exception):
|
||||
await proxy_state.message.edit_text(error_msg[:TELEGRAM_MAX_LENGTH])
|
||||
|
||||
|
||||
async def send_to_telegram(user_id: int, text: str, bot: Bot) -> None:
|
||||
|
||||
@@ -63,7 +63,7 @@ def create_text_agent(
|
||||
|
||||
if rag_db_names:
|
||||
full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}"
|
||||
agent: Agent[None, str] = Agent(
|
||||
agent: Agent[AgentDeps, str] = Agent(
|
||||
model, instructions=full_prompt, deps_type=AgentDeps
|
||||
)
|
||||
|
||||
@@ -101,13 +101,16 @@ def create_text_agent(
|
||||
|
||||
def create_follow_up_agent(
|
||||
api_key: str,
|
||||
model_name: str = "gemini-2.5-flash-lite",
|
||||
model_name: str = "gemini-3.1-flash-lite-preview",
|
||||
system_prompt: str | None = None,
|
||||
) -> Agent[None, FollowUpOptions]:
|
||||
provider = GoogleProvider(api_key=api_key)
|
||||
model = GoogleModel(model_name, provider=provider)
|
||||
prompt = system_prompt or DEFAULT_FOLLOW_UP
|
||||
return Agent(model, output_type=FollowUpOptions, instructions=prompt)
|
||||
agent: Agent[None, FollowUpOptions] = Agent( # ty: ignore[invalid-assignment]
|
||||
model, output_type=FollowUpOptions, instructions=prompt
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
|
||||
|
||||
+11
-8
@@ -5,6 +5,7 @@ import time
|
||||
from aiogram import Bot
|
||||
|
||||
from bot.handlers.message.handler import process_message_from_web
|
||||
from bot.user_lock import get_user_lock
|
||||
from utils import env
|
||||
from utils.collaborative import (
|
||||
CollaborativeClient,
|
||||
@@ -53,14 +54,16 @@ async def start_sync_listener(bot: Bot) -> None:
|
||||
|
||||
async def handle_pending_generation(bot: Bot, item: dict, item_id: str) -> None:
|
||||
try:
|
||||
await process_message_from_web(
|
||||
convex_user_id=item["userId"],
|
||||
text=item["userMessage"],
|
||||
bot=bot,
|
||||
convex_chat_id=item["chatId"],
|
||||
images_base64=item.get("imagesBase64"),
|
||||
images_media_types=item.get("imagesMediaTypes"),
|
||||
)
|
||||
async with get_user_lock(item["userId"]):
|
||||
await process_message_from_web(
|
||||
convex_user_id=item["userId"],
|
||||
text=item["userMessage"],
|
||||
bot=bot,
|
||||
convex_chat_id=item["chatId"],
|
||||
images_base64=item.get("imagesBase64"),
|
||||
images_media_types=item.get("imagesMediaTypes"),
|
||||
pending_generation_id=item_id,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f"Error processing {item_id}: {e}")
|
||||
finally:
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import asyncio
|
||||
|
||||
_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
|
||||
def get_user_lock(convex_user_id: str) -> asyncio.Lock:
|
||||
lock = _locks.get(convex_user_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_locks[convex_user_id] = lock
|
||||
return lock
|
||||
Reference in New Issue
Block a user