Now test-passing!

This commit is contained in:
hhh
2024-03-17 15:54:31 +02:00
parent b0523e141e
commit a0db44f945
3 changed files with 161 additions and 114 deletions

View File

@@ -3,6 +3,7 @@ from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import aiosqlite import aiosqlite
from aiosqlite import IntegrityError
from .constraints import Unique from .constraints import Unique
@@ -158,7 +159,14 @@ def _tweaked_dump_value(self, value):
def _tweaked_dump(self, name): def _tweaked_dump(self, name):
value = getattr(self, name) value = getattr(self, name)
field_types = {key: value.type for key, value in self.__dataclass_fields__.items()}
if (
"NOT NULL UNIQUE" not in self.types_table.get(field_types[name], "")
or value is not None
):
return _tweaked_dump_value(self, value) return _tweaked_dump_value(self, value)
else:
raise IntegrityError
def _tweaked_load_value(data): def _tweaked_load_value(data):

View File

@@ -3,10 +3,10 @@ Defines the Datalite decorator that can be used to convert a dataclass to
a class bound to an sqlite3 database. a class bound to an sqlite3 database.
""" """
from dataclasses import asdict, fields from dataclasses import asdict, fields
from sqlite3.dbapi2 import IntegrityError
from typing import Callable, Dict, Optional from typing import Callable, Dict, Optional
import aiosqlite import aiosqlite
from aiosqlite import IntegrityError
from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table
from .constraints import ConstraintFailedError from .constraints import ConstraintFailedError
@@ -36,6 +36,9 @@ async def _create_entry(self) -> None:
await con.commit() await con.commit()
except IntegrityError: except IntegrityError:
raise ConstraintFailedError("A constraint has failed.") raise ConstraintFailedError("A constraint has failed.")
finally:
await cur.close()
await con.close()
async def _tweaked_create_entry(self) -> None: async def _tweaked_create_entry(self) -> None:
@@ -55,6 +58,9 @@ async def _tweaked_create_entry(self) -> None:
await con.commit() await con.commit()
except IntegrityError: except IntegrityError:
raise ConstraintFailedError("A constraint has failed.") raise ConstraintFailedError("A constraint has failed.")
finally:
await cur.close()
await con.close()
async def _update_entry(self) -> None: async def _update_entry(self) -> None:

View File

@@ -1,223 +1,256 @@
import unittest import unittest
from datalite import datalite from dataclasses import asdict, dataclass
from datalite.constraints import Unique, ConstraintFailedError
from datalite.fetch import fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where
from datalite.mass_actions import create_many, copy_many
from sqlite3 import connect
from dataclasses import dataclass, asdict
from math import floor from math import floor
from datalite.migrations import basic_migrate, _drop_table from sqlite3 import connect
from datalite import datalite
from datalite.constraints import ConstraintFailedError, Unique
from datalite.fetch import (
fetch_all,
fetch_equals,
fetch_from,
fetch_if,
fetch_range,
fetch_where,
)
from datalite.mass_actions import copy_many, create_many
from datalite.migrations import _drop_table, migrate
from datalite.typed import DataliteHinted
@datalite(db_path='test.db') @datalite(db_path="test.db")
@dataclass @dataclass
class TestClass: class TestClass(DataliteHinted):
integer_value: int = 1 integer_value: int = 1
byte_value: bytes = b'a' byte_value: bytes = b"a"
float_value: float = 0.4 float_value: float = 0.4
str_value: str = 'a' str_value: str = "a"
bool_value: bool = True bool_value: bool = True
def __eq__(self, other): def __eq__(self, other):
return asdict(self) == asdict(other) return asdict(self) == asdict(other)
@datalite(db_path='test.db') @datalite(db_path="test.db")
@dataclass @dataclass
class FetchClass: class FetchClass(DataliteHinted):
ordinal: int ordinal: int
str_: str str_: str
def __eq__(self, other): def __eq__(self, other):
return asdict(self) == asdict(other) return asdict(self) == asdict(other)
@datalite(db_path='test.db')
@datalite(db_path="test.db")
@dataclass @dataclass
class Migrate1: class Migrate1(DataliteHinted):
ordinal: int ordinal: int
conventional: str conventional: str
@datalite(db_path='test.db') @datalite(db_path="test.db")
@dataclass @dataclass
class Migrate2: class Migrate2(DataliteHinted):
cardinal: Unique[int] = 1 cardinal: Unique[int] = 1
str_: str = "default" str_: str = "default"
@datalite(db_path='test.db') @datalite(db_path="test.db")
@dataclass @dataclass
class ConstraintedClass: class ConstraintedClass(DataliteHinted):
unique_str: Unique[str] unique_str: Unique[str]
@datalite(db_path='test.db') @datalite(db_path="test.db")
@dataclass @dataclass
class MassCommit: class MassCommit(DataliteHinted):
str_: str str_: str
def getValFromDB(obj_id = 1): def getValFromDB(obj_id=1):
with connect('test.db') as db: with connect("test.db") as db:
cur = db.cursor() cur = db.cursor()
cur.execute(f'SELECT * FROM testclass WHERE obj_id = {obj_id}') cur.execute(f"SELECT * FROM testclass WHERE obj_id = {obj_id}")
fields = list(TestClass.__dataclass_fields__.keys()) fields = list(TestClass.__dataclass_fields__.keys())
fields.sort() fields.sort()
repr = dict(zip(fields, cur.fetchall()[0][1:])) repr = dict(zip(fields, cur.fetchall()[0][1:]))
field_types = {key: value.type for key, value in TestClass.__dataclass_fields__.items()} _ = {key: value.type for key, value in TestClass.__dataclass_fields__.items()}
test_object = TestClass(**repr) test_object = TestClass(**repr)
return test_object return test_object
class DatabaseMain(unittest.TestCase): class DatabaseMain(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.test_object = TestClass(12, b'bytes', 0.4, 'TestValue') self.test_object = TestClass(12, b"bytes", 0.4, "TestValue")
await self.test_object.markup_table()
def test_creation(self): async def test_creation(self):
self.test_object.create_entry() await self.test_object.create_entry()
self.assertEqual(self.test_object, getValFromDB()) self.assertEqual(self.test_object, getValFromDB())
def test_update(self): async def test_update(self):
self.test_object.create_entry() await self.test_object.create_entry()
self.test_object.integer_value = 40 self.test_object.integer_value = 40
self.test_object.update_entry() await self.test_object.update_entry()
from_db = getValFromDB(getattr(self.test_object, 'obj_id')) from_db = getValFromDB(getattr(self.test_object, "obj_id"))
self.assertEqual(self.test_object.integer_value, from_db.integer_value) self.assertEqual(self.test_object.integer_value, from_db.integer_value)
def test_delete(self): async def test_delete(self):
with connect('test.db') as db: with connect("test.db") as db:
cur = db.cursor() cur = db.cursor()
cur.execute('SELECT * FROM testclass') cur.execute("SELECT * FROM testclass")
objects = cur.fetchall() objects = cur.fetchall()
init_len = len(objects) init_len = len(objects)
self.test_object.create_entry() await self.test_object.create_entry()
self.test_object.remove_entry() await self.test_object.remove_entry()
with connect('test.db') as db: with connect("test.db") as db:
cur = db.cursor() cur = db.cursor()
cur.execute('SELECT * FROM testclass') cur.execute("SELECT * FROM testclass")
objects = cur.fetchall() objects = cur.fetchall()
self.assertEqual(len(objects), init_len) self.assertEqual(len(objects), init_len)
class DatabaseFetchCalls(unittest.TestCase): class DatabaseFetchCalls(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.objs = [FetchClass(1, 'a'), FetchClass(2, 'b'), FetchClass(3, 'b')] self.objs = [FetchClass(1, "a"), FetchClass(2, "b"), FetchClass(3, "b")]
[obj.create_entry() for obj in self.objs] await self.objs[0].markup_table()
[await obj.create_entry() for obj in self.objs]
def testFetchFrom(self): async def testFetchFrom(self):
t_obj = fetch_from(FetchClass, self.objs[0].obj_id) t_obj = await fetch_from(FetchClass, self.objs[0].obj_id)
self.assertEqual(self.objs[0], t_obj) self.assertEqual(self.objs[0], t_obj)
def testFetchEquals(self): async def testFetchEquals(self):
t_obj = fetch_equals(FetchClass, 'str_', self.objs[0].str_) t_obj = await fetch_equals(FetchClass, "str_", self.objs[0].str_)
self.assertEqual(self.objs[0], t_obj) self.assertEqual(self.objs[0], t_obj)
def testFetchAll(self): async def testFetchAll(self):
t_objs = fetch_all(FetchClass) t_objs = await fetch_all(FetchClass)
self.assertEqual(tuple(self.objs), t_objs) self.assertEqual(tuple(self.objs), t_objs)
def testFetchIf(self): async def testFetchIf(self):
t_objs = fetch_if(FetchClass, "str_ = \"b\"") t_objs = await fetch_if(FetchClass, 'str_ = "b"')
self.assertEqual(tuple(self.objs[1:]), t_objs) self.assertEqual(tuple(self.objs[1:]), t_objs)
def testFetchWhere(self): async def testFetchWhere(self):
t_objs = fetch_where(FetchClass, 'str_', 'b') t_objs = await fetch_where(FetchClass, "str_", "b")
self.assertEqual(tuple(self.objs[1:]), t_objs) self.assertEqual(tuple(self.objs[1:]), t_objs)
def testFetchRange(self): async def testFetchRange(self):
t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)) t_objs = await fetch_range(
FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)
)
self.assertEqual(tuple(self.objs[0:2]), t_objs) self.assertEqual(tuple(self.objs[0:2]), t_objs)
def tearDown(self) -> None: async def asyncTearDown(self) -> None:
[obj.remove_entry() for obj in self.objs] [await obj.remove_entry() for obj in self.objs]
class DatabaseFetchPaginationCalls(unittest.TestCase): class DatabaseFetchPaginationCalls(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.objs = [FetchClass(i, f'{floor(i/10)}') for i in range(30)] self.objs = [FetchClass(i, f"{floor(i/10)}") for i in range(30)]
[obj.create_entry() for obj in self.objs] await self.objs[0].markup_table()
[await obj.create_entry() for obj in self.objs]
def testFetchAllPagination(self): async def testFetchAllPagination(self):
t_objs = fetch_all(FetchClass, 1, 10) t_objs = await fetch_all(FetchClass, 1, 10)
self.assertEqual(tuple(self.objs[:10]), t_objs) self.assertEqual(tuple(self.objs[:10]), t_objs)
def testFetchWherePagination(self): async def testFetchWherePagination(self):
t_objs = fetch_where(FetchClass, 'str_', '0', 2, 5) t_objs = await fetch_where(FetchClass, "str_", "0", 2, 5)
self.assertEqual(tuple(self.objs[5:10]), t_objs) self.assertEqual(tuple(self.objs[5:10]), t_objs)
def testFetchIfPagination(self): async def testFetchIfPagination(self):
t_objs = fetch_if(FetchClass, 'str_ = "0"', 1, 5) t_objs = await fetch_if(FetchClass, 'str_ = "0"', 1, 5)
self.assertEqual(tuple(self.objs[:5]), t_objs) self.assertEqual(tuple(self.objs[:5]), t_objs)
def tearDown(self) -> None: async def asyncTearDown(self) -> None:
[obj.remove_entry() for obj in self.objs] [await obj.remove_entry() for obj in self.objs]
class DatabaseMigration(unittest.TestCase): class DatabaseMigration(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.objs = [Migrate1(i, "a") for i in range(10)] self.objs = [Migrate1(i, "a") for i in range(10)]
[obj.create_entry() for obj in self.objs] await self.objs[0].markup_table()
[await obj.create_entry() for obj in self.objs]
def testBasicMigrate(self): async def testBasicMigrate(self):
global Migrate1, Migrate2 global Migrate1, Migrate2
Migrate1 = Migrate2 Migrate1 = Migrate2
Migrate1.__name__ = 'Migrate1' Migrate1.__name__ = "Migrate1"
basic_migrate(Migrate1, {'ordinal': 'cardinal'}) await migrate(Migrate1, {"ordinal": "cardinal"})
t_objs = fetch_all(Migrate1) t_objs = await fetch_all(Migrate1)
self.assertEqual([obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs]) self.assertEqual(
[obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs]
)
self.assertEqual(["default" for _ in range(10)], [obj.str_ for obj in t_objs]) self.assertEqual(["default" for _ in range(10)], [obj.str_ for obj in t_objs])
def tearDown(self) -> None: async def asyncTearDown(self) -> None:
t_objs = fetch_all(Migrate1) t_objs = await fetch_all(Migrate1)
[obj.remove_entry() for obj in t_objs] [await obj.remove_entry() for obj in t_objs]
_drop_table('test.db', 'migrate1') await _drop_table("test.db", "migrate1")
def helperFunc(): async def helperFunc():
obj = ConstraintedClass("This string is supposed to be unique.") obj = ConstraintedClass("This string is supposed to be unique.")
obj.create_entry() await obj.create_entry()
class DatabaseConstraints(unittest.TestCase): class DatabaseConstraints(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.obj = ConstraintedClass("This string is supposed to be unique.") self.obj = ConstraintedClass("This string is supposed to be unique.")
self.obj.create_entry() try:
await _drop_table("test.db", "constraintedclass")
except Exception as e:
assert e
await self.obj.markup_table()
await self.obj.create_entry()
def testUniquness(self): async def testUniquness(self):
self.assertRaises(ConstraintFailedError, helperFunc) try:
await helperFunc()
except Exception as e:
self.assertEqual(e.__class__, ConstraintFailedError)
else:
self.fail("Did not raise")
def testNullness(self): async def testNullness(self):
self.assertRaises(ConstraintFailedError, lambda : ConstraintedClass(None).create_entry()) try:
await ConstraintedClass(None).create_entry()
except Exception as e:
self.assertEqual(e.__class__, ConstraintFailedError)
else:
self.fail("Did not raise")
def tearDown(self) -> None: async def asyncTearDown(self) -> None:
self.obj.remove_entry() await self.obj.remove_entry()
class DatabaseMassInsert(unittest.TestCase): class DatabaseMassInsert(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None: async def asyncSetUp(self) -> None:
self.objs = [MassCommit(f'cat + {i}') for i in range(30)] self.objs = [MassCommit(f"cat + {i}") for i in range(30)]
await self.objs[0].markup_table()
async def testMassCreate(self):
def testMassCreate(self): with connect("other.db") as con:
with connect('other.db') as con:
cur = con.cursor() cur = con.cursor()
cur.execute(f'CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)') cur.execute("CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)")
start_tup = fetch_all(MassCommit) start_tup = await fetch_all(MassCommit)
create_many(self.objs, protect_memory=False) await create_many(self.objs, protect_memory=False)
_objs = fetch_all(MassCommit) _objs = await fetch_all(MassCommit)
self.assertEqual(_objs, start_tup + tuple(self.objs)) self.assertEqual(_objs, start_tup + tuple(self.objs))
def _testMassCopy(self): async def _testMassCopy(self):
setattr(MassCommit, 'db_path', 'other.db') setattr(MassCommit, "db_path", "other.db")
start_tup = fetch_all(MassCommit) start_tup = await fetch_all(MassCommit)
copy_many(self.objs, 'other.db', False) await copy_many(self.objs, "other.db", False)
tup = fetch_all(MassCommit) tup = await fetch_all(MassCommit)
self.assertEqual(tup, start_tup + tuple(self.objs)) self.assertEqual(tup, start_tup + tuple(self.objs))
def tearDown(self) -> None: async def asyncTearDown(self) -> None:
[obj.remove_entry() for obj in self.objs] [await obj.remove_entry() for obj in self.objs]
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()