feat(solaris): migrate to pydantic ai, wired respond agent through di providers

This commit is contained in:
h
2025-08-10 15:27:09 +03:00
parent bfa23d4db9
commit 6f1f2732ec
19 changed files with 1010 additions and 171 deletions

View File

@@ -1,10 +1,19 @@
from dishka import make_async_container
from dishka.integrations.aiogram import AiogramProvider
from .providers import GeminiClientProvider, SolarisProvider
from .providers import (
AIServiceProvider,
ConfigProvider,
ModelProvider,
SessionProvider,
SolarisServicesProvider,
)
container = make_async_container(
AiogramProvider(),
SolarisProvider(),
GeminiClientProvider(),
SolarisServicesProvider(),
SessionProvider(),
ModelProvider(),
ConfigProvider(),
AIServiceProvider(),
)

View File

@@ -1,2 +1,3 @@
from .gemini import GeminiClientProvider
from .solaris import SolarisProvider
from .database import ConfigProvider, SessionProvider
from .model import AIServiceProvider, ModelProvider
from .solaris import SolarisServicesProvider

View File

@@ -0,0 +1,31 @@
from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
class SessionProvider(Provider):
@provide(scope=Scope.REQUEST)
async def provide_respond_session(
self, middleware_data: AiogramMiddlewareData
) -> AsyncIterable[RespondSession]:
chat: aiogram.types.Chat = middleware_data["event_chat"]
session = await RespondSession.get_or_create_by_chat_id(chat_id=chat.id)
yield session
@provide(scope=Scope.REQUEST)
async def provide_review_session(
self, middleware_data: AiogramMiddlewareData
) -> AsyncIterable[ReviewSession]:
chat: aiogram.types.Chat = middleware_data["event_chat"]
session = await ReviewSession.get_or_create_by_chat_id(chat_id=chat.id)
yield session
class ConfigProvider(Provider):
@provide(scope=Scope.REQUEST)
async def provide_config(self) -> AsyncIterable[DynamicConfig]:
yield await DynamicConfig.get_or_create()

View File

@@ -1,13 +0,0 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google import genai
from utils.env import env
class GeminiClientProvider(Provider):
@provide(scope=Scope.APP)
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
yield client

View File

@@ -0,0 +1,63 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google.genai.types import (
HarmBlockThreshold,
HarmCategory,
SafetySettingDict,
ThinkingConfigDict,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider
from utils.db.models import DynamicConfig, RespondSession
from utils.env import env
from ..types.model import RespondModel, ReviewModel
GOOGLE_SAFETY_SETTINGS = [
SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
for category in HarmCategory
]
class AIServiceProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_google_provider(
self, respond_session: RespondSession
) -> AsyncIterable[GoogleProvider]:
yield GoogleProvider(
api_key=respond_session.api_key_override or env.google.api_key
)
class ModelProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_respond_model(
self, config: DynamicConfig, google_provider: GoogleProvider
) -> AsyncIterable[RespondModel]:
if config.models.respond_model.startswith("gem"):
model = GoogleModel(
model_name=config.models.respond_model,
provider=google_provider,
settings=GoogleModelSettings(
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
),
)
yield model
@provide(scope=Scope.REQUEST)
async def get_review_model(
self, config: DynamicConfig, google_provider: GoogleProvider
) -> AsyncIterable[ReviewModel]:
if config.models.message_review_model.startswith("gem"):
model = GoogleModel(
model_name=config.models.message_review_model,
provider=google_provider,
settings=GoogleModelSettings(
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
),
)
yield model

View File

@@ -1,26 +1,31 @@
from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from google.genai.client import AsyncClient
from bot.modules.solaris.client import SolarisClient
from bot.modules.solaris.services.respond import RespondService
from bot.modules.solaris.agents import RespondAgent, ReviewAgent
from bot.modules.solaris.prompts import load_prompt
from bot.modules.solaris.services import RespondService
from utils.db.models import RespondSession, ReviewSession
from ..types.model import RespondModel, ReviewModel
class SolarisProvider(Provider):
@provide(scope=Scope.APP)
async def get_solaris_client(
self, client: AsyncClient
) -> AsyncIterable[SolarisClient]:
client = SolarisClient(gemini_client=client)
yield client
class AgentsProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_respond_agent(
self, model: RespondModel, session: RespondSession
) -> AsyncIterable[RespondAgent]:
yield RespondAgent(
model=model,
system_prompt=session.system_prompt_override
or load_prompt("default_system_prompt.txt"),
)
class SolarisServicesProvider(Provider):
@provide(scope=Scope.REQUEST)
async def get_respond_service(
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
self, agent: RespondAgent, session: RespondSession
) -> AsyncIterable[RespondService]:
chat: aiogram.types.Chat = middleware_data["event_chat"]
service = RespondService(client=client, chat_id=chat.id)
service = RespondService(agent=agent, session=session)
yield service

View File

@@ -0,0 +1 @@
pass

View File

@@ -0,0 +1,6 @@
from typing import NewType
from pydantic_ai.models import Model
RespondModel = NewType("RespondModel", Model)
ReviewModel = NewType("ReviewModel", Model)