feat: search functionality

This commit is contained in:
h
2026-05-29 22:50:21 +02:00
parent bcb94b6474
commit 6a5cde6ae4
4 changed files with 150 additions and 0 deletions
@@ -0,0 +1,32 @@
"""fts expression gin on messages and media
Revision ID: b7e3d9f1a4c6
Revises: a1d4f7c2e9b8
Create Date: 2026-05-29 22:30:00.000000
"""
from collections.abc import Sequence
from alembic import op
# Revision identifiers for this Alembic migration.
revision: str = "b7e3d9f1a4c6"
# Revision this migration is applied on top of (see "Revises" in the docstring).
down_revision: str | None = "a1d4f7c2e9b8"
# No branch labels or cross-branch dependencies are declared.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Create the GIN expression indexes used for full-text search.

    One index is built per (table, column) pair over
    ``to_tsvector('russian', coalesce(<column>, ''))``.
    """
    targets = (
        ("ix_messages_tsv", "messages", "text"),
        ("ix_media_tsv", "media", "extracted_text"),
    )
    for index_name, table, column in targets:
        op.execute(
            f"CREATE INDEX {index_name} ON {table} USING gin "
            f"(to_tsvector('russian', coalesce({column}, '')))"
        )
def downgrade() -> None:
    """Drop the full-text indexes, media first and messages second."""
    for index_name in ("ix_media_tsv", "ix_messages_tsv"):
        # IF EXISTS keeps the downgrade from failing when an index is absent.
        op.execute(f"DROP INDEX IF EXISTS {index_name}")
+30
View File
@@ -0,0 +1,30 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
# Page size used when SearchFilters.limit is not supplied.
DEFAULT_LIMIT = 50
# Upper bound for the page size of a single search.
MAX_LIMIT = 500
class SearchFilters(BaseModel):
    """Criteria for a message search.

    Only ``account_id`` is required; every other filter defaults to ``None``.
    """

    # Account whose messages are searched.
    account_id: int
    # Full-text search string.
    query: str | None = None
    # Exact-match filters on the message's chat, sender and media flag.
    chat_id: int | None = None
    sender_id: int | None = None
    has_media: bool | None = None
    # Bounds on the message date.
    date_from: datetime | None = None
    date_to: datetime | None = None
    # Regular-expression pattern.
    regex: str | None = None
    # Page size and number of rows to skip.
    limit: int = DEFAULT_LIMIT
    offset: int = 0
class SearchHit(BaseModel):
    """A single search result row."""

    chat_id: int
    message_id: int
    date: datetime
    sender_id: int | None
    text: str | None
    # Origin of the match; restricted to "text" or "stt"
    # ("stt" presumably means speech-to-text -- TODO confirm).
    source: Literal["text", "stt"]
    # Text extracted from media, when there is any.
    extracted_text: str | None
+88
View File
@@ -0,0 +1,88 @@
import asyncpg
from utils.search.models import MAX_LIMIT, SearchFilters, SearchHit
# SQL fragments for the message search.  ``m`` aliases ``messages`` and ``md``
# aliases ``media`` throughout.

# tsvector expression template: 'russian' configuration, NULL coalesced to ''.
_TSVECTOR = "to_tsvector('russian', coalesce({}, ''))"
_TEXT_TSV = _TSVECTOR.format("m.text")
_STT_TSV = _TSVECTOR.format("md.extracted_text")
# ``{}`` receives the positional index of the bound query parameter.
_TSQUERY = "websearch_to_tsquery('russian', ${})"

# Both branches select the same column list so their rows can be combined.
_HIT_COLS = "SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
_TEXT_COLS = _HIT_COLS + "'text' AS source, NULL::text AS extracted_text"
_STT_COLS = _HIT_COLS + "'stt' AS source, md.extracted_text"

_FROM_MESSAGES = "{cols} FROM messages m "
_HITS = _FROM_MESSAGES + "WHERE {where}"  # text branch
# STT branch: joins each message to its media rows on the full message key.
_STT_HITS = "".join(
    (
        _FROM_MESSAGES,
        "JOIN media md ON md.account_id = m.account_id ",
        "AND md.chat_id = m.chat_id AND md.message_id = m.message_id ",
        "WHERE {where}",
    )
)
def _common_conditions(filters: SearchFilters, params: list[object]) -> list[str]:
    """Return the WHERE clauses shared by every search branch.

    Each value that is used is appended to ``params`` (mutated in place) and
    referenced by its 1-based position as a ``$n`` placeholder.

    Args:
        filters: Search criteria; optional fields left at ``None`` are skipped.
        params: Positional argument list being built for the statement.

    Returns:
        SQL conditions on the ``m`` (messages) alias, account filter first.
    """

    def bind(value: object) -> str:
        # Register the value and hand back the placeholder that refers to it.
        params.append(value)
        return f"${len(params)}"

    clauses = [f"m.account_id = {bind(filters.account_id)}"]
    optional = (
        ("m.chat_id =", filters.chat_id),
        ("m.sender_id =", filters.sender_id),
        ("m.has_media =", filters.has_media),
        ("m.date >=", filters.date_from),
        ("m.date <=", filters.date_to),
    )
    # ``is not None`` (rather than truthiness) keeps has_media=False and id 0.
    clauses.extend(
        f"{lhs} {bind(value)}" for lhs, value in optional if value is not None
    )
    return clauses
def _build_search_sql(filters: SearchFilters) -> tuple[str, list[object]]:
    """Build the search statement and its positional arguments (no I/O).

    Args:
        filters: Search criteria.

    Returns:
        ``(sql, params)``; ``params[i]`` is bound to placeholder ``$<i + 1>``.
    """
    params: list[object] = []
    common = _common_conditions(filters, params)
    if filters.regex is not None:
        params.append(filters.regex)
        # NOTE(review): the pattern is always tested against m.text, so when a
        # query is present it also restricts the STT branch by message text
        # rather than by md.extracted_text -- confirm this is intended.
        common.append(f"m.text ~ ${len(params)}")

    # A missing, empty or whitespace-only query means "no full-text filter".
    # Binding '' instead would produce an empty tsquery that matches no rows.
    query = (filters.query or "").strip()
    if query:
        params.append(query)
        tsquery = _TSQUERY.format(len(params))
        text_where = " AND ".join([*common, f"{_TEXT_TSV} @@ {tsquery}"])
        stt_where = " AND ".join([*common, f"{_STT_TSV} @@ {tsquery}"])
        branches = [
            _HITS.format(cols=_TEXT_COLS, where=text_where),
            _STT_HITS.format(cols=_STT_COLS, where=stt_where),
        ]
    else:
        # Without a query only message rows are listed; media is not consulted.
        branches = [_HITS.format(cols=_TEXT_COLS, where=" AND ".join(common))]

    # PostgreSQL rejects negative LIMIT/OFFSET values, so clamp both at zero
    # in addition to capping the page size at MAX_LIMIT.
    params.append(max(0, min(filters.limit, MAX_LIMIT)))
    limit_idx = len(params)
    params.append(max(0, filters.offset))
    offset_idx = len(params)
    # chat_id/message_id break ties between equal dates so consecutive OFFSET
    # pages neither repeat nor skip rows.
    page = (
        "ORDER BY date DESC, chat_id, message_id "
        f"LIMIT ${limit_idx} OFFSET ${offset_idx}"
    )

    if len(branches) == 1:
        return f"SELECT * FROM ({branches[0]}) hits {page}", params  # noqa: S608

    union = " UNION ALL ".join(f"({branch})" for branch in branches)
    # One row per message: 'text' sorts after 'stt', so ``source DESC`` keeps
    # the text hit when a message matched in both branches.
    dedup = (
        "SELECT DISTINCT ON (chat_id, message_id) * "  # noqa: S608
        f"FROM ({union}) hits ORDER BY chat_id, message_id, source DESC"
    )
    return f"SELECT * FROM ({dedup}) d {page}", params  # noqa: S608


async def search_messages(
    pool: asyncpg.Pool, filters: SearchFilters
) -> list[SearchHit]:
    """Search an account's messages by text, media text and metadata filters.

    With a non-blank ``filters.query`` both ``messages.text`` and
    ``media.extracted_text`` are matched and results are de-duplicated per
    message; otherwise only the metadata/regex filters apply.  Results are
    ordered newest first.

    Args:
        pool: Connection pool used to run the statement.
        filters: Search criteria.  ``limit`` is clamped to ``0..MAX_LIMIT`` and
            a negative ``offset`` is treated as ``0``.

    Returns:
        One ``SearchHit`` per matching message.

    Raises:
        Errors raised by ``pool.fetch`` (for example for an invalid regex
        pattern) propagate unchanged.
    """
    sql, params = _build_search_sql(filters)
    rows = await pool.fetch(sql, *params)
    return [SearchHit(**dict(row)) for row in rows]