Migrate to aiosqlite, add some undocumented features, will be documented later. Some problems with mass_actions, don't use them, because loading algorythm may change
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
import sqlite3 as sql
|
||||
from typing import List, Tuple, Any
|
||||
from .commons import _convert_sql_format, _get_table_cols
|
||||
import aiosqlite
|
||||
from typing import Any, List, Tuple, TypeVar, Type, cast
|
||||
|
||||
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
|
||||
|
||||
import base64
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
||||
@@ -17,7 +23,7 @@ def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
||||
return query + ";"
|
||||
|
||||
|
||||
def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
async def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
"""
|
||||
Check if a record is fetchable given its obj_id and
|
||||
class_ type.
|
||||
@@ -26,16 +32,22 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
:param obj_id: Unique obj_id of the object.
|
||||
:return: If the object is fetchable.
|
||||
"""
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
try:
|
||||
cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
|
||||
except sql.OperationalError:
|
||||
await cur.execute(
|
||||
f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id,)
|
||||
)
|
||||
except aiosqlite.OperationalError:
|
||||
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
|
||||
return bool(cur.fetchall())
|
||||
return bool(await cur.fetchall())
|
||||
|
||||
|
||||
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
||||
async def fetch_equals(
|
||||
class_: Type[T],
|
||||
field: str,
|
||||
value: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Fetch a class_ type variable from its bound db.
|
||||
|
||||
@@ -45,18 +57,30 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
||||
:return: The object whose data is taken from the database.
|
||||
"""
|
||||
table_name = class_.__name__.lower()
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, ))
|
||||
obj_id, *field_values = list(cur.fetchone())
|
||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value,))
|
||||
obj_id, *field_values = list(await cur.fetchone())
|
||||
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||
|
||||
kwargs = dict(zip(field_names, field_values))
|
||||
if getattr(class_, "tweaked"):
|
||||
types_table = getattr(class_, "types_table")
|
||||
field_types = {
|
||||
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||
}
|
||||
for key in kwargs.keys():
|
||||
if field_types[key] not in types_table.keys():
|
||||
kwargs[key] = _tweaked_load_value(
|
||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||
)
|
||||
|
||||
obj = class_(**kwargs)
|
||||
setattr(obj, "obj_id", obj_id)
|
||||
return obj
|
||||
|
||||
|
||||
def fetch_from(class_: type, obj_id: int) -> Any:
|
||||
async def fetch_from(class_: Type[T], obj_id: int) -> T:
|
||||
"""
|
||||
Fetch a class_ type variable from its bound dv.
|
||||
|
||||
@@ -64,13 +88,17 @@ def fetch_from(class_: type, obj_id: int) -> Any:
|
||||
:param obj_id: Unique object id of the object.
|
||||
:return: The fetched object.
|
||||
"""
|
||||
if not is_fetchable(class_, obj_id):
|
||||
raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
||||
f"otherwise is unreachable.")
|
||||
return fetch_equals(class_, 'obj_id', obj_id)
|
||||
if not await is_fetchable(class_, obj_id):
|
||||
raise KeyError(
|
||||
f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
||||
f"otherwise is unreachable."
|
||||
)
|
||||
return await fetch_equals(class_, "obj_id", obj_id)
|
||||
|
||||
|
||||
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
|
||||
def _convert_record_to_object(
|
||||
class_: Type[T], record: Tuple[Any], field_names: List[str]
|
||||
) -> T:
|
||||
"""
|
||||
Convert a given record fetched from an SQL instance to a Python Object of given class_.
|
||||
|
||||
@@ -80,17 +108,30 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis
|
||||
:return: the created object.
|
||||
"""
|
||||
kwargs = dict(zip(field_names, record[1:]))
|
||||
field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()}
|
||||
for key in kwargs:
|
||||
field_types = {
|
||||
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||
}
|
||||
is_tweaked = getattr(class_, "tweaked")
|
||||
types_table = getattr(class_, "types_table")
|
||||
for key in kwargs.keys():
|
||||
if field_types[key] == bytes:
|
||||
kwargs[key] = bytes(kwargs[key], encoding='utf-8')
|
||||
kwargs[key] = bytes(kwargs[key], encoding="utf-8")
|
||||
|
||||
elif is_tweaked:
|
||||
if field_types[key] not in types_table.keys():
|
||||
kwargs[key] = _tweaked_load_value(
|
||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||
)
|
||||
|
||||
obj_id = record[0]
|
||||
obj = class_(**kwargs)
|
||||
setattr(obj, "obj_id", obj_id)
|
||||
return obj
|
||||
|
||||
|
||||
def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_if(
|
||||
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
|
||||
) -> T:
|
||||
"""
|
||||
Fetch all class_ type variables from the bound db,
|
||||
provided they fit the given condition
|
||||
@@ -103,18 +144,27 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
|
||||
of given type class_.
|
||||
"""
|
||||
table_name = class_.__name__.lower()
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
|
||||
records: list = cur.fetchall()
|
||||
field_names: List[str] = _get_table_cols(cur, table_name)
|
||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(
|
||||
_insert_pagination(
|
||||
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
|
||||
)
|
||||
)
|
||||
# noinspection PyTypeChecker
|
||||
records: list = await cur.fetchall()
|
||||
field_names: List[str] = await _get_table_cols(cur, table_name)
|
||||
return tuple(
|
||||
_convert_record_to_object(class_, record, field_names) for record in records
|
||||
)
|
||||
|
||||
|
||||
def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_where(
|
||||
class_: Type[T], field: str, value: Any, page: int = 0, element_count: int = 10
|
||||
) -> tuple[T]:
|
||||
"""
|
||||
Fetch all class_ type variables from the bound db,
|
||||
provided that the field of the records fit the
|
||||
Fetch all class_ type variables from the bound db
|
||||
if the field of the records fit the
|
||||
given value.
|
||||
|
||||
:param class_: Class of the records.
|
||||
@@ -124,10 +174,12 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
|
||||
:param element_count: Element count in each page.
|
||||
:return: A tuple of the records.
|
||||
"""
|
||||
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count)
|
||||
return await fetch_if(
|
||||
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
|
||||
)
|
||||
|
||||
|
||||
def fetch_range(class_: type, range_: range) -> tuple:
|
||||
async def fetch_range(class_: Type[T], range_: range) -> tuple[T]:
|
||||
"""
|
||||
Fetch the records in a given range of object ids.
|
||||
|
||||
@@ -136,12 +188,23 @@ def fetch_range(class_: type, range_: range) -> tuple:
|
||||
:return: A tuple of class_ type objects whose values
|
||||
come from the class_' bound database.
|
||||
"""
|
||||
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
|
||||
return cast(
|
||||
tuple[T],
|
||||
tuple(
|
||||
[
|
||||
(await fetch_from(class_, obj_id))
|
||||
for obj_id in range_
|
||||
if (await is_fetchable(class_, obj_id))
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_all(
|
||||
class_: Type[T], page: int = 0, element_count: int = 10
|
||||
) -> tuple[T]:
|
||||
"""
|
||||
Fetchall the records in the bound database.
|
||||
Fetch all the records in the bound database.
|
||||
|
||||
:param class_: Class of the records.
|
||||
:param page: Which page to retrieve, default all. (0 means closed).
|
||||
@@ -150,15 +213,26 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
||||
the bound database as a tuple.
|
||||
"""
|
||||
try:
|
||||
db_path = getattr(class_, 'db_path')
|
||||
db_path = getattr(class_, "db_path")
|
||||
except AttributeError:
|
||||
raise TypeError("Given class is not decorated with datalite.")
|
||||
with sql.connect(db_path) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(db_path) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
try:
|
||||
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
|
||||
except sql.OperationalError:
|
||||
await cur.execute(
|
||||
_insert_pagination(
|
||||
f"SELECT * FROM {class_.__name__.lower()}", page, element_count
|
||||
)
|
||||
)
|
||||
except aiosqlite.OperationalError:
|
||||
raise TypeError(f"No record of type {class_.__name__.lower()}")
|
||||
records = cur.fetchall()
|
||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
||||
records = await cur.fetchall()
|
||||
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||
# noinspection PyTypeChecker
|
||||
return cast(
|
||||
tuple[T],
|
||||
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]
|
||||
|
||||
Reference in New Issue
Block a user