feat(solaris): migrate to pydantic ai, wired respond agent through di providers
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
from dishka import make_async_container
|
||||
from dishka.integrations.aiogram import AiogramProvider
|
||||
|
||||
from .providers import GeminiClientProvider, SolarisProvider
|
||||
from .providers import (
|
||||
AIServiceProvider,
|
||||
ConfigProvider,
|
||||
ModelProvider,
|
||||
SessionProvider,
|
||||
SolarisServicesProvider,
|
||||
)
|
||||
|
||||
container = make_async_container(
|
||||
AiogramProvider(),
|
||||
SolarisProvider(),
|
||||
GeminiClientProvider(),
|
||||
SolarisServicesProvider(),
|
||||
SessionProvider(),
|
||||
ModelProvider(),
|
||||
ConfigProvider(),
|
||||
AIServiceProvider(),
|
||||
)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .gemini import GeminiClientProvider
|
||||
from .solaris import SolarisProvider
|
||||
from .database import ConfigProvider, SessionProvider
|
||||
from .model import AIServiceProvider, ModelProvider
|
||||
from .solaris import SolarisServicesProvider
|
||||
|
||||
31
src/dependencies/providers/database.py
Normal file
31
src/dependencies/providers/database.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
import aiogram.types
|
||||
from dishka import Provider, Scope, provide
|
||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||
|
||||
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
|
||||
|
||||
|
||||
class SessionProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_respond_session(
|
||||
self, middleware_data: AiogramMiddlewareData
|
||||
) -> AsyncIterable[RespondSession]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
session = await RespondSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||
yield session
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_review_session(
|
||||
self, middleware_data: AiogramMiddlewareData
|
||||
) -> AsyncIterable[ReviewSession]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
session = await ReviewSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||
yield session
|
||||
|
||||
|
||||
class ConfigProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def provide_config(self) -> AsyncIterable[DynamicConfig]:
|
||||
yield await DynamicConfig.get_or_create()
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google import genai
|
||||
|
||||
from utils.env import env
|
||||
|
||||
|
||||
class GeminiClientProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
|
||||
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
|
||||
yield client
|
||||
63
src/dependencies/providers/model.py
Normal file
63
src/dependencies/providers/model.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google.genai.types import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
SafetySettingDict,
|
||||
ThinkingConfigDict,
|
||||
)
|
||||
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
|
||||
from utils.db.models import DynamicConfig, RespondSession
|
||||
from utils.env import env
|
||||
|
||||
from ..types.model import RespondModel, ReviewModel
|
||||
|
||||
GOOGLE_SAFETY_SETTINGS = [
|
||||
SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
|
||||
for category in HarmCategory
|
||||
]
|
||||
|
||||
|
||||
class AIServiceProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_google_provider(
|
||||
self, respond_session: RespondSession
|
||||
) -> AsyncIterable[GoogleProvider]:
|
||||
yield GoogleProvider(
|
||||
api_key=respond_session.api_key_override or env.google.api_key
|
||||
)
|
||||
|
||||
|
||||
class ModelProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[RespondModel]:
|
||||
if config.models.respond_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.respond_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_review_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[ReviewModel]:
|
||||
if config.models.message_review_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.message_review_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
@@ -1,26 +1,31 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
import aiogram.types
|
||||
from dishka import Provider, Scope, provide
|
||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||
from google.genai.client import AsyncClient
|
||||
|
||||
from bot.modules.solaris.client import SolarisClient
|
||||
from bot.modules.solaris.services.respond import RespondService
|
||||
from bot.modules.solaris.agents import RespondAgent, ReviewAgent
|
||||
from bot.modules.solaris.prompts import load_prompt
|
||||
from bot.modules.solaris.services import RespondService
|
||||
from utils.db.models import RespondSession, ReviewSession
|
||||
|
||||
from ..types.model import RespondModel, ReviewModel
|
||||
|
||||
|
||||
class SolarisProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_solaris_client(
|
||||
self, client: AsyncClient
|
||||
) -> AsyncIterable[SolarisClient]:
|
||||
client = SolarisClient(gemini_client=client)
|
||||
yield client
|
||||
class AgentsProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_agent(
|
||||
self, model: RespondModel, session: RespondSession
|
||||
) -> AsyncIterable[RespondAgent]:
|
||||
yield RespondAgent(
|
||||
model=model,
|
||||
system_prompt=session.system_prompt_override
|
||||
or load_prompt("default_system_prompt.txt"),
|
||||
)
|
||||
|
||||
|
||||
class SolarisServicesProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_service(
|
||||
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
|
||||
self, agent: RespondAgent, session: RespondSession
|
||||
) -> AsyncIterable[RespondService]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
service = RespondService(client=client, chat_id=chat.id)
|
||||
service = RespondService(agent=agent, session=session)
|
||||
yield service
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
pass
|
||||
|
||||
6
src/dependencies/types/model.py
Normal file
6
src/dependencies/types/model.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import NewType
|
||||
|
||||
from pydantic_ai.models import Model
|
||||
|
||||
RespondModel = NewType("RespondModel", Model)
|
||||
ReviewModel = NewType("ReviewModel", Model)
|
||||
Reference in New Issue
Block a user