feat(solaris): migrate to Pydantic AI, wire respond agent through DI providers

This commit is contained in:
h
2025-08-10 15:27:09 +03:00
parent bfa23d4db9
commit 6f1f2732ec
19 changed files with 1010 additions and 171 deletions

View File

@@ -10,7 +10,7 @@ dependencies = [
"aiogram>=3.20.0.post0",
"beanie>=2.0.0",
"dishka>=1.6.0",
"google-genai>=1.23.0",
"pydantic-ai>=0.4.4",
"pydantic-settings>=2.10.1",
"pydub>=0.25.1",
"rich>=14.0.0",

View File

@@ -1,5 +1,3 @@
import json
from google import genai
from .respond import RespondAgent

View File

@@ -1,32 +1,25 @@
import json
from typing import List
from google import genai
from google.genai import chats, types
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage
from utils.config import dconfig
from dependencies.types.model import RespondModel
from ..constants import SAFETY_SETTINGS
from ..structures import InputMessage, OutputMessage
from ..tools import RESPOND_TOOLS
from ..structures import OutputMessage
from ..tools import RESPOND_TOOLSET
class RespondAgent:
def __init__(self, client: genai.client.AsyncClient) -> None:
self.client = client
def __init__(self, model: RespondModel, system_prompt: str) -> None:
self.agent = Agent(
model=model,
instructions=system_prompt,
output_type=list[OutputMessage],
toolsets=[RESPOND_TOOLSET],
)
async def generate_response(
self, system_prompt: str, history: list[types.Content]
) -> list[OutputMessage]:
content = await self.client.models.generate_content(
model=(await dconfig()).models.respond_model,
contents=history,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[OutputMessage],
# safety_settings=SAFETY_SETTINGS,
tools=RESPOND_TOOLS,
),
)
return content.parsed
self, request: str, history: List[ModelMessage]
) -> tuple[list[OutputMessage], List[ModelMessage]]:
result = await self.agent.run(user_prompt=request, message_history=history)
return result.output, result.all_messages()

View File

@@ -4,8 +4,6 @@ from google import genai
from google.genai import types
from pydub import AudioSegment
from ..constants import SAFETY_SETTINGS
TTS_MODEL = "gemini-2.5-flash-preview-tts"

View File

@@ -1,7 +1,5 @@
from google.genai import types
from .structures import OutputMessage
SAFETY_SETTINGS = [
types.SafetySetting(
category=category.value, threshold=types.HarmBlockThreshold.OFF.value

View File

@@ -0,0 +1 @@
from .respond import RespondService

View File

@@ -1,55 +1,20 @@
import json
from beanie.odm.operators.update.general import Set
from google import genai
from google.genai import chats, types
from utils.config import dconfig
from utils.db.models.session import RespondSession
from utils.logging import console
from ..agents.respond import RespondAgent
from ..constants import SAFETY_SETTINGS
from ..prompts import load_prompt
from ..structures import InputMessage, OutputMessage
class RespondService:
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
self.agent = RespondAgent(client)
self.chat_id = chat_id
self.session: RespondSession | None = None
async def _get_or_create_session(self) -> RespondSession:
session = await RespondSession.get_by_chat_id(chat_id=self.chat_id)
if not session:
session = await RespondSession.create_empty(
chat_id=self.chat_id,
system_prompt=load_prompt("default_system_prompt.txt"),
)
def __init__(self, agent: RespondAgent, session: RespondSession) -> None:
self.agent = agent
self.session = session
return session
async def process_message(self, message: InputMessage) -> list[OutputMessage]:
session = await self._get_or_create_session()
session.history.append(
types.Content(role="user", parts=[types.Part(text=message.text)])
await self.session.sync()
result, history = await self.agent.generate_response(
request=message.text, history=self.session.history
)
await self.session.update_history(history=history)
response_messages = await self.agent.generate_response(
system_prompt=session.system_prompt, history=session.history
)
if response_messages:
model_response_content = response_messages[0].text
session.history.append(
types.Content(
role="model", parts=[types.Part(text=model_response_content)]
)
)
await session.update_history(history=session.history)
return response_messages
return result

View File

@@ -1,3 +1,5 @@
from .test import test_tool
from pydantic_ai.toolsets import CombinedToolset
RESPOND_TOOLS = [test_tool]
from .test import TEST_TOOLSET
RESPOND_TOOLSET = CombinedToolset([TEST_TOOLSET])

View File

@@ -1,3 +1,5 @@
from pydantic_ai.toolsets import FunctionToolset
from utils.logging import console
@@ -12,3 +14,6 @@ async def test_tool(content: str):
"""
console.print(content)
return "ok"
TEST_TOOLSET = FunctionToolset(tools=[test_tool])

View File

@@ -1,10 +1,19 @@
from dishka import make_async_container
from dishka.integrations.aiogram import AiogramProvider
from .providers import GeminiClientProvider, SolarisProvider
from .providers import (
AIServiceProvider,
ConfigProvider,
ModelProvider,
SessionProvider,
SolarisServicesProvider,
)
container = make_async_container(
AiogramProvider(),
SolarisProvider(),
GeminiClientProvider(),
SolarisServicesProvider(),
SessionProvider(),
ModelProvider(),
ConfigProvider(),
AIServiceProvider(),
)

View File

@@ -1,2 +1,3 @@
from .gemini import GeminiClientProvider
from .solaris import SolarisProvider
from .database import ConfigProvider, SessionProvider
from .model import AIServiceProvider, ModelProvider
from .solaris import SolarisServicesProvider

View File

@@ -0,0 +1,31 @@
from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
class SessionProvider(Provider):
    """Dishka provider that yields per-chat session documents.

    Both providers are REQUEST-scoped: each incoming aiogram update gets
    the session bound to its chat, creating it on first contact.
    """

    @provide(scope=Scope.REQUEST)
    async def provide_respond_session(
        self, middleware_data: AiogramMiddlewareData
    ) -> AsyncIterable[RespondSession]:
        # Resolve the Telegram chat from the aiogram middleware context;
        # "event_chat" is populated by aiogram for chat-bound events.
        chat: aiogram.types.Chat = middleware_data["event_chat"]
        session = await RespondSession.get_or_create_by_chat_id(chat_id=chat.id)
        yield session

    @provide(scope=Scope.REQUEST)
    async def provide_review_session(
        self, middleware_data: AiogramMiddlewareData
    ) -> AsyncIterable[ReviewSession]:
        # Same lookup strategy as the respond session, separate collection.
        chat: aiogram.types.Chat = middleware_data["event_chat"]
        session = await ReviewSession.get_or_create_by_chat_id(chat_id=chat.id)
        yield session
class ConfigProvider(Provider):
    """Dishka provider for the mutable application configuration document."""

    @provide(scope=Scope.REQUEST)
    async def provide_config(self) -> AsyncIterable[DynamicConfig]:
        # Re-fetched per request so runtime config edits take effect
        # without a restart; created lazily on first access.
        yield await DynamicConfig.get_or_create()

View File

@@ -1,13 +0,0 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google import genai
from utils.env import env
class GeminiClientProvider(Provider):
@provide(scope=Scope.APP)
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
yield client

View File

@@ -0,0 +1,63 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google.genai.types import (
HarmBlockThreshold,
HarmCategory,
SafetySettingDict,
ThinkingConfigDict,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider
from utils.db.models import DynamicConfig, RespondSession
from utils.env import env
from ..types.model import RespondModel, ReviewModel
# Disable Gemini safety filtering for every concrete harm category.
# HARM_CATEGORY_UNSPECIFIED is excluded: the Gemini API rejects safety
# settings whose category is unspecified (INVALID_ARGUMENT), so iterating
# the full enum would make every request fail.
GOOGLE_SAFETY_SETTINGS = [
    SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
    for category in HarmCategory
    if category != HarmCategory.HARM_CATEGORY_UNSPECIFIED
]
class AIServiceProvider(Provider):
    """Dishka provider for upstream AI SDK provider objects."""

    @provide(scope=Scope.REQUEST)
    async def get_google_provider(
        self, respond_session: RespondSession
    ) -> AsyncIterable[GoogleProvider]:
        """Yield a GoogleProvider for the current request.

        The per-session API-key override (plain string) wins; otherwise we
        fall back to the environment key. Fix: `env.google.api_key` is a
        secret wrapper (the removed gemini provider unwrapped it with
        `.get_secret_value()`), so it must be unwrapped here too instead of
        being passed to GoogleProvider as-is.
        """
        yield GoogleProvider(
            api_key=respond_session.api_key_override
            or env.google.api_key.get_secret_value()
        )
class ModelProvider(Provider):
    """Dishka provider that builds configured pydantic-ai models per request."""

    @staticmethod
    def _build_google_model(model_name: str, provider: GoogleProvider) -> GoogleModel:
        # Shared construction used by both roles: thinking disabled,
        # all safety filters off.
        return GoogleModel(
            model_name=model_name,
            provider=provider,
            settings=GoogleModelSettings(
                google_thinking_config=ThinkingConfigDict(thinking_budget=0),
                google_safety_settings=GOOGLE_SAFETY_SETTINGS,
            ),
        )

    @provide(scope=Scope.REQUEST)
    async def get_respond_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[RespondModel]:
        """Yield the model used by the respond agent.

        Only Gemini ("gem*") model names are supported so far. Fix: the
        original had no failure path for other names — the generator either
        referenced an unbound local or yielded nothing — so unsupported
        names now raise a clear error instead.
        """
        name = config.models.respond_model
        if not name.startswith("gem"):
            raise ValueError(f"Unsupported respond model: {name!r}")
        yield RespondModel(self._build_google_model(name, google_provider))

    @provide(scope=Scope.REQUEST)
    async def get_review_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[ReviewModel]:
        """Yield the model used by the message-review agent.

        Same support matrix and failure semantics as get_respond_model.
        """
        name = config.models.message_review_model
        if not name.startswith("gem"):
            raise ValueError(f"Unsupported review model: {name!r}")
        yield ReviewModel(self._build_google_model(name, google_provider))

View File

@@ -1,26 +1,31 @@
from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from google.genai.client import AsyncClient
from bot.modules.solaris.client import SolarisClient
from bot.modules.solaris.services.respond import RespondService
from bot.modules.solaris.agents import RespondAgent, ReviewAgent
from bot.modules.solaris.prompts import load_prompt
from bot.modules.solaris.services import RespondService
from utils.db.models import RespondSession, ReviewSession
from ..types.model import RespondModel, ReviewModel
class SolarisProvider(Provider):
@provide(scope=Scope.APP)
async def get_solaris_client(
self, client: AsyncClient
) -> AsyncIterable[SolarisClient]:
client = SolarisClient(gemini_client=client)
yield client
class AgentsProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_respond_agent(
self, model: RespondModel, session: RespondSession
) -> AsyncIterable[RespondAgent]:
yield RespondAgent(
model=model,
system_prompt=session.system_prompt_override
or load_prompt("default_system_prompt.txt"),
)
class SolarisServicesProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_respond_service(
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
self, agent: RespondAgent, session: RespondSession
) -> AsyncIterable[RespondService]:
chat: aiogram.types.Chat = middleware_data["event_chat"]
service = RespondService(client=client, chat_id=chat.id)
service = RespondService(agent=agent, session=session)
yield service

View File

@@ -0,0 +1 @@
pass

View File

@@ -0,0 +1,6 @@
from typing import NewType
from pydantic_ai.models import Model
# Distinct nominal aliases over pydantic-ai's Model so the DI container can
# inject a differently-configured model per role (responding vs. reviewing).
# NewType is erased at runtime — both are plain Model instances when used.
RespondModel = NewType("RespondModel", Model)
ReviewModel = NewType("ReviewModel", Model)

View File

@@ -1,14 +1,15 @@
from typing import Annotated, List
from beanie import Document, Indexed
from google.genai.types import Content
from pydantic import BaseModel, Field
from pydantic_ai.messages import ModelMessage
class SessionBase(BaseModel):
chat_id: Annotated[int, Indexed(unique=True)]
system_prompt: str
history: List[Content] = Field(default_factory=list)
system_prompt_override: str = None
history: List[ModelMessage] = Field(default_factory=list)
api_key_override: str = None
class __CommonSessionRepository(SessionBase, Document):
@@ -17,10 +18,17 @@ class __CommonSessionRepository(SessionBase, Document):
return await cls.find_one(cls.chat_id == chat_id)
@classmethod
async def create_empty(cls, chat_id: int, system_prompt: str):
return await cls(chat_id=chat_id, system_prompt=system_prompt).insert()
async def create_empty(cls, chat_id: int):
return await cls(chat_id=chat_id).insert()
async def update_history(self, history: List[Content]):
@classmethod
async def get_or_create_by_chat_id(cls, chat_id: int):
session = await cls.get_by_chat_id(chat_id)
if not session:
session = await cls.create_empty(chat_id=chat_id)
return session
async def update_history(self, history: List[ModelMessage]):
await self.set({self.history: history})

882
uv.lock generated

File diff suppressed because it is too large Load Diff