Add fetch_if, related documention and refactor code slightly.

This commit is contained in:
Ege Emir Özkan
2020-08-03 04:43:13 +03:00
parent 6ce0613b4c
commit 77a6afdfd2
3 changed files with 45 additions and 15 deletions

View File

@@ -1,6 +1,6 @@
import sqlite3 as sql
from dataclasses import Field, asdict, dataclass
from typing import List, Dict, Optional, Callable, Any
from typing import List, Dict, Optional, Callable, Any, Tuple
def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
@@ -171,6 +171,39 @@ def fetch_from(class_: type, obj_id: int) -> Any:
return obj
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
"""
Convert a given record fetched from an SQL instance to a Python Object of given class_.
:param class_: Class type to convert the record to.
:param record: Record to get data from.
:param field_names: Field names of the class.
:return: the created object.
"""
kwargs = dict(zip(field_names, record[1:]))
obj_id = record[0]
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
return obj
def fetch_when(class_: type, condition: str) -> tuple:
"""
Fetch all class_ type variables from the bound db,
provided they fit the given condition
:param class_: Class type to fetch.
:param condition: Condition to check for.
:return: A tuple of records that fit the given condition
of given type class_.
"""
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f"SELECT * FROM {table_name} WHERE {condition};")
field_names: List[str] = _get_table_cols(cur, table_name)
records: list = cur.fetchall()
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
def fetch_range(class_: type, range_: range) -> tuple:
"""
Fetch the records in a given range of object ids.
@@ -182,7 +215,7 @@ def fetch_range(class_: type, range_: range) -> tuple:
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
def fetch_all(class_: type) -> tuple:
def fetch_if(class_: type) -> tuple:
"""
Fetchall the records in the bound database.
:param class_: Class of the records.
@@ -201,11 +234,4 @@ def fetch_all(class_: type) -> tuple:
raise TypeError(f"No record of type {class_.__name__.lower()}")
records = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
objects: List[class_] = []
for record in records:
kwargs = dict(zip(field_names, record[1:]))
obj_id = record[0]
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
objects.append(obj)
return tuple(objects)
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)