feat(solaris): migrate to pydantic ai, wired respond agent through di providers
This commit is contained in:
63
src/dependencies/providers/model.py
Normal file
63
src/dependencies/providers/model.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google.genai.types import (
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
SafetySettingDict,
|
||||
ThinkingConfigDict,
|
||||
)
|
||||
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
|
||||
from utils.db.models import DynamicConfig, RespondSession
|
||||
from utils.env import env
|
||||
|
||||
from ..types.model import RespondModel, ReviewModel
|
||||
|
||||
GOOGLE_SAFETY_SETTINGS = [
|
||||
SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
|
||||
for category in HarmCategory
|
||||
]
|
||||
|
||||
|
||||
class AIServiceProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_google_provider(
|
||||
self, respond_session: RespondSession
|
||||
) -> AsyncIterable[GoogleProvider]:
|
||||
yield GoogleProvider(
|
||||
api_key=respond_session.api_key_override or env.google.api_key
|
||||
)
|
||||
|
||||
|
||||
class ModelProvider(Provider):
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[RespondModel]:
|
||||
if config.models.respond_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.respond_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_review_model(
|
||||
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||
) -> AsyncIterable[ReviewModel]:
|
||||
if config.models.message_review_model.startswith("gem"):
|
||||
model = GoogleModel(
|
||||
model_name=config.models.message_review_model,
|
||||
provider=google_provider,
|
||||
settings=GoogleModelSettings(
|
||||
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||
),
|
||||
)
|
||||
yield model
|
||||
Reference in New Issue
Block a user