From 6a5cde6ae448ebb3483d89f4ad0fbe74de76af4a Mon Sep 17 00:00:00 2001
From: h
Date: Fri, 29 May 2026 22:50:21 +0200
Subject: [PATCH] feat: search functionality

---
 .../versions/b7e3d9f1a4c6_fts_tsvector_gin.py |  39 ++++++
 backend/src/utils/search/__init__.py          |   0
 backend/src/utils/search/models.py            |  36 +++++
 backend/src/utils/search/repository.py        | 129 ++++++++++++++++++++
 4 files changed, 204 insertions(+)
 create mode 100644 backend/migrations/versions/b7e3d9f1a4c6_fts_tsvector_gin.py
 create mode 100644 backend/src/utils/search/__init__.py
 create mode 100644 backend/src/utils/search/models.py
 create mode 100644 backend/src/utils/search/repository.py

diff --git a/backend/migrations/versions/b7e3d9f1a4c6_fts_tsvector_gin.py b/backend/migrations/versions/b7e3d9f1a4c6_fts_tsvector_gin.py
new file mode 100644
index 0000000..554ad59
--- /dev/null
+++ b/backend/migrations/versions/b7e3d9f1a4c6_fts_tsvector_gin.py
@@ -0,0 +1,39 @@
+"""fts expression gin on messages and media
+
+Revision ID: b7e3d9f1a4c6
+Revises: a1d4f7c2e9b8
+Create Date: 2026-05-29 22:30:00.000000
+
+"""
+
+from collections.abc import Sequence
+
+from alembic import op
+
+revision: str = "b7e3d9f1a4c6"
+down_revision: str | None = "a1d4f7c2e9b8"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+
+def upgrade() -> None:
+    """Create GIN expression indexes for Russian full-text search.
+
+    The indexed expressions mirror the ``to_tsvector`` expressions built in
+    ``utils/search/repository.py``; keep both in sync, because PostgreSQL only
+    uses an expression index when the query repeats the same expression.
+    """
+    op.execute(
+        "CREATE INDEX ix_messages_tsv ON messages USING gin "
+        "(to_tsvector('russian', coalesce(text, '')))"
+    )
+    op.execute(
+        "CREATE INDEX ix_media_tsv ON media USING gin "
+        "(to_tsvector('russian', coalesce(extracted_text, '')))"
+    )
+
+
+def downgrade() -> None:
+    """Drop the full-text indexes in reverse order of creation."""
+    op.execute("DROP INDEX IF EXISTS ix_media_tsv")
+    op.execute("DROP INDEX IF EXISTS ix_messages_tsv")
diff --git a/backend/src/utils/search/__init__.py b/backend/src/utils/search/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/backend/src/utils/search/models.py b/backend/src/utils/search/models.py
new file mode 100644
index 0000000..bf9c2b6
--- /dev/null
+++ b/backend/src/utils/search/models.py
@@ -0,0 +1,36 @@
+"""Request and response models for message search."""
+
+from datetime import datetime
+from typing import Literal
+
+from pydantic import BaseModel
+
+DEFAULT_LIMIT = 50  # page size used when the caller does not set ``limit``
+MAX_LIMIT = 500  # upper bound the repository applies to ``limit``
+
+
+class SearchFilters(BaseModel):
+    """Criteria for one search request; ``None`` fields are not applied."""
+
+    account_id: int
+    query: str | None = None  # full-text query (websearch syntax)
+    chat_id: int | None = None
+    sender_id: int | None = None
+    has_media: bool | None = None
+    date_from: datetime | None = None  # inclusive lower bound on the date
+    date_to: datetime | None = None  # inclusive upper bound on the date
+    regex: str | None = None  # POSIX regex matched against the message text
+    limit: int = DEFAULT_LIMIT
+    offset: int = 0
+
+
+class SearchHit(BaseModel):
+    """One matching message as returned by the search repository."""
+
+    chat_id: int
+    message_id: int
+    date: datetime
+    sender_id: int | None
+    text: str | None
+    source: Literal["text", "stt"]  # which branch matched: message text or media
+    extracted_text: str | None  # media ``extracted_text``; None for "text" hits
diff --git a/backend/src/utils/search/repository.py b/backend/src/utils/search/repository.py
new file mode 100644
index 0000000..10b96cf
--- /dev/null
+++ b/backend/src/utils/search/repository.py
@@ -0,0 +1,129 @@
+"""SQL construction and execution for message search."""
+
+import asyncpg
+
+from utils.search.models import MAX_LIMIT, SearchFilters, SearchHit
+
+# Keep these expressions identical (apart from the table alias) to the ones
+# indexed by ix_messages_tsv / ix_media_tsv, or the GIN indexes go unused.
+_TEXT_TSV = "to_tsvector('russian', coalesce(m.text, ''))"
+_STT_TSV = "to_tsvector('russian', coalesce(md.extracted_text, ''))"
+_TSQUERY = "websearch_to_tsquery('russian', ${})"
+
+_HITS = "{cols} FROM messages m WHERE {where}"  # text branch
+_STT_HITS = (
+    "{cols} FROM messages m "
+    "JOIN media md ON md.account_id = m.account_id "
+    "AND md.chat_id = m.chat_id AND md.message_id = m.message_id "
+    "WHERE {where}"
+)
+_TEXT_COLS = (
+    "SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
+    "'text' AS source, NULL::text AS extracted_text"
+)
+_STT_COLS = (
+    "SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
+    "'stt' AS source, md.extracted_text"
+)
+
+
+def _common_conditions(filters: SearchFilters, params: list[object]) -> list[str]:
+    """Build the WHERE conditions shared by every search branch.
+
+    Each applied filter appends its value to ``params`` (mutated in place) and
+    refers to it through the matching ``$n`` placeholder, so no filter
+    value is ever interpolated into the SQL text.
+
+    Args:
+        filters: Search criteria; optional fields left as ``None`` are skipped.
+        params: Positional query arguments collected so far.
+
+    Returns:
+        SQL condition fragments, meant to be joined with ``AND``.
+    """
+    params.append(filters.account_id)
+    conds = [f"m.account_id = ${len(params)}"]
+    optional: tuple[tuple[str, object], ...] = (
+        ("m.chat_id =", filters.chat_id),
+        ("m.sender_id =", filters.sender_id),
+        ("m.has_media =", filters.has_media),
+        ("m.date >=", filters.date_from),
+        ("m.date <=", filters.date_to),
+    )
+    for lhs, value in optional:
+        if value is not None:
+            params.append(value)
+            conds.append(f"{lhs} ${len(params)}")
+    return conds
+
+
+async def search_messages(
+    pool: asyncpg.Pool, filters: SearchFilters
+) -> list[SearchHit]:
+    """Return one page of messages matching ``filters``, newest first.
+
+    Without a full-text query only the ``messages`` table is read.  With a
+    query, hits on the message text and hits on ``media.extracted_text`` are
+    combined and reduced to one row per message; when both match, the
+    ``"text"`` hit is kept.
+
+    Args:
+        pool: asyncpg pool used to run the single search statement.
+        filters: Search criteria and paging.  A blank ``query`` counts as no
+            query; ``limit`` is clamped to ``0..MAX_LIMIT`` and a negative
+            ``offset`` is treated as ``0``.
+
+    Returns:
+        Hits ordered by date descending, then by chat and message id.
+
+    Raises:
+        asyncpg.PostgresError: If PostgreSQL rejects the statement, for
+            example because ``regex`` is not a valid pattern.
+    """
+    # A whitespace-only query would turn into an empty tsquery that matches
+    # nothing, so treat it exactly like an absent query.
+    query = (filters.query or "").strip() or None
+    params: list[object] = []
+    common = _common_conditions(filters, params)
+    if filters.regex is not None:
+        # NOTE(review): the regex is tested against m.text in both branches, so
+        # an STT hit whose message text is NULL is dropped -- confirm intent.
+        params.append(filters.regex)
+        common.append(f"m.text ~ ${len(params)}")
+
+    text_conds = list(common)
+    stt_conds: list[str] | None = None
+    if query is not None:
+        params.append(query)
+        tsquery = _TSQUERY.format(len(params))
+        text_conds.append(f"{_TEXT_TSV} @@ {tsquery}")
+        stt_conds = [*common, f"{_STT_TSV} @@ {tsquery}"]
+    text_sql = _HITS.format(cols=_TEXT_COLS, where=" AND ".join(text_conds))
+
+    # PostgreSQL raises an error for a negative LIMIT or OFFSET; clamp instead.
+    params.append(max(0, min(filters.limit, MAX_LIMIT)))
+    limit_idx = len(params)
+    params.append(max(0, filters.offset))
+    offset_idx = len(params)
+    # chat_id/message_id break ties between equal dates, which keeps OFFSET
+    # paging deterministic.
+    page = (
+        "ORDER BY date DESC, chat_id, message_id "
+        f"LIMIT ${limit_idx} OFFSET ${offset_idx}"
+    )
+
+    if stt_conds is None:
+        sql = f"SELECT * FROM ({text_sql}) hits {page}"  # noqa: S608
+    else:
+        stt_sql = _STT_HITS.format(cols=_STT_COLS, where=" AND ".join(stt_conds))
+        # 'text' sorts after 'stt', so "source DESC" keeps the text hit when a
+        # message matched in both branches.
+        dedup = (
+            "SELECT DISTINCT ON (chat_id, message_id) * "  # noqa: S608
+            f"FROM (({text_sql}) UNION ALL ({stt_sql})) hits "
+            "ORDER BY chat_id, message_id, source DESC"
+        )
+        sql = f"SELECT * FROM ({dedup}) d {page}"  # noqa: S608
+
+    rows = await pool.fetch(sql, *params)
+    return [SearchHit(**dict(row)) for row in rows]