feat: search functionality
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""fts expression gin on messages and media
|
||||
|
||||
Revision ID: b7e3d9f1a4c6
|
||||
Revises: a1d4f7c2e9b8
|
||||
Create Date: 2026-05-29 22:30:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "b7e3d9f1a4c6"
|
||||
down_revision: str | None = "a1d4f7c2e9b8"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the GIN expression indexes used for Russian full-text search.

    One index covers ``messages.text`` and one covers
    ``media.extracted_text``. Each is built over
    ``to_tsvector('russian', coalesce(<column>, ''))`` so that NULL values
    are indexed as empty documents.
    """
    # (index name, table, text column) for every index this revision adds.
    targets = (
        ("ix_messages_tsv", "messages", "text"),
        ("ix_media_tsv", "media", "extracted_text"),
    )
    for index_name, table, column in targets:
        op.execute(
            f"CREATE INDEX {index_name} ON {table} USING gin "
            f"(to_tsvector('russian', coalesce({column}, '')))"
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop both full-text search indexes, ignoring any that are absent."""
    # The media index is dropped first, then the messages index.
    for index_name in ("ix_media_tsv", "ix_messages_tsv"):
        op.execute(f"DROP INDEX IF EXISTS {index_name}")
|
||||
@@ -0,0 +1,30 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Page size used when a search request does not specify ``limit``.
DEFAULT_LIMIT = 50
# Largest page size a search is meant to return.
MAX_LIMIT = 500
|
||||
|
||||
|
||||
class SearchFilters(BaseModel):
    """Criteria describing one message search request.

    ``account_id`` is mandatory; every other filter is optional and is
    ignored when left as ``None``. ``limit`` and ``offset`` select the
    page of results.
    """

    # Scope: the account whose messages are searched.
    account_id: int

    # Free-text search query.
    query: str | None = None

    # Exact-match filters.
    chat_id: int | None = None
    sender_id: int | None = None
    has_media: bool | None = None

    # Date range bounds.
    date_from: datetime | None = None
    date_to: datetime | None = None

    # Regular-expression pattern filter.
    regex: str | None = None

    # Pagination.
    limit: int = DEFAULT_LIMIT
    offset: int = 0
|
||||
|
||||
|
||||
class SearchHit(BaseModel):
    """A single message returned by a search."""

    # Identity of the matched message.
    chat_id: int
    message_id: int

    # Message metadata and body; sender and text may be absent.
    date: datetime
    sender_id: int | None
    text: str | None

    # Which content produced the match: the message text ("text") or text
    # extracted from attached media ("stt" - presumably speech-to-text;
    # confirm with the media pipeline).
    source: Literal["text", "stt"]

    # Text extracted from media, when available.
    extracted_text: str | None
|
||||
@@ -0,0 +1,88 @@
|
||||
import asyncpg
|
||||
|
||||
from utils.search.models import MAX_LIMIT, SearchFilters, SearchHit
|
||||
|
||||
# tsvector expressions over the message text (alias ``m``) and the media
# extracted text (alias ``md``).
_TEXT_TSV = "to_tsvector('russian', coalesce(m.text, ''))"
_STT_TSV = "to_tsvector('russian', coalesce(md.extracted_text, ''))"

# tsquery template; ``.format(n)`` fills in the positional placeholder ``$n``.
_TSQUERY = "websearch_to_tsquery('russian', ${})"

# FROM/WHERE skeleton for hits on the message text itself.
_HITS = "{cols} FROM messages m WHERE {where}"

# FROM/WHERE skeleton for hits on media text: messages joined to their media
# rows on (account_id, chat_id, message_id).
_STT_HITS = (
    "{cols} FROM messages m "
    "JOIN media md ON md.account_id = m.account_id "
    "AND md.chat_id = m.chat_id AND md.message_id = m.message_id "
    "WHERE {where}"
)

# Both branches select the same leading columns so they can be UNIONed; they
# differ only in the ``source`` tag and in where ``extracted_text`` comes from.
_HIT_COLS = "SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
_TEXT_COLS = _HIT_COLS + "'text' AS source, NULL::text AS extracted_text"
_STT_COLS = _HIT_COLS + "'stt' AS source, md.extracted_text"
|
||||
|
||||
|
||||
def _common_conditions(filters: SearchFilters, params: list[object]) -> list[str]:
    """Build the WHERE predicates shared by every search branch.

    Each bound value is appended to ``params`` (mutated in place) and the
    matching predicate refers to it through a positional ``$n`` placeholder,
    where ``n`` is the value's 1-based position in ``params``.

    Args:
        filters: Search criteria; attributes that are ``None`` are skipped,
            except ``account_id`` which is always applied.
        params: Query parameter list to extend.

    Returns:
        SQL predicates on the ``m`` (messages) alias, in a fixed order.
    """
    # The account predicate is unconditional and always comes first.
    params.append(filters.account_id)
    clauses = [f"m.account_id = ${len(params)}"]

    # (left-hand side with operator, bound value) for each optional filter.
    optional = (
        ("m.chat_id =", filters.chat_id),
        ("m.sender_id =", filters.sender_id),
        ("m.has_media =", filters.has_media),
        ("m.date >=", filters.date_from),
        ("m.date <=", filters.date_to),
    )
    for lhs, value in optional:
        if value is None:
            continue
        params.append(value)
        clauses.append(f"{lhs} ${len(params)}")
    return clauses
|
||||
|
||||
|
||||
async def search_messages(
    pool: asyncpg.Pool, filters: SearchFilters
) -> list[SearchHit]:
    """Run a message search and return one page of hits, newest first.

    Without a full-text query only the ``messages`` table is scanned, using
    the common filters and the optional regex. With a query, two branches
    are UNIONed: messages whose own text matches, and messages whose joined
    media ``extracted_text`` matches. A message found by both branches is
    returned once.

    Args:
        pool: asyncpg connection pool used to execute the query.
        filters: Search criteria and pagination settings.

    Returns:
        The requested page of hits, ordered by date descending with
        ``chat_id`` / ``message_id`` as tiebreakers.
    """
    # A blank or whitespace-only query means "no full-text filter". Passing
    # an empty string to websearch_to_tsquery would produce an empty tsquery
    # that matches no rows, silently emptying a filter-only search.
    query = (filters.query or "").strip() or None

    params: list[object] = []
    common = _common_conditions(filters, params)
    if filters.regex is not None:
        # The regex is part of the common conditions, so it is applied to
        # m.text in both the text branch and the media branch.
        params.append(filters.regex)
        common.append(f"m.text ~ ${len(params)}")

    if query is None:
        branches = [_HITS.format(cols=_TEXT_COLS, where=" AND ".join(common))]
    else:
        # The query is bound once; both branches reuse the same placeholder.
        params.append(query)
        tsquery = _TSQUERY.format(len(params))
        text_where = " AND ".join([*common, f"{_TEXT_TSV} @@ {tsquery}"])
        stt_where = " AND ".join([*common, f"{_STT_TSV} @@ {tsquery}"])
        branches = [
            _HITS.format(cols=_TEXT_COLS, where=text_where),
            _STT_HITS.format(cols=_STT_COLS, where=stt_where),
        ]

    # Clamp paging to what PostgreSQL accepts: LIMIT in [0, MAX_LIMIT] and a
    # non-negative OFFSET (negative values are rejected by the server).
    params.append(max(0, min(filters.limit, MAX_LIMIT)))
    limit_idx = len(params)
    params.append(max(0, filters.offset))
    offset_idx = len(params)
    # chat_id / message_id break ties between equal dates so that successive
    # LIMIT/OFFSET pages neither repeat nor skip rows.
    page = (
        "ORDER BY date DESC, chat_id, message_id "
        f"LIMIT ${limit_idx} OFFSET ${offset_idx}"
    )

    if len(branches) == 1:
        sql = f"SELECT * FROM ({branches[0]}) hits {page}"  # noqa: S608
    else:
        union = " UNION ALL ".join(f"({b})" for b in branches)
        # DISTINCT ON keeps one row per message. 'text' sorts after 'stt',
        # so "source DESC" prefers the text hit when both branches matched.
        dedup = (
            "SELECT DISTINCT ON (chat_id, message_id) * "  # noqa: S608
            f"FROM ({union}) hits ORDER BY chat_id, message_id, source DESC"
        )
        sql = f"SELECT * FROM ({dedup}) d {page}"  # noqa: S608

    rows = await pool.fetch(sql, *params)
    return [SearchHit(**dict(row)) for row in rows]
|
||||
Reference in New Issue
Block a user