feat(solaris): migrate to pydantic ai, wired respond agent through di providers
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
from typing import Annotated, List
|
||||
|
||||
from beanie import Document, Indexed
|
||||
from google.genai.types import Content
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_ai.messages import ModelMessage
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
chat_id: Annotated[int, Indexed(unique=True)]
|
||||
system_prompt: str
|
||||
history: List[Content] = Field(default_factory=list)
|
||||
system_prompt_override: str = None
|
||||
history: List[ModelMessage] = Field(default_factory=list)
|
||||
api_key_override: str = None
|
||||
|
||||
|
||||
class __CommonSessionRepository(SessionBase, Document):
|
||||
@@ -17,10 +18,17 @@ class __CommonSessionRepository(SessionBase, Document):
|
||||
return await cls.find_one(cls.chat_id == chat_id)
|
||||
|
||||
@classmethod
|
||||
async def create_empty(cls, chat_id: int, system_prompt: str):
|
||||
return await cls(chat_id=chat_id, system_prompt=system_prompt).insert()
|
||||
async def create_empty(cls, chat_id: int):
|
||||
return await cls(chat_id=chat_id).insert()
|
||||
|
||||
async def update_history(self, history: List[Content]):
|
||||
@classmethod
|
||||
async def get_or_create_by_chat_id(cls, chat_id: int):
|
||||
session = await cls.get_by_chat_id(chat_id)
|
||||
if not session:
|
||||
session = await cls.create_empty(chat_id=chat_id)
|
||||
return session
|
||||
|
||||
async def update_history(self, history: List[ModelMessage]):
|
||||
await self.set({self.history: history})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user