"""Dishka providers for the Google GenAI provider and the respond/review models."""

from typing import AsyncIterable

from dishka import Provider, Scope, provide
from google.genai.types import (
    HarmBlockThreshold,
    HarmCategory,
    SafetySettingDict,
    ThinkingConfigDict,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider

from utils.db.models import DynamicConfig, RespondSession
from utils.env import env

from ..types.model import RespondModel, ReviewModel

# Safety settings shared by every Google model built in this module: the block
# threshold is OFF for every member of HarmCategory.
# NOTE(review): iterating HarmCategory also yields HARM_CATEGORY_UNSPECIFIED;
# confirm the Gemini API accepts a setting for that member.
GOOGLE_SAFETY_SETTINGS = [
    SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
    for category in HarmCategory
]

# Configured model names starting with this prefix are served via GoogleModel.
_GOOGLE_MODEL_PREFIX = "gem"


def _build_google_model(
    model_name: str, google_provider: GoogleProvider, role: str
) -> GoogleModel:
    """Build a GoogleModel for *model_name* with the shared settings.

    Args:
        model_name: Model name taken from the dynamic config.
        google_provider: Request-scoped provider carrying the API key.
        role: Human-readable role ("respond" / "review"), used only in the
            error message.

    Returns:
        A GoogleModel with a thinking budget of 0 and GOOGLE_SAFETY_SETTINGS.

    Raises:
        ValueError: If *model_name* does not start with the Google prefix.
            No other backend is wired up here, and failing loudly beats an
            async generator that silently never yields.
    """
    if not model_name.startswith(_GOOGLE_MODEL_PREFIX):
        raise ValueError(
            f"Unsupported {role} model {model_name!r}: only models whose name "
            f"starts with {_GOOGLE_MODEL_PREFIX!r} are supported"
        )
    return GoogleModel(
        model_name=model_name,
        provider=google_provider,
        settings=GoogleModelSettings(
            # Requests a thinking budget of 0 for every call made with this model.
            google_thinking_config=ThinkingConfigDict(thinking_budget=0),
            google_safety_settings=GOOGLE_SAFETY_SETTINGS,
        ),
    )


class AIServiceProvider(Provider):
    """Provides AI service clients, scoped to a single request."""

    @provide(scope=Scope.REQUEST)
    async def get_google_provider(
        self, respond_session: RespondSession
    ) -> AsyncIterable[GoogleProvider]:
        """Yield a GoogleProvider for the current request.

        Uses the session's ``api_key_override`` when it is truthy, otherwise
        the Google API key from the environment settings.
        """
        yield GoogleProvider(
            api_key=respond_session.api_key_override
            or env.google.api_key.get_secret_value()
        )


class ModelProvider(Provider):
    """Provides the respond and review models selected by the dynamic config."""

    @provide(scope=Scope.REQUEST)
    async def get_respond_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[RespondModel]:
        """Yield the model named by ``config.models.respond_model``.

        Raises:
            ValueError: If the configured model is not a Google model.
        """
        yield _build_google_model(
            config.models.respond_model, google_provider, "respond"
        )

    @provide(scope=Scope.REQUEST)
    async def get_review_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[ReviewModel]:
        """Yield the model named by ``config.models.message_review_model``.

        Raises:
            ValueError: If the configured model is not a Google model.
        """
        yield _build_google_model(
            config.models.message_review_model, google_provider, "review"
        )