feat(solaris): migrate to pydantic ai, wired respond agent through di providers
This commit is contained in:
@@ -10,7 +10,7 @@ dependencies = [
|
|||||||
"aiogram>=3.20.0.post0",
|
"aiogram>=3.20.0.post0",
|
||||||
"beanie>=2.0.0",
|
"beanie>=2.0.0",
|
||||||
"dishka>=1.6.0",
|
"dishka>=1.6.0",
|
||||||
"google-genai>=1.23.0",
|
"pydantic-ai>=0.4.4",
|
||||||
"pydantic-settings>=2.10.1",
|
"pydantic-settings>=2.10.1",
|
||||||
"pydub>=0.25.1",
|
"pydub>=0.25.1",
|
||||||
"rich>=14.0.0",
|
"rich>=14.0.0",
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
from .respond import RespondAgent
|
from .respond import RespondAgent
|
||||||
|
|||||||
@@ -1,32 +1,25 @@
|
|||||||
import json
|
from typing import List
|
||||||
|
|
||||||
from google import genai
|
from pydantic_ai import Agent
|
||||||
from google.genai import chats, types
|
from pydantic_ai.messages import ModelMessage
|
||||||
|
|
||||||
from utils.config import dconfig
|
from dependencies.types.model import RespondModel
|
||||||
|
|
||||||
from ..constants import SAFETY_SETTINGS
|
from ..structures import OutputMessage
|
||||||
from ..structures import InputMessage, OutputMessage
|
from ..tools import RESPOND_TOOLSET
|
||||||
from ..tools import RESPOND_TOOLS
|
|
||||||
|
|
||||||
|
|
||||||
class RespondAgent:
|
class RespondAgent:
|
||||||
def __init__(self, client: genai.client.AsyncClient) -> None:
|
def __init__(self, model: RespondModel, system_prompt: str) -> None:
|
||||||
self.client = client
|
self.agent = Agent(
|
||||||
|
model=model,
|
||||||
|
instructions=system_prompt,
|
||||||
|
output_type=list[OutputMessage],
|
||||||
|
toolsets=[RESPOND_TOOLSET],
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self, system_prompt: str, history: list[types.Content]
|
self, request: str, history: List[ModelMessage]
|
||||||
) -> list[OutputMessage]:
|
) -> tuple[list[OutputMessage], List[ModelMessage]]:
|
||||||
content = await self.client.models.generate_content(
|
result = await self.agent.run(user_prompt=request, message_history=history)
|
||||||
model=(await dconfig()).models.respond_model,
|
return result.output, result.all_messages()
|
||||||
contents=history,
|
|
||||||
config=types.GenerateContentConfig(
|
|
||||||
system_instruction=system_prompt,
|
|
||||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
|
||||||
response_mime_type="application/json",
|
|
||||||
response_schema=list[OutputMessage],
|
|
||||||
# safety_settings=SAFETY_SETTINGS,
|
|
||||||
tools=RESPOND_TOOLS,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return content.parsed
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from google import genai
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
from ..constants import SAFETY_SETTINGS
|
|
||||||
|
|
||||||
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from .structures import OutputMessage
|
|
||||||
|
|
||||||
SAFETY_SETTINGS = [
|
SAFETY_SETTINGS = [
|
||||||
types.SafetySetting(
|
types.SafetySetting(
|
||||||
category=category.value, threshold=types.HarmBlockThreshold.OFF.value
|
category=category.value, threshold=types.HarmBlockThreshold.OFF.value
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
from .respond import RespondService
|
||||||
|
|||||||
@@ -1,55 +1,20 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from beanie.odm.operators.update.general import Set
|
|
||||||
from google import genai
|
|
||||||
from google.genai import chats, types
|
|
||||||
|
|
||||||
from utils.config import dconfig
|
|
||||||
from utils.db.models.session import RespondSession
|
from utils.db.models.session import RespondSession
|
||||||
from utils.logging import console
|
|
||||||
|
|
||||||
from ..agents.respond import RespondAgent
|
from ..agents.respond import RespondAgent
|
||||||
from ..constants import SAFETY_SETTINGS
|
|
||||||
from ..prompts import load_prompt
|
from ..prompts import load_prompt
|
||||||
from ..structures import InputMessage, OutputMessage
|
from ..structures import InputMessage, OutputMessage
|
||||||
|
|
||||||
|
|
||||||
class RespondService:
|
class RespondService:
|
||||||
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
|
def __init__(self, agent: RespondAgent, session: RespondSession) -> None:
|
||||||
self.agent = RespondAgent(client)
|
self.agent = agent
|
||||||
self.chat_id = chat_id
|
|
||||||
self.session: RespondSession | None = None
|
|
||||||
|
|
||||||
async def _get_or_create_session(self) -> RespondSession:
|
|
||||||
session = await RespondSession.get_by_chat_id(chat_id=self.chat_id)
|
|
||||||
if not session:
|
|
||||||
session = await RespondSession.create_empty(
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
system_prompt=load_prompt("default_system_prompt.txt"),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.session = session
|
self.session = session
|
||||||
return session
|
|
||||||
|
|
||||||
async def process_message(self, message: InputMessage) -> list[OutputMessage]:
|
async def process_message(self, message: InputMessage) -> list[OutputMessage]:
|
||||||
session = await self._get_or_create_session()
|
await self.session.sync()
|
||||||
|
result, history = await self.agent.generate_response(
|
||||||
session.history.append(
|
request=message.text, history=self.session.history
|
||||||
types.Content(role="user", parts=[types.Part(text=message.text)])
|
|
||||||
)
|
)
|
||||||
|
await self.session.update_history(history=history)
|
||||||
|
|
||||||
response_messages = await self.agent.generate_response(
|
return result
|
||||||
system_prompt=session.system_prompt, history=session.history
|
|
||||||
)
|
|
||||||
|
|
||||||
if response_messages:
|
|
||||||
model_response_content = response_messages[0].text
|
|
||||||
session.history.append(
|
|
||||||
types.Content(
|
|
||||||
role="model", parts=[types.Part(text=model_response_content)]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.update_history(history=session.history)
|
|
||||||
|
|
||||||
return response_messages
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
from .test import test_tool
|
from pydantic_ai.toolsets import CombinedToolset
|
||||||
|
|
||||||
RESPOND_TOOLS = [test_tool]
|
from .test import TEST_TOOLSET
|
||||||
|
|
||||||
|
RESPOND_TOOLSET = CombinedToolset([TEST_TOOLSET])
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from pydantic_ai.toolsets import FunctionToolset
|
||||||
|
|
||||||
from utils.logging import console
|
from utils.logging import console
|
||||||
|
|
||||||
|
|
||||||
@@ -12,3 +14,6 @@ async def test_tool(content: str):
|
|||||||
"""
|
"""
|
||||||
console.print(content)
|
console.print(content)
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
TEST_TOOLSET = FunctionToolset(tools=[test_tool])
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
from dishka import make_async_container
|
from dishka import make_async_container
|
||||||
from dishka.integrations.aiogram import AiogramProvider
|
from dishka.integrations.aiogram import AiogramProvider
|
||||||
|
|
||||||
from .providers import GeminiClientProvider, SolarisProvider
|
from .providers import (
|
||||||
|
AIServiceProvider,
|
||||||
|
ConfigProvider,
|
||||||
|
ModelProvider,
|
||||||
|
SessionProvider,
|
||||||
|
SolarisServicesProvider,
|
||||||
|
)
|
||||||
|
|
||||||
container = make_async_container(
|
container = make_async_container(
|
||||||
AiogramProvider(),
|
AiogramProvider(),
|
||||||
SolarisProvider(),
|
SolarisServicesProvider(),
|
||||||
GeminiClientProvider(),
|
SessionProvider(),
|
||||||
|
ModelProvider(),
|
||||||
|
ConfigProvider(),
|
||||||
|
AIServiceProvider(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .gemini import GeminiClientProvider
|
from .database import ConfigProvider, SessionProvider
|
||||||
from .solaris import SolarisProvider
|
from .model import AIServiceProvider, ModelProvider
|
||||||
|
from .solaris import SolarisServicesProvider
|
||||||
|
|||||||
31
src/dependencies/providers/database.py
Normal file
31
src/dependencies/providers/database.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from typing import AsyncIterable
|
||||||
|
|
||||||
|
import aiogram.types
|
||||||
|
from dishka import Provider, Scope, provide
|
||||||
|
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||||
|
|
||||||
|
from utils.db.models import DynamicConfig, RespondSession, ReviewSession
|
||||||
|
|
||||||
|
|
||||||
|
class SessionProvider(Provider):
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def provide_respond_session(
|
||||||
|
self, middleware_data: AiogramMiddlewareData
|
||||||
|
) -> AsyncIterable[RespondSession]:
|
||||||
|
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||||
|
session = await RespondSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def provide_review_session(
|
||||||
|
self, middleware_data: AiogramMiddlewareData
|
||||||
|
) -> AsyncIterable[ReviewSession]:
|
||||||
|
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||||
|
session = await ReviewSession.get_or_create_by_chat_id(chat_id=chat.id)
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigProvider(Provider):
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def provide_config(self) -> AsyncIterable[DynamicConfig]:
|
||||||
|
yield await DynamicConfig.get_or_create()
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
from typing import AsyncIterable
|
|
||||||
|
|
||||||
from dishka import Provider, Scope, provide
|
|
||||||
from google import genai
|
|
||||||
|
|
||||||
from utils.env import env
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiClientProvider(Provider):
|
|
||||||
@provide(scope=Scope.APP)
|
|
||||||
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
|
|
||||||
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
|
|
||||||
yield client
|
|
||||||
63
src/dependencies/providers/model.py
Normal file
63
src/dependencies/providers/model.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from typing import AsyncIterable
|
||||||
|
|
||||||
|
from dishka import Provider, Scope, provide
|
||||||
|
from google.genai.types import (
|
||||||
|
HarmBlockThreshold,
|
||||||
|
HarmCategory,
|
||||||
|
SafetySettingDict,
|
||||||
|
ThinkingConfigDict,
|
||||||
|
)
|
||||||
|
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
|
||||||
|
from pydantic_ai.providers.google import GoogleProvider
|
||||||
|
|
||||||
|
from utils.db.models import DynamicConfig, RespondSession
|
||||||
|
from utils.env import env
|
||||||
|
|
||||||
|
from ..types.model import RespondModel, ReviewModel
|
||||||
|
|
||||||
|
GOOGLE_SAFETY_SETTINGS = [
|
||||||
|
SafetySettingDict(category=category, threshold=HarmBlockThreshold.OFF)
|
||||||
|
for category in HarmCategory
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AIServiceProvider(Provider):
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def get_google_provider(
|
||||||
|
self, respond_session: RespondSession
|
||||||
|
) -> AsyncIterable[GoogleProvider]:
|
||||||
|
yield GoogleProvider(
|
||||||
|
api_key=respond_session.api_key_override or env.google.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(Provider):
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def get_respond_model(
|
||||||
|
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||||
|
) -> AsyncIterable[RespondModel]:
|
||||||
|
if config.models.respond_model.startswith("gem"):
|
||||||
|
model = GoogleModel(
|
||||||
|
model_name=config.models.respond_model,
|
||||||
|
provider=google_provider,
|
||||||
|
settings=GoogleModelSettings(
|
||||||
|
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||||
|
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield model
|
||||||
|
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def get_review_model(
|
||||||
|
self, config: DynamicConfig, google_provider: GoogleProvider
|
||||||
|
) -> AsyncIterable[ReviewModel]:
|
||||||
|
if config.models.message_review_model.startswith("gem"):
|
||||||
|
model = GoogleModel(
|
||||||
|
model_name=config.models.message_review_model,
|
||||||
|
provider=google_provider,
|
||||||
|
settings=GoogleModelSettings(
|
||||||
|
google_thinking_config=ThinkingConfigDict(thinking_budget=0),
|
||||||
|
google_safety_settings=GOOGLE_SAFETY_SETTINGS,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield model
|
||||||
@@ -1,26 +1,31 @@
|
|||||||
from typing import AsyncIterable
|
from typing import AsyncIterable
|
||||||
|
|
||||||
import aiogram.types
|
|
||||||
from dishka import Provider, Scope, provide
|
from dishka import Provider, Scope, provide
|
||||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
|
||||||
from google.genai.client import AsyncClient
|
|
||||||
|
|
||||||
from bot.modules.solaris.client import SolarisClient
|
from bot.modules.solaris.agents import RespondAgent, ReviewAgent
|
||||||
from bot.modules.solaris.services.respond import RespondService
|
from bot.modules.solaris.prompts import load_prompt
|
||||||
|
from bot.modules.solaris.services import RespondService
|
||||||
|
from utils.db.models import RespondSession, ReviewSession
|
||||||
|
|
||||||
|
from ..types.model import RespondModel, ReviewModel
|
||||||
|
|
||||||
|
|
||||||
class SolarisProvider(Provider):
|
class AgentsProvider(Provider):
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.REQUEST)
|
||||||
async def get_solaris_client(
|
async def get_respond_agent(
|
||||||
self, client: AsyncClient
|
self, model: RespondModel, session: RespondSession
|
||||||
) -> AsyncIterable[SolarisClient]:
|
) -> AsyncIterable[RespondAgent]:
|
||||||
client = SolarisClient(gemini_client=client)
|
yield RespondAgent(
|
||||||
yield client
|
model=model,
|
||||||
|
system_prompt=session.system_prompt_override
|
||||||
|
or load_prompt("default_system_prompt.txt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SolarisServicesProvider(Provider):
|
||||||
@provide(scope=Scope.REQUEST)
|
@provide(scope=Scope.REQUEST)
|
||||||
async def get_respond_service(
|
async def get_respond_service(
|
||||||
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
|
self, agent: RespondAgent, session: RespondSession
|
||||||
) -> AsyncIterable[RespondService]:
|
) -> AsyncIterable[RespondService]:
|
||||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
service = RespondService(agent=agent, session=session)
|
||||||
service = RespondService(client=client, chat_id=chat.id)
|
|
||||||
yield service
|
yield service
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
pass
|
||||||
|
|||||||
6
src/dependencies/types/model.py
Normal file
6
src/dependencies/types/model.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from typing import NewType
|
||||||
|
|
||||||
|
from pydantic_ai.models import Model
|
||||||
|
|
||||||
|
RespondModel = NewType("RespondModel", Model)
|
||||||
|
ReviewModel = NewType("ReviewModel", Model)
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
from typing import Annotated, List
|
from typing import Annotated, List
|
||||||
|
|
||||||
from beanie import Document, Indexed
|
from beanie import Document, Indexed
|
||||||
from google.genai.types import Content
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic_ai.messages import ModelMessage
|
||||||
|
|
||||||
|
|
||||||
class SessionBase(BaseModel):
|
class SessionBase(BaseModel):
|
||||||
chat_id: Annotated[int, Indexed(unique=True)]
|
chat_id: Annotated[int, Indexed(unique=True)]
|
||||||
system_prompt: str
|
system_prompt_override: str = None
|
||||||
history: List[Content] = Field(default_factory=list)
|
history: List[ModelMessage] = Field(default_factory=list)
|
||||||
|
api_key_override: str = None
|
||||||
|
|
||||||
|
|
||||||
class __CommonSessionRepository(SessionBase, Document):
|
class __CommonSessionRepository(SessionBase, Document):
|
||||||
@@ -17,10 +18,17 @@ class __CommonSessionRepository(SessionBase, Document):
|
|||||||
return await cls.find_one(cls.chat_id == chat_id)
|
return await cls.find_one(cls.chat_id == chat_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_empty(cls, chat_id: int, system_prompt: str):
|
async def create_empty(cls, chat_id: int):
|
||||||
return await cls(chat_id=chat_id, system_prompt=system_prompt).insert()
|
return await cls(chat_id=chat_id).insert()
|
||||||
|
|
||||||
async def update_history(self, history: List[Content]):
|
@classmethod
|
||||||
|
async def get_or_create_by_chat_id(cls, chat_id: int):
|
||||||
|
session = await cls.get_by_chat_id(chat_id)
|
||||||
|
if not session:
|
||||||
|
session = await cls.create_empty(chat_id=chat_id)
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def update_history(self, history: List[ModelMessage]):
|
||||||
await self.set({self.history: history})
|
await self.set({self.history: history})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user