feat(*): add RAG support

This commit is contained in:
h
2026-01-25 16:44:59 +01:00
parent 5b1f50a6f6
commit a992e3f0c2
20 changed files with 1412 additions and 17 deletions

View File

@@ -20,6 +20,7 @@ from convex import ConvexInt64
from bot.modules.ai import (
SUMMARIZE_PROMPT,
AgentDeps,
ImageData,
create_follow_up_agent,
create_text_agent,
@@ -218,6 +219,18 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
api_key = user["geminiApiKey"]
model_name = user.get("model", "gemini-3-pro-preview")
rag_connections = await convex.query(
"ragConnections:getActiveForUser", {"userId": convex_user_id}
)
rag_db_names: list[str] = []
if rag_connections:
for conn in rag_connections:
db = await convex.query(
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
)
if db:
rag_db_names.append(db["name"])
assistant_message_id = await convex.mutation(
"messages:create",
{
@@ -235,7 +248,16 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
system_prompt = SUMMARIZE_PROMPT if is_summarize else user.get("systemPrompt")
text_agent = create_text_agent(
api_key=api_key, model_name=model_name, system_prompt=system_prompt
api_key=api_key,
model_name=model_name,
system_prompt=system_prompt,
rag_db_names=rag_db_names if rag_db_names else None,
)
agent_deps = (
AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
if rag_db_names
else None
)
processing_msg = None
@@ -266,7 +288,7 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
chat_images = await fetch_chat_images(convex_chat_id)
final_answer = await stream_response(
text_agent, prompt_text, hist, on_chunk, images=chat_images
text_agent, prompt_text, hist, on_chunk, images=chat_images, deps=agent_deps
)
if state:
@@ -354,6 +376,19 @@ async def process_message(
active_chat_id = user["activeChatId"]
api_key = user["geminiApiKey"]
model_name = user.get("model", "gemini-3-pro-preview")
convex_user_id = user["_id"]
rag_connections = await convex.query(
"ragConnections:getActiveForUser", {"userId": convex_user_id}
)
rag_db_names: list[str] = []
if rag_connections:
for conn in rag_connections:
db = await convex.query(
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
)
if db:
rag_db_names.append(db["name"])
if not skip_user_message:
await convex.mutation(
@@ -382,7 +417,16 @@ async def process_message(
)
text_agent = create_text_agent(
api_key=api_key, model_name=model_name, system_prompt=user.get("systemPrompt")
api_key=api_key,
model_name=model_name,
system_prompt=user.get("systemPrompt"),
rag_db_names=rag_db_names if rag_db_names else None,
)
agent_deps = (
AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
if rag_db_names
else None
)
processing_msg = await bot.send_message(chat_id, "...")
@@ -401,7 +445,12 @@ async def process_message(
chat_images = await fetch_chat_images(active_chat_id)
final_answer = await stream_response(
text_agent, text, history[:-2], on_chunk, images=chat_images
text_agent,
text,
history[:-2],
on_chunk,
images=chat_images,
deps=agent_deps,
)
await state.flush()