Add field option for fetch_from

This commit is contained in:
Ege Emir Özkan
2020-08-10 05:32:34 +03:00
parent 951623617c
commit 9cae63bfec
4 changed files with 17 additions and 13 deletions

View File

@@ -62,10 +62,11 @@ inserted onto the system.
> :warning: **Limitation! Fetch can only fetch limited classes correctly**: int, float and str! > :warning: **Limitation! Fetch can only fetch limited classes correctly**: int, float and str!
Finally, you may wish to recreate objects from a table that already exist, for Finally, you may wish to recreate objects from a table that already exist, for
this purpose we have the function `fetch_from(class_, object_id)` as well this purpose we have the function `fetch_from(class_, value, field)` as well
as `is_fetchable(className, object_id)` former fetches a record from the as `is_fetchable(className, object_id)` former fetches a record from the
SQL database whereas the latter checks if it is fetchable (most likely SQL database given a field and value. If only the value is provided,
to check if it exists.) field defaults to 'obj_id' which is unique for all objects,
whereas the latter checks if it is fetchable (most likely to check if it exists.)
```python ```python
>>> fetch_from(Student, 2) >>> fetch_from(Student, 2)

View File

@@ -171,21 +171,19 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
return [row_info[1] for row_info in cur.fetchall()][1:] return [row_info[1] for row_info in cur.fetchall()][1:]
def fetch_from(class_: type, obj_id: int) -> Any: def fetch_from(class_: type, value: Any, field: str = 'obj_id') -> Any:
""" """
Fetch a class_ type variable from its bound db. Fetch a class_ type variable from its bound db.
:param class_: Class to fetch. :param class_: Class to fetch.
:param obj_id: Unique object id of the class. :param field: Field to check for, by default, object id.
:param value: Value of the field to check for.
:return: The object whose data is taken from the database. :return: The object whose data is taken from the database.
""" """
table_name = class_.__name__.lower() table_name = class_.__name__.lower()
if not is_fetchable(class_, obj_id):
raise KeyError(f"An object with the id {obj_id} in table {table_name} does not exist."
f"or is otherwise unable to be fetched.")
with sql.connect(getattr(class_, 'db_path')) as con: with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor() cur: sql.Cursor = con.cursor()
cur.execute(f"SELECT * FROM {class_.__name__.lower()} WHERE obj_id = {obj_id};") # Guaranteed to work. cur.execute(f"SELECT * FROM {table_name} WHERE {field} = {_convert_sql_format(value)};")
field_values: List[str] = list(cur.fetchone())[1:] obj_id, *field_values = list(cur.fetchone())
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower()) field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
kwargs = dict(zip(field_names, field_values)) kwargs = dict(zip(field_names, field_values))
obj = class_(**kwargs) obj = class_(**kwargs)

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup( setuptools.setup(
name="datalite", # Replace with your own username name="datalite", # Replace with your own username
version="0.4.1", version="0.4.2",
author="Ege Ozkan", author="Ege Ozkan",
author_email="egeemirozkan24@gmail.com", author_email="egeemirozkan24@gmail.com",
description="A small package that binds dataclasses to an sqlite database", description="A small package that binds dataclasses to an sqlite database",
@@ -14,9 +14,9 @@ setuptools.setup(
url="https://github.com/ambertide/datalite", url="https://github.com/ambertide/datalite",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3.7",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
python_requires='>=3.6', python_requires='>=3.7',
) )

View File

@@ -82,6 +82,10 @@ class DatabaseFetchCalls(unittest.TestCase):
t_obj = fetch_from(FetchClass, self.objs[0].obj_id) t_obj = fetch_from(FetchClass, self.objs[0].obj_id)
self.assertEqual(self.objs[0], t_obj) self.assertEqual(self.objs[0], t_obj)
def testFetchFromDif(self):
t_obj = fetch_from(FetchClass, self.objs[0].str_, 'str_')
self.assertEqual(self.objs[0], t_obj)
def testFetchAll(self): def testFetchAll(self):
t_objs = fetch_all(FetchClass) t_objs = fetch_all(FetchClass)
self.assertEqual(tuple(self.objs), t_objs) self.assertEqual(tuple(self.objs), t_objs)
@@ -97,5 +101,6 @@ class DatabaseFetchCalls(unittest.TestCase):
def tearDown(self) -> None: def tearDown(self) -> None:
[obj.remove_entry() for obj in self.objs] [obj.remove_entry() for obj in self.objs]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()