diff --git a/Pipfile b/Pipfile index b5846df..2c0f00c 100644 --- a/Pipfile +++ b/Pipfile @@ -6,6 +6,7 @@ verify_ssl = true [dev-packages] [packages] +sphinx = "*" [requires] python_version = "3.8" diff --git a/datalite/datalite_decorator.py b/datalite/datalite_decorator.py index 171ffca..b3d5bad 100644 --- a/datalite/datalite_decorator.py +++ b/datalite/datalite_decorator.py @@ -1,8 +1,14 @@ +""" +Defines the Datalite decorator that can be used to convert a dataclass to +a class bound to a sqlite3 database. +""" + from typing import Dict, Optional, List, Callable from dataclasses import Field, asdict import sqlite3 as sql from .commons import _convert_sql_format, _convert_type + def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: """ Check if the field's default object is filled, @@ -90,8 +96,14 @@ def _remove_entry(self) -> None: remove_from(self.__class__, getattr(self, 'obj_id')) -def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, - *args, **kwargs) -> Callable: +def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable: + """Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as + `create_entry()`, `remove_entry()` and `update_entry()`. + + :param db_path: Path of the database to be binded. + :param type_overload: Type overload dictionary. + :return: The new dataclass. + """ def decorator(dataclass_: type, *args_i, **kwargs_i): type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL", str: "TEXT", bytes: "BLOB"} diff --git a/datalite/fetch.py b/datalite/fetch.py index 12cf4b4..1b9d164 100644 --- a/datalite/fetch.py +++ b/datalite/fetch.py @@ -3,9 +3,10 @@ from typing import List, Tuple, Any from .commons import _convert_sql_format -def insert_pagination(query: str, page: int, element_count: int) -> str: +def _insert_pagination(query: str, page: int, element_count: int) -> str: """ Insert the pagination arguments if page number is given. + :param query: Query to insert to :param page: Page to get. :param element_count: Element count in each page. @@ -20,6 +21,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool: """ Check if a record is fetchable given its obj_id and class_ type. + :param class_: Class type of the object. :param obj_id: Unique obj_id of the object. :return: If the object is fetchable. @@ -36,6 +38,7 @@ def is_fetchable(class_: type, obj_id: int) -> bool: def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]: """ Get the column data of a table. + :param cur: Cursor in database. :param table_name: Name of the table. :return: the information about columns. @@ -47,6 +50,7 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]: def fetch_equals(class_: type, field: str, value: Any, ) -> Any: """ Fetch a class_ type variable from its bound db. + :param class_: Class to fetch. :param field: Field to check for, by default, object id. :param value: Value of the field to check for. @@ -67,6 +71,7 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any: def fetch_from(class_: type, obj_id: int) -> Any: """ Fetch a class_ type variable from its bound dv. + :param class_: Class to fetch from. :param obj_id: Unique object id of the object. :return: The fetched object. @@ -80,6 +85,7 @@ def fetch_from(class_: type, obj_id: int) -> Any: def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any: """ Convert a given record fetched from an SQL instance to a Python Object of given class_. + :param class_: Class type to convert the record to. :param record: Record to get data from. :param field_names: Field names of the class. @@ -100,6 +106,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1 """ Fetch all class_ type variables from the bound db, provided they fit the given condition + :param class_: Class type to fetch. :param condition: Condition to check for. :param page: Which page to retrieve, default all. (0 means closed). @@ -110,7 +117,7 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1 table_name = class_.__name__.lower() with sql.connect(getattr(class_, 'db_path')) as con: cur: sql.Cursor = con.cursor() - cur.execute(insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count)) + cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count)) records: list = cur.fetchall() field_names: List[str] = _get_table_cols(cur, table_name) return tuple(_convert_record_to_object(class_, record, field_names) for record in records) @@ -121,6 +128,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou Fetch all class_ type variables from the bound db, provided that the field of the records fit the given value. + :param class_: Class of the records. :param field: Field to check. :param value: Value to check for. @@ -134,6 +142,7 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou def fetch_range(class_: type, range_: range) -> tuple: """ Fetch the records in a given range of object ids. + :param class_: Class of the records. :param range_: Range of the object ids. :return: A tuple of class_ type objects whose values @@ -145,6 +154,7 @@ def fetch_range(class_: type, range_: range) -> tuple: def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple: """ Fetchall the records in the bound database. + :param class_: Class of the records. :param page: Which page to retrieve, default all. (0 means closed). :param element_count: Element count in each page. @@ -158,7 +168,7 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple: with sql.connect(db_path) as con: cur: sql.Cursor = con.cursor() try: - cur.execute(insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count)) + cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count)) except sql.OperationalError: raise TypeError(f"No record of type {class_.__name__.lower()}") records = cur.fetchall()