Now test-passing!
This commit is contained in:
@@ -1,223 +1,256 @@
|
||||
import unittest
|
||||
from datalite import datalite
|
||||
from datalite.constraints import Unique, ConstraintFailedError
|
||||
from datalite.fetch import fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where
|
||||
from datalite.mass_actions import create_many, copy_many
|
||||
from sqlite3 import connect
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from math import floor
|
||||
from datalite.migrations import basic_migrate, _drop_table
|
||||
from sqlite3 import connect
|
||||
|
||||
from datalite import datalite
|
||||
from datalite.constraints import ConstraintFailedError, Unique
|
||||
from datalite.fetch import (
|
||||
fetch_all,
|
||||
fetch_equals,
|
||||
fetch_from,
|
||||
fetch_if,
|
||||
fetch_range,
|
||||
fetch_where,
|
||||
)
|
||||
from datalite.mass_actions import copy_many, create_many
|
||||
from datalite.migrations import _drop_table, migrate
|
||||
from datalite.typed import DataliteHinted
|
||||
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class TestClass:
|
||||
class TestClass(DataliteHinted):
|
||||
integer_value: int = 1
|
||||
byte_value: bytes = b'a'
|
||||
byte_value: bytes = b"a"
|
||||
float_value: float = 0.4
|
||||
str_value: str = 'a'
|
||||
str_value: str = "a"
|
||||
bool_value: bool = True
|
||||
|
||||
def __eq__(self, other):
|
||||
return asdict(self) == asdict(other)
|
||||
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class FetchClass:
|
||||
class FetchClass(DataliteHinted):
|
||||
ordinal: int
|
||||
str_: str
|
||||
|
||||
def __eq__(self, other):
|
||||
return asdict(self) == asdict(other)
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class Migrate1:
|
||||
class Migrate1(DataliteHinted):
|
||||
ordinal: int
|
||||
conventional: str
|
||||
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class Migrate2:
|
||||
class Migrate2(DataliteHinted):
|
||||
cardinal: Unique[int] = 1
|
||||
str_: str = "default"
|
||||
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class ConstraintedClass:
|
||||
class ConstraintedClass(DataliteHinted):
|
||||
unique_str: Unique[str]
|
||||
|
||||
|
||||
@datalite(db_path='test.db')
|
||||
@datalite(db_path="test.db")
|
||||
@dataclass
|
||||
class MassCommit:
|
||||
class MassCommit(DataliteHinted):
|
||||
str_: str
|
||||
|
||||
|
||||
def getValFromDB(obj_id = 1):
|
||||
with connect('test.db') as db:
|
||||
def getValFromDB(obj_id=1):
|
||||
with connect("test.db") as db:
|
||||
cur = db.cursor()
|
||||
cur.execute(f'SELECT * FROM testclass WHERE obj_id = {obj_id}')
|
||||
cur.execute(f"SELECT * FROM testclass WHERE obj_id = {obj_id}")
|
||||
fields = list(TestClass.__dataclass_fields__.keys())
|
||||
fields.sort()
|
||||
repr = dict(zip(fields, cur.fetchall()[0][1:]))
|
||||
field_types = {key: value.type for key, value in TestClass.__dataclass_fields__.items()}
|
||||
_ = {key: value.type for key, value in TestClass.__dataclass_fields__.items()}
|
||||
test_object = TestClass(**repr)
|
||||
return test_object
|
||||
|
||||
|
||||
class DatabaseMain(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.test_object = TestClass(12, b'bytes', 0.4, 'TestValue')
|
||||
class DatabaseMain(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.test_object = TestClass(12, b"bytes", 0.4, "TestValue")
|
||||
await self.test_object.markup_table()
|
||||
|
||||
def test_creation(self):
|
||||
self.test_object.create_entry()
|
||||
async def test_creation(self):
|
||||
await self.test_object.create_entry()
|
||||
self.assertEqual(self.test_object, getValFromDB())
|
||||
|
||||
def test_update(self):
|
||||
self.test_object.create_entry()
|
||||
async def test_update(self):
|
||||
await self.test_object.create_entry()
|
||||
self.test_object.integer_value = 40
|
||||
self.test_object.update_entry()
|
||||
from_db = getValFromDB(getattr(self.test_object, 'obj_id'))
|
||||
await self.test_object.update_entry()
|
||||
from_db = getValFromDB(getattr(self.test_object, "obj_id"))
|
||||
self.assertEqual(self.test_object.integer_value, from_db.integer_value)
|
||||
|
||||
def test_delete(self):
|
||||
with connect('test.db') as db:
|
||||
async def test_delete(self):
|
||||
with connect("test.db") as db:
|
||||
cur = db.cursor()
|
||||
cur.execute('SELECT * FROM testclass')
|
||||
cur.execute("SELECT * FROM testclass")
|
||||
objects = cur.fetchall()
|
||||
init_len = len(objects)
|
||||
self.test_object.create_entry()
|
||||
self.test_object.remove_entry()
|
||||
with connect('test.db') as db:
|
||||
await self.test_object.create_entry()
|
||||
await self.test_object.remove_entry()
|
||||
with connect("test.db") as db:
|
||||
cur = db.cursor()
|
||||
cur.execute('SELECT * FROM testclass')
|
||||
cur.execute("SELECT * FROM testclass")
|
||||
objects = cur.fetchall()
|
||||
self.assertEqual(len(objects), init_len)
|
||||
|
||||
|
||||
class DatabaseFetchCalls(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.objs = [FetchClass(1, 'a'), FetchClass(2, 'b'), FetchClass(3, 'b')]
|
||||
[obj.create_entry() for obj in self.objs]
|
||||
class DatabaseFetchCalls(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.objs = [FetchClass(1, "a"), FetchClass(2, "b"), FetchClass(3, "b")]
|
||||
await self.objs[0].markup_table()
|
||||
[await obj.create_entry() for obj in self.objs]
|
||||
|
||||
def testFetchFrom(self):
|
||||
t_obj = fetch_from(FetchClass, self.objs[0].obj_id)
|
||||
async def testFetchFrom(self):
|
||||
t_obj = await fetch_from(FetchClass, self.objs[0].obj_id)
|
||||
self.assertEqual(self.objs[0], t_obj)
|
||||
|
||||
def testFetchEquals(self):
|
||||
t_obj = fetch_equals(FetchClass, 'str_', self.objs[0].str_)
|
||||
async def testFetchEquals(self):
|
||||
t_obj = await fetch_equals(FetchClass, "str_", self.objs[0].str_)
|
||||
self.assertEqual(self.objs[0], t_obj)
|
||||
|
||||
def testFetchAll(self):
|
||||
t_objs = fetch_all(FetchClass)
|
||||
async def testFetchAll(self):
|
||||
t_objs = await fetch_all(FetchClass)
|
||||
self.assertEqual(tuple(self.objs), t_objs)
|
||||
|
||||
def testFetchIf(self):
|
||||
t_objs = fetch_if(FetchClass, "str_ = \"b\"")
|
||||
async def testFetchIf(self):
|
||||
t_objs = await fetch_if(FetchClass, 'str_ = "b"')
|
||||
self.assertEqual(tuple(self.objs[1:]), t_objs)
|
||||
|
||||
def testFetchWhere(self):
|
||||
t_objs = fetch_where(FetchClass, 'str_', 'b')
|
||||
async def testFetchWhere(self):
|
||||
t_objs = await fetch_where(FetchClass, "str_", "b")
|
||||
self.assertEqual(tuple(self.objs[1:]), t_objs)
|
||||
|
||||
def testFetchRange(self):
|
||||
t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id))
|
||||
async def testFetchRange(self):
|
||||
t_objs = await fetch_range(
|
||||
FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id)
|
||||
)
|
||||
self.assertEqual(tuple(self.objs[0:2]), t_objs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
[obj.remove_entry() for obj in self.objs]
|
||||
async def asyncTearDown(self) -> None:
|
||||
[await obj.remove_entry() for obj in self.objs]
|
||||
|
||||
|
||||
class DatabaseFetchPaginationCalls(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.objs = [FetchClass(i, f'{floor(i/10)}') for i in range(30)]
|
||||
[obj.create_entry() for obj in self.objs]
|
||||
class DatabaseFetchPaginationCalls(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.objs = [FetchClass(i, f"{floor(i/10)}") for i in range(30)]
|
||||
await self.objs[0].markup_table()
|
||||
[await obj.create_entry() for obj in self.objs]
|
||||
|
||||
def testFetchAllPagination(self):
|
||||
t_objs = fetch_all(FetchClass, 1, 10)
|
||||
async def testFetchAllPagination(self):
|
||||
t_objs = await fetch_all(FetchClass, 1, 10)
|
||||
self.assertEqual(tuple(self.objs[:10]), t_objs)
|
||||
|
||||
def testFetchWherePagination(self):
|
||||
t_objs = fetch_where(FetchClass, 'str_', '0', 2, 5)
|
||||
async def testFetchWherePagination(self):
|
||||
t_objs = await fetch_where(FetchClass, "str_", "0", 2, 5)
|
||||
self.assertEqual(tuple(self.objs[5:10]), t_objs)
|
||||
|
||||
def testFetchIfPagination(self):
|
||||
t_objs = fetch_if(FetchClass, 'str_ = "0"', 1, 5)
|
||||
async def testFetchIfPagination(self):
|
||||
t_objs = await fetch_if(FetchClass, 'str_ = "0"', 1, 5)
|
||||
self.assertEqual(tuple(self.objs[:5]), t_objs)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
[obj.remove_entry() for obj in self.objs]
|
||||
async def asyncTearDown(self) -> None:
|
||||
[await obj.remove_entry() for obj in self.objs]
|
||||
|
||||
|
||||
class DatabaseMigration(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
class DatabaseMigration(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.objs = [Migrate1(i, "a") for i in range(10)]
|
||||
[obj.create_entry() for obj in self.objs]
|
||||
await self.objs[0].markup_table()
|
||||
[await obj.create_entry() for obj in self.objs]
|
||||
|
||||
def testBasicMigrate(self):
|
||||
async def testBasicMigrate(self):
|
||||
global Migrate1, Migrate2
|
||||
Migrate1 = Migrate2
|
||||
Migrate1.__name__ = 'Migrate1'
|
||||
basic_migrate(Migrate1, {'ordinal': 'cardinal'})
|
||||
t_objs = fetch_all(Migrate1)
|
||||
self.assertEqual([obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs])
|
||||
Migrate1.__name__ = "Migrate1"
|
||||
await migrate(Migrate1, {"ordinal": "cardinal"})
|
||||
t_objs = await fetch_all(Migrate1)
|
||||
self.assertEqual(
|
||||
[obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs]
|
||||
)
|
||||
self.assertEqual(["default" for _ in range(10)], [obj.str_ for obj in t_objs])
|
||||
|
||||
def tearDown(self) -> None:
|
||||
t_objs = fetch_all(Migrate1)
|
||||
[obj.remove_entry() for obj in t_objs]
|
||||
_drop_table('test.db', 'migrate1')
|
||||
async def asyncTearDown(self) -> None:
|
||||
t_objs = await fetch_all(Migrate1)
|
||||
[await obj.remove_entry() for obj in t_objs]
|
||||
await _drop_table("test.db", "migrate1")
|
||||
|
||||
|
||||
def helperFunc():
|
||||
async def helperFunc():
|
||||
obj = ConstraintedClass("This string is supposed to be unique.")
|
||||
obj.create_entry()
|
||||
await obj.create_entry()
|
||||
|
||||
|
||||
class DatabaseConstraints(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
class DatabaseConstraints(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.obj = ConstraintedClass("This string is supposed to be unique.")
|
||||
self.obj.create_entry()
|
||||
try:
|
||||
await _drop_table("test.db", "constraintedclass")
|
||||
except Exception as e:
|
||||
assert e
|
||||
await self.obj.markup_table()
|
||||
await self.obj.create_entry()
|
||||
|
||||
def testUniquness(self):
|
||||
self.assertRaises(ConstraintFailedError, helperFunc)
|
||||
async def testUniquness(self):
|
||||
try:
|
||||
await helperFunc()
|
||||
except Exception as e:
|
||||
self.assertEqual(e.__class__, ConstraintFailedError)
|
||||
else:
|
||||
self.fail("Did not raise")
|
||||
|
||||
def testNullness(self):
|
||||
self.assertRaises(ConstraintFailedError, lambda : ConstraintedClass(None).create_entry())
|
||||
async def testNullness(self):
|
||||
try:
|
||||
await ConstraintedClass(None).create_entry()
|
||||
except Exception as e:
|
||||
self.assertEqual(e.__class__, ConstraintFailedError)
|
||||
else:
|
||||
self.fail("Did not raise")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.obj.remove_entry()
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.obj.remove_entry()
|
||||
|
||||
|
||||
class DatabaseMassInsert(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.objs = [MassCommit(f'cat + {i}') for i in range(30)]
|
||||
class DatabaseMassInsert(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
self.objs = [MassCommit(f"cat + {i}") for i in range(30)]
|
||||
await self.objs[0].markup_table()
|
||||
|
||||
|
||||
def testMassCreate(self):
|
||||
with connect('other.db') as con:
|
||||
async def testMassCreate(self):
|
||||
with connect("other.db") as con:
|
||||
cur = con.cursor()
|
||||
cur.execute(f'CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)')
|
||||
cur.execute("CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)")
|
||||
|
||||
start_tup = fetch_all(MassCommit)
|
||||
create_many(self.objs, protect_memory=False)
|
||||
_objs = fetch_all(MassCommit)
|
||||
start_tup = await fetch_all(MassCommit)
|
||||
await create_many(self.objs, protect_memory=False)
|
||||
_objs = await fetch_all(MassCommit)
|
||||
self.assertEqual(_objs, start_tup + tuple(self.objs))
|
||||
|
||||
def _testMassCopy(self):
|
||||
setattr(MassCommit, 'db_path', 'other.db')
|
||||
start_tup = fetch_all(MassCommit)
|
||||
copy_many(self.objs, 'other.db', False)
|
||||
tup = fetch_all(MassCommit)
|
||||
async def _testMassCopy(self):
|
||||
setattr(MassCommit, "db_path", "other.db")
|
||||
start_tup = await fetch_all(MassCommit)
|
||||
await copy_many(self.objs, "other.db", False)
|
||||
tup = await fetch_all(MassCommit)
|
||||
self.assertEqual(tup, start_tup + tuple(self.objs))
|
||||
|
||||
def tearDown(self) -> None:
|
||||
[obj.remove_entry() for obj in self.objs]
|
||||
async def asyncTearDown(self) -> None:
|
||||
[await obj.remove_entry() for obj in self.objs]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user