165 lines
5.0 KiB
Python
165 lines
5.0 KiB
Python
import unittest
|
|
from datalite import datalite
|
|
from datalite.fetch import fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where
|
|
from sqlite3 import connect
|
|
from dataclasses import dataclass, asdict
|
|
from math import floor
|
|
from os import remove
|
|
|
|
from datalite.migrations import basic_migrate, _drop_table
|
|
|
|
|
|
@datalite(db_path='test.db')
@dataclass
class TestClass:
    """Datalite-persisted record with one int, bytes, float and str field each."""

    # Not a test case despite the ``Test`` prefix: stop pytest from trying to
    # collect it (it would only warn because the dataclass defines __init__).
    __test__ = False

    integer_value: int = 1
    byte_value: bytes = b'a'
    float_value: float = 0.4
    str_value: str = 'a'

    def __eq__(self, other):
        """Compare two ``TestClass`` records by their dataclass field values.

        Operands of an unrelated type yield ``NotImplemented`` so Python can
        fall back to its default handling instead of ``asdict`` raising
        ``TypeError`` (e.g. for ``obj == None``).
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        return asdict(self) == asdict(other)
|
|
|
|
|
|
@datalite(db_path='test.db')
@dataclass
class FetchClass:
    """Two-field datalite record used by the fetch and pagination tests."""

    ordinal: int
    str_: str

    def __eq__(self, other):
        """Compare two ``FetchClass`` records by their dataclass field values.

        Operands of an unrelated type yield ``NotImplemented`` instead of
        ``asdict`` raising ``TypeError`` (e.g. for ``obj == None``).
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        return asdict(self) == asdict(other)
|
|
|
|
@datalite(db_path='test.db')
@dataclass
class Migrate1:
    """Pre-migration schema: a single integer field named ``ordinal``."""

    ordinal: int
|
|
|
|
|
|
@datalite(db_path='test.db')
@dataclass
class Migrate2:
    """Post-migration schema: ``ordinal`` renamed to ``cardinal``."""

    cardinal: int
|
|
|
|
|
|
def getValFromDB(obj_id=1):
    """Load the ``testclass`` row with the given ``obj_id`` straight from ``test.db``.

    Uses plain ``sqlite3`` rather than datalite's fetch helpers, so tests can
    compare what was actually stored against the in-memory object.

    Args:
        obj_id: primary key of the row to read (defaults to 1).

    Returns:
        A ``TestClass`` instance rebuilt from the stored column values.

    Raises:
        IndexError: if no row with ``obj_id`` exists.
    """
    from contextlib import closing

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # closing() is what actually releases it. The id is bound as a parameter
    # instead of being formatted into the SQL text.
    with closing(connect('test.db')) as db:
        row = db.execute('SELECT * FROM testclass WHERE obj_id = ?', (obj_id,)).fetchone()
    if row is None:
        raise IndexError(f'no testclass row with obj_id = {obj_id}')

    field_types = {name: field.type for name, field in TestClass.__dataclass_fields__.items()}
    # Assumes obj_id is the first column and the remaining columns follow the
    # field names in alphabetical order -- TODO confirm against datalite's layout.
    values = dict(zip(sorted(field_types), row[1:]))
    for name, field_type in field_types.items():
        if field_type is bytes:
            # bytes fields are expected to come back as UTF-8 text; re-encode them.
            values[name] = bytes(values[name], encoding='utf-8')
    return TestClass(**values)
|
|
|
|
|
|
class DatabaseMain(unittest.TestCase):
    """Create / update / delete round trips for a single ``TestClass`` row."""

    def setUp(self) -> None:
        self.test_object = TestClass(12, b'bytes', 0.4, 'TestValue')

    @staticmethod
    def _count_testclass_rows() -> int:
        """Return how many rows the ``testclass`` table in ``test.db`` holds."""
        from contextlib import closing

        # The sqlite3 connection context manager does not close the connection;
        # closing() does.
        with closing(connect('test.db')) as db:
            return db.execute('SELECT COUNT(*) FROM testclass').fetchone()[0]

    def test_creation(self):
        self.test_object.create_entry()
        # Remove the row afterwards so test.db does not grow across runs.
        self.addCleanup(self.test_object.remove_entry)
        # Look the row up by the id it was actually assigned; assuming 1 only
        # holds on a freshly created database.
        self.assertEqual(self.test_object, getValFromDB(self.test_object.obj_id))

    def test_update(self):
        self.test_object.create_entry()
        self.addCleanup(self.test_object.remove_entry)
        self.test_object.integer_value = 40
        self.test_object.update_entry()
        from_db = getValFromDB(self.test_object.obj_id)
        self.assertEqual(self.test_object.integer_value, from_db.integer_value)

    def test_delete(self):
        init_len = self._count_testclass_rows()
        self.test_object.create_entry()
        self.test_object.remove_entry()
        # Creating and then removing an entry must leave the row count unchanged.
        self.assertEqual(self._count_testclass_rows(), init_len)
|
|
|
|
|
|
class DatabaseFetchCalls(unittest.TestCase):
    """Exercise each datalite fetch helper against three known ``FetchClass`` rows."""

    def setUp(self) -> None:
        self.objs = [FetchClass(1, 'a'), FetchClass(2, 'b'), FetchClass(3, 'b')]
        for obj in self.objs:
            obj.create_entry()

    def testFetchFrom(self):
        t_obj = fetch_from(FetchClass, self.objs[0].obj_id)
        self.assertEqual(self.objs[0], t_obj)

    def testFetchEquals(self):
        t_obj = fetch_equals(FetchClass, 'str_', self.objs[0].str_)
        self.assertEqual(self.objs[0], t_obj)

    def testFetchAll(self):
        t_objs = fetch_all(FetchClass)
        self.assertEqual(tuple(self.objs), t_objs)

    def testFetchIf(self):
        # The condition is an SQL fragment: single quotes make 'b' a string
        # literal, whereas double quotes denote an identifier and only work via
        # SQLite's legacy double-quoted-string fallback.
        t_objs = fetch_if(FetchClass, "str_ = 'b'")
        self.assertEqual(tuple(self.objs[1:]), t_objs)

    def testFetchWhere(self):
        t_objs = fetch_where(FetchClass, 'str_', 'b')
        self.assertEqual(tuple(self.objs[1:]), t_objs)

    def testFetchRange(self):
        # Expected result assumes the range stop is excluded, as with Python's
        # range(), so only the first two objects fall inside it.
        t_objs = fetch_range(FetchClass, range(self.objs[0].obj_id, self.objs[2].obj_id))
        self.assertEqual(tuple(self.objs[0:2]), t_objs)

    def tearDown(self) -> None:
        for obj in self.objs:
            obj.remove_entry()
|
|
|
|
|
|
class DatabaseFetchPaginationCalls(unittest.TestCase):
    """Check the pagination arguments of the fetch helpers over 30 ``FetchClass`` rows."""

    def setUp(self) -> None:
        # str_ is the tens digit of the ordinal: '0' for 0-9, '1' for 10-19, '2' for 20-29.
        self.objs = [FetchClass(i, str(i // 10)) for i in range(30)]
        for obj in self.objs:
            obj.create_entry()

    def testFetchAllPagination(self):
        t_objs = fetch_all(FetchClass, 1, 10)
        self.assertEqual(tuple(self.objs[:10]), t_objs)

    def testFetchWherePagination(self):
        # The trailing arguments appear to be (page, page size), 1-based:
        # page 2 of 5 rows with str_ == '0' is objs[5:10].
        t_objs = fetch_where(FetchClass, 'str_', '0', 2, 5)
        self.assertEqual(tuple(self.objs[5:10]), t_objs)

    def testFetchIfPagination(self):
        # Single quotes: '0' must be an SQL string literal, not an identifier.
        t_objs = fetch_if(FetchClass, "str_ = '0'", 1, 5)
        self.assertEqual(tuple(self.objs[:5]), t_objs)

    def tearDown(self) -> None:
        for obj in self.objs:
            obj.remove_entry()
|
|
|
|
|
|
class DatabaseMigration(unittest.TestCase):
    """Migrate the ``migrate1`` table from an ``ordinal`` field to ``cardinal``."""

    def setUp(self) -> None:
        self.objs = [Migrate1(i) for i in range(10)]
        for obj in self.objs:
            obj.create_entry()

    def testBasicMigrate(self):
        # basic_migrate and fetch_all appear to pick the table from the class
        # name, so present Migrate2 (the new ``cardinal`` schema) as
        # ``Migrate1`` for the duration of this test only. The name is restored
        # afterwards and the module globals are never rebound, so nothing else
        # in the module observes the swap.
        original_name = Migrate2.__name__
        Migrate2.__name__ = 'Migrate1'
        self.addCleanup(setattr, Migrate2, '__name__', original_name)

        basic_migrate(Migrate2, {'ordinal': 'cardinal'})
        t_objs = fetch_all(Migrate2)
        self.assertEqual([obj.ordinal for obj in self.objs], [obj.cardinal for obj in t_objs])

    def tearDown(self) -> None:
        # Dropping the table discards every row, so no per-entry removal is needed.
        _drop_table('test.db', 'migrate1')
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly (not imported): discover and run every TestCase in this module.
    unittest.main()
|