diff --git a/.gitignore b/.gitignore index 0cca9c1..70e4135 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ datalite/__pycache__/* *.pyc datalite/*.db +*.db build/* datalite.egg-info/* dist/* diff --git a/README.md b/README.md index e03261a..352d86e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ # Datalite [![Maintainability](https://api.codeclimate.com/v1/badges/9d4ce56bfbd3b63649be/maintainability)](https://codeclimate.com/github/ambertide/datalite/maintainability) - +[![Test Coverage](https://api.codeclimate.com/v1/badges/9d4ce56bfbd3b63649be/test_coverage)](https://codeclimate.com/github/ambertide/datalite/test_coverage) +[![PyPI version shields.io](https://img.shields.io/pypi/v/datalite.svg)](https://pypi.python.org/pypi/datalite/) +[![PyPI license](https://img.shields.io/pypi/l/datalite.svg)](https://pypi.python.org/pypi/datalite/) Datalite is a simple Python package that binds your dataclasses to a table in a sqlite3 database, diff --git a/test.db b/test.db deleted file mode 100644 index 917b78b..0000000 Binary files a/test.db and /dev/null differ diff --git a/test/main_tests.py b/test/main_tests.py index 9ff509c..88324f9 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -1,9 +1,5 @@ import unittest -try: - from datalite import datalite -except ModuleNotFoundError: - import importlib - importlib.import_module('datalite', '../datalite/') +from datalite import datalite, fetch_if, fetch_all, fetch_range, fetch_from from sqlite3 import connect from dataclasses import dataclass, asdict from os import remove @@ -21,6 +17,16 @@ class TestClass: return asdict(self) == asdict(other) +@datalite(db_path='test.db') +@dataclass +class FetchClass: + ordinal: int + str_: str + + def __eq__(self, other): + return asdict(self) == asdict(other) + + def getValFromDB(obj_id = 1): with connect('test.db') as db: cur = db.cursor() @@ -66,5 +72,30 @@ class DatabaseMain(unittest.TestCase): self.assertEqual(len(objects), init_len) + +class DatabaseFetchCalls(unittest.TestCase): + def setUp(self) -> None: + self.objs = [FetchClass(1, 'a'), FetchClass(2, 'b'), FetchClass(3, 'b')] + [obj.create_entry() for obj in self.objs] + + def testFetchFrom(self): + t_obj = fetch_from(FetchClass, self.objs[0].obj_id) + self.assertEqual(self.objs[0], t_obj) + + def testFetchAll(self): + t_objs = fetch_all(FetchClass) + self.assertEqual(tuple(self.objs), t_objs) + + def testFetchIf(self): + t_objs = fetch_if(FetchClass, "str_ = \"b\"") + self.assertEqual(tuple(self.objs[1:]), t_objs) + + def testFetchRange(self): + t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)) + self.assertEqual(tuple(self.objs[0:2]), t_objs) + + def tearDown(self) -> None: + [obj.remove_entry() for obj in self.objs] + if __name__ == '__main__': unittest.main()