feat(solaris): moving to service system to support multi-chat

This commit is contained in:
h
2025-07-05 01:13:12 +03:00
parent fd84210a65
commit 41927a1e07
16 changed files with 143 additions and 74 deletions

View File

@@ -23,6 +23,7 @@ async def runner():
await bot.delete_webhook(drop_pending_updates=True) await bot.delete_webhook(drop_pending_updates=True)
await dp.start_polling(bot) await dp.start_polling(bot)
def plugins(): def plugins():
from rich import traceback from rich import traceback

View File

@@ -1,9 +1,12 @@
from aiogram import Router, types from aiogram import Router, types
from aiogram.filters import CommandStart from aiogram.filters import CommandStart
from dishka.integrations.aiogram import FromDishka
from bot.modules.solaris.services.respond import RespondService
router = Router() router = Router()
@router.message(CommandStart()) @router.message(CommandStart())
async def on_start(message: types.Message): async def on_start(message: types.Message, respond_service: FromDishka[RespondService]):
await message.reply("hewo everynyan") await message.reply(str(respond_service.chat_id))

View File

@@ -1,18 +1,31 @@
import json import json
from dataclasses import asdict
from google import genai from google import genai
from google.genai import chats, types
from ..content_configs import generate_respond_config from utils.config import dconfig
from ..constants import SAFETY_SETTINGS
from ..structures import InputMessage, OutputMessage from ..structures import InputMessage, OutputMessage
RESPOND_MODEL = "gemini-2.5-flash"
class RespondAgent: class RespondAgent:
def __init__(self, client: genai.client.AsyncClient, prompt: str) -> None: chat: chats.AsyncChat
self.chat = client.chats.create(
model=RESPOND_MODEL, config=generate_respond_config(prompt=prompt) def __init__(self, client: genai.client.AsyncClient) -> None:
self.client = client
async def load_chat(self, history: list[types.Content], system_prompt: str):
self.chat = self.client.chats.create(
model=(await dconfig()).models.respond_model,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[OutputMessage],
safety_settings=SAFETY_SETTINGS,
),
history=history,
) )
async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]: async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]:

View File

@@ -2,10 +2,10 @@ import json
from google import genai from google import genai
from ..content_configs import generate_review_config from ..constants import generate_review_config
from ..structures import InputMessage from ..structures import InputMessage
REVIEW_MODEL = ("gemini-2.5-flash-lite-preview-06-17",) # надо будет гемму REVIEW_MODEL = "gemini-2.5-flash-lite-preview-06-17" # надо будет гемму
class ReviewAgent: class ReviewAgent:

View File

@@ -4,7 +4,7 @@ from google import genai
from google.genai import types from google.genai import types
from pydub import AudioSegment from pydub import AudioSegment
from ..content_configs import generate_tts_config from ..constants import SAFETY_SETTINGS
TTS_MODEL = "gemini-2.5-flash-preview-tts" TTS_MODEL = "gemini-2.5-flash-preview-tts"
@@ -12,18 +12,34 @@ TTS_MODEL = "gemini-2.5-flash-preview-tts"
class TTSAgent: class TTSAgent:
def __init__(self, client: genai.client.AsyncClient) -> None: def __init__(self, client: genai.client.AsyncClient) -> None:
self.client = client self.client = client
self.content_config = generate_tts_config()
self.content_config = types.GenerateContentConfig(
response_modalities=[types.Modality.AUDIO],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name="Kore",
)
)
),
safety_settings=SAFETY_SETTINGS,
)
async def generate(self, text: str): async def generate(self, text: str):
response = await self.client.models.generate_content( response = await self.client.models.generate_content(
model=TTS_MODEL, contents=text, config=self.content_config model=TTS_MODEL, contents=text, config=self.content_config
) )
data = response.candidates[0].content.parts[0].inline_data.data data = response.candidates[0].content.parts[0].inline_data.data
pcm_io = io.BytesIO(data) pcm_io = io.BytesIO(data)
pcm_io.seek(0)
audio = AudioSegment( audio = AudioSegment(
pcm_io.read(), sample_width=2, frame_rate=24000, channels=1 pcm_io.read(), sample_width=2, frame_rate=24000, channels=1
) )
ogg_io = io.BytesIO() ogg_io = io.BytesIO()
audio.export(ogg_io, format="ogg", codec="libopus") audio.export(ogg_io, format="ogg", codec="libopus")
ogg_bytes = ogg_io.getvalue() ogg_bytes = ogg_io.getvalue()
return ogg_bytes return ogg_bytes

View File

@@ -4,9 +4,8 @@ from .agents import BuildAgent
class SolarisClient: class SolarisClient:
def __init__(self, api_key: str) -> None: def __init__(self, gemini_client: genai.client.AsyncClient) -> None:
client = genai.Client(api_key=api_key).aio self.builder = BuildAgent(client=gemini_client)
self.builder = BuildAgent(client=client)
async def parse_user_data(self, some_data_idk): async def parse_user_data(self, some_data_idk):
self.reviewer, self.responder = await self.builder.build( self.reviewer, self.responder = await self.builder.build(

View File

@@ -0,0 +1,28 @@
from google.genai import types
from .structures import OutputMessage
SAFETY_SETTINGS = [
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
for category in types.HarmCategory
]
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
return
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
return types.GenerateContentConfig(
system_instruction=prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[int],
safety_settings=SAFETY_SETTINGS,
)
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
# response_schema=list[OutputMessage],
# )

View File

@@ -1,51 +0,0 @@
from google.genai import types
from .structures import OutputMessage
safety_settings = [
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
for category in types.HarmCategory
]
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
return types.GenerateContentConfig(
system_instruction=prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[
OutputMessage
], # ты уверен что там json надо? мне просто каж что судя по всему вот эта нам
safety_settings=safety_settings,
)
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
return types.GenerateContentConfig(
system_instruction=prompt,
thinking_config=types.ThinkingConfig(thinking_budget=0),
response_mime_type="application/json",
response_schema=list[int],
safety_settings=safety_settings,
)
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
# response_schema=list[OutputMessage],
# )
# можно было и константу оставить но хезе некрасиво чет
def generate_tts_config() -> types.GenerateContentConfig:
return types.GenerateContentConfig(
response_modalities=[types.Modality.AUDIO],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(
voice_name="Kore",
)
)
),
safety_settings=safety_settings,
)

View File

@@ -0,0 +1,22 @@
import json
from google import genai
from google.genai import chats, types
from utils.config import dconfig
from utils.logging import console
from ..agents.respond import RespondAgent
from ..constants import SAFETY_SETTINGS
from ..structures import InputMessage, OutputMessage
class RespondService:
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
self.agent = RespondAgent(client)
self.chat_id = chat_id
async def spawn_agent(self):
console.print(self.chat_id)
await self.agent.load_chat(history=[], system_prompt="nya nya")

View File

@@ -1,7 +1,10 @@
from dishka import make_async_container from dishka import make_async_container
from dishka.integrations.aiogram import AiogramProvider
from .providers import SolarisClientProvider from .providers import GeminiClientProvider, SolarisProvider
container = make_async_container( container = make_async_container(
SolarisClientProvider(), AiogramProvider(),
SolarisProvider(),
GeminiClientProvider(),
) )

View File

@@ -1 +1,2 @@
from .solaris import SolarisClientProvider from .gemini import GeminiClientProvider
from .solaris import SolarisProvider

View File

@@ -0,0 +1,13 @@
from typing import AsyncIterable
from dishka import Provider, Scope, provide
from google import genai
from utils.env import env
class GeminiClientProvider(Provider):
@provide(scope=Scope.APP)
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
yield client

View File

@@ -1,13 +1,27 @@
from typing import AsyncIterable from typing import AsyncIterable
import aiogram.types
from dishka import Provider, Scope, provide from dishka import Provider, Scope, provide
from dishka.integrations.aiogram import AiogramMiddlewareData
from google.genai.client import AsyncClient
from bot.modules.solaris.client import SolarisClient from bot.modules.solaris.client import SolarisClient
from utils.env import env from bot.modules.solaris.services.respond import RespondService
class SolarisClientProvider(Provider): class SolarisProvider(Provider):
@provide(scope=Scope.APP) @provide(scope=Scope.APP)
async def get_client(self) -> AsyncIterable[SolarisClient]: async def get_solaris_client(
client = SolarisClient(env.google.api_key.get_secret_value()) self, client: AsyncClient
) -> AsyncIterable[SolarisClient]:
client = SolarisClient(gemini_client=client)
yield client yield client
@provide(scope=Scope.REQUEST)
async def get_respond_service(
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
) -> AsyncIterable[RespondService]:
chat: aiogram.types.Chat = middleware_data["event_chat"]
service = RespondService(client=client, chat_id=chat.id)
await service.spawn_agent()
yield service

View File

@@ -7,8 +7,15 @@ class BotConfig(BaseModel):
chats_whitelist: list[int] = [] chats_whitelist: list[int] = []
class GeminiModelsConfig(BaseModel):
respond_model: str = "gemini-2.5-flash"
message_review_model: str = "gemini-2.5-flash-lite-preview-06-17"
tts_model: str = "gemini-2.5-flash-preview-tts"
class DynamicConfigBase(BaseModel): class DynamicConfigBase(BaseModel):
bot: BotConfig = Field(default_factory=BotConfig) bot: BotConfig = Field(default_factory=BotConfig)
models: GeminiModelsConfig = Field(default_factory=GeminiModelsConfig)
class DynamicConfig(DynamicConfigBase, Document): class DynamicConfig(DynamicConfigBase, Document):