feat(solaris): moving to service system to support multi-chat
This commit is contained in:
@@ -23,6 +23,7 @@ async def runner():
|
||||
await bot.delete_webhook(drop_pending_updates=True)
|
||||
await dp.start_polling(bot)
|
||||
|
||||
|
||||
def plugins():
|
||||
from rich import traceback
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from aiogram import Router, types
|
||||
from aiogram.filters import CommandStart
|
||||
from dishka.integrations.aiogram import FromDishka
|
||||
|
||||
from bot.modules.solaris.services.respond import RespondService
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message(CommandStart())
|
||||
async def on_start(message: types.Message):
|
||||
await message.reply("hewo everynyan")
|
||||
async def on_start(message: types.Message, respond_service: FromDishka[RespondService]):
|
||||
await message.reply(str(respond_service.chat_id))
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
|
||||
from google import genai
|
||||
from google.genai import chats, types
|
||||
|
||||
from ..content_configs import generate_respond_config
|
||||
from utils.config import dconfig
|
||||
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
from ..structures import InputMessage, OutputMessage
|
||||
|
||||
RESPOND_MODEL = "gemini-2.5-flash"
|
||||
|
||||
|
||||
class RespondAgent:
|
||||
def __init__(self, client: genai.client.AsyncClient, prompt: str) -> None:
|
||||
self.chat = client.chats.create(
|
||||
model=RESPOND_MODEL, config=generate_respond_config(prompt=prompt)
|
||||
chat: chats.AsyncChat
|
||||
|
||||
def __init__(self, client: genai.client.AsyncClient) -> None:
|
||||
self.client = client
|
||||
|
||||
async def load_chat(self, history: list[types.Content], system_prompt: str):
|
||||
self.chat = self.client.chats.create(
|
||||
model=(await dconfig()).models.respond_model,
|
||||
config=types.GenerateContentConfig(
|
||||
system_instruction=system_prompt,
|
||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||
response_mime_type="application/json",
|
||||
response_schema=list[OutputMessage],
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
),
|
||||
history=history,
|
||||
)
|
||||
|
||||
async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]:
|
||||
|
||||
@@ -2,10 +2,10 @@ import json
|
||||
|
||||
from google import genai
|
||||
|
||||
from ..content_configs import generate_review_config
|
||||
from ..constants import generate_review_config
|
||||
from ..structures import InputMessage
|
||||
|
||||
REVIEW_MODEL = ("gemini-2.5-flash-lite-preview-06-17",) # надо будет гемму
|
||||
REVIEW_MODEL = "gemini-2.5-flash-lite-preview-06-17" # надо будет гемму
|
||||
|
||||
|
||||
class ReviewAgent:
|
||||
|
||||
@@ -4,7 +4,7 @@ from google import genai
|
||||
from google.genai import types
|
||||
from pydub import AudioSegment
|
||||
|
||||
from ..content_configs import generate_tts_config
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
|
||||
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||
|
||||
@@ -12,18 +12,34 @@ TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||
class TTSAgent:
|
||||
def __init__(self, client: genai.client.AsyncClient) -> None:
|
||||
self.client = client
|
||||
self.content_config = generate_tts_config()
|
||||
|
||||
self.content_config = types.GenerateContentConfig(
|
||||
response_modalities=[types.Modality.AUDIO],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name="Kore",
|
||||
)
|
||||
)
|
||||
),
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
async def generate(self, text: str):
|
||||
response = await self.client.models.generate_content(
|
||||
model=TTS_MODEL, contents=text, config=self.content_config
|
||||
)
|
||||
|
||||
data = response.candidates[0].content.parts[0].inline_data.data
|
||||
pcm_io = io.BytesIO(data)
|
||||
pcm_io.seek(0)
|
||||
|
||||
audio = AudioSegment(
|
||||
pcm_io.read(), sample_width=2, frame_rate=24000, channels=1
|
||||
)
|
||||
|
||||
ogg_io = io.BytesIO()
|
||||
audio.export(ogg_io, format="ogg", codec="libopus")
|
||||
ogg_bytes = ogg_io.getvalue()
|
||||
|
||||
return ogg_bytes
|
||||
|
||||
@@ -4,9 +4,8 @@ from .agents import BuildAgent
|
||||
|
||||
|
||||
class SolarisClient:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
client = genai.Client(api_key=api_key).aio
|
||||
self.builder = BuildAgent(client=client)
|
||||
def __init__(self, gemini_client: genai.client.AsyncClient) -> None:
|
||||
self.builder = BuildAgent(client=gemini_client)
|
||||
|
||||
async def parse_user_data(self, some_data_idk):
|
||||
self.reviewer, self.responder = await self.builder.build(
|
||||
|
||||
28
src/bot/modules/solaris/constants.py
Normal file
28
src/bot/modules/solaris/constants.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from google.genai import types
|
||||
|
||||
from .structures import OutputMessage
|
||||
|
||||
SAFETY_SETTINGS = [
|
||||
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
|
||||
for category in types.HarmCategory
|
||||
]
|
||||
|
||||
|
||||
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
|
||||
return
|
||||
|
||||
|
||||
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=prompt,
|
||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||
response_mime_type="application/json",
|
||||
response_schema=list[int],
|
||||
safety_settings=SAFETY_SETTINGS,
|
||||
)
|
||||
|
||||
|
||||
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
|
||||
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
|
||||
# response_schema=list[OutputMessage],
|
||||
# )
|
||||
@@ -1,51 +0,0 @@
|
||||
from google.genai import types
|
||||
|
||||
from .structures import OutputMessage
|
||||
|
||||
safety_settings = [
|
||||
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
|
||||
for category in types.HarmCategory
|
||||
]
|
||||
|
||||
|
||||
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=prompt,
|
||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||
response_mime_type="application/json",
|
||||
response_schema=list[
|
||||
OutputMessage
|
||||
], # ты уверен что там json надо? мне просто каж что судя по всему вот эта нам
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
|
||||
|
||||
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=prompt,
|
||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||
response_mime_type="application/json",
|
||||
response_schema=list[int],
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
|
||||
|
||||
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
|
||||
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
|
||||
# response_schema=list[OutputMessage],
|
||||
# )
|
||||
|
||||
|
||||
# можно было и константу оставить но хезе некрасиво чет
|
||||
def generate_tts_config() -> types.GenerateContentConfig:
|
||||
return types.GenerateContentConfig(
|
||||
response_modalities=[types.Modality.AUDIO],
|
||||
speech_config=types.SpeechConfig(
|
||||
voice_config=types.VoiceConfig(
|
||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||
voice_name="Kore",
|
||||
)
|
||||
)
|
||||
),
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
0
src/bot/modules/solaris/services/__init__.py
Normal file
0
src/bot/modules/solaris/services/__init__.py
Normal file
22
src/bot/modules/solaris/services/respond.py
Normal file
22
src/bot/modules/solaris/services/respond.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import json
|
||||
|
||||
from google import genai
|
||||
from google.genai import chats, types
|
||||
|
||||
from utils.config import dconfig
|
||||
from utils.logging import console
|
||||
|
||||
from ..agents.respond import RespondAgent
|
||||
from ..constants import SAFETY_SETTINGS
|
||||
from ..structures import InputMessage, OutputMessage
|
||||
|
||||
|
||||
class RespondService:
|
||||
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
|
||||
self.agent = RespondAgent(client)
|
||||
self.chat_id = chat_id
|
||||
|
||||
async def spawn_agent(self):
|
||||
console.print(self.chat_id)
|
||||
|
||||
await self.agent.load_chat(history=[], system_prompt="nya nya")
|
||||
0
src/bot/modules/solaris/tools/__init__.py
Normal file
0
src/bot/modules/solaris/tools/__init__.py
Normal file
@@ -1,7 +1,10 @@
|
||||
from dishka import make_async_container
|
||||
from dishka.integrations.aiogram import AiogramProvider
|
||||
|
||||
from .providers import SolarisClientProvider
|
||||
from .providers import GeminiClientProvider, SolarisProvider
|
||||
|
||||
container = make_async_container(
|
||||
SolarisClientProvider(),
|
||||
AiogramProvider(),
|
||||
SolarisProvider(),
|
||||
GeminiClientProvider(),
|
||||
)
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .solaris import SolarisClientProvider
|
||||
from .gemini import GeminiClientProvider
|
||||
from .solaris import SolarisProvider
|
||||
|
||||
13
src/dependencies/providers/gemini.py
Normal file
13
src/dependencies/providers/gemini.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from google import genai
|
||||
|
||||
from utils.env import env
|
||||
|
||||
|
||||
class GeminiClientProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
|
||||
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
|
||||
yield client
|
||||
@@ -1,13 +1,27 @@
|
||||
from typing import AsyncIterable
|
||||
|
||||
import aiogram.types
|
||||
from dishka import Provider, Scope, provide
|
||||
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||
from google.genai.client import AsyncClient
|
||||
|
||||
from bot.modules.solaris.client import SolarisClient
|
||||
from utils.env import env
|
||||
from bot.modules.solaris.services.respond import RespondService
|
||||
|
||||
|
||||
class SolarisClientProvider(Provider):
|
||||
class SolarisProvider(Provider):
|
||||
@provide(scope=Scope.APP)
|
||||
async def get_client(self) -> AsyncIterable[SolarisClient]:
|
||||
client = SolarisClient(env.google.api_key.get_secret_value())
|
||||
async def get_solaris_client(
|
||||
self, client: AsyncClient
|
||||
) -> AsyncIterable[SolarisClient]:
|
||||
client = SolarisClient(gemini_client=client)
|
||||
yield client
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_respond_service(
|
||||
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
|
||||
) -> AsyncIterable[RespondService]:
|
||||
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||
service = RespondService(client=client, chat_id=chat.id)
|
||||
await service.spawn_agent()
|
||||
yield service
|
||||
|
||||
@@ -7,8 +7,15 @@ class BotConfig(BaseModel):
|
||||
chats_whitelist: list[int] = []
|
||||
|
||||
|
||||
class GeminiModelsConfig(BaseModel):
|
||||
respond_model: str = "gemini-2.5-flash"
|
||||
message_review_model: str = "gemini-2.5-flash-lite-preview-06-17"
|
||||
tts_model: str = "gemini-2.5-flash-preview-tts"
|
||||
|
||||
|
||||
class DynamicConfigBase(BaseModel):
|
||||
bot: BotConfig = Field(default_factory=BotConfig)
|
||||
models: GeminiModelsConfig = Field(default_factory=GeminiModelsConfig)
|
||||
|
||||
|
||||
class DynamicConfig(DynamicConfigBase, Document):
|
||||
|
||||
Reference in New Issue
Block a user