feat(bot): developing integration with solaris
@@ -7,28 +7,26 @@ from utils.config import dconfig
from ..constants import SAFETY_SETTINGS
from ..structures import InputMessage, OutputMessage
from ..tools import RESPOND_TOOLS


class RespondAgent:
    chat: chats.AsyncChat

    def __init__(self, client: genai.client.AsyncClient) -> None:
        self.client = client

    async def load_chat(self, history: list[types.Content], system_prompt: str):
        self.chat = self.client.chats.create(
    async def generate_response(
        self, system_prompt: str, history: list[types.Content]
    ) -> list[OutputMessage]:
        content = await self.client.models.generate_content(
            model=(await dconfig()).models.respond_model,
            contents=history,
            config=types.GenerateContentConfig(
                system_instruction=system_prompt,
                thinking_config=types.ThinkingConfig(thinking_budget=0),
                response_mime_type="application/json",
                response_schema=list[OutputMessage],
                safety_settings=SAFETY_SETTINGS,
                # safety_settings=SAFETY_SETTINGS,
                tools=RESPOND_TOOLS,
            ),
            history=history,
        )

    async def send_messages(self, messages: list[InputMessage]) -> list[OutputMessage]:
        data = json.dumps([msg.model_dump() for msg in messages], ensure_ascii=True)
        response = await self.chat.send_message(data)
        return response.parsed
        return content.parsed

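Note (reviewer sketch, not part of this commit): the new generate_response path relies on google-genai structured output, where response_schema plus response_mime_type="application/json" makes response.parsed come back as typed objects. A minimal standalone version under those assumptions; the model name and the one-field OutputMessage below are placeholders, and genai.Client() is assumed to pick up the API key from the environment.

import asyncio

from google import genai
from google.genai import types
from pydantic import BaseModel


class OutputMessage(BaseModel):
    # stand-in for the project's OutputMessage; only a text field is assumed here
    text: str


async def main() -> None:
    client = genai.Client()  # expects GEMINI_API_KEY / GOOGLE_API_KEY in the environment

    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash",  # placeholder; the bot reads its model from dconfig()
        contents="Reply with one short greeting.",
        config=types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=list[OutputMessage],
        ),
    )
    # With a response_schema set, .parsed is already a list[OutputMessage]
    print(response.parsed)


asyncio.run(main())
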
@@ -22,7 +22,7 @@ class TTSAgent:
                    )
                )
            ),
            safety_settings=SAFETY_SETTINGS,
            # safety_settings=SAFETY_SETTINGS,
        )

    async def generate(self, text: str):

@@ -3,7 +3,9 @@ from google.genai import types
from .structures import OutputMessage

SAFETY_SETTINGS = [
    types.SafetySetting(category=category, threshold=types.HarmBlockThreshold.OFF)
    types.SafetySetting(
        category=category.value, threshold=types.HarmBlockThreshold.OFF.value
    )
    for category in types.HarmCategory
]

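Note (not part of the diff): the comprehension now passes the raw enum values instead of the enum members. A local check, with no API call, of what it builds, assuming types.HarmCategory and types.HarmBlockThreshold from google-genai as used above.

from google.genai import types

settings = [
    types.SafetySetting(
        category=category.value, threshold=types.HarmBlockThreshold.OFF.value
    )
    for category in types.HarmCategory
]

# One SafetySetting per HarmCategory member, all with blocking switched off
for setting in settings:
    print(setting.category, "->", setting.threshold)
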
@@ -18,7 +20,7 @@ def generate_review_config(prompt: str) -> types.GenerateContentConfig:
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        response_mime_type="application/json",
        response_schema=list[int],
        safety_settings=SAFETY_SETTINGS,
        # safety_settings=SAFETY_SETTINGS,
    )


17  src/bot/modules/solaris/prompts/__init__.py  Normal file
@@ -0,0 +1,17 @@
from pathlib import Path

STATIC_PATH = Path(__file__).parent


def load_prompt(prompt_path: str, **kwargs) -> str:
    full_path = STATIC_PATH / prompt_path

    try:
        with open(full_path, "r", encoding="utf-8") as file:
            template = file.read()

        return template.format(**kwargs)
    except FileNotFoundError:
        raise FileNotFoundError(f"Prompt template not found: {full_path}")
    except KeyError as e:
        raise KeyError(f"Missing placeholder in template: {e}")
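Note (not part of the commit): a usage sketch for load_prompt. The import path is assumed from the file location; default_system_prompt.txt is added by this commit, while greeting.txt and its user_name placeholder are hypothetical.

from bot.modules.solaris.prompts import load_prompt  # assumed module path

# No placeholders in the template: the file contents come back unchanged
system_prompt = load_prompt("default_system_prompt.txt")

# With placeholders, str.format() fills them in; a missing kwarg raises KeyError
# greeting = load_prompt("greeting.txt", user_name="Alice")  # hypothetical template
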
@@ -0,0 +1,7 @@
You are Solaris, a friendly, multi-purpose AI assistant on Telegram.
Your main task is to help users by answering their questions,
keeping up the conversation, and completing tasks with the available tools.
You can remember the context of our dialogue so that the conversation feels more natural and productive.
Address users politely and try to be as helpful as possible.
When you are asked to use a tool to perform a task,
always confirm the tool call and the result it produced.
@@ -1,13 +1,16 @@
import json

from beanie.odm.operators.update.general import Set
from google import genai
from google.genai import chats, types

from utils.config import dconfig
from utils.db.models.session import RespondSession
from utils.logging import console

from ..agents.respond import RespondAgent
from ..constants import SAFETY_SETTINGS
from ..prompts import load_prompt
from ..structures import InputMessage, OutputMessage


@@ -15,8 +18,38 @@ class RespondService:
    def __init__(self, client: genai.client.AsyncClient, chat_id: int) -> None:
        self.agent = RespondAgent(client)
        self.chat_id = chat_id
        self.session: RespondSession | None = None

    async def spawn_agent(self):
        console.print(self.chat_id)
    async def _get_or_create_session(self) -> RespondSession:
        session = await RespondSession.get_by_chat_id(chat_id=self.chat_id)
        if not session:
            session = await RespondSession.create_empty(
                chat_id=self.chat_id,
                system_prompt=load_prompt("default_system_prompt.txt"),
            )

        await self.agent.load_chat(history=[], system_prompt="nya nya")
        self.session = session
        return session

    async def process_message(self, message: InputMessage) -> list[OutputMessage]:
        session = await self._get_or_create_session()

        session.history.append(
            types.Content(role="user", parts=[types.Part(text=message.text)])
        )

        response_messages = await self.agent.generate_response(
            system_prompt=session.system_prompt, history=session.history
        )

        if response_messages:
            model_response_content = response_messages[0].text
            session.history.append(
                types.Content(
                    role="model", parts=[types.Part(text=model_response_content)]
                )
            )

        await session.update_history(history=session.history)

        return response_messages

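Note (hypothetical wiring, not part of this commit): roughly how a Telegram handler could drive process_message. The module paths and the handler signature are assumptions; the InputMessage fields and the RespondService constructor match the diff above.

import datetime

from google import genai

from bot.modules.solaris.services.respond import RespondService  # assumed path
from bot.modules.solaris.structures import InputMessage  # assumed path


async def handle_message(chat_id: int, message_id: int, user_id: int, text: str) -> list[str]:
    client = genai.Client().aio  # the genai.client.AsyncClient the service expects
    service = RespondService(client, chat_id=chat_id)

    replies = await service.process_message(
        InputMessage(
            time=datetime.datetime.now(datetime.timezone.utc),
            message_id=message_id,
            text=text,
            user_id=user_id,
        )
    )
    # OutputMessage.text is what process_message stores back into the history
    return [reply.text for reply in replies]
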
@@ -1,10 +1,11 @@
import datetime
from typing import Optional

from pydantic import BaseModel


class InputMessage(BaseModel):
    time: str
    time: datetime.datetime
    message_id: int
    text: str
    user_id: int

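Note (not part of the diff): with the field changed from str to datetime.datetime, pydantic v2 (which the project already uses, given model_dump above) coerces ISO-8601 strings on input. A self-contained check:

import datetime

from pydantic import BaseModel


class InputMessage(BaseModel):
    time: datetime.datetime
    message_id: int
    text: str
    user_id: int


msg = InputMessage(
    time="2024-01-01T12:00:00+00:00",  # ISO string is coerced to an aware datetime
    message_id=1,
    text="hello",
    user_id=42,
)
print(msg.time, msg.time.tzinfo)
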
@@ -0,0 +1,3 @@
from .test import test_tool

RESPOND_TOOLS = [test_tool]

14  src/bot/modules/solaris/tools/test.py  Normal file
@@ -0,0 +1,14 @@
from utils.logging import console


async def test_tool(content: str):
    """Prints the content to the developer console.

    Args:
        content: Anything you want to print.

    Returns:
        A status if content was printed.
    """
    console.print(content)
    return "ok"
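Note (sketch, not part of this commit): RESPOND_TOOLS is handed to GenerateContentConfig(tools=...) in the agent above, so the SDK's automatic function calling can derive the declaration from test_tool's signature and docstring. A minimal call under that assumption, with a placeholder model name and an assumed import path; whether the async client accepts this async callable is also an assumption.

import asyncio

from google import genai
from google.genai import types

from bot.modules.solaris.tools import RESPOND_TOOLS  # assumed module path


async def main() -> None:
    client = genai.Client()
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash",  # placeholder
        contents="Use the test tool to print 'hello from solaris'.",
        config=types.GenerateContentConfig(tools=RESPOND_TOOLS),
    )
    print(response.text)


asyncio.run(main())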