feat: search functionality
This commit is contained in:
@@ -0,0 +1,32 @@
|
|||||||
|
"""fts expression gin on messages and media
|
||||||
|
|
||||||
|
Revision ID: b7e3d9f1a4c6
|
||||||
|
Revises: a1d4f7c2e9b8
|
||||||
|
Create Date: 2026-05-29 22:30:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "b7e3d9f1a4c6"
|
||||||
|
down_revision: str | None = "a1d4f7c2e9b8"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_messages_tsv ON messages USING gin "
|
||||||
|
"(to_tsvector('russian', coalesce(text, '')))"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_media_tsv ON media USING gin "
|
||||||
|
"(to_tsvector('russian', coalesce(extracted_text, '')))"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_media_tsv")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_messages_tsv")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
DEFAULT_LIMIT = 50
|
||||||
|
MAX_LIMIT = 500
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFilters(BaseModel):
|
||||||
|
account_id: int
|
||||||
|
query: str | None = None
|
||||||
|
chat_id: int | None = None
|
||||||
|
sender_id: int | None = None
|
||||||
|
has_media: bool | None = None
|
||||||
|
date_from: datetime | None = None
|
||||||
|
date_to: datetime | None = None
|
||||||
|
regex: str | None = None
|
||||||
|
limit: int = DEFAULT_LIMIT
|
||||||
|
offset: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class SearchHit(BaseModel):
|
||||||
|
chat_id: int
|
||||||
|
message_id: int
|
||||||
|
date: datetime
|
||||||
|
sender_id: int | None
|
||||||
|
text: str | None
|
||||||
|
source: Literal["text", "stt"]
|
||||||
|
extracted_text: str | None
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from utils.search.models import MAX_LIMIT, SearchFilters, SearchHit
|
||||||
|
|
||||||
|
_TEXT_TSV = "to_tsvector('russian', coalesce(m.text, ''))"
|
||||||
|
_STT_TSV = "to_tsvector('russian', coalesce(md.extracted_text, ''))"
|
||||||
|
_TSQUERY = "websearch_to_tsquery('russian', ${})"
|
||||||
|
|
||||||
|
_HITS = "{cols} FROM messages m WHERE {where}" # text branch
|
||||||
|
_STT_HITS = (
|
||||||
|
"{cols} FROM messages m "
|
||||||
|
"JOIN media md ON md.account_id = m.account_id "
|
||||||
|
"AND md.chat_id = m.chat_id AND md.message_id = m.message_id "
|
||||||
|
"WHERE {where}"
|
||||||
|
)
|
||||||
|
_TEXT_COLS = (
|
||||||
|
"SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
|
||||||
|
"'text' AS source, NULL::text AS extracted_text"
|
||||||
|
)
|
||||||
|
_STT_COLS = (
|
||||||
|
"SELECT m.chat_id, m.message_id, m.date, m.sender_id, m.text, "
|
||||||
|
"'stt' AS source, md.extracted_text"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _common_conditions(filters: SearchFilters, params: list[object]) -> list[str]:
|
||||||
|
params.append(filters.account_id)
|
||||||
|
conds = [f"m.account_id = ${len(params)}"]
|
||||||
|
if filters.chat_id is not None:
|
||||||
|
params.append(filters.chat_id)
|
||||||
|
conds.append(f"m.chat_id = ${len(params)}")
|
||||||
|
if filters.sender_id is not None:
|
||||||
|
params.append(filters.sender_id)
|
||||||
|
conds.append(f"m.sender_id = ${len(params)}")
|
||||||
|
if filters.has_media is not None:
|
||||||
|
params.append(filters.has_media)
|
||||||
|
conds.append(f"m.has_media = ${len(params)}")
|
||||||
|
if filters.date_from is not None:
|
||||||
|
params.append(filters.date_from)
|
||||||
|
conds.append(f"m.date >= ${len(params)}")
|
||||||
|
if filters.date_to is not None:
|
||||||
|
params.append(filters.date_to)
|
||||||
|
conds.append(f"m.date <= ${len(params)}")
|
||||||
|
return conds
|
||||||
|
|
||||||
|
|
||||||
|
async def search_messages(
|
||||||
|
pool: asyncpg.Pool, filters: SearchFilters
|
||||||
|
) -> list[SearchHit]:
|
||||||
|
query = filters.query.strip() if filters.query else None
|
||||||
|
params: list[object] = []
|
||||||
|
common = _common_conditions(filters, params)
|
||||||
|
if filters.regex is not None:
|
||||||
|
params.append(filters.regex)
|
||||||
|
common.append(f"m.text ~ ${len(params)}")
|
||||||
|
|
||||||
|
q_idx: int | None = None
|
||||||
|
if query is not None:
|
||||||
|
params.append(query)
|
||||||
|
q_idx = len(params)
|
||||||
|
|
||||||
|
text_conds = [*common]
|
||||||
|
if q_idx is not None:
|
||||||
|
text_conds.append(f"{_TEXT_TSV} @@ {_TSQUERY.format(q_idx)}")
|
||||||
|
|
||||||
|
branches = [_HITS.format(cols=_TEXT_COLS, where=" AND ".join(text_conds))]
|
||||||
|
if q_idx is not None:
|
||||||
|
stt_conds = [*common, f"{_STT_TSV} @@ {_TSQUERY.format(q_idx)}"]
|
||||||
|
branches.append(_STT_HITS.format(cols=_STT_COLS, where=" AND ".join(stt_conds)))
|
||||||
|
|
||||||
|
params.append(min(filters.limit, MAX_LIMIT))
|
||||||
|
limit_idx = len(params)
|
||||||
|
params.append(filters.offset)
|
||||||
|
offset_idx = len(params)
|
||||||
|
page = f"ORDER BY date DESC LIMIT ${limit_idx} OFFSET ${offset_idx}"
|
||||||
|
|
||||||
|
if len(branches) == 1:
|
||||||
|
sql = f"SELECT * FROM ({branches[0]}) hits {page}" # noqa: S608
|
||||||
|
else:
|
||||||
|
union = " UNION ALL ".join(f"({b})" for b in branches)
|
||||||
|
dedup = (
|
||||||
|
"SELECT DISTINCT ON (chat_id, message_id) * " # noqa: S608
|
||||||
|
f"FROM ({union}) hits ORDER BY chat_id, message_id, source DESC"
|
||||||
|
)
|
||||||
|
sql = f"SELECT * FROM ({dedup}) d {page}" # noqa: S608
|
||||||
|
|
||||||
|
rows = await pool.fetch(sql, *params)
|
||||||
|
return [SearchHit(**dict(row)) for row in rows]
|
||||||
Reference in New Issue
Block a user