Add fetch_where, a single conditional wrapper around fetch_if, similar to fetch_equals

This commit is contained in:
Ege Emir Özkan
2020-08-10 06:29:20 +03:00
parent a3eed28f42
commit f0aac81209
3 changed files with 19 additions and 2 deletions

View File

@@ -241,6 +241,19 @@ def fetch_if(class_: type, condition: str) -> tuple:
return tuple(_convert_record_to_object(class_, record, field_names) for record in records) return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
def fetch_where(class_: type, field: str, value: Any) -> tuple:
"""
Fetch all class_ type variables from the bound db,
provided that the field of the records fit the
given value.
:param class_: Class of the records.
:param field: Field to check.
:param value: Value to check for.
:return: A tuple of the records.
"""
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}")
def fetch_range(class_: type, range_: range) -> tuple: def fetch_range(class_: type, range_: range) -> tuple:
""" """
Fetch the records in a given range of object ids. Fetch the records in a given range of object ids.

View File

@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
setuptools.setup( setuptools.setup(
name="datalite", # Replace with your own username name="datalite", # Replace with your own username
version="0.4.2", version="0.4.3",
author="Ege Ozkan", author="Ege Ozkan",
author_email="egeemirozkan24@gmail.com", author_email="egeemirozkan24@gmail.com",
description="A small package that binds dataclasses to an sqlite database", description="A small package that binds dataclasses to an sqlite database",

View File

@@ -1,5 +1,5 @@
import unittest import unittest
from datalite import datalite, fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals from datalite import datalite, fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where
from sqlite3 import connect from sqlite3 import connect
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from os import remove from os import remove
@@ -94,6 +94,10 @@ class DatabaseFetchCalls(unittest.TestCase):
t_objs = fetch_if(FetchClass, "str_ = \"b\"") t_objs = fetch_if(FetchClass, "str_ = \"b\"")
self.assertEqual(tuple(self.objs[1:]), t_objs) self.assertEqual(tuple(self.objs[1:]), t_objs)
def testFetchWhere(self):
t_objs = fetch_where(FetchClass, 'str_', 'b')
self.assertEqual(tuple(self.objs[1:]), t_objs)
def testFetchRange(self): def testFetchRange(self):
t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)) t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id))
self.assertEqual(tuple(self.objs[0:2]), t_objs) self.assertEqual(tuple(self.objs[0:2]), t_objs)