From f0aac812099617f7b8deee9982ff69c5db957e60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ege=20Emir=20=C3=96zkan?= Date: Mon, 10 Aug 2020 06:29:20 +0300 Subject: [PATCH] Add fetch_where, a single conditional wrapper around fetch_if, similar to fetch_equals --- datalite/__init__.py | 13 +++++++++++++ setup.py | 2 +- test/main_tests.py | 6 +++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datalite/__init__.py b/datalite/__init__.py index d8fd205..227a3d0 100644 --- a/datalite/__init__.py +++ b/datalite/__init__.py @@ -241,6 +241,19 @@ def fetch_if(class_: type, condition: str) -> tuple: return tuple(_convert_record_to_object(class_, record, field_names) for record in records) +def fetch_where(class_: type, field: str, value: Any) -> tuple: + """ + Fetch all class_ type variables from the bound db, + provided that the field of the records fit the + given value. + :param class_: Class of the records. + :param field: Field to check. + :param value: Value to check for. + :return: A tuple of the records. + """ + return fetch_if(class_, f"{field} = {_convert_sql_format(value)}") + + def fetch_range(class_: type, range_: range) -> tuple: """ Fetch the records in a given range of object ids. diff --git a/setup.py b/setup.py index f2bd892..c564f83 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ with open("README.md", "r") as fh: setuptools.setup( name="datalite", # Replace with your own username - version="0.4.2", + version="0.4.3", author="Ege Ozkan", author_email="egeemirozkan24@gmail.com", description="A small package that binds dataclasses to an sqlite database", diff --git a/test/main_tests.py b/test/main_tests.py index 97a0a32..8329dcd 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -1,5 +1,5 @@ import unittest -from datalite import datalite, fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals +from datalite import datalite, fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where from sqlite3 import connect from dataclasses import dataclass, asdict from os import remove @@ -94,6 +94,10 @@ class DatabaseFetchCalls(unittest.TestCase): t_objs = fetch_if(FetchClass, "str_ = \"b\"") self.assertEqual(tuple(self.objs[1:]), t_objs) + def testFetchWhere(self): + t_objs = fetch_where(FetchClass, 'str_', 'b') + self.assertEqual(tuple(self.objs[1:]), t_objs) + def testFetchRange(self): t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)) self.assertEqual(tuple(self.objs[0:2]), t_objs)