diff --git a/src/datalite.py b/src/datalite.py index 8a0e243..590383b 100644 --- a/src/datalite.py +++ b/src/datalite.py @@ -1,14 +1,8 @@ from os.path import exists from pathlib import Path import sqlite3 as sql -from dataclasses import Field, _MISSING_TYPE, dataclass -from typing import List, Dict, Optional, Callable - -""" -The default type table for conversion between -Python types and SQLite3 Datatypes. -""" - +from dataclasses import Field, asdict +from typing import List, Dict, Optional, Callable, Any def _database_exists(db_path: str) -> bool: @@ -43,6 +37,23 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str raise TypeError("Requested type not in the default or overloaded type table.") +def _convert_sql_format(value: Any) -> str: + """ + Given a Python value, convert to string representation + of the equivalent SQL datatype. + :param value: A value, ie: a literal, a variable etc. + :return: The string representation of the SQL equivalent. + >>> _convert_sql_format(1) + "1" + >>> _convert_sql_format("John Smith") + '"John Smith"' + """ + if isinstance(value, str): + return f'"{value}"' + else: + return str(value) + + def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: """ Check if the field's default object is filled, @@ -72,12 +83,47 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional """ fields: List[Field] = [class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()] + fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}" f"{_get_default(field.default, type_overload)}" for field in fields) sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});") +def _create_entry(self, cur: sql.Cursor) -> None: + """ + Given an object, create the entry for the object. As a side-effect, + this will set the object_id attribute of the object to the unique + id of the entry. + :param cur: Cursor of the database. + :param self: Instance of the object. + :param args: Initialisation arguments. + :param kwargs: Initialisation keyword arguments. + :return: None. + """ + table_name: str = self.__class__.__name__.lower() + kv_pairs = [item for item in asdict(self).items()] + kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields. + cur.execute(f"INSERT INTO {table_name}(" + f"{', '.join(item[0] for item in kv_pairs)})" + f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});") + self.__setattr__("obj_id", cur.lastrowid) + + +def _modify_init(dataclass_: type): + def modifier(self, *args, **kwargs): + self.__init__() + if "create_entry" in kwargs and kwargs["create_entry"]: + try: + with sql.connect(dataclass_.__db_path__) as con: + cur: sql.Cursor = con.cursor() + self._create_entry(cur) + con.commit() + except AttributeError: + raise TypeError("Are you sure this is a datalite class?") + return modifier + + def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, *args, **kwargs) -> Callable: def decorator(dataclass_: type, *args_i, **kwargs_i): @@ -90,5 +136,7 @@ def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = No with sql.connect(db_path) as con: cur: sql.Cursor = con.cursor() _create_table(dataclass_, cur, type_table) + dataclass_.__db_path__ == db_path # We add the path of the database to class itself. + dataclass_.__init__ = _modify_init(dataclass_) # Replace the init method. return dataclass_ return decorator