feat(solaris): migrate to pydantic-ai, wire respond agent through DI providers

This commit is contained in:
h
2025-08-10 15:27:09 +03:00
parent bfa23d4db9
commit 6f1f2732ec
19 changed files with 1010 additions and 171 deletions

View File

@@ -10,7 +10,7 @@ dependencies = [
"aiogram>=3.20.0.post0", "aiogram>=3.20.0.post0",
"beanie>=2.0.0", "beanie>=2.0.0",
"dishka>=1.6.0", "dishka>=1.6.0",
"google-genai>=1.23.0", "pydantic-ai>=0.4.4",
"pydantic-settings>=2.10.1", "pydantic-settings>=2.10.1",
"pydub>=0.25.1", "pydub>=0.25.1",
"rich>=14.0.0", "rich>=14.0.0",

View File

@@ -1,5 +1,3 @@
import json
from google import genai from google import genai
from .respond import RespondAgent from .respond import RespondAgent

View File

@@ -1,32 +1,25 @@
import json from typing import List
from google import genai from pydantic_ai import Agent
from google.genai import chats, types from pydantic_ai.messages import ModelMessage
from utils.config import dconfig from dependencies.types.model import RespondModel
from ..constants import SAFETY_SETTINGS from ..structures import OutputMessage
from ..structures import InputMessage, OutputMessage from ..tools import RESPOND_TOOLSET
from ..tools import RESPOND_TOOLS
class RespondAgent: class RespondAgent:
def __init__(self, client: genai.client.AsyncClient) -> None: def __init__(self, model: RespondModel, system_prompt: str) -> None:
self.client = client self.agent = Agent(
model=model,
instructions=system_prompt,
output_type=list[OutputMessage],
toolsets=[RESPOND_TOOLSET],
)
async def generate_response( async def generate_response(
self, system_prompt: str, history: list[types.Content] self, request: str, history: List[ModelMessage]
) -> list[OutputMessage]: ) -> tuple[list[OutputMessage], List[ModelMessage]]:
content = await self.client.models.generate_content( result = await self.agent.run(user_prompt=request, message_history=history)
model=(await dconfig()).models.respond_model, return result.output, result.all_messages()
contents=history,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[OutputMessage],
# safety_settings=SAFETY_SETTINGS,
tools=RESPOND_TOOLS,
),
)
return content.parsed

View File

@@ -4,8 +4,6 @@ from google import genai
from google.genai import types from google.genai import types
from pydub import AudioSegment from pydub import AudioSegment
from ..constants import SAFETY_SETTINGS
TTS_MODEL = "gemini-2.5-flash-preview-tts" TTS_MODEL = "gemini-2.5-flash-preview-tts"

View File

@@ -1,7 +1,5 @@
from google.genai import types from google.genai import types
from .structures import OutputMessage
SAFETY_SETTINGS = [ SAFETY_SETTINGS = [
types.SafetySetting( types.SafetySetting(
category=category.value, threshold=types.HarmBlockThreshold.OFF.value category=category.value, threshold=types.HarmBlockThreshold.OFF.value

View File

@@ -0,0 +1 @@
from .respond import RespondService

View File

@@ -1,55 +1,20 @@
import json
from beanie.odm.operators.update.general import Set
from google import genai
from google.genai import chats, types
from utils.config import dconfig
from utils.db.models.session import RespondSession from utils.db.models.session import RespondSession
from utils.logging import console
from ..agents.respond import RespondAgent from ..agents.respond import RespondAgent
from ..constants import SAFETY_SETTINGS
from ..prompts import load_prompt from ..prompts import load_prompt
from ..structures import InputMessage, OutputMessage from ..structures import InputMessage, OutputMessage
class RespondService: class RespondService:
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None: def __init__(self, agent: RespondAgent, session: RespondSession) -> None:
self.agent = RespondAgent(client) self.agent = agent
self.chat_id = chat_id
self.session: RespondSession | None = None
async def _get_or_create_session(self) -> RespondSession:
session = await RespondSession.get_by_chat_id(chat_id=self.chat_id)
if not session:
session = await RespondSession.create_empty(
chat_id=self.chat_id,
system_prompt=load_prompt("default_system_prompt.txt"),
)
self.session = session self.session = session
return session
async def process_message(self, message: InputMessage) -> list[OutputMessage]: async def process_message(self, message: InputMessage) -> list[OutputMessage]:
session = await self._get_or_create_session() await self.session.sync()
result, history = await self.agent.generate_response(
session.history.append( request=message.text, history=self.session.history
types.Content(role="user", parts=[types.Part(text=message.text)])
) )
await self.session.update_history(history=history)
response_messages = await self.agent.generate_response( return result
system_prompt=session.system_prompt, history=session.history
)
if response_messages:
model_response_content = response_messages[0].text
session.history.append(
types.Content(
role="model", parts=[types.Part(text=model_response_content)]
)
)
await session.update_history(history=session.history)
return response_messages

View File

@@ -1,3 +1,5 @@
from .test import test_tool from pydantic_ai.toolsets import CombinedToolset
RESPOND_TOOLS = [test_tool] from .test import TEST_TOOLSET
RESPOND_TOOLSET = CombinedToolset([TEST_TOOLSET])

View File

@@ -1,3 +1,5 @@
from pydantic_ai.toolsets import FunctionToolset
from utils.logging import console from utils.logging import console
@@ -12,3 +14,6 @@ async def test_tool(content: str):
""" """
console.print(content) console.print(content)
return "ok" return "ok"
TEST_TOOLSET = FunctionToolset(tools=[test_tool])

View File

@@ -1,10 +1,19 @@
from dishka import make_async_container from dishka import make_async_container
from dishka.integrations.aiogram import AiogramProvider from dishka.integrations.aiogram import AiogramProvider
from .providers import GeminiClientProvider, SolarisProvider from .providers import (
AIServiceProvider,
ConfigProvider,
ModelProvider,
SessionProvider,
SolarisServicesProvider,
)
container = make_async_container( container = make_async_container(
AiogramProvider(), AiogramProvider(),
SolarisProvider(), SolarisServicesProvider(),
GeminiClientProvider(), SessionProvider(),
ModelProvider(),
ConfigProvider(),
AIServiceProvider(),
) )

View File

@@ -1,2 +1,3 @@
from .gemini import GeminiClientProvider from .database import ConfigProvider, SessionProvider
from .solaris import SolarisProvider from .model import AIServiceProvider, ModelProvider
from .solaris import SolarisServicesProvider

View File

@@ -0,0 +1,31 @@
from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
class SessionProvider(Provider):
    """Dishka provider resolving per-chat session documents.

    Each request-scoped factory reads the current chat from the aiogram
    middleware data and loads (creating on first use) the matching session
    document.
    """

    @provide(scope=Scope.REQUEST)
    async def provide_respond_session(
        self, middleware_data: AiogramMiddlewareData
    ) -> AsyncIterable[RespondSession]:
        """Yield the RespondSession bound to the event's chat."""
        event_chat: aiogram.types.Chat = middleware_data["event_chat"]
        yield await RespondSession.get_or_create_by_chat_id(chat_id=event_chat.id)

    @provide(scope=Scope.REQUEST)
    async def provide_review_session(
        self, middleware_data: AiogramMiddlewareData
    ) -> AsyncIterable[ReviewSession]:
        """Yield the ReviewSession bound to the event's chat."""
        event_chat: aiogram.types.Chat = middleware_data["event_chat"]
        yield await ReviewSession.get_or_create_by_chat_id(chat_id=event_chat.id)
class ConfigProvider(Provider):
    """Dishka provider for the database-backed dynamic configuration."""

    @provide(scope=Scope.REQUEST)
    async def provide_config(self) -> AsyncIterable[DynamicConfig]:
        """Yield the DynamicConfig document, creating it on first use."""
        config = await DynamicConfig.get_or_create()
        yield config

View File

@@ -1,13 +0,0 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google import genai
from utils.env import env
class GeminiClientProvider(Provider):
    """App-scoped Dishka provider for the async Gemini API client."""

    @provide(scope=Scope.APP)
    async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
        """Yield an AsyncClient authenticated with the configured API key."""
        yield genai.Client(api_key=env.google.api_key.get_secret_value()).aio

View File

@@ -0,0 +1,63 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google.genai.types import (
HarmBlockThreshold,
HarmCategory,
SafetySettingDict,
ThinkingConfigDict,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider
from utils.db.models import DynamicConfig, RespondSession
from utils.env import env
from ..types.model import RespondModel, ReviewModel
# One SafetySettingDict per HarmCategory member, each with blocking turned
# OFF — the Google safety filters are disabled for every category.
GOOGLE_SAFETY_SETTINGS = [
    SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
    for category in HarmCategory
]
class AIServiceProvider(Provider):
    """Dishka provider for upstream AI service clients (currently Google)."""

    @provide(scope=Scope.REQUEST)
    async def get_google_provider(
        self, respond_session: RespondSession
    ) -> AsyncIterable[GoogleProvider]:
        """Yield a GoogleProvider, preferring the session's API-key override."""
        api_key = respond_session.api_key_override or env.google.api_key
        yield GoogleProvider(api_key=api_key)
class ModelProvider(Provider):
    """Dishka provider that builds pydantic-ai model instances from config.

    Only Google (Gemini) model names are supported. A non-Gemini name is
    rejected with an explicit ValueError; previously ``model`` stayed unbound
    in that case and ``yield model`` raised UnboundLocalError instead.
    """

    @staticmethod
    def _build_google_model(
        model_name: str, google_provider: GoogleProvider
    ) -> GoogleModel:
        """Construct a GoogleModel with thinking disabled and safety filters off."""
        return GoogleModel(
            model_name=model_name,
            provider=google_provider,
            settings=GoogleModelSettings(
                # thinking_budget=0 disables "thinking" tokens entirely
                google_thinking_config=ThinkingConfigDict(thinking_budget=0),
                google_safety_settings=GOOGLE_SAFETY_SETTINGS,
            ),
        )

    @provide(scope=Scope.REQUEST)
    async def get_respond_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[RespondModel]:
        """Yield the model used by the respond agent.

        Raises:
            ValueError: if the configured model name is not a Gemini model.
        """
        name = config.models.respond_model
        if not name.startswith("gem"):
            raise ValueError(f"Unsupported respond model: {name!r}")
        yield self._build_google_model(name, google_provider)

    @provide(scope=Scope.REQUEST)
    async def get_review_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[ReviewModel]:
        """Yield the model used by the message-review agent.

        Raises:
            ValueError: if the configured model name is not a Gemini model.
        """
        name = config.models.message_review_model
        if not name.startswith("gem"):
            raise ValueError(f"Unsupported review model: {name!r}")
        yield self._build_google_model(name, google_provider)

View File

@@ -1,26 +1,31 @@
from typing import AsyncIterable from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from google.genai.client import AsyncClient
from bot.modules.solaris.client import SolarisClient from bot.modules.solaris.agents import RespondAgent, ReviewAgent
from bot.modules.solaris.services.respond import RespondService from bot.modules.solaris.prompts import load_prompt
from bot.modules.solaris.services import RespondService
from utils.db.models import RespondSession, ReviewSession
from ..types.model import RespondModel, ReviewModel
class SolarisProvider(Provider): class AgentsProvider(Provider):
@provide(scope=Scope.APP) @provide(scope=Scope.REQUEST)
async def get_solaris_client( async def get_respond_agent(
self, client: AsyncClient self, model: RespondModel, session: RespondSession
) -> AsyncIterable[SolarisClient]: ) -> AsyncIterable[RespondAgent]:
client = SolarisClient(gemini_client=client) yield RespondAgent(
yield client model=model,
system_prompt=session.system_prompt_override
or load_prompt("default_system_prompt.txt"),
)
class SolarisServicesProvider(Provider):
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
async def get_respond_service( async def get_respond_service(
self, client: AsyncClient, middleware_data: AiogramMiddlewareData self, agent: RespondAgent, session: RespondSession
) -> AsyncIterable[RespondService]: ) -> AsyncIterable[RespondService]:
chat: aiogram.types.Chat = middleware_data["event_chat"] service = RespondService(agent=agent, session=session)
service = RespondService(client=client, chat_id=chat.id)
yield service yield service

View File

@@ -0,0 +1 @@
pass

View File

@@ -0,0 +1,6 @@
from typing import NewType
from pydantic_ai.models import Model
# Distinct nominal aliases over pydantic-ai's Model so the DI container can
# register and resolve the respond and review models as separate dependencies.
RespondModel = NewType("RespondModel", Model)
ReviewModel = NewType("ReviewModel", Model)

View File

@@ -1,14 +1,15 @@
from typing import Annotated, List from typing import Annotated, List
from beanie import Document, Indexed from beanie import Document, Indexed
from google.genai.types import Content
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic_ai.messages import ModelMessage
class SessionBase(BaseModel): class SessionBase(BaseModel):
chat_id: Annotated[int, Indexed(unique=True)] chat_id: Annotated[int, Indexed(unique=True)]
system_prompt: str system_prompt_override: str = None
history: List[Content] = Field(default_factory=list) history: List[ModelMessage] = Field(default_factory=list)
api_key_override: str = None
class __CommonSessionRepository(SessionBase, Document): class __CommonSessionRepository(SessionBase, Document):
@@ -17,10 +18,17 @@ class __CommonSessionRepository(SessionBase, Document):
return await cls.find_one(cls.chat_id == chat_id) return await cls.find_one(cls.chat_id == chat_id)
@classmethod @classmethod
async def create_empty(cls, chat_id: int, system_prompt: str): async def create_empty(cls, chat_id: int):
return await cls(chat_id=chat_id, system_prompt=system_prompt).insert() return await cls(chat_id=chat_id).insert()
async def update_history(self, history: List[Content]): @classmethod
async def get_or_create_by_chat_id(cls, chat_id: int):
session = await cls.get_by_chat_id(chat_id)
if not session:
session = await cls.create_empty(chat_id=chat_id)
return session
async def update_history(self, history: List[ModelMessage]):
await self.set({self.history: history}) await self.set({self.history: history})

882
uv.lock generated

File diff suppressed because it is too large Load Diff