feat(*): add RAG support
This commit is contained in:
@@ -20,6 +20,7 @@ from convex import ConvexInt64
|
||||
|
||||
from bot.modules.ai import (
|
||||
SUMMARIZE_PROMPT,
|
||||
AgentDeps,
|
||||
ImageData,
|
||||
create_follow_up_agent,
|
||||
create_text_agent,
|
||||
@@ -218,6 +219,18 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
api_key = user["geminiApiKey"]
|
||||
model_name = user.get("model", "gemini-3-pro-preview")
|
||||
|
||||
rag_connections = await convex.query(
|
||||
"ragConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
rag_db_names: list[str] = []
|
||||
if rag_connections:
|
||||
for conn in rag_connections:
|
||||
db = await convex.query(
|
||||
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
|
||||
)
|
||||
if db:
|
||||
rag_db_names.append(db["name"])
|
||||
|
||||
assistant_message_id = await convex.mutation(
|
||||
"messages:create",
|
||||
{
|
||||
@@ -235,7 +248,16 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
|
||||
system_prompt = SUMMARIZE_PROMPT if is_summarize else user.get("systemPrompt")
|
||||
text_agent = create_text_agent(
|
||||
api_key=api_key, model_name=model_name, system_prompt=system_prompt
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
system_prompt=system_prompt,
|
||||
rag_db_names=rag_db_names if rag_db_names else None,
|
||||
)
|
||||
|
||||
agent_deps = (
|
||||
AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
|
||||
if rag_db_names
|
||||
else None
|
||||
)
|
||||
|
||||
processing_msg = None
|
||||
@@ -266,7 +288,7 @@ async def process_message_from_web( # noqa: C901, PLR0912, PLR0913, PLR0915
|
||||
chat_images = await fetch_chat_images(convex_chat_id)
|
||||
|
||||
final_answer = await stream_response(
|
||||
text_agent, prompt_text, hist, on_chunk, images=chat_images
|
||||
text_agent, prompt_text, hist, on_chunk, images=chat_images, deps=agent_deps
|
||||
)
|
||||
|
||||
if state:
|
||||
@@ -354,6 +376,19 @@ async def process_message(
|
||||
active_chat_id = user["activeChatId"]
|
||||
api_key = user["geminiApiKey"]
|
||||
model_name = user.get("model", "gemini-3-pro-preview")
|
||||
convex_user_id = user["_id"]
|
||||
|
||||
rag_connections = await convex.query(
|
||||
"ragConnections:getActiveForUser", {"userId": convex_user_id}
|
||||
)
|
||||
rag_db_names: list[str] = []
|
||||
if rag_connections:
|
||||
for conn in rag_connections:
|
||||
db = await convex.query(
|
||||
"rag:getDatabaseById", {"ragDatabaseId": conn["ragDatabaseId"]}
|
||||
)
|
||||
if db:
|
||||
rag_db_names.append(db["name"])
|
||||
|
||||
if not skip_user_message:
|
||||
await convex.mutation(
|
||||
@@ -382,7 +417,16 @@ async def process_message(
|
||||
)
|
||||
|
||||
text_agent = create_text_agent(
|
||||
api_key=api_key, model_name=model_name, system_prompt=user.get("systemPrompt")
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
system_prompt=user.get("systemPrompt"),
|
||||
rag_db_names=rag_db_names if rag_db_names else None,
|
||||
)
|
||||
|
||||
agent_deps = (
|
||||
AgentDeps(user_id=convex_user_id, api_key=api_key, rag_db_names=rag_db_names)
|
||||
if rag_db_names
|
||||
else None
|
||||
)
|
||||
|
||||
processing_msg = await bot.send_message(chat_id, "...")
|
||||
@@ -401,7 +445,12 @@ async def process_message(
|
||||
chat_images = await fetch_chat_images(active_chat_id)
|
||||
|
||||
final_answer = await stream_response(
|
||||
text_agent, text, history[:-2], on_chunk, images=chat_images
|
||||
text_agent,
|
||||
text,
|
||||
history[:-2],
|
||||
on_chunk,
|
||||
images=chat_images,
|
||||
deps=agent_deps,
|
||||
)
|
||||
|
||||
await state.flush()
|
||||
|
||||
Reference in New Issue
Block a user