Files
aiodatalite/datalite/fetch.py
2022-05-29 01:32:59 +03:00

165 lines
6.2 KiB
Python

import sqlite3 as sql
from typing import List, Tuple, Any
from .commons import _convert_sql_format, _get_table_cols
def _insert_pagination(query: str, page: int, element_count: int) -> str:
"""
Insert the pagination arguments if page number is given.
:param query: Query to insert to
:param page: Page to get.
:param element_count: Element count in each page.
:return: The modified (or not) query.
"""
if page:
query += f" ORDER BY obj_id LIMIT {element_count} OFFSET {(page - 1) * element_count}"
return query + ";"
def is_fetchable(class_: type, obj_id: int) -> bool:
"""
Check if a record is fetchable given its obj_id and
class_ type.
:param class_: Class type of the object.
:param obj_id: Unique obj_id of the object.
:return: If the object is fetchable.
"""
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
try:
cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
except sql.OperationalError:
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
return bool(cur.fetchall())
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
"""
Fetch a class_ type variable from its bound db.
:param class_: Class to fetch.
:param field: Field to check for, by default, object id.
:param value: Value of the field to check for.
:return: The object whose data is taken from the database.
"""
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, ))
obj_id, *field_values = list(cur.fetchone())
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
kwargs = dict(zip(field_names, field_values))
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
return obj
def fetch_from(class_: type, obj_id: int) -> Any:
"""
Fetch a class_ type variable from its bound dv.
:param class_: Class to fetch from.
:param obj_id: Unique object id of the object.
:return: The fetched object.
"""
if not is_fetchable(class_, obj_id):
raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or"
f"otherwise is unreachable.")
return fetch_equals(class_, 'obj_id', obj_id)
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
"""
Convert a given record fetched from an SQL instance to a Python Object of given class_.
:param class_: Class type to convert the record to.
:param record: Record to get data from.
:param field_names: Field names of the class.
:return: the created object.
"""
kwargs = dict(zip(field_names, record[1:]))
field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()}
for key in kwargs:
if field_types[key] == bytes:
kwargs[key] = bytes(kwargs[key], encoding='utf-8')
obj_id = record[0]
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
return obj
def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple:
"""
Fetch all class_ type variables from the bound db,
provided they fit the given condition
:param class_: Class type to fetch.
:param condition: Condition to check for.
:param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page.
:return: A tuple of records that fit the given condition
of given type class_.
"""
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
records: list = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, table_name)
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple:
"""
Fetch all class_ type variables from the bound db,
provided that the field of the records fit the
given value.
:param class_: Class of the records.
:param field: Field to check.
:param value: Value to check for.
:param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page.
:return: A tuple of the records.
"""
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count)
def fetch_range(class_: type, range_: range) -> tuple:
"""
Fetch the records in a given range of object ids.
:param class_: Class of the records.
:param range_: Range of the object ids.
:return: A tuple of class_ type objects whose values
come from the class_' bound database.
"""
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
"""
Fetchall the records in the bound database.
:param class_: Class of the records.
:param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page.
:return: All the records of type class_ in
the bound database as a tuple.
"""
try:
db_path = getattr(class_, 'db_path')
except AttributeError:
raise TypeError("Given class is not decorated with datalite.")
with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor()
try:
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
except sql.OperationalError:
raise TypeError(f"No record of type {class_.__name__.lower()}")
records = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)