feat(solaris): moving to service system to support multi-chat
This commit is contained in:
@@ -23,6 +23,7 @@ async def runner():
|
|||||||
await bot.delete_webhook(drop_pending_updates=True)
|
await bot.delete_webhook(drop_pending_updates=True)
|
||||||
await dp.start_polling(bot)
|
await dp.start_polling(bot)
|
||||||
|
|
||||||
|
|
||||||
def plugins():
|
def plugins():
|
||||||
from rich import traceback
|
from rich import traceback
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from aiogram import Router, types
|
from aiogram import Router, types
|
||||||
from aiogram.filters import CommandStart
|
from aiogram.filters import CommandStart
|
||||||
|
from dishka.integrations.aiogram import FromDishka
|
||||||
|
|
||||||
|
from bot.modules.solaris.services.respond import RespondService
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
|
|
||||||
@router.message(CommandStart())
|
@router.message(CommandStart())
|
||||||
async def on_start(message: types.Message):
|
async def on_start(message: types.Message, respond_service: FromDishka[RespondService]):
|
||||||
await message.reply("hewo everynyan")
|
await message.reply(str(respond_service.chat_id))
|
||||||
|
|||||||
@@ -1,18 +1,31 @@
|
|||||||
import json
|
import json
|
||||||
from dataclasses import asdict
|
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
from google.genai import chats, types
|
||||||
|
|
||||||
from ..content_configs import generate_respond_config
|
from utils.config import dconfig
|
||||||
|
|
||||||
|
from ..constants import SAFETY_SETTINGS
|
||||||
from ..structures import InputMessage, OutputMessage
|
from ..structures import InputMessage, OutputMessage
|
||||||
|
|
||||||
RESPOND_MODEL = "gemini-2.5-flash"
|
|
||||||
|
|
||||||
|
|
||||||
class RespondAgent:
|
class RespondAgent:
|
||||||
def __init__(self, client: genai.client.AsyncClient, prompt: str) -> None:
|
chat: chats.AsyncChat
|
||||||
self.chat = client.chats.create(
|
|
||||||
model=RESPOND_MODEL, config=generate_respond_config(prompt=prompt)
|
def __init__(self, client: genai.client.AsyncClient) -> None:
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def load_chat(self, history: list[types.Content], system_prompt: str):
|
||||||
|
self.chat = self.client.chats.create(
|
||||||
|
model=(await dconfig()).models.respond_model,
|
||||||
|
config=types.GenerateContentConfig(
|
||||||
|
system_instruction=system_prompt,
|
||||||
|
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||||
|
response_mime_type="application/json",
|
||||||
|
response_schema=list[OutputMessage],
|
||||||
|
safety_settings=SAFETY_SETTINGS,
|
||||||
|
),
|
||||||
|
history=history,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]:
|
async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]:
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ import json
|
|||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
from ..content_configs import generate_review_config
|
from ..constants import generate_review_config
|
||||||
from ..structures import InputMessage
|
from ..structures import InputMessage
|
||||||
|
|
||||||
REVIEW_MODEL = ("gemini-2.5-flash-lite-preview-06-17",) # надо будет гемму
|
REVIEW_MODEL = "gemini-2.5-flash-lite-preview-06-17" # надо будет гемму
|
||||||
|
|
||||||
|
|
||||||
class ReviewAgent:
|
class ReviewAgent:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from google import genai
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
|
||||||
from ..content_configs import generate_tts_config
|
from ..constants import SAFETY_SETTINGS
|
||||||
|
|
||||||
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||||
|
|
||||||
@@ -12,18 +12,34 @@ TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
|||||||
class TTSAgent:
|
class TTSAgent:
|
||||||
def __init__(self, client: genai.client.AsyncClient) -> None:
|
def __init__(self, client: genai.client.AsyncClient) -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.content_config = generate_tts_config()
|
|
||||||
|
self.content_config = types.GenerateContentConfig(
|
||||||
|
response_modalities=[types.Modality.AUDIO],
|
||||||
|
speech_config=types.SpeechConfig(
|
||||||
|
voice_config=types.VoiceConfig(
|
||||||
|
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
||||||
|
voice_name="Kore",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
safety_settings=SAFETY_SETTINGS,
|
||||||
|
)
|
||||||
|
|
||||||
async def generate(self, text: str):
|
async def generate(self, text: str):
|
||||||
response = await self.client.models.generate_content(
|
response = await self.client.models.generate_content(
|
||||||
model=TTS_MODEL, contents=text, config=self.content_config
|
model=TTS_MODEL, contents=text, config=self.content_config
|
||||||
)
|
)
|
||||||
|
|
||||||
data = response.candidates[0].content.parts[0].inline_data.data
|
data = response.candidates[0].content.parts[0].inline_data.data
|
||||||
pcm_io = io.BytesIO(data)
|
pcm_io = io.BytesIO(data)
|
||||||
|
pcm_io.seek(0)
|
||||||
|
|
||||||
audio = AudioSegment(
|
audio = AudioSegment(
|
||||||
pcm_io.read(), sample_width=2, frame_rate=24000, channels=1
|
pcm_io.read(), sample_width=2, frame_rate=24000, channels=1
|
||||||
)
|
)
|
||||||
|
|
||||||
ogg_io = io.BytesIO()
|
ogg_io = io.BytesIO()
|
||||||
audio.export(ogg_io, format="ogg", codec="libopus")
|
audio.export(ogg_io, format="ogg", codec="libopus")
|
||||||
ogg_bytes = ogg_io.getvalue()
|
ogg_bytes = ogg_io.getvalue()
|
||||||
|
|
||||||
return ogg_bytes
|
return ogg_bytes
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ from .agents import BuildAgent
|
|||||||
|
|
||||||
|
|
||||||
class SolarisClient:
|
class SolarisClient:
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, gemini_client: genai.client.AsyncClient) -> None:
|
||||||
client = genai.Client(api_key=api_key).aio
|
self.builder = BuildAgent(client=gemini_client)
|
||||||
self.builder = BuildAgent(client=client)
|
|
||||||
|
|
||||||
async def parse_user_data(self, some_data_idk):
|
async def parse_user_data(self, some_data_idk):
|
||||||
self.reviewer, self.responder = await self.builder.build(
|
self.reviewer, self.responder = await self.builder.build(
|
||||||
|
|||||||
28
src/bot/modules/solaris/constants.py
Normal file
28
src/bot/modules/solaris/constants.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
from .structures import OutputMessage
|
||||||
|
|
||||||
|
SAFETY_SETTINGS = [
|
||||||
|
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
|
||||||
|
for category in types.HarmCategory
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
|
||||||
|
return types.GenerateContentConfig(
|
||||||
|
system_instruction=prompt,
|
||||||
|
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
||||||
|
response_mime_type="application/json",
|
||||||
|
response_schema=list[int],
|
||||||
|
safety_settings=SAFETY_SETTINGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
|
||||||
|
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
|
||||||
|
# response_schema=list[OutputMessage],
|
||||||
|
# )
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
from google.genai import types
|
|
||||||
|
|
||||||
from .structures import OutputMessage
|
|
||||||
|
|
||||||
safety_settings = [
|
|
||||||
types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
|
|
||||||
for category in types.HarmCategory
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def generate_respond_config(prompt: str) -> types.GenerateContentConfig:
|
|
||||||
return types.GenerateContentConfig(
|
|
||||||
system_instruction=prompt,
|
|
||||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
|
||||||
response_mime_type="application/json",
|
|
||||||
response_schema=list[
|
|
||||||
OutputMessage
|
|
||||||
], # ты уверен что там json надо? мне просто каж что судя по всему вот эта нам
|
|
||||||
safety_settings=safety_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_review_config(prompt: str) -> types.GenerateContentConfig:
|
|
||||||
return types.GenerateContentConfig(
|
|
||||||
system_instruction=prompt,
|
|
||||||
thinking_config=types.ThinkingConfig(thinking_budget=0),
|
|
||||||
response_mime_type="application/json",
|
|
||||||
response_schema=list[int],
|
|
||||||
safety_settings=safety_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# MESSAGE_CONTENT_CONFIG = types.GenerateContentConfig(
|
|
||||||
# response_mime_type="application/json", # возможно можно скипнуть это если мы используем response_schema, надо проверить
|
|
||||||
# response_schema=list[OutputMessage],
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
# можно было и константу оставить но хезе некрасиво чет
|
|
||||||
def generate_tts_config() -> types.GenerateContentConfig:
|
|
||||||
return types.GenerateContentConfig(
|
|
||||||
response_modalities=[types.Modality.AUDIO],
|
|
||||||
speech_config=types.SpeechConfig(
|
|
||||||
voice_config=types.VoiceConfig(
|
|
||||||
prebuilt_voice_config=types.PrebuiltVoiceConfig(
|
|
||||||
voice_name="Kore",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
),
|
|
||||||
safety_settings=safety_settings,
|
|
||||||
)
|
|
||||||
0
src/bot/modules/solaris/services/__init__.py
Normal file
0
src/bot/modules/solaris/services/__init__.py
Normal file
22
src/bot/modules/solaris/services/respond.py
Normal file
22
src/bot/modules/solaris/services/respond.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import chats, types
|
||||||
|
|
||||||
|
from utils.config import dconfig
|
||||||
|
from utils.logging import console
|
||||||
|
|
||||||
|
from ..agents.respond import RespondAgent
|
||||||
|
from ..constants import SAFETY_SETTINGS
|
||||||
|
from ..structures import InputMessage, OutputMessage
|
||||||
|
|
||||||
|
|
||||||
|
class RespondService:
|
||||||
|
def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
|
||||||
|
self.agent = RespondAgent(client)
|
||||||
|
self.chat_id = chat_id
|
||||||
|
|
||||||
|
async def spawn_agent(self):
|
||||||
|
console.print(self.chat_id)
|
||||||
|
|
||||||
|
await self.agent.load_chat(history=[], system_prompt="nya nya")
|
||||||
0
src/bot/modules/solaris/tools/__init__.py
Normal file
0
src/bot/modules/solaris/tools/__init__.py
Normal file
@@ -1,7 +1,10 @@
|
|||||||
from dishka import make_async_container
|
from dishka import make_async_container
|
||||||
|
from dishka.integrations.aiogram import AiogramProvider
|
||||||
|
|
||||||
from .providers import SolarisClientProvider
|
from .providers import GeminiClientProvider, SolarisProvider
|
||||||
|
|
||||||
container = make_async_container(
|
container = make_async_container(
|
||||||
SolarisClientProvider(),
|
AiogramProvider(),
|
||||||
|
SolarisProvider(),
|
||||||
|
GeminiClientProvider(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
from .solaris import SolarisClientProvider
|
from .gemini import GeminiClientProvider
|
||||||
|
from .solaris import SolarisProvider
|
||||||
|
|||||||
13
src/dependencies/providers/gemini.py
Normal file
13
src/dependencies/providers/gemini.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from typing import AsyncIterable
|
||||||
|
|
||||||
|
from dishka import Provider, Scope, provide
|
||||||
|
from google import genai
|
||||||
|
|
||||||
|
from utils.env import env
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiClientProvider(Provider):
|
||||||
|
@provide(scope=Scope.APP)
|
||||||
|
async def get_client(self) -> AsyncIterable[genai.client.AsyncClient]:
|
||||||
|
client = genai.Client(api_key=env.google.api_key.get_secret_value()).aio
|
||||||
|
yield client
|
||||||
@@ -1,13 +1,27 @@
|
|||||||
from typing import AsyncIterable
|
from typing import AsyncIterable
|
||||||
|
|
||||||
|
import aiogram.types
|
||||||
from dishka import Provider, Scope, provide
|
from dishka import Provider, Scope, provide
|
||||||
|
from dishka.integrations.aiogram import AiogramMiddlewareData
|
||||||
|
from google.genai.client import AsyncClient
|
||||||
|
|
||||||
from bot.modules.solaris.client import SolarisClient
|
from bot.modules.solaris.client import SolarisClient
|
||||||
from utils.env import env
|
from bot.modules.solaris.services.respond import RespondService
|
||||||
|
|
||||||
|
|
||||||
class SolarisClientProvider(Provider):
|
class SolarisProvider(Provider):
|
||||||
@provide(scope=Scope.APP)
|
@provide(scope=Scope.APP)
|
||||||
async def get_client(self) -> AsyncIterable[SolarisClient]:
|
async def get_solaris_client(
|
||||||
client = SolarisClient(env.google.api_key.get_secret_value())
|
self, client: AsyncClient
|
||||||
|
) -> AsyncIterable[SolarisClient]:
|
||||||
|
client = SolarisClient(gemini_client=client)
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
@provide(scope=Scope.REQUEST)
|
||||||
|
async def get_respond_service(
|
||||||
|
self, client: AsyncClient, middleware_data: AiogramMiddlewareData
|
||||||
|
) -> AsyncIterable[RespondService]:
|
||||||
|
chat: aiogram.types.Chat = middleware_data["event_chat"]
|
||||||
|
service = RespondService(client=client, chat_id=chat.id)
|
||||||
|
await service.spawn_agent()
|
||||||
|
yield service
|
||||||
|
|||||||
@@ -7,8 +7,15 @@ class BotConfig(BaseModel):
|
|||||||
chats_whitelist: list[int] = []
|
chats_whitelist: list[int] = []
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModelsConfig(BaseModel):
|
||||||
|
respond_model: str = "gemini-2.5-flash"
|
||||||
|
message_review_model: str = "gemini-2.5-flash-lite-preview-06-17"
|
||||||
|
tts_model: str = "gemini-2.5-flash-preview-tts"
|
||||||
|
|
||||||
|
|
||||||
class DynamicConfigBase(BaseModel):
|
class DynamicConfigBase(BaseModel):
|
||||||
bot: BotConfig = Field(default_factory=BotConfig)
|
bot: BotConfig = Field(default_factory=BotConfig)
|
||||||
|
models: GeminiModelsConfig = Field(default_factory=GeminiModelsConfig)
|
||||||
|
|
||||||
|
|
||||||
class DynamicConfig(DynamicConfigBase, Document):
|
class DynamicConfig(DynamicConfigBase, Document):
|
||||||
|
|||||||
Reference in New Issue
Block a user