diff --git a/datalite/__init__.py b/datalite/__init__.py index 4d9ba59..34c6357 100644 --- a/datalite/__init__.py +++ b/datalite/__init__.py @@ -32,6 +32,8 @@ def _convert_sql_format(value: Any) -> str: """ if isinstance(value, str): return f'"{value}"' + elif isinstance(value, bytes): + return '"' + str(value).replace("b'", "")[:-1] + '"' else: return str(value) @@ -47,10 +49,7 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str empty string if no string is necessary. """ if type(default_object) in type_overload: - if isinstance(default_object, str): - return f' DEFAULT "{default_object}"' - else: - return f" DEFAULT {str(default_object)}" + return f' DEFAULT {_convert_sql_format(default_object)}' return "" @@ -100,12 +99,13 @@ def _update_entry(self) -> None: """ with sql.connect(getattr(self, "db_path")) as con: cur: sql.Cursor = con.cursor() - table_name: str = self.__clas__.__name__.lower() + table_name: str = self.__class__.__name__.lower() kv_pairs = [item for item in asdict(self).items()] kv_pairs.sort(key=lambda item: item[0]) - cur.execute(f"UPDATE {table_name}" - f"SET {', '.join(item[0] + ' = ' + _convert_sql_format(item[1]) for item in kv_pairs)}" - f"WHERE obj_id = {getattr(self, 'obj_id')}") + query = f"UPDATE {table_name} " + \ + f"SET {', '.join(item[0] + ' = ' + _convert_sql_format(item[1]) for item in kv_pairs)}" + \ + f"WHERE obj_id = {getattr(self, 'obj_id')};" + cur.execute(query) con.commit() @@ -202,6 +202,10 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis :return: the created object. """ kwargs = dict(zip(field_names, record[1:])) + field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()} + for key in kwargs: + if field_types[key] == bytes: + kwargs[key] = bytes(kwargs[key], encoding='utf-8') obj_id = record[0] obj = class_(**kwargs) setattr(obj, "obj_id", obj_id) diff --git a/requirments.txt b/requirments.txt new file mode 100644 index 0000000..5cb2823 --- /dev/null +++ b/requirments.txt @@ -0,0 +1,26 @@ +bleach==3.1.5 +certifi==2020.6.20 +cffi==1.14.1 +chardet==3.0.4 +colorama==0.4.3 +coverage==5.2.1 +cryptography==3.0 +docutils==0.16 +idna==2.10 +jeepney==0.4.3 +keyring==21.3.0 +packaging==20.4 +pkginfo==1.5.0.1 +pycparser==2.20 +Pygments==2.6.1 +pyparsing==2.4.7 +readme-renderer==26.0 +requests==2.24.0 +requests-toolbelt==0.9.1 +rfc3986==1.4.0 +SecretStorage==3.1.2 +six==1.15.0 +tqdm==4.48.1 +twine==3.2.0 +urllib3==1.25.10 +webencodings==0.5.1 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/main_tests.py b/test/main_tests.py new file mode 100644 index 0000000..9ff509c --- /dev/null +++ b/test/main_tests.py @@ -0,0 +1,70 @@ +import unittest +try: + from datalite import datalite +except ModuleNotFoundError: + import importlib + importlib.import_module('datalite', '../datalite/') +from sqlite3 import connect +from dataclasses import dataclass, asdict +from os import remove + + +@datalite(db_path='test.db') +@dataclass +class TestClass: + integer_value: int = 1 + byte_value: bytes = b'a' + float_value: float = 0.4 + str_value: str = 'a' + + def __eq__(self, other): + return asdict(self) == asdict(other) + + +def getValFromDB(obj_id = 1): + with connect('test.db') as db: + cur = db.cursor() + cur.execute(f'SELECT * FROM testclass WHERE obj_id = {obj_id}') + fields = list(TestClass.__dataclass_fields__.keys()) + fields.sort() + repr = dict(zip(fields, cur.fetchall()[0][1:])) + field_types = {key: value.type for key, value in TestClass.__dataclass_fields__.items()} + for key in fields: + if field_types[key] == bytes: + repr[key] = bytes(repr[key], encoding='utf-8') + test_object = TestClass(**repr) + return test_object + + +class DatabaseMain(unittest.TestCase): + def setUp(self) -> None: + self.test_object = TestClass(12, b'bytes', 0.4, 'TestValue') + + def test_creation(self): + self.test_object.create_entry() + self.assertEqual(self.test_object, getValFromDB()) + + def test_update(self): + self.test_object.create_entry() + self.test_object.integer_value = 40 + self.test_object.update_entry() + from_db = getValFromDB(getattr(self.test_object, 'obj_id')) + self.assertEqual(self.test_object.integer_value, from_db.integer_value) + + def test_delete(self): + with connect('test.db') as db: + cur = db.cursor() + cur.execute('SELECT * FROM testclass') + objects = cur.fetchall() + init_len = len(objects) + self.test_object.create_entry() + self.test_object.remove_entry() + with connect('test.db') as db: + cur = db.cursor() + cur.execute('SELECT * FROM testclass') + objects = cur.fetchall() + self.assertEqual(len(objects), init_len) + + +if __name__ == '__main__': + unittest.main() diff --git a/travis.yml b/travis.yml new file mode 100644 index 0000000..bcdf647 --- /dev/null +++ b/travis.yml @@ -0,0 +1,26 @@ +dist: trusty +language: python +python: + - "3.8" +# command to install dependencies +install: + - curl -L https://codeclimate.com/downloads/test-reporter/test-reporter-latest-linux-amd64 > ./cc-test-reporter + - chmod +x ./cc-test-reporter + - pip install -r requirments.txt +# for codecoverage on codeclimate.com +env: + global: + - GIT_COMMITTED_AT=$(if [ "$TRAVIS_PULL_REQUEST" == "false" ]; then git log -1 --pretty=format:%ct; else git log -1 --skip 1 --pretty=format:%ct; fi) + - CODECLIMATE_REPO_TOKEN=[token] + - CC_TEST_REPORTER_ID=[id] + +before_script: + - ./cc-test-reporter before-build + +script: + - "coverage run -m unittest test/main_tests.py" + +after_script: + - coverage xml + - if [[ "$TRAVIS_PULL_REQUEST" == "false" && "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then ./cc-test-reporter after-build --exit-code $TRAVIS_TEST_RESULT; fi +# command to run tests \ No newline at end of file