This commit is contained in:
Ege Emir Özkan
2020-08-13 21:37:48 +03:00
parent 6adda2e719
commit 8098a50e52
3 changed files with 28 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ verify_ssl = true
[dev-packages]
[packages]
sphinx = "*"
[requires]
python_version = "3.8"

View File

@@ -1,8 +1,14 @@
"""
Defines the Datalite decorator that can be used to convert a dataclass to
a class bound to a sqlite3 database.
"""
from typing import Dict, Optional, List, Callable
from dataclasses import Field, asdict
import sqlite3 as sql
from .commons import _convert_sql_format, _convert_type
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
"""
Check if the field's default object is filled,
@@ -90,8 +96,14 @@ def _remove_entry(self) -> None:
remove_from(self.__class__, getattr(self, 'obj_id'))
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None,
*args, **kwargs) -> Callable:
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable:
"""Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
`create_entry()`, `remove_entry()` and `update_entry()`.
:param db_path: Path of the database to be binded.
:param type_overload: Type overload dictionary.
:return: The new dataclass.
"""
def decorator(dataclass_: type, *args_i, **kwargs_i):
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
str: "TEXT", bytes: "BLOB"}

View File

@@ -3,9 +3,10 @@ from typing import List, Tuple, Any
from .commons import _convert_sql_format
def insert_pagination(query: str, page: int, element_count: int) -> str:
def _insert_pagination(query: str, page: int, element_count: int) -> str:
"""
Insert the pagination arguments if page number is given.
:param query: Query to insert to
:param page: Page to get.
:param element_count: Element count in each page.
@@ -20,6 +21,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
"""
Check if a record is fetchable given its obj_id and
class_ type.
:param class_: Class type of the object.
:param obj_id: Unique obj_id of the object.
:return: If the object is fetchable.
@@ -36,6 +38,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
"""
Get the column data of a table.
:param cur: Cursor in database.
:param table_name: Name of the table.
:return: the information about columns.
@@ -47,6 +50,7 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
"""
Fetch a class_ type variable from its bound db.
:param class_: Class to fetch.
:param field: Field to check for, by default, object id.
:param value: Value of the field to check for.
@@ -67,6 +71,7 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
def fetch_from(class_: type, obj_id: int) -> Any:
"""
Fetch a class_ type variable from its bound dv.
:param class_: Class to fetch from.
:param obj_id: Unique object id of the object.
:return: The fetched object.
@@ -80,6 +85,7 @@ def fetch_from(class_: type, obj_id: int) -> Any:
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
"""
Convert a given record fetched from an SQL instance to a Python Object of given class_.
:param class_: Class type to convert the record to.
:param record: Record to get data from.
:param field_names: Field names of the class.
@@ -100,6 +106,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
"""
Fetch all class_ type variables from the bound db,
provided they fit the given condition
:param class_: Class type to fetch.
:param condition: Condition to check for.
:param page: Which page to retrieve, default all. (0 means closed).
@@ -110,7 +117,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
records: list = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, table_name)
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
@@ -121,6 +128,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
Fetch all class_ type variables from the bound db,
provided that the field of the records fit the
given value.
:param class_: Class of the records.
:param field: Field to check.
:param value: Value to check for.
@@ -134,6 +142,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
def fetch_range(class_: type, range_: range) -> tuple:
"""
Fetch the records in a given range of object ids.
:param class_: Class of the records.
:param range_: Range of the object ids.
:return: A tuple of class_ type objects whose values
@@ -145,6 +154,7 @@ def fetch_range(class_: type, range_: range) -> tuple:
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
"""
Fetchall the records in the bound database.
:param class_: Class of the records.
:param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page.
@@ -158,7 +168,7 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor()
try:
cur.execute(insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
except sql.OperationalError:
raise TypeError(f"No record of type {class_.__name__.lower()}")
records = cur.fetchall()