Add new entry creation.

This commit is contained in:
Ege Emir Özkan
2020-08-01 18:24:36 +03:00
parent 4cf92ca963
commit 0ddc9dd19b

View File

@@ -1,14 +1,8 @@
from os.path import exists from os.path import exists
from pathlib import Path from pathlib import Path
import sqlite3 as sql import sqlite3 as sql
from dataclasses import Field, _MISSING_TYPE, dataclass from dataclasses import Field, asdict
from typing import List, Dict, Optional, Callable from typing import List, Dict, Optional, Callable, Any
"""
The default type table for conversion between
Python types and SQLite3 Datatypes.
"""
def _database_exists(db_path: str) -> bool: def _database_exists(db_path: str) -> bool:
@@ -43,6 +37,23 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
raise TypeError("Requested type not in the default or overloaded type table.") raise TypeError("Requested type not in the default or overloaded type table.")
def _convert_sql_format(value: Any) -> str:
"""
Given a Python value, convert to string representation
of the equivalent SQL datatype.
:param value: A value, ie: a literal, a variable etc.
:return: The string representation of the SQL equivalent.
>>> _convert_sql_format(1)
"1"
>>> _convert_sql_format("John Smith")
'"John Smith"'
"""
if isinstance(value, str):
return f'"{value}"'
else:
return str(value)
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
""" """
Check if the field's default object is filled, Check if the field's default object is filled,
@@ -72,12 +83,47 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional
""" """
fields: List[Field] = [class_.__dataclass_fields__[key] for fields: List[Field] = [class_.__dataclass_fields__[key] for
key in class_.__dataclass_fields__.keys()] key in class_.__dataclass_fields__.keys()]
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}" sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}"
f"{_get_default(field.default, type_overload)}" for field in fields) f"{_get_default(field.default, type_overload)}" for field in fields)
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});") cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});")
def _create_entry(self, cur: sql.Cursor) -> None:
"""
Given an object, create the entry for the object. As a side-effect,
this will set the object_id attribute of the object to the unique
id of the entry.
:param cur: Cursor of the database.
:param self: Instance of the object.
:param args: Initialisation arguments.
:param kwargs: Initialisation keyword arguments.
:return: None.
"""
table_name: str = self.__class__.__name__.lower()
kv_pairs = [item for item in asdict(self).items()]
kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields.
cur.execute(f"INSERT INTO {table_name}("
f"{', '.join(item[0] for item in kv_pairs)})"
f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});")
self.__setattr__("obj_id", cur.lastrowid)
def _modify_init(dataclass_: type):
def modifier(self, *args, **kwargs):
self.__init__()
if "create_entry" in kwargs and kwargs["create_entry"]:
try:
with sql.connect(dataclass_.__db_path__) as con:
cur: sql.Cursor = con.cursor()
self._create_entry(cur)
con.commit()
except AttributeError:
raise TypeError("Are you sure this is a datalite class?")
return modifier
def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None,
*args, **kwargs) -> Callable: *args, **kwargs) -> Callable:
def decorator(dataclass_: type, *args_i, **kwargs_i): def decorator(dataclass_: type, *args_i, **kwargs_i):
@@ -90,5 +136,7 @@ def sqlify(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = No
with sql.connect(db_path) as con: with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor() cur: sql.Cursor = con.cursor()
_create_table(dataclass_, cur, type_table) _create_table(dataclass_, cur, type_table)
dataclass_.__db_path__ == db_path # We add the path of the database to class itself.
dataclass_.__init__ = _modify_init(dataclass_) # Replace the init method.
return dataclass_ return dataclass_
return decorator return decorator