"""Beanie documents that persist per-chat sessions.

Each session stores an optional system-prompt override, an optional API-key
override and the model message history for one chat. Two independent
collections are defined: review sessions and respond sessions.
"""

from typing import Annotated, List, Optional

from beanie import Document, Indexed
from pydantic import BaseModel, Field
from pydantic_ai.messages import ModelMessage


class SessionBase(BaseModel):
    """Fields shared by every session document.

    Attributes:
        chat_id: Identifier of the chat this session belongs to. It is backed
            by a unique index, so each collection holds at most one session
            per chat.
        system_prompt_override: Replacement system prompt, or ``None`` when
            not overridden.
        history: Stored model message history; empty for a new session.
        api_key_override: Replacement API key, or ``None`` when not overridden.
    """

    chat_id: Annotated[int, Indexed(unique=True)]
    # Optional[...] is required: pydantic v2 accepts a bare ``str = None``
    # default without validating it, but the resulting ``null`` is persisted
    # and then rejected as "not a str" when the document is loaded back.
    system_prompt_override: Optional[str] = None
    history: List[ModelMessage] = Field(default_factory=list)
    api_key_override: Optional[str] = None


class __CommonSessionRepository(SessionBase, Document):
    """Query and update helpers shared by the concrete session collections.

    Subclasses select their MongoDB collection through ``Settings.name``.
    """

    @classmethod
    async def get_by_chat_id(cls, chat_id: int):
        """Return the session stored for ``chat_id``, or ``None`` if absent."""
        return await cls.find_one(cls.chat_id == chat_id)

    @classmethod
    async def create_empty(cls, chat_id: int):
        """Insert and return a new session for ``chat_id`` with default fields."""
        return await cls(chat_id=chat_id).insert()

    @classmethod
    async def get_or_create_by_chat_id(cls, chat_id: int):
        """Return the session for ``chat_id``, inserting an empty one if missing.

        NOTE(review): the lookup and the insert are separate round-trips, so two
        concurrent callers for a new ``chat_id`` can both miss and both try to
        insert. The unique index on ``chat_id`` will then reject one insert.
        Consider catching the duplicate-key error and re-fetching.
        """
        session = await cls.get_by_chat_id(chat_id)
        # "Not found" is signalled by None; test for it explicitly instead of
        # relying on the document's truthiness.
        if session is None:
            session = await cls.create_empty(chat_id=chat_id)
        return session

    async def update_history(self, history: List[ModelMessage]) -> None:
        """Persist ``history`` as this session's message history."""
        # The key must be the class-level field expression (as with
        # ``cls.chat_id`` above). ``self.history`` is the instance's current
        # list value, which is unhashable and cannot be used as a dict key.
        await self.set({type(self).history: history})


class ReviewSession(__CommonSessionRepository):
    """Session document stored in the ``review_sessions`` collection."""

    class Settings:
        name = "review_sessions"


class RespondSession(__CommonSessionRepository):
    """Session document stored in the ``respond_sessions`` collection."""

    class Settings:
        name = "respond_sessions"