feat(*): add RAG support

This commit is contained in:
h
2026-01-25 16:44:59 +01:00
parent 5b1f50a6f6
commit a992e3f0c2
20 changed files with 1412 additions and 17 deletions

View File

@@ -7,16 +7,22 @@ from pydantic_ai import (
ModelMessage,
ModelRequest,
ModelResponse,
RunContext,
TextPart,
UserPromptPart,
)
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider
from utils import env
from utils.convex import ConvexClient
from utils.logging import logger
from .models import FollowUpOptions
from .prompts import DEFAULT_FOLLOW_UP
StreamCallback = Callable[[str], Awaitable[None]]
convex = ConvexClient(env.convex_url)
@dataclass
@@ -25,21 +31,70 @@ class ImageData:
media_type: str
@dataclass
class AgentDeps:
    """Per-request dependencies made available to agent tools via RunContext."""

    # Convex user id the knowledge-base search is scoped to ("userId" arg).
    user_id: str
    # API key forwarded to the rag:searchMultiple Convex action.
    api_key: str
    # Names of the knowledge-base databases to search ("dbNames" arg).
    rag_db_names: list[str]
# Appended to every system prompt so model output renders math consistently.
LATEX_INSTRUCTION = "For math, use LaTeX: $...$ inline, $$...$$ display."
# Fallback system prompt when the caller supplies none.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Provide clear, concise answers."
)
# Extra instruction appended only when RAG databases are configured; tells the
# model when to call the search_knowledge_base tool. Leading space is
# intentional — it is concatenated directly after the base prompt.
RAG_SYSTEM_ADDITION = (
    " You have access to a knowledge base. Use the search_knowledge_base tool "
    "to find relevant information when the user asks about topics that might "
    "be covered in the knowledge base."
)
def create_text_agent(
    api_key: str,
    model_name: str = "gemini-3-pro-preview",
    system_prompt: str | None = None,
    rag_db_names: list[str] | None = None,
) -> Agent[AgentDeps, str] | Agent[None, str]:
    """Create a Gemini-backed text agent, optionally with a RAG search tool.

    Args:
        api_key: Google API key used to construct the provider.
        model_name: Gemini model identifier.
        system_prompt: Base system prompt; falls back to DEFAULT_SYSTEM_PROMPT.
        rag_db_names: Knowledge-base names. When non-empty, the returned agent
            takes AgentDeps and exposes a search_knowledge_base tool.

    Returns:
        An Agent[AgentDeps, str] when RAG is enabled, otherwise Agent[None, str].
    """
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(model_name, provider=provider)
    base_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    if rag_db_names:
        # RAG_SYSTEM_ADDITION starts with a space, so no separator is needed.
        full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}"
        # Annotation matches deps_type=AgentDeps (was wrongly Agent[None, str]).
        agent: Agent[AgentDeps, str] = Agent(
            model, instructions=full_prompt, deps_type=AgentDeps
        )

        @agent.tool
        async def search_knowledge_base(
            ctx: RunContext[AgentDeps], query: str
        ) -> str:
            """Search the user's knowledge base for relevant information.

            Args:
                ctx: The run context containing user dependencies.
                query: The search query to find relevant information.

            Returns:
                Relevant text from the knowledge base.
            """
            logger.info(f"Searching knowledge base for {query}")
            result = await convex.action(
                "rag:searchMultiple",
                {
                    "userId": ctx.deps.user_id,
                    "dbNames": ctx.deps.rag_db_names,
                    "apiKey": ctx.deps.api_key,
                    "query": query,
                    "limit": 5,
                },
            )
            if result and result.get("text"):
                return f"Knowledge base results:\n\n{result['text']}"
            return "No relevant information found in the knowledge base."

        return agent

    full_prompt = f"{base_prompt} {LATEX_INSTRUCTION}"
    return Agent(model, instructions=full_prompt)
@@ -68,12 +123,13 @@ def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
async def stream_response( # noqa: PLR0913
text_agent: Agent[None, str],
text_agent: Agent[AgentDeps, str] | Agent[None, str],
message: str,
history: list[dict[str, str]] | None = None,
on_chunk: StreamCallback | None = None,
image: ImageData | None = None,
images: list[ImageData] | None = None,
deps: AgentDeps | None = None,
) -> str:
message_history = build_message_history(history) if history else None
@@ -88,7 +144,7 @@ async def stream_response( # noqa: PLR0913
else:
prompt = message # type: ignore[assignment]
stream = text_agent.run_stream(prompt, message_history=message_history)
stream = text_agent.run_stream(prompt, message_history=message_history, deps=deps)
async with stream as result:
async for text in result.stream_text():
if on_chunk: