"""Agent factories and streaming helpers built on pydantic-ai Google models."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from pydantic_ai import (
    Agent,
    BinaryContent,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    RunContext,
    TextPart,
    UserPromptPart,
)
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers.google import GoogleProvider

from utils import env
from utils.convex import ConvexClient
from utils.logging import logger

from .models import FollowUpOptions
from .prompts import DEFAULT_FOLLOW_UP
# Async callback invoked once per streamed text chunk during a run.
StreamCallback = Callable[[str], Awaitable[None]]

# Module-level Convex client shared by all knowledge-base searches.
convex = ConvexClient(env.convex_url)
@dataclass
class ImageData:
    """An image attachment: raw bytes plus their media type."""

    # Raw image bytes, forwarded unchanged to BinaryContent.
    data: bytes
    # Media type string passed through to BinaryContent
    # (presumably a MIME type such as "image/png" — confirm with callers).
    media_type: str
@dataclass
class AgentDeps:
    """Per-run dependencies injected into the RAG-enabled agent's tool."""

    # User id sent to the Convex rag:searchMultiple action to scope results.
    user_id: str
    # API key forwarded to the Convex rag:searchMultiple action.
    api_key: str
    # Names of the knowledge-base databases to search.
    rag_db_names: list[str]
# Appended to every system prompt so the model formats math consistently.
LATEX_INSTRUCTION = "For math, use LaTeX: $...$ inline, $$...$$ display."

# Fallback system prompt used when the caller supplies none.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful AI assistant. Provide clear, concise answers."
)

# Extra instructions (note the intentional leading space) appended when RAG
# databases are configured, advertising the search_knowledge_base tool.
RAG_SYSTEM_ADDITION = (
    " You have access to a knowledge base. Use the search_knowledge_base tool "
    "to find relevant information when the user asks about topics that might "
    "be covered in the knowledge base."
)
def create_text_agent(
    api_key: str,
    model_name: str = "gemini-3-pro-preview",
    system_prompt: str | None = None,
    rag_db_names: list[str] | None = None,
) -> Agent[AgentDeps, str] | Agent[None, str]:
    """Create the main text-generation agent.

    Args:
        api_key: Google API key used to build the model provider.
        model_name: Gemini model identifier.
        system_prompt: Optional custom system prompt; falls back to
            ``DEFAULT_SYSTEM_PROMPT`` when falsy.
        rag_db_names: When non-empty, the agent is given a
            ``search_knowledge_base`` tool and must be run with ``AgentDeps``.

    Returns:
        An ``Agent[AgentDeps, str]`` when RAG is enabled, otherwise a
        dependency-free ``Agent[None, str]``.
    """
    provider = GoogleProvider(api_key=api_key)
    model = GoogleModel(model_name, provider=provider)
    base_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    if rag_db_names:
        # RAG_SYSTEM_ADDITION carries its own leading space, so no separator.
        full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}"
        # Fix: previously annotated Agent[None, str], but deps_type=AgentDeps
        # makes this an Agent[AgentDeps, str] (matching the return annotation).
        agent: Agent[AgentDeps, str] = Agent(
            model, instructions=full_prompt, deps_type=AgentDeps
        )

        @agent.tool
        async def search_knowledge_base(
            ctx: RunContext[AgentDeps], query: str
        ) -> str:
            """Search the user's knowledge base for relevant information.

            Args:
                ctx: The run context containing user dependencies.
                query: The search query to find relevant information.

            Returns:
                Relevant text from the knowledge base.
            """
            logger.info(f"Searching knowledge base for {query}")
            result = await convex.action(
                "rag:searchMultiple",
                {
                    "userId": ctx.deps.user_id,
                    "dbNames": ctx.deps.rag_db_names,
                    "apiKey": ctx.deps.api_key,
                    "query": query,
                    "limit": 5,
                },
            )
            # Only surface results when the action returned non-empty text.
            if result and result.get("text"):
                return f"Knowledge base results:\n\n{result['text']}"
            return "No relevant information found in the knowledge base."

        return agent

    full_prompt = f"{base_prompt} {LATEX_INSTRUCTION}"
    return Agent(model, instructions=full_prompt)
def create_follow_up_agent(
    api_key: str,
    model_name: str = "gemini-2.5-flash-lite",
    system_prompt: str | None = None,
) -> Agent[None, FollowUpOptions]:
    """Build the lightweight agent that proposes follow-up questions.

    Args:
        api_key: Google API key used to build the model provider.
        model_name: Gemini model identifier (defaults to a fast/cheap model).
        system_prompt: Optional custom prompt; falls back to
            ``DEFAULT_FOLLOW_UP`` when falsy.

    Returns:
        A dependency-free agent whose output type is ``FollowUpOptions``.
    """
    google_model = GoogleModel(
        model_name, provider=GoogleProvider(api_key=api_key)
    )
    instructions = system_prompt or DEFAULT_FOLLOW_UP
    return Agent(
        google_model, output_type=FollowUpOptions, instructions=instructions
    )
def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
    """Convert a plain role/content chat history into pydantic-ai messages.

    Entries with role ``"user"`` become ``ModelRequest``s; every other role
    becomes a ``ModelResponse``.
    """
    return [
        ModelRequest(parts=[UserPromptPart(content=entry["content"])])
        if entry["role"] == "user"
        else ModelResponse(parts=[TextPart(content=entry["content"])])
        for entry in history
    ]
async def stream_response(  # noqa: PLR0913
    text_agent: Agent[AgentDeps, str] | Agent[None, str],
    message: str,
    history: list[dict[str, str]] | None = None,
    on_chunk: StreamCallback | None = None,
    image: ImageData | None = None,
    images: list[ImageData] | None = None,
    deps: AgentDeps | None = None,
) -> str:
    """Run the agent on *message*, streaming text through *on_chunk*.

    Args:
        text_agent: Agent to run (RAG-enabled or plain).
        message: The user's text prompt.
        history: Optional prior turns as role/content dicts.
        on_chunk: Optional async callback awaited with each streamed chunk.
        image: Optional single image attachment (ignored when *images* is set).
        images: Optional list of image attachments; takes precedence.
        deps: Dependencies forwarded to the agent run (required by RAG agents).

    Returns:
        The complete final output text.
    """
    past_messages = build_message_history(history) if history else None

    # A non-empty `images` list wins; otherwise fall back to the single image.
    attachments = images or ([image] if image else [])
    if attachments:
        prompt: str | list[str | BinaryContent] = [
            message,
            *(
                BinaryContent(data=att.data, media_type=att.media_type)
                for att in attachments
            ),
        ]
    else:
        prompt = message

    async with text_agent.run_stream(
        prompt, message_history=past_messages, deps=deps
    ) as result:
        async for chunk in result.stream_text():
            if on_chunk:
                await on_chunk(chunk)
        return await result.get_output()
async def get_follow_ups(
    follow_up_agent: Agent[None, FollowUpOptions],
    history: list[dict[str, str]],
    images: list[ImageData] | None = None,
) -> list[str]:
    """Generate follow-up question suggestions for a conversation.

    Args:
        follow_up_agent: Agent whose structured output is ``FollowUpOptions``.
        history: Prior turns as role/content dicts.
        images: Optional image attachments to include in the prompt.

    Returns:
        The list of suggested follow-up strings from the agent's output.
    """
    past_messages = build_message_history(history) if history else None

    if images:
        prompt: str | list[str | BinaryContent] = [
            "Process this:",
            *(
                BinaryContent(data=img.data, media_type=img.media_type)
                for img in images
            ),
        ]
    else:
        prompt = "Process this conversation."

    run_result = await follow_up_agent.run(prompt, message_history=past_messages)
    return run_result.output["options"]