"""Agent construction and streaming helpers built on pydantic-ai Google models.

Provides factories for a text-chat agent (optionally RAG-enabled via a Convex
knowledge-base search tool) and a follow-up-suggestion agent, plus helpers to
convert chat history and stream responses.
"""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from pydantic_ai import (
    Agent,
    BinaryContent,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    RunContext,
    TextPart,
    UserPromptPart,
)
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider

from utils import env
from utils.convex import ConvexClient
from utils.logging import logger

from .models import FollowUpOptions
from .prompts import DEFAULT_FOLLOW_UP

# Async callback invoked with each accumulated text chunk during streaming.
StreamCallback = Callable[[str], Awaitable[None]]

# Module-level Convex client shared by all RAG tool invocations.
convex = ConvexClient(env.convex_url)


@dataclass
class ImageData:
    """Raw image bytes together with their MIME type (e.g. "image/png")."""

    data: bytes
    media_type: str


@dataclass
class AgentDeps:
    """Per-run dependencies injected into RAG-enabled agents."""

    user_id: str
    api_key: str
    rag_db_names: list[str]


LATEX_INSTRUCTION = "For math, use LaTeX: $...$ inline, $$...$$ display."

DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Provide clear, concise answers."
)

RAG_SYSTEM_ADDITION = (
    " You have access to a knowledge base. Use the search_knowledge_base tool "
    "to find relevant information when the user asks about topics that might "
    "be covered in the knowledge base."
)


def create_text_agent(
    api_key: str,
    model_name: str = "gemini-3-pro-preview",
    system_prompt: str | None = None,
    rag_db_names: list[str] | None = None,
) -> Agent[AgentDeps, str] | Agent[None, str]:
    """Create a text-chat agent, optionally with a knowledge-base search tool.

    Args:
        api_key: Google API key used by the model provider.
        model_name: Gemini model identifier.
        system_prompt: Override for the default system prompt.
        rag_db_names: If non-empty, the agent gets a `search_knowledge_base`
            tool and must be run with `AgentDeps`.

    Returns:
        An `Agent[AgentDeps, str]` when RAG is enabled, otherwise an
        `Agent[None, str]`.
    """
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(model_name, provider=provider)

    base_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    if rag_db_names:
        full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}"
        # BUG FIX: the local annotation previously said Agent[None, str],
        # contradicting deps_type=AgentDeps; the agent carries AgentDeps.
        agent: Agent[AgentDeps, str] = Agent(
            model, instructions=full_prompt, deps_type=AgentDeps
        )

        @agent.tool
        async def search_knowledge_base(
            ctx: RunContext[AgentDeps], query: str
        ) -> str:
            """Search the user's knowledge base for relevant information.

            Args:
                ctx: The run context containing user dependencies.
                query: The search query to find relevant information.

            Returns:
                Relevant text from the knowledge base.
            """
            logger.info(f"Searching knowledge base for {query}")
            result = await convex.action(
                "rag:searchMultiple",
                {
                    "userId": ctx.deps.user_id,
                    "dbNames": ctx.deps.rag_db_names,
                    "apiKey": ctx.deps.api_key,
                    "query": query,
                    "limit": 5,
                },
            )
            if result and result.get("text"):
                return f"Knowledge base results:\n\n{result['text']}"
            return "No relevant information found in the knowledge base."

        return agent

    full_prompt = f"{base_prompt} {LATEX_INSTRUCTION}"
    return Agent(model, instructions=full_prompt)


def create_follow_up_agent(
    api_key: str,
    model_name: str = "gemini-2.5-flash-lite",
    system_prompt: str | None = None,
) -> Agent[None, FollowUpOptions]:
    """Create an agent that produces structured follow-up suggestions.

    Args:
        api_key: Google API key used by the model provider.
        model_name: Gemini model identifier (defaults to a lightweight model).
        system_prompt: Override for the default follow-up prompt.

    Returns:
        An agent whose output type is `FollowUpOptions`.
    """
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(model_name, provider=provider)
    prompt = system_prompt or DEFAULT_FOLLOW_UP
    return Agent(model, output_type=FollowUpOptions, instructions=prompt)


def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
    """Convert role/content chat dicts into pydantic-ai model messages.

    Entries with role "user" become `ModelRequest`s; everything else is
    treated as an assistant turn and becomes a `ModelResponse`.

    Args:
        history: Chat turns as `{"role": ..., "content": ...}` dicts.

    Returns:
        The equivalent list of `ModelMessage` objects.
    """
    messages: list[ModelMessage] = []
    for msg in history:
        if msg["role"] == "user":
            messages.append(
                ModelRequest(parts=[UserPromptPart(content=msg["content"])])
            )
        else:
            messages.append(ModelResponse(parts=[TextPart(content=msg["content"])]))
    return messages


async def stream_response(  # noqa: PLR0913
    text_agent: Agent[AgentDeps, str] | Agent[None, str],
    message: str,
    history: list[dict[str, str]] | None = None,
    on_chunk: StreamCallback | None = None,
    image: ImageData | None = None,
    images: list[ImageData] | None = None,
    deps: AgentDeps | None = None,
) -> str:
    """Run the agent with streaming and return the final text output.

    Args:
        text_agent: Agent produced by `create_text_agent`.
        message: The user's message text.
        history: Prior chat turns; converted via `build_message_history`.
        on_chunk: Awaited with each accumulated text chunk as it streams.
        image: Single image attachment (ignored when `images` is given).
        images: Multiple image attachments; takes precedence over `image`.
        deps: Dependencies required when the agent is RAG-enabled.

    Returns:
        The complete generated response text.
    """
    message_history = build_message_history(history) if history else None

    # `images` wins over the legacy single-`image` parameter.
    all_images = images or ([image] if image else [])

    if all_images:
        prompt: list[str | BinaryContent] = [message]
        prompt.extend(
            BinaryContent(data=img.data, media_type=img.media_type)
            for img in all_images
        )
    else:
        prompt = message  # type: ignore[assignment]

    stream = text_agent.run_stream(prompt, message_history=message_history, deps=deps)
    async with stream as result:
        async for text in result.stream_text():
            if on_chunk:
                await on_chunk(text)
        return await result.get_output()


async def get_follow_ups(
    follow_up_agent: Agent[None, FollowUpOptions],
    history: list[dict[str, str]],
    images: list[ImageData] | None = None,
) -> list[str]:
    """Generate follow-up suggestions for a conversation.

    Args:
        follow_up_agent: Agent produced by `create_follow_up_agent`.
        history: Prior chat turns; converted via `build_message_history`.
        images: Optional image attachments to include in the prompt.

    Returns:
        The list of suggested follow-up strings.
    """
    message_history = build_message_history(history) if history else None

    if images:
        prompt: list[str | BinaryContent] = ["Process this:"]
        prompt.extend(
            BinaryContent(data=img.data, media_type=img.media_type)
            for img in images
        )
    else:
        prompt = "Process this conversation."  # type: ignore[assignment]

    result = await follow_up_agent.run(prompt, message_history=message_history)
    # NOTE(review): assumes FollowUpOptions is subscriptable (e.g. a TypedDict)
    # with an "options" key — confirm against .models.FollowUpOptions.
    return result.output["options"]