feat(*): add RAG support
@@ -1,4 +1,5 @@
 from .agent import (
+    AgentDeps,
     ImageData,
     StreamCallback,
     create_follow_up_agent,
@@ -12,6 +13,7 @@ __all__ = [
     "DEFAULT_FOLLOW_UP",
     "PRESETS",
     "SUMMARIZE_PROMPT",
+    "AgentDeps",
     "ImageData",
     "StreamCallback",
     "create_follow_up_agent",
@@ -7,16 +7,22 @@ from pydantic_ai import (
     ModelMessage,
     ModelRequest,
     ModelResponse,
+    RunContext,
     TextPart,
     UserPromptPart,
 )
 from pydantic_ai.models.google import GoogleModel
 from pydantic_ai.providers.google import GoogleProvider
 
+from utils import env
+from utils.convex import ConvexClient
+from utils.logging import logger
+
 from .models import FollowUpOptions
 from .prompts import DEFAULT_FOLLOW_UP
 
 StreamCallback = Callable[[str], Awaitable[None]]
+convex = ConvexClient(env.convex_url)
 
 
 @dataclass
@@ -25,21 +31,70 @@ class ImageData:
     media_type: str
 
 
+@dataclass
+class AgentDeps:
+    user_id: str
+    api_key: str
+    rag_db_names: list[str]
+
+
 LATEX_INSTRUCTION = "For math, use LaTeX: $...$ inline, $$...$$ display."
 
 DEFAULT_SYSTEM_PROMPT = (
     "You are a helpful AI assistant. Provide clear, concise answers."
 )
 
+RAG_SYSTEM_ADDITION = (
+    " You have access to a knowledge base. Use the search_knowledge_base tool "
+    "to find relevant information when the user asks about topics that might "
+    "be covered in the knowledge base."
+)
+
 
 def create_text_agent(
     api_key: str,
     model_name: str = "gemini-3-pro-preview",
     system_prompt: str | None = None,
-) -> Agent[None, str]:
+    rag_db_names: list[str] | None = None,
+) -> Agent[AgentDeps, str] | Agent[None, str]:
     provider = GoogleProvider(api_key=api_key)
     model = GoogleModel(model_name, provider=provider)
     base_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
 
+    if rag_db_names:
+        full_prompt = f"{base_prompt}{RAG_SYSTEM_ADDITION} {LATEX_INSTRUCTION}"
+        agent: Agent[AgentDeps, str] = Agent(
+            model, instructions=full_prompt, deps_type=AgentDeps
+        )
+
+        @agent.tool
+        async def search_knowledge_base(ctx: RunContext[AgentDeps], query: str) -> str:
+            """Search the user's knowledge base for relevant information.
+
+            Args:
+                ctx: The run context containing user dependencies.
+                query: The search query to find relevant information.
+
+            Returns:
+                Relevant text from the knowledge base.
+            """
+            logger.info(f"Searching knowledge base for {query}")
+            result = await convex.action(
+                "rag:searchMultiple",
+                {
+                    "userId": ctx.deps.user_id,
+                    "dbNames": ctx.deps.rag_db_names,
+                    "apiKey": ctx.deps.api_key,
+                    "query": query,
+                    "limit": 5,
+                },
+            )
+            if result and result.get("text"):
+                return f"Knowledge base results:\n\n{result['text']}"
+            return "No relevant information found in the knowledge base."
+
+        return agent
+
     full_prompt = f"{base_prompt} {LATEX_INSTRUCTION}"
     return Agent(model, instructions=full_prompt)
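
Note: a minimal usage sketch of the new RAG path. This is not part of the commit; the API key, user id, and database name are placeholders, and it is assumed to run inside an async function.

# Sketch only -- hypothetical values.
agent = create_text_agent(api_key="GEMINI_KEY", rag_db_names=["course-notes"])
deps = AgentDeps(user_id="user_123", api_key="GEMINI_KEY", rag_db_names=["course-notes"])
result = await agent.run("What is stigmergy?", deps=deps)
print(result.output)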
@@ -68,12 +123,13 @@ def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
 
 
 async def stream_response(  # noqa: PLR0913
-    text_agent: Agent[None, str],
+    text_agent: Agent[AgentDeps, str] | Agent[None, str],
     message: str,
     history: list[dict[str, str]] | None = None,
     on_chunk: StreamCallback | None = None,
     image: ImageData | None = None,
     images: list[ImageData] | None = None,
+    deps: AgentDeps | None = None,
 ) -> str:
     message_history = build_message_history(history) if history else None
 
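
build_message_history itself is unchanged and its body is not shown in this diff; only its signature appears as hunk context. For orientation, a hypothetical sketch of such a mapping, assuming history entries are {"role": ..., "content": ...} dicts -- not the repository's actual implementation:

# Hypothetical sketch using the imports from this module.
def build_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
    messages: list[ModelMessage] = []
    for entry in history:
        if entry["role"] == "user":
            messages.append(ModelRequest(parts=[UserPromptPart(content=entry["content"])]))
        else:
            messages.append(ModelResponse(parts=[TextPart(content=entry["content"])]))
    return messages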
@@ -88,7 +144,7 @@ async def stream_response(  # noqa: PLR0913
     else:
         prompt = message  # type: ignore[assignment]
 
-    stream = text_agent.run_stream(prompt, message_history=message_history)
+    stream = text_agent.run_stream(prompt, message_history=message_history, deps=deps)
     async with stream as result:
         async for text in result.stream_text():
             if on_chunk:
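
With deps now threaded through to run_stream, a caller might look like the sketch below (placeholder prompt and callback; reuses the agent and deps from the earlier sketch, inside an async function):

# Sketch only.
async def print_chunk(chunk: str) -> None:
    print(chunk, end="", flush=True)

answer = await stream_response(
    agent,
    "Explain crossover in genetic programming",
    on_chunk=print_chunk,
    deps=deps,  # lets search_knowledge_base read user_id and rag_db_names
)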
@@ -22,6 +22,39 @@ for example Group A: 1, Group A: 2a, Group B: 2b, etc.
 Or, Theory: 1, Theory: 2a, Practice: 1, etc.
 Only output identifiers that exist in the image."""
 
+
+RAGTHEORY_SYSTEM = """You help answer theoretical exam questions.
+
+When you receive an IMAGE with exam questions:
+1. Identify ALL questions/blanks to fill
+2. For EACH question, use search_knowledge_base to find relevant info
+3. Provide exam-ready answers
+
+OUTPUT FORMAT:
+- Number each answer matching the question number
+- Answer length: match what the question expects
+  (1 sentence, 1-2 sentences, fill blank, list items, etc.)
+- Write answers EXACTLY as they should appear on the exam sheet - ready to copy 1:1
+- Use precise terminology from the course
+- No explanations, no "because", no fluff - just the answer itself
+- For multi-part questions (a, b, c), answer each part separately
+
+LANGUAGE: Match the exam language (usually English for technical terms)
+
+STYLE: Academic, precise, minimal - as if you're writing on paper with limited space
+
+Example input:
+"Stigmergy is ............"
+Example output:
+"1. Stigmergy is indirect communication through environment \
+modification, e.g. by leaving some marks."
+
+Example input:
+"How is crossing over performed in genetic programming? (one precise sentence)"
+Example output:
+"3. Usually implemented as swapping randomly selected subtrees of parent trees"
+"""
+
 DEFAULT_FOLLOW_UP = (
     "Based on the conversation, suggest 3 short follow-up questions "
     "the user might want to ask. Each option should be under 50 characters."
@@ -38,4 +71,7 @@ Summarize VERY briefly:
 
 Max 2-3 sentences. This is for Apple Watch display."""
 
-PRESETS: dict[str, tuple[str, str]] = {"exam": (EXAM_SYSTEM, EXAM_FOLLOW_UP)}
+PRESETS: dict[str, tuple[str, str]] = {
+    "exam": (EXAM_SYSTEM, EXAM_FOLLOW_UP),
+    "ragtheory": (RAGTHEORY_SYSTEM, EXAM_FOLLOW_UP),
+}
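
How the new "ragtheory" preset is consumed is not shown in this diff; one plausible lookup, sketched with a placeholder key and database name:

# Hypothetical consumer -- assumes PRESETS and create_text_agent are imported.
system_prompt, follow_up_prompt = PRESETS["ragtheory"]
agent = create_text_agent(
    api_key="GEMINI_KEY",  # placeholder
    system_prompt=system_prompt,
    rag_db_names=["exam-notes"],  # placeholder; enables search_knowledge_base
)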