feat(solaris): migrate to pydantic ai, wired respond agent through di providers
This commit is contained in:
@@ -10,7 +10,7 @@ dependencies = [
|
||||
"aiogram>=3.20.0.post0",
|
||||
"beanie>=2.0.0",
|
||||
"dishka>=1.6.0",
|
||||
"google-genai>=1.23.0",
|
||||
"pydantic-ai>=0.4.4",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"pydub>=0.25.1",
|
||||
"rich>=14.0.0",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import json
|
||||
|
||||
from google import genai
|
||||
|
||||
from .respond import RespondAgent
|
||||
|
||||
@@ -1,32 +1,25 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from google import genai
|
||||
from google.genai import chats, types
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.messages import ModelMessage
|
||||
|
||||
from utils.config import dconfig
|
||||
from dependencies.types.model import RespondModel
|
||||
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
from ..structures import InputMessage, OutputMessage
|
||||
from ..tools import RESPOND_TOOLS
|
||||
from ..structures import OutputMessage
|
||||
from ..tools import RESPOND_TOOLSET
|
||||
|
||||
|
||||
class RespondAgent:
|
||||
def __init__(self, client: genai.client.AsyncClient) -> None:
|
||||
self.client = client
|
||||
def __init__(self, model: RespondModel, system_prompt: str) -> None:
|
||||
self.agent = Agent(
|
||||
model=model,
|
||||
instructions=system_prompt,
|
||||
output_type=list[OutputMessage],
|
||||
toolsets=[RESPOND_TOOLSET],
|
||||
)
|
||||
|
||||
async def generate_response(
|
||||
self, system_prompt: str, history: list[types.Content]
|
||||
) -> list[OutputMessage]:
|
||||
content = await self.client.models.generate_content(
|
||||
model=(await dconfig()).models.respond_model,
|
||||
contents=history,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||
response_mime_type="application/json",
|
||||
response_schema=list[OutputMessage],
|
||||
# safety_settings=SAFETY_SETTINGS,
|
||||
tools=RESPOND_TOOLS,
|
||||
),
|
||||
)
|
||||
return content.parsed
|
||||
self, request: str, history: List[ModelMessage]
|
||||
) -> tuple[list[OutputMessage], List[ModelMessage]]:
|
||||
result = await self.agent.run(user_prompt=request, message_history=history)
|
||||
return result.output, result.all_messages()
|
||||
|
||||
@@ -4,8 +4,6 @@ from google import genai
|
||||
from google.genai import types
|
||||
from pydub import AudioSegment
|
||||
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
|
||||
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from google.genai import types
|
||||
|
||||
from .structures import OutputMessage
|
||||
|
||||
SAFETY_SETTINGS = [
|
||||
types.SafetySetting(
|
||||
category=category.value, threshold=types.HarmBlockThreshold.OFF.value
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .respond import RespondService
|
||||
|
||||
@@ -1,55 +1,20 @@
|
||||
import json
|
||||
|
||||
from beanie.odm.operators.update.general import Set
|
||||
from google import genai
|
||||
from google.genai import chats, types
|
||||
|
||||
from utils.config import dconfig
|
||||
from utils.db.models.session import RespondSession
|
||||
from utils.logging import console
|
||||
|
||||
from ..agents.respond import RespondAgent
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
from ..prompts import load_prompt
|
||||
from ..structures import InputMessage, OutputMessage
|
||||
|
||||
|
||||
class RespondService:
|
||||
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
|
||||
self.agent = RespondAgent(client)
|
||||
self.chat_id = chat_id
|
||||
self.session: RespondSession | None = None
|
||||
|
||||
async def _get_or_create_session(self) -> RespondSession:
|
||||
session = await RespondSession.get_by_chat_id(chat_id=self.chat_id)
|
||||
if not session:
|
||||
session = await RespondSession.create_empty(
|
||||
chat_id=self.chat_id,
|
||||
system_prompt=load_prompt("default_system_prompt.txt"),
|
||||
)
|
||||
|
||||
def __init__(self, agent: RespondAgent, session: RespondSession) -> None:
|
||||
self.agent = agent
|
||||
self.session = session
|
||||
return session
|
||||
|
||||
async def process_message(self, message: InputMessage) -> list[OutputMessage]:
|
||||
session = await self._get_or_create_session()
|
||||
|
||||
session.history.append(
|
||||
types.Content(role="user", parts=[types.Part(text=message.text)])
|
||||
await self.session.sync()
|
||||
result, history = await self.agent.generate_response(
|
||||
request=message.text, history=self.session.history
|
||||
)
|
||||
await self.session.update_history(history=history)
|
||||
|
||||
response_messages = await self.agent.generate_response(
|
||||
system_prompt=session.system_prompt, history=session.history
|
||||
)
|
||||
|
||||
if response_messages:
|
||||
model_response_content = response_messages[0].text
|
||||
session.history.append(
|
||||
types.Content(
|
||||
role="model", parts=[types.Part(text=model_response_content)]
|
||||
)
|
||||
)
|
||||
|
||||
await session.update_history(history=session.history)
|
||||
|
||||
return response_messages
|
||||
return result
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from .test import test_tool
|
||||
from pydantic_ai.toolsets import CombinedToolset
|
||||
|
||||
RESPOND_TOOLS = [test_tool]
|
||||
from .test import TEST_TOOLSET
|
||||
|
||||
RESPOND_TOOLSET = CombinedToolset([TEST_TOOLSET])
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from pydantic_ai.toolsets import FunctionToolset
|
||||
|
||||
from utils.logging import console
|
||||
|
||||
|
||||
@@ -12,3 +14,6 @@ async def test_tool(content: str):
|
||||
"""
|
||||
console.print(content)
|
||||
return "ok"
|
||||
|
||||
|
||||
TEST_TOOLSET = FunctionToolset(tools=[test_tool])
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
from dishka import make_async_container
|
||||
from dishka.integrations.aiogram import AiogramProvider
|
||||
|
||||
from .providers import GeminiClientProvider, SolarisProvider
|
||||
from .providers import (
|
||||
AIServiceProvider,
|
||||
ConfigProvider,
|
||||
ModelProvider,
|
||||
SessionProvider,
|
||||
SolarisServicesProvider,
|
||||
)
|
||||
|
||||
container = make_async_container(
|
||||
AiogramProvider(),
|
||||
SolarisProvider(),
|
||||
GeminiClientProvider(),
|
||||
SolarisServicesProvider(),
|
||||
SessionProvider(),
|
||||
ModelProvider(),
|
||||
ConfigProvider(),
|
||||
AIServiceProvider(),
|
||||
)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .gemini import GeminiClientProvider
|
||||
from .solaris import SolarisProvider
|
||||
from .database import ConfigProvider, SessionProvider
|
||||
from .model import AIServiceProvider, ModelProvider
|
||||
from .solaris import SolarisServicesProvider
|
||||
|
||||
31
src/dependencies/providers/database.py
Normal file
31
src/dependencies/providers/database.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
import aiogram.types
|
||||
from dishka import Provider, Scope, provide
|
||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||
|
||||
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
|
||||
|
||||
|
||||
class SessionProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_respond_session(
|
||||
self, middleware_data: AiogramMiddlewareData
|
||||
) -> AsyncIterable[RespondSession]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
session = await RespondSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||
yield session
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_review_session(
|
||||
self, middleware_data: AiogramMiddlewareData
|
||||
) -> AsyncIterable[ReviewSession]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
session = await ReviewSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||
yield session
|
||||
|
||||
|
||||
class ConfigProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_config(self) -> AsyncIterable[DynamicConfig]:
|
||||
yield await DynamicConfig.get_or_create()
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google import genai
|
||||
|
||||
from utils.env import env
|
||||
|
||||
|
||||
class GeminiClientProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
|
||||
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
|
||||
yield client
|
||||
63
src/dependencies/providers/model.py
Normal file
63
src/dependencies/providers/model.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google.genai.types import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
SafetySettingDict,
|
||||
ThinkingConfigDict,
|
||||
)
|
||||
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
|
||||
from utils.db.models import DynamicConfig, RespondSession
|
||||
from utils.env import env
|
||||
|
||||
from ..types.model import RespondModel, ReviewModel
|
||||
|
||||
GOOGLE_SAFETY_SETTINGS = [
|
||||
SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
|
||||
for category in HarmCategory
|
||||
]
|
||||
|
||||
|
||||
class AIServiceProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_google_provider(
|
||||
self, respond_session: RespondSession
|
||||
) -> AsyncIterable[GoogleProvider]:
|
||||
yield GoogleProvider(
|
||||
api_key=respond_session.api_key_override or env.google.api_key
|
||||
)
|
||||
|
||||
|
||||
class ModelProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[RespondModel]:
|
||||
if config.models.respond_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.respond_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_review_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[ReviewModel]:
|
||||
if config.models.message_review_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.message_review_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
@@ -1,26 +1,31 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
import aiogram.types
|
||||
from dishka import Provider, Scope, provide
|
||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||
from google.genai.client import AsyncClient
|
||||
|
||||
from bot.modules.solaris.client import SolarisClient
|
||||
from bot.modules.solaris.services.respond import RespondService
|
||||
from bot.modules.solaris.agents import RespondAgent, ReviewAgent
|
||||
from bot.modules.solaris.prompts import load_prompt
|
||||
from bot.modules.solaris.services import RespondService
|
||||
from utils.db.models import RespondSession, ReviewSession
|
||||
|
||||
from ..types.model import RespondModel, ReviewModel
|
||||
|
||||
|
||||
class SolarisProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_solaris_client(
|
||||
self, client: AsyncClient
|
||||
) -> AsyncIterable[SolarisClient]:
|
||||
client = SolarisClient(gemini_client=client)
|
||||
yield client
|
||||
class AgentsProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_agent(
|
||||
self, model: RespondModel, session: RespondSession
|
||||
) -> AsyncIterable[RespondAgent]:
|
||||
yield RespondAgent(
|
||||
model=model,
|
||||
system_prompt=session.system_prompt_override
|
||||
or load_prompt("default_system_prompt.txt"),
|
||||
)
|
||||
|
||||
|
||||
class SolarisServicesProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_service(
|
||||
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
|
||||
self, agent: RespondAgent, session: RespondSession
|
||||
) -> AsyncIterable[RespondService]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
service = RespondService(client=client, chat_id=chat.id)
|
||||
service = RespondService(agent=agent, session=session)
|
||||
yield service
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
pass
|
||||
|
||||
6
src/dependencies/types/model.py
Normal file
6
src/dependencies/types/model.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import NewType
|
||||
|
||||
from pydantic_ai.models import Model
|
||||
|
||||
RespondModel = NewType("RespondModel", Model)
|
||||
ReviewModel = NewType("ReviewModel", Model)
|
||||
@@ -1,14 +1,15 @@
|
||||
from typing import Annotated, List
|
||||
|
||||
from beanie import Document, Indexed
|
||||
from google.genai.types import Content
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_ai.messages import ModelMessage
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
chat_id: Annotated[int, Indexed(unique=True)]
|
||||
system_prompt: str
|
||||
history: List[Content] = Field(default_factory=list)
|
||||
system_prompt_override: str = None
|
||||
history: List[ModelMessage] = Field(default_factory=list)
|
||||
api_key_override: str = None
|
||||
|
||||
|
||||
class __CommonSessionRepository(SessionBase, Document):
|
||||
@@ -17,10 +18,17 @@ class __CommonSessionRepository(SessionBase, Document):
|
||||
return await cls.find_one(cls.chat_id == chat_id)
|
||||
|
||||
@classmethod
|
||||
async def create_empty(cls, chat_id: int, system_prompt: str):
|
||||
return await cls(chat_id=chat_id, system_prompt=system_prompt).insert()
|
||||
async def create_empty(cls, chat_id: int):
|
||||
return await cls(chat_id=chat_id).insert()
|
||||
|
||||
async def update_history(self, history: List[Content]):
|
||||
@classmethod
|
||||
async def get_or_create_by_chat_id(cls, chat_id: int):
|
||||
session = await cls.get_by_chat_id(chat_id)
|
||||
if not session:
|
||||
session = await cls.create_empty(chat_id=chat_id)
|
||||
return session
|
||||
|
||||
async def update_history(self, history: List[ModelMessage]):
|
||||
await self.set({self.history: history})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user