43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
from typing import Annotated, List, Optional

from beanie import Document, Indexed
from pydantic import BaseModel, Field
from pydantic_ai.messages import ModelMessage
|
|
|
|
|
|
class SessionBase(BaseModel):
    """Fields shared by every per-chat session document.

    ``chat_id`` is declared with a unique Beanie index, so lookups by chat id
    are expected to match at most one document per collection.
    """

    chat_id: Annotated[int, Indexed(unique=True)]

    # ``None`` (the default) leaves the override unset. Declared Optional so a
    # stored ``null`` (e.g. from a session created with defaults) validates when
    # the document is read back; a bare ``str`` annotation rejects ``None``
    # under pydantic v2.
    system_prompt_override: Optional[str] = None

    # Conversation so far; a fresh list per instance via ``default_factory``.
    history: List[ModelMessage] = Field(default_factory=list)

    # Optional for the same reason as above. ``repr=False`` keeps the key out of
    # repr()/log output of the model.
    api_key_override: Optional[str] = Field(default=None, repr=False)
|
|
|
|
|
|
class __CommonSessionRepository(SessionBase, Document):
    """Shared Beanie query/update helpers for per-chat session documents.

    Concrete subclasses provide the collection via their own ``Settings``
    inner class; this base only adds helpers keyed on ``chat_id``.

    NOTE(review): the leading double underscore only keeps this name out of
    ``from module import *``; it is not mangled at module scope, but it WOULD
    be mangled if referenced from inside another class body.
    """

    @classmethod
    async def get_by_chat_id(cls, chat_id: int):
        """Return the session stored for ``chat_id``, or ``None`` if absent."""
        return await cls.find_one(cls.chat_id == chat_id)

    @classmethod
    async def create_empty(cls, chat_id: int):
        """Insert and return a new session for ``chat_id`` with default fields."""
        return await cls(chat_id=chat_id).insert()

    @classmethod
    async def get_or_create_by_chat_id(cls, chat_id: int):
        """Return the session for ``chat_id``, inserting an empty one if missing.

        NOTE(review): lookup-then-insert is not atomic. Two concurrent calls
        for the same new ``chat_id`` can both miss, and the unique index on
        ``chat_id`` is then expected to reject the second insert; confirm how
        callers handle that.
        """
        session = await cls.get_by_chat_id(chat_id)
        if session is None:
            session = await cls.create_empty(chat_id=chat_id)
        return session

    async def update_history(self, history: List[ModelMessage]):
        """Persist ``history`` as this session's message history via ``set``.

        NOTE(review): ``ModelMessage`` values are pydantic_ai objects; confirm
        that Beanie's encoder round-trips them as intended.
        """
        # Key by the class-level field expression (as ``get_by_chat_id`` does
        # with ``cls.chat_id``). ``self.history`` is the instance's current
        # list value, which is unhashable and is not a field reference, so
        # using it as the key raised TypeError.
        await self.set({type(self).history: history})
|
|
|
|
|
|
class ReviewSession(__CommonSessionRepository):
    """Per-chat session document persisted in the ``review_sessions`` collection."""

    class Settings:
        # Beanie takes the MongoDB collection name from ``Settings.name``.
        name = "review_sessions"
|
|
|
|
|
|
class RespondSession(__CommonSessionRepository):
    """Per-chat session document persisted in the ``respond_sessions`` collection."""

    class Settings:
        # Beanie takes the MongoDB collection name from ``Settings.name``.
        name = "respond_sessions"
|