diff --git a/aiodatalite/commons.py b/aiodatalite/commons.py index 26eac90..440104a 100644 --- a/aiodatalite/commons.py +++ b/aiodatalite/commons.py @@ -1,3 +1,4 @@ +import sqlite3 from dataclasses import MISSING, Field from pickle import HIGHEST_PROTOCOL, dumps, loads from typing import Any, Dict, List, Optional @@ -103,6 +104,28 @@ def _get_default( return " DEFAULT ?" +def _get_creation_data( + class_: type, + type_overload: Dict[Optional[type], str], + type_converter, +): + fields: List[Field] = [ + class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys() + ] + fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. + + def_params = list() + + sql_fields = ", ".join( + f"{field.name} {type_converter(field.type, type_overload)}" + f"{_get_default(field.default, type_overload, def_params)}" + for field in fields + ) + + sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields + return sql_fields, def_params + + # noinspection PyDefaultArgument async def _tweaked_create_table( class_: type, @@ -129,27 +152,46 @@ async def _create_table( with a custom table, this is that custom table. :return: None. """ - # noinspection PyUnresolvedReferences - fields: List[Field] = [ - class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys() - ] - fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. - - def_params = list() - - sql_fields = ", ".join( - f"{field.name} {type_converter(field.type, type_overload)}" - f"{_get_default(field.default, type_overload, def_params)}" - for field in fields + sql_fields, def_params = _get_creation_data( + class_, type_overload, type_converter=type_converter ) - - sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields + print(sql_fields) + print(def_params) await cursor.execute( f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});", def_params if def_params else None, ) +# noinspection PyDefaultArgument +def _sync_create_table( + class_: type, + cursor: sqlite3.Cursor, + type_overload: Dict[Optional[type], str] = type_table, + type_converter=_convert_type, +) -> None: + sql_fields, def_params = _get_creation_data( + class_, type_overload, type_converter=type_converter + ) + print(sql_fields) + print(def_params) + cursor.execute( + f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});", + def_params if def_params else (), + ) + + +# noinspection PyDefaultArgument +def _tweaked_sync_create_table( + class_: type, + cursor: sqlite3.Cursor, + type_overload: Dict[Optional[type], str] = type_table, +) -> None: + _sync_create_table( + class_, cursor, type_overload, type_converter=_tweaked_convert_type + ) + + def _tweaked_dump_value(self, value): if type(value) in self.types_table: return value diff --git a/aiodatalite/datalite_decorator.py b/aiodatalite/datalite_decorator.py index 00585d0..3ec20e8 100644 --- a/aiodatalite/datalite_decorator.py +++ b/aiodatalite/datalite_decorator.py @@ -2,13 +2,21 @@ Defines the Datalite decorator that can be used to convert a dataclass to a class bound to an sqlite3 database. """ +import sqlite3 from dataclasses import asdict, fields from typing import Callable, Dict, Optional import aiosqlite from aiosqlite import IntegrityError -from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table +from .commons import ( + _create_table, + _sync_create_table, + _tweaked_create_table, + _tweaked_dump, + _tweaked_sync_create_table, + type_table, +) from .constraints import ConstraintFailedError @@ -132,6 +140,7 @@ def datalite( db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None, tweaked: bool = True, + automarkup: bool = False, ) -> Callable: """Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as `create_entry()`, `remove_entry()` and `update_entry()`. @@ -139,6 +148,7 @@ def datalite( :param db_path: Path of the database to be bound. :param type_overload: Type overload dictionary. :param tweaked: Whether to use pickle type tweaks + :param automarkup: Whether to use automarkup (synchronously) :return: The new dataclass. """ @@ -151,6 +161,14 @@ def datalite( setattr(dataclass_, "types_table", types_table) setattr(dataclass_, "tweaked", tweaked) + if automarkup: + with sqlite3.connect(db_path) as con: + cur: sqlite3.Cursor = con.cursor() + if tweaked: + _tweaked_sync_create_table(dataclass_, cur, types_table) + else: + _sync_create_table(dataclass_, cur, types_table) + if tweaked: dataclass_.markup_table = _markup_table(_tweaked_create_table) dataclass_.create_entry = _tweaked_create_entry