diff --git a/.gitignore b/.gitignore index 1c2d52b..d4ea606 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/* +src/*.db diff --git a/src/datalite.py b/src/datalite.py index 2f3b1ac..8a0e243 100644 --- a/src/datalite.py +++ b/src/datalite.py @@ -1,27 +1,35 @@ from os.path import exists +from pathlib import Path import sqlite3 as sql -from dataclasses import Field, _MISSING_TYPE -from typing import List, Dict, Optional - +from dataclasses import Field, _MISSING_TYPE, dataclass +from typing import List, Dict, Optional, Callable """ The default type table for conversion between Python types and SQLite3 Datatypes. """ -type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL", - str: "TEXT", bytes: "BLOB"} -def _database_exists(db_name: str) -> bool: + +def _database_exists(db_path: str) -> bool: """ Check if a given database exists. - :param db_name: Name of the database, including the extension. + :param db_path: Relative path of the database, including the extension. :return: True if database exists, False otherwise. """ - return exists(db_name) + return exists(db_path) -def _sqlify(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str: +def _create_db(db_path: str) -> None: + """ + Create the database file. + :param db_path: Relative path of the database file, including the extension. + :return: None. + """ + Path(db_path).touch() + + +def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str: """ Given a Python type, return the str name of its SQLlite equivalent. @@ -29,29 +37,28 @@ def _sqlify(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> :param type_overload: A type table to overload the custom type table. :return: The str name of the sql type. """ - types_dict = type_table.copy() - types_dict.update(type_overload) try: - return types_dict[type_] + return type_overload[type_] except KeyError: raise TypeError("Requested type not in the default or overloaded type table.") -def get_default(default_object: object) -> str: +def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: """ Check if the field's default object is filled, if filled return the string to be put in the, database. :param default_object: The default field of the field. + :param type_overload: Type overload table. :return: The string to be put on the table statement, empty string if no string is necessary. """ - if isinstance(default_object, _MISSING_TYPE): - return "" - elif isinstance(default_object, str): - return f' DEFAULT "{default_object}"' - else: - return f" DEFAULT {str(default_object)}" + if type(default_object) in type_overload: + if isinstance(default_object, str): + return f' DEFAULT "{default_object}"' + else: + return f" DEFAULT {str(default_object)}" + return "" def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str]) -> None: @@ -65,6 +72,23 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional """ fields: List[Field] = [class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()] - sql_fields = ', '.join(f"{field.name} {_sqlify(field.type, type_overload)}" - f"{get_default(field.default)}" for field in fields) - cursor.execute(f"CREATE TABLE {class_.__name__.lower} ({sql_fields})") + sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}" + f"{_get_default(field.default, type_overload)}" for field in fields) + sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields + cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});") + + +def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, + *args, **kwargs) -> Callable: + def decorator(dataclass_: type, *args_i, **kwargs_i): + if not _database_exists(db_path): + _create_db(db_path) + type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL", + str: "TEXT", bytes: "BLOB"} + if type_overload is not None: + type_table.update(type_overload) + with sql.connect(db_path) as con: + cur: sql.Cursor = con.cursor() + _create_table(dataclass_, cur, type_table) + return dataclass_ + return decorator