From a0db44f9455b66b0eef3e709fd2338dff49b14c4 Mon Sep 17 00:00:00 2001 From: hhh Date: Sun, 17 Mar 2024 15:54:31 +0200 Subject: [PATCH] Now test-passing! --- datalite/commons.py | 10 +- datalite/datalite_decorator.py | 8 +- test/main_tests.py | 257 +++++++++++++++++++-------------- 3 files changed, 161 insertions(+), 114 deletions(-) diff --git a/datalite/commons.py b/datalite/commons.py index 52e2332..26eac90 100644 --- a/datalite/commons.py +++ b/datalite/commons.py @@ -3,6 +3,7 @@ from pickle import HIGHEST_PROTOCOL, dumps, loads from typing import Any, Dict, List, Optional import aiosqlite +from aiosqlite import IntegrityError from .constraints import Unique @@ -158,7 +159,14 @@ def _tweaked_dump_value(self, value): def _tweaked_dump(self, name): value = getattr(self, name) - return _tweaked_dump_value(self, value) + field_types = {key: value.type for key, value in self.__dataclass_fields__.items()} + if ( + "NOT NULL UNIQUE" not in self.types_table.get(field_types[name], "") + or value is not None + ): + return _tweaked_dump_value(self, value) + else: + raise IntegrityError def _tweaked_load_value(data): diff --git a/datalite/datalite_decorator.py b/datalite/datalite_decorator.py index 99178d6..00585d0 100644 --- a/datalite/datalite_decorator.py +++ b/datalite/datalite_decorator.py @@ -3,10 +3,10 @@ Defines the Datalite decorator that can be used to convert a dataclass to a class bound to an sqlite3 database. """ from dataclasses import asdict, fields -from sqlite3.dbapi2 import IntegrityError from typing import Callable, Dict, Optional import aiosqlite +from aiosqlite import IntegrityError from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table from .constraints import ConstraintFailedError @@ -36,6 +36,9 @@ async def _create_entry(self) -> None: await con.commit() except IntegrityError: raise ConstraintFailedError("A constraint has failed.") + finally: + await cur.close() + await con.close() async def _tweaked_create_entry(self) -> None: @@ -55,6 +58,9 @@ async def _tweaked_create_entry(self) -> None: await con.commit() except IntegrityError: raise ConstraintFailedError("A constraint has failed.") + finally: + await cur.close() + await con.close() async def _update_entry(self) -> None: diff --git a/test/main_tests.py b/test/main_tests.py index e52a095..6d5330b 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -1,223 +1,256 @@ import unittest -from datalite import datalite -from datalite.constraints import Unique, ConstraintFailedError -from datalite.fetch import fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where -from datalite.mass_actions import create_many, copy_many -from sqlite3 import connect -from dataclasses import dataclass, asdict +from dataclasses import asdict, dataclass from math import floor -from datalite.migrations import basic_migrate, _drop_table +from sqlite3 import connect + +from datalite import datalite +from datalite.constraints import ConstraintFailedError, Unique +from datalite.fetch import ( + fetch_all, + fetch_equals, + fetch_from, + fetch_if, + fetch_range, + fetch_where, +) +from datalite.mass_actions import copy_many, create_many +from datalite.migrations import _drop_table, migrate +from datalite.typed import DataliteHinted -@datalite(db_path='test.db') +@datalite(db_path="test.db") @dataclass -class TestClass: +class TestClass(DataliteHinted): integer_value: int = 1 - byte_value: bytes = b'a' + byte_value: bytes = b"a" float_value: float = 0.4 - str_value: str = 'a' + str_value: str = "a" bool_value: bool = True def __eq__(self, other): return asdict(self) == asdict(other) -@datalite(db_path='test.db') +@datalite(db_path="test.db") @dataclass -class FetchClass: +class FetchClass(DataliteHinted): ordinal: int str_: str def __eq__(self, other): return asdict(self) == asdict(other) -@datalite(db_path='test.db') + +@datalite(db_path="test.db") @dataclass -class Migrate1: +class Migrate1(DataliteHinted): ordinal: int conventional: str -@datalite(db_path='test.db') +@datalite(db_path="test.db") @dataclass -class Migrate2: +class Migrate2(DataliteHinted): cardinal: Unique[int] = 1 str_: str = "default" -@datalite(db_path='test.db') +@datalite(db_path="test.db") @dataclass -class ConstraintedClass: +class ConstraintedClass(DataliteHinted): unique_str: Unique[str] -@datalite(db_path='test.db') +@datalite(db_path="test.db") @dataclass -class MassCommit: +class MassCommit(DataliteHinted): str_: str -def getValFromDB(obj_id = 1): - with connect('test.db') as db: +def getValFromDB(obj_id=1): + with connect("test.db") as db: cur = db.cursor() - cur.execute(f'SELECT * FROM testclass WHERE obj_id = {obj_id}') + cur.execute(f"SELECT * FROM testclass WHERE obj_id = {obj_id}") fields = list(TestClass.__dataclass_fields__.keys()) fields.sort() repr = dict(zip(fields, cur.fetchall()[0][1:])) - field_types = {key: value.type for key, value in TestClass.__dataclass_fields__.items()} + _ = {key: value.type for key, value in TestClass.__dataclass_fields__.items()} test_object = TestClass(**repr) return test_object -class DatabaseMain(unittest.TestCase): - def setUp(self) -> None: - self.test_object = TestClass(12, b'bytes', 0.4, 'TestValue') +class DatabaseMain(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + self.test_object = TestClass(12, b"bytes", 0.4, "TestValue") + await self.test_object.markup_table() - def test_creation(self): - self.test_object.create_entry() + async def test_creation(self): + await self.test_object.create_entry() self.assertEqual(self.test_object, getValFromDB()) - def test_update(self): - self.test_object.create_entry() + async def test_update(self): + await self.test_object.create_entry() self.test_object.integer_value = 40 - self.test_object.update_entry() - from_db = getValFromDB(getattr(self.test_object, 'obj_id')) + await self.test_object.update_entry() + from_db = getValFromDB(getattr(self.test_object, "obj_id")) self.assertEqual(self.test_object.integer_value, from_db.integer_value) - def test_delete(self): - with connect('test.db') as db: + async def test_delete(self): + with connect("test.db") as db: cur = db.cursor() - cur.execute('SELECT * FROM testclass') + cur.execute("SELECT * FROM testclass") objects = cur.fetchall() init_len = len(objects) - self.test_object.create_entry() - self.test_object.remove_entry() - with connect('test.db') as db: + await self.test_object.create_entry() + await self.test_object.remove_entry() + with connect("test.db") as db: cur = db.cursor() - cur.execute('SELECT * FROM testclass') + cur.execute("SELECT * FROM testclass") objects = cur.fetchall() self.assertEqual(len(objects), init_len) -class DatabaseFetchCalls(unittest.TestCase): - def setUp(self) -> None: - self.objs = [FetchClass(1, 'a'), FetchClass(2, 'b'), FetchClass(3, 'b')] - [obj.create_entry() for obj in self.objs] +class DatabaseFetchCalls(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + self.objs = [FetchClass(1, "a"), FetchClass(2, "b"), FetchClass(3, "b")] + await self.objs[0].markup_table() + [await obj.create_entry() for obj in self.objs] - def testFetchFrom(self): - t_obj = fetch_from(FetchClass, self.objs[0].obj_id) + async def testFetchFrom(self): + t_obj = await fetch_from(FetchClass, self.objs[0].obj_id) self.assertEqual(self.objs[0], t_obj) - def testFetchEquals(self): - t_obj = fetch_equals(FetchClass, 'str_', self.objs[0].str_) + async def testFetchEquals(self): + t_obj = await fetch_equals(FetchClass, "str_", self.objs[0].str_) self.assertEqual(self.objs[0], t_obj) - def testFetchAll(self): - t_objs = fetch_all(FetchClass) + async def testFetchAll(self): + t_objs = await fetch_all(FetchClass) self.assertEqual(tuple(self.objs), t_objs) - def testFetchIf(self): - t_objs = fetch_if(FetchClass, "str_ = \"b\"") + async def testFetchIf(self): + t_objs = await fetch_if(FetchClass, 'str_ = "b"') self.assertEqual(tuple(self.objs[1:]), t_objs) - def testFetchWhere(self): - t_objs = fetch_where(FetchClass, 'str_', 'b') + async def testFetchWhere(self): + t_objs = await fetch_where(FetchClass, "str_", "b") self.assertEqual(tuple(self.objs[1:]), t_objs) - def testFetchRange(self): - t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)) + async def testFetchRange(self): + t_objs = await fetch_range( + FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id) + ) self.assertEqual(tuple(self.objs[0:2]), t_objs) - def tearDown(self) -> None: - [obj.remove_entry() for obj in self.objs] + async def asyncTearDown(self) -> None: + [await obj.remove_entry() for obj in self.objs] -class DatabaseFetchPaginationCalls(unittest.TestCase): - def setUp(self) -> None: - self.objs = [FetchClass(i, f'{floor(i/10)}') for i in range(30)] - [obj.create_entry() for obj in self.objs] +class DatabaseFetchPaginationCalls(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + self.objs = [FetchClass(i, f"{floor(i/10)}") for i in range(30)] + await self.objs[0].markup_table() + [await obj.create_entry() for obj in self.objs] - def testFetchAllPagination(self): - t_objs = fetch_all(FetchClass, 1, 10) + async def testFetchAllPagination(self): + t_objs = await fetch_all(FetchClass, 1, 10) self.assertEqual(tuple(self.objs[:10]), t_objs) - def testFetchWherePagination(self): - t_objs = fetch_where(FetchClass, 'str_', '0', 2, 5) + async def testFetchWherePagination(self): + t_objs = await fetch_where(FetchClass, "str_", "0", 2, 5) self.assertEqual(tuple(self.objs[5:10]), t_objs) - def testFetchIfPagination(self): - t_objs = fetch_if(FetchClass, 'str_ = "0"', 1, 5) + async def testFetchIfPagination(self): + t_objs = await fetch_if(FetchClass, 'str_ = "0"', 1, 5) self.assertEqual(tuple(self.objs[:5]), t_objs) - def tearDown(self) -> None: - [obj.remove_entry() for obj in self.objs] + async def asyncTearDown(self) -> None: + [await obj.remove_entry() for obj in self.objs] -class DatabaseMigration(unittest.TestCase): - def setUp(self) -> None: +class DatabaseMigration(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: self.objs = [Migrate1(i, "a") for i in range(10)] - [obj.create_entry() for obj in self.objs] + await self.objs[0].markup_table() + [await obj.create_entry() for obj in self.objs] - def testBasicMigrate(self): + async def testBasicMigrate(self): global Migrate1, Migrate2 Migrate1 = Migrate2 - Migrate1.__name__ = 'Migrate1' - basic_migrate(Migrate1, {'ordinal': 'cardinal'}) - t_objs = fetch_all(Migrate1) - self.assertEqual([obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs]) + Migrate1.__name__ = "Migrate1" + await migrate(Migrate1, {"ordinal": "cardinal"}) + t_objs = await fetch_all(Migrate1) + self.assertEqual( + [obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs] + ) self.assertEqual(["default" for _ in range(10)], [obj.str_ for obj in t_objs]) - def tearDown(self) -> None: - t_objs = fetch_all(Migrate1) - [obj.remove_entry() for obj in t_objs] - _drop_table('test.db', 'migrate1') + async def asyncTearDown(self) -> None: + t_objs = await fetch_all(Migrate1) + [await obj.remove_entry() for obj in t_objs] + await _drop_table("test.db", "migrate1") -def helperFunc(): +async def helperFunc(): obj = ConstraintedClass("This string is supposed to be unique.") - obj.create_entry() + await obj.create_entry() -class DatabaseConstraints(unittest.TestCase): - def setUp(self) -> None: +class DatabaseConstraints(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: self.obj = ConstraintedClass("This string is supposed to be unique.") - self.obj.create_entry() + try: + await _drop_table("test.db", "constraintedclass") + except Exception as e: + assert e + await self.obj.markup_table() + await self.obj.create_entry() - def testUniquness(self): - self.assertRaises(ConstraintFailedError, helperFunc) + async def testUniquness(self): + try: + await helperFunc() + except Exception as e: + self.assertEqual(e.__class__, ConstraintFailedError) + else: + self.fail("Did not raise") - def testNullness(self): - self.assertRaises(ConstraintFailedError, lambda : ConstraintedClass(None).create_entry()) + async def testNullness(self): + try: + await ConstraintedClass(None).create_entry() + except Exception as e: + self.assertEqual(e.__class__, ConstraintFailedError) + else: + self.fail("Did not raise") - def tearDown(self) -> None: - self.obj.remove_entry() + async def asyncTearDown(self) -> None: + await self.obj.remove_entry() -class DatabaseMassInsert(unittest.TestCase): - def setUp(self) -> None: - self.objs = [MassCommit(f'cat + {i}') for i in range(30)] +class DatabaseMassInsert(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + self.objs = [MassCommit(f"cat + {i}") for i in range(30)] + await self.objs[0].markup_table() - - def testMassCreate(self): - with connect('other.db') as con: + async def testMassCreate(self): + with connect("other.db") as con: cur = con.cursor() - cur.execute(f'CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)') + cur.execute("CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)") - start_tup = fetch_all(MassCommit) - create_many(self.objs, protect_memory=False) - _objs = fetch_all(MassCommit) + start_tup = await fetch_all(MassCommit) + await create_many(self.objs, protect_memory=False) + _objs = await fetch_all(MassCommit) self.assertEqual(_objs, start_tup + tuple(self.objs)) - def _testMassCopy(self): - setattr(MassCommit, 'db_path', 'other.db') - start_tup = fetch_all(MassCommit) - copy_many(self.objs, 'other.db', False) - tup = fetch_all(MassCommit) + async def _testMassCopy(self): + setattr(MassCommit, "db_path", "other.db") + start_tup = await fetch_all(MassCommit) + await copy_many(self.objs, "other.db", False) + tup = await fetch_all(MassCommit) self.assertEqual(tup, start_tup + tuple(self.objs)) - def tearDown(self) -> None: - [obj.remove_entry() for obj in self.objs] + async def asyncTearDown(self) -> None: + [await obj.remove_entry() for obj in self.objs] -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()