Files
solaris-guest-bot/src/dependencies/providers/model.py

65 lines
2.2 KiB
Python

from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google.genai.types import (
HarmBlockThreshold,
HarmCategory,
SafetySettingDict,
ThinkingConfigDict,
)
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider
from utils.db.models import DynamicConfig, RespondSession
from utils.env import env
from ..types.model import RespondModel, ReviewModel
# Safety settings that switch every content filter off: one entry per
# HarmCategory member, each with threshold HarmBlockThreshold.OFF.
# NOTE(review): this walks the SDK enum verbatim, so it also emits
# HARM_CATEGORY_UNSPECIFIED and any category added by future google-genai
# releases — confirm the target API accepts every member after upgrades.
GOOGLE_SAFETY_SETTINGS = [
    SafetySettingDict(threshold=HarmBlockThreshold.OFF, category=harm_category)
    for harm_category in HarmCategory
]
class AIServiceProvider(Provider):
    """Dishka provider that supplies the Google API client provider."""

    @provide(scope=Scope.REQUEST)
    async def get_google_provider(
        self, respond_session: RespondSession
    ) -> AsyncIterable[GoogleProvider]:
        """Yield a request-scoped ``GoogleProvider``.

        The API key is ``respond_session.api_key_override`` when that value is
        truthy; otherwise it is the secret stored in ``env.google.api_key``.
        """
        override_key = respond_session.api_key_override
        if override_key:
            api_key = override_key
        else:
            api_key = env.google.api_key.get_secret_value()
        # NOTE(review): nothing runs after this yield, so no cleanup happens
        # when the request scope closes. If GoogleProvider holds network
        # resources, confirm they do not need to be released explicitly.
        yield GoogleProvider(api_key=api_key)
def _build_google_model(model_name: str, provider: GoogleProvider) -> GoogleModel:
    """Build a ``GoogleModel`` with a zero thinking budget and all safety filters off.

    Args:
        model_name: Model identifier taken from the dynamic config.
        provider: Google API provider the model sends its requests through.

    Returns:
        A configured ``GoogleModel``.

    Raises:
        ValueError: If ``model_name`` does not start with ``"gem"``. Only
            Google models are built here, and failing loudly beats yielding an
            unbound name.
    """
    # NOTE(review): the "gem" prefix matches both "gemini-*" and "gemma-*";
    # confirm every matching model accepts the thinking config below.
    if not model_name.startswith("gem"):
        raise ValueError(
            f"Unsupported model {model_name!r}: only Google models "
            '(names starting with "gem") are supported'
        )
    return GoogleModel(
        model_name=model_name,
        provider=provider,
        settings=GoogleModelSettings(
            # A thinking budget of 0 requests no thinking tokens.
            google_thinking_config=ThinkingConfigDict(thinking_budget=0),
            google_safety_settings=GOOGLE_SAFETY_SETTINGS,
        ),
    )


class ModelProvider(Provider):
    """Dishka provider for the request-scoped respond and review models."""

    @provide(scope=Scope.REQUEST)
    async def get_respond_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[RespondModel]:
        """Yield the ``RespondModel`` named by ``config.models.respond_model``.

        Raises:
            ValueError: If the configured name is not a Google ("gem...") model.
        """
        yield _build_google_model(config.models.respond_model, google_provider)

    @provide(scope=Scope.REQUEST)
    async def get_review_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[ReviewModel]:
        """Yield the ``ReviewModel`` named by ``config.models.message_review_model``.

        Raises:
            ValueError: If the configured name is not a Google ("gem...") model.
        """
        yield _build_google_model(
            config.models.message_review_model, google_provider
        )