diff --git a/src/bot/handlers/__init__.py b/src/bot/handlers/__init__.py index 974cdbb..93e9d54 100644 --- a/src/bot/handlers/__init__.py +++ b/src/bot/handlers/__init__.py @@ -2,6 +2,7 @@ from aiogram import Router from . import ( initialize, + message, start, ) @@ -11,4 +12,5 @@ router = Router() router.include_routers( start.router, initialize.router, + message.router, ) diff --git a/src/bot/handlers/message/__init__.py b/src/bot/handlers/message/__init__.py new file mode 100644 index 0000000..66e6c25 --- /dev/null +++ b/src/bot/handlers/message/__init__.py @@ -0,0 +1 @@ +from .message import router diff --git a/src/bot/handlers/message/message.py b/src/bot/handlers/message/message.py new file mode 100644 index 0000000..0ceaa29 --- /dev/null +++ b/src/bot/handlers/message/message.py @@ -0,0 +1,29 @@ +from aiogram import F, Router +from aiogram.types import Message +from dishka import FromDishka + +from bot.modules.solaris.services.respond import RespondService +from bot.modules.solaris.structures import InputMessage + +router = Router() + + +@router.message(F.text) +async def message_handler( + message: Message, respond_service: FromDishka[RespondService] +): + input_message = InputMessage( + time=message.date, + message_id=message.message_id, + text=message.text, + user_id=message.from_user.id, + username=message.from_user.full_name, + reply_to=( + message.reply_to_message.message_id if message.reply_to_message else None + ), + ) + + output_messages = await respond_service.process_message(input_message) + + for msg in output_messages: + await message.answer(msg.text) diff --git a/src/bot/modules/solaris/agents/respond.py b/src/bot/modules/solaris/agents/respond.py index 29c5e5c..b994b49 100644 --- a/src/bot/modules/solaris/agents/respond.py +++ b/src/bot/modules/solaris/agents/respond.py @@ -7,28 +7,26 @@ from utils.config import dconfig from ..constants import SAFETY_SETTINGS from ..structures import InputMessage, OutputMessage +from ..tools import RESPOND_TOOLS class RespondAgent: - chat: chats.AsyncChat - def __init__(self, client: genai.client.AsyncClient) -> None: self.client = client - async def load_chat(self, history: list[types.Content], system_prompt: str): - self.chat = self.client.chats.create( + async def generate_response( + self, system_prompt: str, history: list[types.Content] + ) -> list[OutputMessage]: + content = await self.client.models.generate_content( model=(await dconfig()).models.respond_model, + contents=history, config=types.GenerateContentConfig( system_instruction=system_prompt, thinking_config=types.ThinkingConfig(thinking_budget=0), response_mime_type="application/json", response_schema=list[OutputMessage], - safety_settings=SAFETY_SETTINGS, + # safety_settings=SAFETY_SETTINGS, + tools=RESPOND_TOOLS, ), - history=history, ) - - async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]: - data = json.dumps([msg.model_dump() for msg in messages], ensure_ascii=True) - response = await self.chat.send_message(data) - return response.parsed + return content.parsed diff --git a/src/bot/modules/solaris/agents/tts.py b/src/bot/modules/solaris/agents/tts.py index 1cff3c1..de9cd70 100644 --- a/src/bot/modules/solaris/agents/tts.py +++ b/src/bot/modules/solaris/agents/tts.py @@ -22,7 +22,7 @@ class TTSAgent: ) ) ), - safety_settings=SAFETY_SETTINGS, + # safety_settings=SAFETY_SETTINGS, ) async def generate(self, text: str): diff --git a/src/bot/modules/solaris/constants.py b/src/bot/modules/solaris/constants.py index be205f6..bab8191 100644 --- a/src/bot/modules/solaris/constants.py +++ b/src/bot/modules/solaris/constants.py @@ -3,7 +3,9 @@ from google.genai import types from .structures import OutputMessage SAFETY_SETTINGS = [ - types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF) + types.SafetySetting( + category=category.value, threshold=types.HarmBlockThreshold.OFF.value + ) for category in types.HarmCategory ] @@ -18,7 +20,7 @@ def generate_review_config(prompt: str) -> types.GenerateContentConfig: thinking_config=types.ThinkingConfig(thinking_budget=0), response_mime_type="application/json", response_schema=list[int], - safety_settings=SAFETY_SETTINGS, + # safety_settings=SAFETY_SETTINGS, ) diff --git a/src/bot/modules/solaris/prompts/__init__.py b/src/bot/modules/solaris/prompts/__init__.py new file mode 100644 index 0000000..959cb90 --- /dev/null +++ b/src/bot/modules/solaris/prompts/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +STATIC_PATH = Path(__file__).parent + + +def load_prompt(prompt_path: str, **kwargs) -> str: + full_path = STATIC_PATH / prompt_path + + try: + with open(full_path, "r", encoding="utf-8") as file: + template = file.read() + + return template.format(**kwargs) + except FileNotFoundError: + raise FileNotFoundError(f"Prompt template not found: {full_path}") + except KeyError as e: + raise KeyError(f"Missing placeholder in template: {e}") diff --git a/src/bot/modules/solaris/prompts/default_system_prompt.txt b/src/bot/modules/solaris/prompts/default_system_prompt.txt new file mode 100644 index 0000000..f538750 --- /dev/null +++ b/src/bot/modules/solaris/prompts/default_system_prompt.txt @@ -0,0 +1,7 @@ +Ты — Солярис, дружелюбный и многофункциональный AI-ассистент в Telegram. +Твоя главная задача — помогать пользователям, отвечая на их вопросы, +поддерживая беседу и выполняя задачи с помощью доступных инструментов. +Ты можешь запоминать контекст нашего диалога, чтобы общение было более естественным и продуктивным. +Обращайся к пользователям вежливо и старайся быть максимально полезным. +Когда тебя просят использовать инструмент для выполнения какой-либо задачи, +всегда подтверждай вызов инструмента и результат его работы. diff --git a/src/bot/modules/solaris/services/respond.py b/src/bot/modules/solaris/services/respond.py index 4ab66c1..35ef546 100644 --- a/src/bot/modules/solaris/services/respond.py +++ b/src/bot/modules/solaris/services/respond.py @@ -1,13 +1,16 @@ import json +from beanie.odm.operators.update.general import Set from google import genai from google.genai import chats, types from utils.config import dconfig +from utils.db.models.session import RespondSession from utils.logging import console from ..agents.respond import RespondAgent from ..constants import SAFETY_SETTINGS +from ..prompts import load_prompt from ..structures import InputMessage, OutputMessage @@ -15,8 +18,38 @@ class RespondService: def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None: self.agent = RespondAgent(client) self.chat_id = chat_id + self.session: RespondSession | None = None - async def spawn_agent(self): - console.print(self.chat_id) + async def _get_or_create_session(self) -> RespondSession: + session = await RespondSession.get_by_chat_id(chat_id=self.chat_id) + if not session: + session = await RespondSession.create_empty( + chat_id=self.chat_id, + system_prompt=load_prompt("default_system_prompt.txt"), + ) - await self.agent.load_chat(history=[], system_prompt="nya nya") + self.session = session + return session + + async def process_message(self, message: InputMessage) -> list[OutputMessage]: + session = await self._get_or_create_session() + + session.history.append( + types.Content(role="user", parts=[types.Part(text=message.text)]) + ) + + response_messages = await self.agent.generate_response( + system_prompt=session.system_prompt, history=session.history + ) + + if response_messages: + model_response_content = response_messages[0].text + session.history.append( + types.Content( + role="model", parts=[types.Part(text=model_response_content)] + ) + ) + + await session.update_history(history=session.history) + + return response_messages diff --git a/src/bot/modules/solaris/structures/input_message.py b/src/bot/modules/solaris/structures/input_message.py index 913c83e..396bca2 100644 --- a/src/bot/modules/solaris/structures/input_message.py +++ b/src/bot/modules/solaris/structures/input_message.py @@ -1,10 +1,11 @@ +import datetime from typing import Optional from pydantic import BaseModel class InputMessage(BaseModel): - time: str + time: datetime.datetime message_id: int text: str user_id: int diff --git a/src/bot/modules/solaris/tools/__init__.py b/src/bot/modules/solaris/tools/__init__.py index e69de29..41a2d6f 100644 --- a/src/bot/modules/solaris/tools/__init__.py +++ b/src/bot/modules/solaris/tools/__init__.py @@ -0,0 +1,3 @@ +from .test import test_tool + +RESPOND_TOOLS = [test_tool] diff --git a/src/bot/modules/solaris/tools/test.py b/src/bot/modules/solaris/tools/test.py new file mode 100644 index 0000000..21a2b61 --- /dev/null +++ b/src/bot/modules/solaris/tools/test.py @@ -0,0 +1,14 @@ +from utils.logging import console + + +async def test_tool(content: str): + """Prints the content to the developer console. + + Args: + content: Anything you want to print. + + Returns: + A status if content was printed. + """ + console.print(content) + return "ok" diff --git a/src/dependencies/providers/solaris.py b/src/dependencies/providers/solaris.py index 8857e94..dc8fe22 100644 --- a/src/dependencies/providers/solaris.py +++ b/src/dependencies/providers/solaris.py @@ -23,5 +23,4 @@ class SolarisProvider(Provider): ) -> AsyncIterable[RespondService]: chat: aiogram.types.Chat = middleware_data["event_chat"] service = RespondService(client=client, chat_id=chat.id) - await service.spawn_agent() yield service diff --git a/src/utils/db/__init__.py b/src/utils/db/__init__.py index cc1333b..c201964 100644 --- a/src/utils/db/__init__.py +++ b/src/utils/db/__init__.py @@ -9,11 +9,15 @@ client = AsyncIOMotorClient(env.db.connection_url) async def init_db(): from .models import ( DynamicConfig, + RespondSession, + ReviewSession, ) await init_beanie( database=client[env.db.db_name], document_models=[ DynamicConfig, + RespondSession, + ReviewSession, ], ) diff --git a/src/utils/db/models/__init__.py b/src/utils/db/models/__init__.py index f745e21..2255ab5 100644 --- a/src/utils/db/models/__init__.py +++ b/src/utils/db/models/__init__.py @@ -1 +1,2 @@ from .config import DynamicConfig +from .session import RespondSession, ReviewSession diff --git a/src/utils/db/models/session.py b/src/utils/db/models/session.py index 88fc85c..60e2a30 100644 --- a/src/utils/db/models/session.py +++ b/src/utils/db/models/session.py @@ -11,11 +11,24 @@ class SessionBase(BaseModel): history: List[Content] = Field(default_factory=list) -class ReviewSession(SessionBase, Document): +class __CommonSessionRepository(SessionBase, Document): + @classmethod + async def get_by_chat_id(cls, chat_id: int): + return await cls.find_one(cls.chat_id == chat_id) + + @classmethod + async def create_empty(cls, chat_id: int, system_prompt: str): + return await cls(chat_id=chat_id, system_prompt=system_prompt).insert() + + async def update_history(self, history: List[Content]): + await self.set({self.history: history}) + + +class ReviewSession(__CommonSessionRepository): class Settings: name = "review_sessions" -class RespondSession(SessionBase, Document): +class RespondSession(__CommonSessionRepository): class Settings: name = "respond_sessions"