"""Dishka providers for the Google (Gemini) provider and the respond/review models."""
from typing import AsyncIterable
|
|
|
|
from dishka import Provider, Scope, provide
|
|
from google.genai.types import (
|
|
HarmBlockThreshold,
|
|
HarmCategory,
|
|
SafetySettingDict,
|
|
ThinkingConfigDict,
|
|
)
|
|
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
|
from pydantic_ai.providers.google import GoogleProvider
|
|
|
|
from utils.db.models import DynamicConfig, RespondSession
|
|
from utils.env import env
|
|
|
|
from ..types.model import RespondModel, ReviewModel
|
|
|
|
# One safety setting per Gemini harm category, each with its block threshold
# set to HarmBlockThreshold.OFF.
# NOTE(review): this iterates *every* HarmCategory member (including any
# "unspecified" or API-specific members the installed google-genai version
# defines) — confirm the target API accepts all of them.
GOOGLE_SAFETY_SETTINGS: list[SafetySettingDict] = [
    SafetySettingDict(category=harm_category, threshold=HarmBlockThreshold.OFF)
    for harm_category in HarmCategory
]


class AIServiceProvider(Provider):
    """Dishka provider for AI service clients, created per request."""

    @provide(scope=Scope.REQUEST)
    async def get_google_provider(
        self, respond_session: RespondSession
    ) -> AsyncIterable[GoogleProvider]:
        """Yield a ``GoogleProvider`` for the current respond session.

        A truthy ``respond_session.api_key_override`` takes precedence;
        otherwise the secret value of ``env.google.api_key`` is used.
        """
        session_key = respond_session.api_key_override
        if session_key:
            api_key = session_key
        else:
            # Only unwrap the environment secret when no override is present.
            api_key = env.google.api_key.get_secret_value()
        yield GoogleProvider(api_key=api_key)


def _build_google_model(
    model_name: str, google_provider: GoogleProvider
) -> GoogleModel:
    """Build a Gemini ``GoogleModel`` with the shared model settings.

    The model is configured with a thinking budget of 0 and with
    ``GOOGLE_SAFETY_SETTINGS`` as its safety settings.

    Args:
        model_name: Model identifier taken from the dynamic config.
        google_provider: Provider holding the API credentials.

    Returns:
        The configured ``GoogleModel``.

    Raises:
        ValueError: If ``model_name`` does not start with ``"gem"``; no other
            model family is wired up here.
    """
    if not model_name.startswith("gem"):
        # Previously this fell through with ``model`` unbound, producing an
        # opaque UnboundLocalError (or a generator that never yielded).
        raise ValueError(
            f"Unsupported model {model_name!r}: only Google models "
            "(names starting with 'gem') are supported"
        )
    return GoogleModel(
        model_name=model_name,
        provider=google_provider,
        settings=GoogleModelSettings(
            google_thinking_config=ThinkingConfigDict(thinking_budget=0),
            google_safety_settings=GOOGLE_SAFETY_SETTINGS,
        ),
    )


class ModelProvider(Provider):
    """Dishka provider for the respond and message-review models.

    Both models are built per request from the names in
    ``DynamicConfig.models``. Only names starting with ``"gem"`` are
    supported; any other name raises ``ValueError``.
    """

    @provide(scope=Scope.REQUEST)
    async def get_respond_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[RespondModel]:
        """Yield the model configured as ``config.models.respond_model``.

        Raises:
            ValueError: If the configured model name is not supported.
        """
        yield _build_google_model(config.models.respond_model, google_provider)

    @provide(scope=Scope.REQUEST)
    async def get_review_model(
        self, config: DynamicConfig, google_provider: GoogleProvider
    ) -> AsyncIterable[ReviewModel]:
        """Yield the model configured as ``config.models.message_review_model``.

        Raises:
            ValueError: If the configured model name is not supported.
        """
        yield _build_google_model(
            config.models.message_review_model, google_provider
        )