From 9cae63bfec44f649fbd4edd43ff26a271e9b2b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ege=20Emir=20=C3=96zkan?= Date: Mon, 10 Aug 2020 05:32:34 +0300 Subject: [PATCH] Add field option for fetch_from --- README.md | 7 ++++--- datalite/__init__.py | 12 +++++------- setup.py | 6 +++--- test/main_tests.py | 5 +++++ 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 4f9f056..3fcee99 100644 --- a/README.md +++ b/README.md @@ -62,10 +62,11 @@ inserted onto the system. > :warning: **Limitation! Fetch can only fetch limited classes correctly**: int, float and str! Finally, you may wish to recreate objects from a table that already exist, for -this purpose we have the function `fetch_from(class_, object_id)` as well +this purpose we have the function `fetch_from(class_, value, field)` as well as `is_fetchable(className, object_id)` former fetches a record from the -SQL database whereas the latter checks if it is fetchable (most likely -to check if it exists.) +SQL database given a field and value. If only the value is provided, + field defaults to 'obj_id' which is unique for all objects, + whereas the latter checks if it is fetchable (most likely to check if it exists.) ```python >>> fetch_from(Student, 2) diff --git a/datalite/__init__.py b/datalite/__init__.py index 34c6357..fc72f89 100644 --- a/datalite/__init__.py +++ b/datalite/__init__.py @@ -171,21 +171,19 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]: return [row_info[1] for row_info in cur.fetchall()][1:] -def fetch_from(class_: type, obj_id: int) -> Any: +def fetch_from(class_: type, value: Any, field: str = 'obj_id') -> Any: """ Fetch a class_ type variable from its bound db. :param class_: Class to fetch. - :param obj_id: Unique object id of the class. + :param field: Field to check for, by default, object id. + :param value: Value of the field to check for. :return: The object whose data is taken from the database. """ table_name = class_.__name__.lower() - if not is_fetchable(class_, obj_id): - raise KeyError(f"An object with the id {obj_id} in table {table_name} does not exist." - f"or is otherwise unable to be fetched.") with sql.connect(getattr(class_, 'db_path')) as con: cur: sql.Cursor = con.cursor() - cur.execute(f"SELECT * FROM {class_.__name__.lower()} WHERE obj_id = {obj_id};") # Guaranteed to work. - field_values: List[str] = list(cur.fetchone())[1:] + cur.execute(f"SELECT * FROM {table_name} WHERE {field} = {_convert_sql_format(value)};") + obj_id, *field_values = list(cur.fetchone()) field_names: List[str] = _get_table_cols(cur, class_.__name__.lower()) kwargs = dict(zip(field_names, field_values)) obj = class_(**kwargs) diff --git a/setup.py b/setup.py index 7c6b0fc..f2bd892 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r") as fh: setuptools.setup( name="datalite", # Replace with your own username - version="0.4.1", + version="0.4.2", author="Ege Ozkan", author_email="egeemirozkan24@gmail.com", description="A small package that binds dataclasses to an sqlite database", @@ -14,9 +14,9 @@ setuptools.setup( url="https://github.com/ambertide/datalite", packages=setuptools.find_packages(), classifiers=[ - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires='>=3.7', ) \ No newline at end of file diff --git a/test/main_tests.py b/test/main_tests.py index 88324f9..75c478b 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -82,6 +82,10 @@ class DatabaseFetchCalls(unittest.TestCase): t_obj = fetch_from(FetchClass, self.objs[0].obj_id) self.assertEqual(self.objs[0], t_obj) + def testFetchFromDif(self): + t_obj = fetch_from(FetchClass, self.objs[0].str_, 'str_') + self.assertEqual(self.objs[0], t_obj) + def testFetchAll(self): t_objs = fetch_all(FetchClass) self.assertEqual(tuple(self.objs), t_objs) @@ -97,5 +101,6 @@ class DatabaseFetchCalls(unittest.TestCase): def tearDown(self) -> None: [obj.remove_entry() for obj in self.objs] + if __name__ == '__main__': unittest.main()