feat(*): add multiple image support
This commit is contained in:
@@ -3,10 +3,17 @@ import base64
|
||||
import contextlib
|
||||
import io
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from aiogram import Bot, F, Router, html, types
|
||||
from aiogram import BaseMiddleware, Bot, F, Router, html, types
|
||||
from aiogram.enums import ChatAction
|
||||
from aiogram.types import KeyboardButton, ReplyKeyboardMarkup, ReplyKeyboardRemove
|
||||
from aiogram.types import (
|
||||
KeyboardButton,
|
||||
ReplyKeyboardMarkup,
|
||||
ReplyKeyboardRemove,
|
||||
TelegramObject,
|
||||
)
|
||||
from convex import ConvexInt64
|
||||
|
||||
from bot.modules.ai import (
|
||||
@@ -23,6 +30,45 @@ from utils.convex import ConvexClient
|
||||
router = Router()
|
||||
convex = ConvexClient(env.convex_url)
|
||||
|
||||
ALBUM_COLLECT_DELAY = 0.5
|
||||
|
||||
|
||||
class AlbumMiddleware(BaseMiddleware):
|
||||
def __init__(self) -> None:
|
||||
self.albums: dict[str, list[types.Message]] = {}
|
||||
self.scheduled: set[str] = set()
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
|
||||
event: TelegramObject,
|
||||
data: dict[str, Any],
|
||||
) -> Any: # noqa: ANN401
|
||||
if not isinstance(event, types.Message) or not event.media_group_id:
|
||||
return await handler(event, data)
|
||||
|
||||
album_id = event.media_group_id
|
||||
if album_id not in self.albums:
|
||||
self.albums[album_id] = []
|
||||
self.albums[album_id].append(event)
|
||||
|
||||
if album_id in self.scheduled:
|
||||
return None
|
||||
|
||||
self.scheduled.add(album_id)
|
||||
await asyncio.sleep(ALBUM_COLLECT_DELAY)
|
||||
|
||||
messages = self.albums.pop(album_id, [])
|
||||
self.scheduled.discard(album_id)
|
||||
|
||||
if messages:
|
||||
data["album"] = messages
|
||||
return await handler(messages[0], data)
|
||||
return None
|
||||
|
||||
|
||||
router.message.middleware(AlbumMiddleware())
|
||||
|
||||
EDIT_THROTTLE_SECONDS = 1.0
|
||||
TELEGRAM_MAX_LENGTH = 4096
|
||||
|
||||
@@ -398,6 +444,74 @@ async def on_text_message(message: types.Message, bot: Bot) -> None:
|
||||
await process_message(message.from_user.id, message.text, bot, message.chat.id)
|
||||
|
||||
|
||||
@router.message(F.media_group_id, F.photo)
|
||||
async def on_album_message(
|
||||
message: types.Message, bot: Bot, album: list[types.Message]
|
||||
) -> None:
|
||||
if not message.from_user:
|
||||
return
|
||||
|
||||
await convex.mutation(
|
||||
"users:getOrCreate",
|
||||
{
|
||||
"telegramId": ConvexInt64(message.from_user.id),
|
||||
"telegramChatId": ConvexInt64(message.chat.id),
|
||||
},
|
||||
)
|
||||
|
||||
user = await convex.query(
|
||||
"users:getByTelegramId", {"telegramId": ConvexInt64(message.from_user.id)}
|
||||
)
|
||||
|
||||
if not user or not user.get("activeChatId"):
|
||||
await message.answer("Use /new first to create a chat.")
|
||||
return
|
||||
|
||||
caption = message.caption or "Process the images according to your task"
|
||||
|
||||
images_base64: list[str] = []
|
||||
images_media_types: list[str] = []
|
||||
|
||||
for msg in album:
|
||||
if not msg.photo:
|
||||
continue
|
||||
|
||||
photo = msg.photo[-1]
|
||||
file = await bot.get_file(photo.file_id)
|
||||
if not file.file_path:
|
||||
continue
|
||||
|
||||
buffer = io.BytesIO()
|
||||
await bot.download_file(file.file_path, buffer)
|
||||
image_bytes = buffer.getvalue()
|
||||
images_base64.append(base64.b64encode(image_bytes).decode())
|
||||
|
||||
ext = file.file_path.rsplit(".", 1)[-1].lower()
|
||||
media_type = f"image/{ext}" if ext in ("png", "gif", "webp") else "image/jpeg"
|
||||
images_media_types.append(media_type)
|
||||
|
||||
if not images_base64:
|
||||
await message.answer("Failed to get photos.")
|
||||
return
|
||||
|
||||
active_chat_id = user["activeChatId"]
|
||||
await convex.mutation(
|
||||
"messages:create",
|
||||
{
|
||||
"chatId": active_chat_id,
|
||||
"role": "user",
|
||||
"content": caption,
|
||||
"source": "telegram",
|
||||
"imagesBase64": images_base64,
|
||||
"imagesMediaTypes": images_media_types,
|
||||
},
|
||||
)
|
||||
|
||||
await process_message(
|
||||
message.from_user.id, caption, bot, message.chat.id, skip_user_message=True
|
||||
)
|
||||
|
||||
|
||||
@router.message(F.photo)
|
||||
async def on_photo_message(message: types.Message, bot: Bot) -> None:
|
||||
if not message.from_user or not message.photo:
|
||||
|
||||
Reference in New Issue
Block a user