This commit is contained in:
Ege Emir Özkan
2020-08-13 21:37:48 +03:00
parent 6adda2e719
commit 8098a50e52
3 changed files with 28 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ verify_ssl = true
[dev-packages] [dev-packages]
[packages] [packages]
sphinx = "*"
[requires] [requires]
python_version = "3.8" python_version = "3.8"

View File

@@ -1,8 +1,14 @@
"""
Defines the Datalite decorator that can be used to convert a dataclass to
a class bound to a sqlite3 database.
"""
from typing import Dict, Optional, List, Callable from typing import Dict, Optional, List, Callable
from dataclasses import Field, asdict from dataclasses import Field, asdict
import sqlite3 as sql import sqlite3 as sql
from .commons import _convert_sql_format, _convert_type from .commons import _convert_sql_format, _convert_type
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
""" """
Check if the field's default object is filled, Check if the field's default object is filled,
@@ -90,8 +96,14 @@ def _remove_entry(self) -> None:
remove_from(self.__class__, getattr(self, 'obj_id')) remove_from(self.__class__, getattr(self, 'obj_id'))
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable:
*args, **kwargs) -> Callable: """Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
`create_entry()`, `remove_entry()` and `update_entry()`.
:param db_path: Path of the database to be binded.
:param type_overload: Type overload dictionary.
:return: The new dataclass.
"""
def decorator(dataclass_: type, *args_i, **kwargs_i): def decorator(dataclass_: type, *args_i, **kwargs_i):
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL", type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
str: "TEXT", bytes: "BLOB"} str: "TEXT", bytes: "BLOB"}

View File

@@ -3,9 +3,10 @@ from typing import List, Tuple, Any
from .commons import _convert_sql_format from .commons import _convert_sql_format
def insert_pagination(query: str, page: int, element_count: int) -> str: def _insert_pagination(query: str, page: int, element_count: int) -> str:
""" """
Insert the pagination arguments if page number is given. Insert the pagination arguments if page number is given.
:param query: Query to insert to :param query: Query to insert to
:param page: Page to get. :param page: Page to get.
:param element_count: Element count in each page. :param element_count: Element count in each page.
@@ -20,6 +21,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
""" """
Check if a record is fetchable given its obj_id and Check if a record is fetchable given its obj_id and
class_ type. class_ type.
:param class_: Class type of the object. :param class_: Class type of the object.
:param obj_id: Unique obj_id of the object. :param obj_id: Unique obj_id of the object.
:return: If the object is fetchable. :return: If the object is fetchable.
@@ -36,6 +38,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]: def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
""" """
Get the column data of a table. Get the column data of a table.
:param cur: Cursor in database. :param cur: Cursor in database.
:param table_name: Name of the table. :param table_name: Name of the table.
:return: the information about columns. :return: the information about columns.
@@ -47,6 +50,7 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
def fetch_equals(class_: type, field: str, value: Any, ) -> Any: def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
""" """
Fetch a class_ type variable from its bound db. Fetch a class_ type variable from its bound db.
:param class_: Class to fetch. :param class_: Class to fetch.
:param field: Field to check for, by default, object id. :param field: Field to check for, by default, object id.
:param value: Value of the field to check for. :param value: Value of the field to check for.
@@ -67,6 +71,7 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
def fetch_from(class_: type, obj_id: int) -> Any: def fetch_from(class_: type, obj_id: int) -> Any:
""" """
Fetch a class_ type variable from its bound dv. Fetch a class_ type variable from its bound dv.
:param class_: Class to fetch from. :param class_: Class to fetch from.
:param obj_id: Unique object id of the object. :param obj_id: Unique object id of the object.
:return: The fetched object. :return: The fetched object.
@@ -80,6 +85,7 @@ def fetch_from(class_: type, obj_id: int) -> Any:
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any: def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
""" """
Convert a given record fetched from an SQL instance to a Python Object of given class_. Convert a given record fetched from an SQL instance to a Python Object of given class_.
:param class_: Class type to convert the record to. :param class_: Class type to convert the record to.
:param record: Record to get data from. :param record: Record to get data from.
:param field_names: Field names of the class. :param field_names: Field names of the class.
@@ -100,6 +106,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
""" """
Fetch all class_ type variables from the bound db, Fetch all class_ type variables from the bound db,
provided they fit the given condition provided they fit the given condition
:param class_: Class type to fetch. :param class_: Class type to fetch.
:param condition: Condition to check for. :param condition: Condition to check for.
:param page: Which page to retrieve, default all. (0 means closed). :param page: Which page to retrieve, default all. (0 means closed).
@@ -110,7 +117,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
table_name = class_.__name__.lower() table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con: with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor() cur: sql.Cursor = con.cursor()
cur.execute(insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count)) cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
records: list = cur.fetchall() records: list = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, table_name) field_names: List[str] = _get_table_cols(cur, table_name)
return tuple(_convert_record_to_object(class_, record, field_names) for record in records) return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
@@ -121,6 +128,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
Fetch all class_ type variables from the bound db, Fetch all class_ type variables from the bound db,
provided that the field of the records fit the provided that the field of the records fit the
given value. given value.
:param class_: Class of the records. :param class_: Class of the records.
:param field: Field to check. :param field: Field to check.
:param value: Value to check for. :param value: Value to check for.
@@ -134,6 +142,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
def fetch_range(class_: type, range_: range) -> tuple: def fetch_range(class_: type, range_: range) -> tuple:
""" """
Fetch the records in a given range of object ids. Fetch the records in a given range of object ids.
:param class_: Class of the records. :param class_: Class of the records.
:param range_: Range of the object ids. :param range_: Range of the object ids.
:return: A tuple of class_ type objects whose values :return: A tuple of class_ type objects whose values
@@ -145,6 +154,7 @@ def fetch_range(class_: type, range_: range) -> tuple:
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple: def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
""" """
Fetchall the records in the bound database. Fetchall the records in the bound database.
:param class_: Class of the records. :param class_: Class of the records.
:param page: Which page to retrieve, default all. (0 means closed). :param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page. :param element_count: Element count in each page.
@@ -158,7 +168,7 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
with sql.connect(db_path) as con: with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor() cur: sql.Cursor = con.cursor()
try: try:
cur.execute(insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count)) cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
except sql.OperationalError: except sql.OperationalError:
raise TypeError(f"No record of type {class_.__name__.lower()}") raise TypeError(f"No record of type {class_.__name__.lower()}")
records = cur.fetchall() records = cur.fetchall()