diff --git a/datalite/__init__.py b/datalite/__init__.py
index 2834220..ed2ba27 100644
--- a/datalite/__init__.py
+++ b/datalite/__init__.py
@@ -1,2 +1,10 @@
-__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints', 'mass_actions']
+__all__ = [
+    "commons",
+    "datalite_decorator",
+    "fetch",
+    "migrations",
+    "datalite",
+    "constraints",
+    "mass_actions",
+]
 from .datalite_decorator import datalite
diff --git a/datalite/commons.py b/datalite/commons.py
index c9cab8b..9b9bc07 100644
--- a/datalite/commons.py
+++ b/datalite/commons.py
@@ -1,15 +1,28 @@
 from dataclasses import Field
-from typing import Any, Optional, Dict, List
+from pickle import HIGHEST_PROTOCOL, dumps, loads
+from typing import Any, Dict, List, Optional
+
+import aiosqlite
+import base64
+
 from .constraints import Unique
-import sqlite3 as sql
+
+type_table: Dict[Optional[type], str] = {
+    None: "NULL",
+    int: "INTEGER",
+    float: "REAL",
+    str: "TEXT",
+    bytes: "BLOB",
+    bool: "INTEGER",
+}
+type_table.update(
+    {Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()}
+)
 
-type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
-                                         str: "TEXT", bytes: "BLOB", bool: "INTEGER"}
-type_table.update({Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()})
-
-
-def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
+def _convert_type(
+    type_: Optional[type], type_overload: Dict[Optional[type], str]
+) -> str:
     """
     Given a Python type, return the str name of its SQLite equivalent.
@@ -22,19 +35,24 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
     try:
         return type_overload[type_]
     except KeyError:
-        raise TypeError("Requested type not in the default or overloaded type table.")
+        raise TypeError(
+            "Requested type not in the default or overloaded type table. Use @datalite(tweaked=True) to "
+            "encode custom types."
+        )
 
 
-def _convert_sql_format(value: Any) -> str:
+def _tweaked_convert_type(
+    type_: Optional[type], type_overload: Dict[Optional[type], str]
+) -> str:
+    return type_overload.get(type_, "BLOB")
+
+
+def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
     """
     Given a Python value, convert to string representation of the
     equivalent SQL datatype.
 
     :param value: A value, i.e. a literal, a variable etc.
     :return: The string representation of the SQL equivalent.
-    >>> _convert_sql_format(1)
-    "1"
-    >>> _convert_sql_format("John Smith")
-    '"John Smith"'
     """
     if value is None:
         return "NULL"
@@ -44,11 +62,13 @@ def _convert_sql_format(value: Any) -> str:
         return '"' + str(value).replace("b'", "")[:-1] + '"'
     elif isinstance(value, bool):
         return "TRUE" if value else "FALSE"
-    else:
+    elif type(value) in type_overload:
         return str(value)
+    else:
+        return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
 
 
-def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
+async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
     """
     Get the column data of a table.
 
@@ -56,11 +76,13 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
     :param cur: Cursor in database.
     :param table_name: Name of the table.
     :return: the information about columns.
""" - cur.execute(f"PRAGMA table_info({table_name});") - return [row_info[1] for row_info in cur.fetchall()][1:] + await cur.execute(f"PRAGMA table_info({table_name});") + return [row_info[1] for row_info in await cur.fetchall()][1:] -def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str: +def _get_default( + default_object: object, type_overload: Dict[Optional[type], str] +) -> str: """ Check if the field's default object is filled, if filled return the string to be put in the, @@ -71,11 +93,28 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str empty string if no string is necessary. """ if type(default_object) in type_overload: - return f' DEFAULT {_convert_sql_format(default_object)}' + return f" DEFAULT {_convert_sql_format(default_object, type_overload)}" return "" -def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str] = type_table) -> None: +# noinspection PyDefaultArgument +async def _tweaked_create_table( + class_: type, + cursor: aiosqlite.Cursor, + type_overload: Dict[Optional[type], str] = type_table, +) -> None: + await _create_table( + class_, cursor, type_overload, type_converter=_tweaked_convert_type + ) + + +# noinspection PyDefaultArgument +async def _create_table( + class_: type, + cursor: aiosqlite.Cursor, + type_overload: Dict[Optional[type], str] = type_table, + type_converter=_convert_type, +) -> None: """ Create the table for a specific dataclass given :param class_: A dataclass. @@ -84,10 +123,35 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional with a custom table, this is that custom table. :return: None. """ - fields: List[Field] = [class_.__dataclass_fields__[key] for - key in class_.__dataclass_fields__.keys()] + # noinspection PyUnresolvedReferences + fields: List[Field] = [ + class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys() + ] fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. - sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}" - f"{_get_default(field.default, type_overload)}" for field in fields) + + sql_fields = ", ".join( + f"{field.name} {type_converter(field.type, type_overload)}" + f"{_get_default(field.default, type_overload)}" + for field in fields + ) + sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields - cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});") + await cursor.execute( + f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});" + ) + + +def _tweaked_dump_value(self, value): + if type(value) in self.types_table: + return value + else: + return bytes(dumps(value, protocol=HIGHEST_PROTOCOL)) + + +def _tweaked_dump(self, name): + value = getattr(self, name) + return _tweaked_dump_value(self, value) + + +def _tweaked_load_value(data): + return loads(bytes(data)) diff --git a/datalite/constraints.py b/datalite/constraints.py index 310ce3e..d7cb26d 100644 --- a/datalite/constraints.py +++ b/datalite/constraints.py @@ -4,15 +4,16 @@ datalite.constraints module introduces constraint that can be used to signal datalite decorator constraints in the database. """ -from typing import TypeVar, Union, Tuple +from typing import Tuple, TypeVar, Union -T = TypeVar('T') +T = TypeVar("T") class ConstraintFailedError(Exception): """ This exception is raised when a Constraint fails. 
""" + pass @@ -24,3 +25,4 @@ Dataclass fields hinted with this type signals Unique = Union[Tuple[T], T] +__all__ = ['Unique', 'ConstraintFailedError'] diff --git a/datalite/datalite_decorator.py b/datalite/datalite_decorator.py index 9a2ddfe..99178d6 100644 --- a/datalite/datalite_decorator.py +++ b/datalite/datalite_decorator.py @@ -1,93 +1,160 @@ """ Defines the Datalite decorator that can be used to convert a dataclass to -a class bound to a sqlite3 database. +a class bound to an sqlite3 database. """ +from dataclasses import asdict, fields from sqlite3.dbapi2 import IntegrityError -from typing import Dict, Optional, Callable -from dataclasses import asdict -import sqlite3 as sql +from typing import Callable, Dict, Optional +import aiosqlite + +from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table from .constraints import ConstraintFailedError -from .commons import _convert_sql_format, _convert_type, _create_table, type_table -def _create_entry(self) -> None: +async def _create_entry(self) -> None: """ - Given an object, create the entry for the object. As a side-effect, + Given an object, create the entry for the object. As a side effect, this will set the object_id attribute of the object to the unique id of the entry. :param self: Instance of the object. :return: None. """ - with sql.connect(getattr(self, "db_path")) as con: - cur: sql.Cursor = con.cursor() + async with aiosqlite.connect(getattr(self, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() table_name: str = self.__class__.__name__.lower() kv_pairs = [item for item in asdict(self).items()] kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields. try: - cur.execute(f"INSERT INTO {table_name}(" - f"{', '.join(item[0] for item in kv_pairs)})" - f" VALUES ({', '.join('?' for item in kv_pairs)})", - [item[1] for item in kv_pairs]) + await cur.execute( + f"INSERT INTO {table_name}(" + f"{', '.join(item[0] for item in kv_pairs)})" + f" VALUES ({', '.join('?' for _ in kv_pairs)})", + [item[1] for item in kv_pairs], + ) self.__setattr__("obj_id", cur.lastrowid) - con.commit() + await con.commit() except IntegrityError: raise ConstraintFailedError("A constraint has failed.") -def _update_entry(self) -> None: +async def _tweaked_create_entry(self) -> None: + async with aiosqlite.connect(getattr(self, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + table_name: str = self.__class__.__name__.lower() + kv_pairs = [item for item in fields(self)] + kv_pairs.sort(key=lambda item: item.name) # Sort by the name of the fields. + try: + await cur.execute( + f"INSERT INTO {table_name}(" + f"{', '.join(item.name for item in kv_pairs)})" + f" VALUES ({', '.join('?' for _ in kv_pairs)})", + [_tweaked_dump(self, item.name) for item in kv_pairs], + ) + self.__setattr__("obj_id", cur.lastrowid) + await con.commit() + except IntegrityError: + raise ConstraintFailedError("A constraint has failed.") + + +async def _update_entry(self) -> None: """ Given an object, update the objects entry in the bound database. :param self: The object. :return: None. """ - with sql.connect(getattr(self, "db_path")) as con: - cur: sql.Cursor = con.cursor() + async with aiosqlite.connect(getattr(self, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() table_name: str = self.__class__.__name__.lower() kv_pairs = [item for item in asdict(self).items()] kv_pairs.sort(key=lambda item: item[0]) - query = f"UPDATE {table_name} " + \ - f"SET {', '.join(item[0] + ' = ?' 
for item in kv_pairs)} " + \ - f"WHERE obj_id = {getattr(self, 'obj_id')};" - cur.execute(query, [item[1] for item in kv_pairs]) - con.commit() + query = ( + f"UPDATE {table_name} " + f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} " + f"WHERE obj_id = {getattr(self, 'obj_id')};" + ) + await cur.execute(query, [item[1] for item in kv_pairs]) + await con.commit() -def remove_from(class_: type, obj_id: int): - with sql.connect(getattr(class_, "db_path")) as con: - cur: sql.Cursor = con.cursor() - cur.execute(f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id, )) - con.commit() +async def _tweaked_update_entry(self) -> None: + async with aiosqlite.connect(getattr(self, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + table_name: str = self.__class__.__name__.lower() + kv_pairs = [item for item in fields(self)] + kv_pairs.sort(key=lambda item: item.name) + query = ( + f"UPDATE {table_name} " + f"SET {', '.join(item.name + ' = ?' for item in kv_pairs)} " + f"WHERE obj_id = {getattr(self, 'obj_id')};" + ) + await cur.execute(query, [_tweaked_dump(self, item.name) for item in kv_pairs]) + await con.commit() -def _remove_entry(self) -> None: +async def remove_from(class_: type, obj_id: int): + async with aiosqlite.connect(getattr(class_, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + await cur.execute( + f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id,) + ) + await con.commit() + + +async def _remove_entry(self) -> None: """ Remove the object's record in the underlying database. :param self: self instance. :return: None. """ - remove_from(self.__class__, getattr(self, 'obj_id')) + await remove_from(self.__class__, getattr(self, "obj_id")) -def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable: +def _markup_table(markup_function): + async def inner(self=None, **kwargs): + if not kwargs: + async with aiosqlite.connect(getattr(self, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + await markup_function(self.__class__, cur, self.types_table) + else: + await markup_function(**kwargs) + + return inner + + +def datalite( + db_path: str, + type_overload: Optional[Dict[Optional[type], str]] = None, + tweaked: bool = True, +) -> Callable: """Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as `create_entry()`, `remove_entry()` and `update_entry()`. - :param db_path: Path of the database to be binded. + :param db_path: Path of the database to be bound. :param type_overload: Type overload dictionary. + :param tweaked: Whether to use pickle type tweaks :return: The new dataclass. """ - def decorator(dataclass_: type, *args_i, **kwargs_i): + + def decorator(dataclass_: type, *_, **__): types_table = type_table.copy() if type_overload is not None: types_table.update(type_overload) - with sql.connect(db_path) as con: - cur: sql.Cursor = con.cursor() - _create_table(dataclass_, cur, types_table) - setattr(dataclass_, 'db_path', db_path) # We add the path of the database to class itself. - setattr(dataclass_, 'types_table', types_table) # We add the type table for migration. 
-        dataclass_.create_entry = _create_entry
+
+        setattr(dataclass_, "db_path", db_path)
+        setattr(dataclass_, "types_table", types_table)
+        setattr(dataclass_, "tweaked", tweaked)
+
+        if tweaked:
+            dataclass_.markup_table = _markup_table(_tweaked_create_table)
+            dataclass_.create_entry = _tweaked_create_entry
+            dataclass_.update_entry = _tweaked_update_entry
+        else:
+            dataclass_.markup_table = _markup_table(_create_table)
+            dataclass_.create_entry = _create_entry
+            dataclass_.update_entry = _update_entry
         dataclass_.remove_entry = _remove_entry
-        dataclass_.update_entry = _update_entry
+
         return dataclass_
+
     return decorator
diff --git a/datalite/fetch.py b/datalite/fetch.py
index 99f53b3..51c1e49 100644
--- a/datalite/fetch.py
+++ b/datalite/fetch.py
@@ -1,6 +1,12 @@
-import sqlite3 as sql
-from typing import List, Tuple, Any
-from .commons import _convert_sql_format, _get_table_cols
+import aiosqlite
+from typing import Any, List, Tuple, TypeVar, Type, cast
+
+from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
+
+import base64
+
+
+T = TypeVar("T")
 
 
 def _insert_pagination(query: str, page: int, element_count: int) -> str:
@@ -17,7 +23,7 @@ def _insert_pagination(query: str, page: int, element_count: int) -> str:
     return query + ";"
 
 
-def is_fetchable(class_: type, obj_id: int) -> bool:
+async def is_fetchable(class_: type, obj_id: int) -> bool:
     """
     Check if a record is fetchable given its obj_id and
@@ -26,16 +32,22 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
     class_ type.
 
     :param class_: Class of the object.
     :param obj_id: Unique obj_id of the object.
     :return: If the object is fetchable.
     """
-    with sql.connect(getattr(class_, 'db_path')) as con:
-        cur: sql.Cursor = con.cursor()
+    async with aiosqlite.connect(getattr(class_, "db_path")) as con:
+        cur: aiosqlite.Cursor = await con.cursor()
         try:
-            cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
-        except sql.OperationalError:
+            await cur.execute(
+                f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id,)
+            )
+        except aiosqlite.OperationalError:
             raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
-        return bool(cur.fetchall())
+        return bool(await cur.fetchall())
 
 
-def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
+async def fetch_equals(
+    class_: Type[T],
+    field: str,
+    value: Any,
+) -> T:
     """
     Fetch a class_ type variable from its bound db.
 
@@ -45,18 +57,30 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
     :param field: Field to check for.
     :param value: Value of the field.
     :return: The object whose data is taken from the database.
""" table_name = class_.__name__.lower() - with sql.connect(getattr(class_, 'db_path')) as con: - cur: sql.Cursor = con.cursor() - cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, )) - obj_id, *field_values = list(cur.fetchone()) - field_names: List[str] = _get_table_cols(cur, class_.__name__.lower()) + async with aiosqlite.connect(getattr(class_, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + await cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value,)) + obj_id, *field_values = list(await cur.fetchone()) + field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower()) + kwargs = dict(zip(field_names, field_values)) + if getattr(class_, "tweaked"): + types_table = getattr(class_, "types_table") + field_types = { + key: value.type for key, value in class_.__dataclass_fields__.items() + } + for key in kwargs.keys(): + if field_types[key] not in types_table.keys(): + kwargs[key] = _tweaked_load_value( + kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8")) + ) + obj = class_(**kwargs) setattr(obj, "obj_id", obj_id) return obj -def fetch_from(class_: type, obj_id: int) -> Any: +async def fetch_from(class_: Type[T], obj_id: int) -> T: """ Fetch a class_ type variable from its bound dv. @@ -64,13 +88,17 @@ def fetch_from(class_: type, obj_id: int) -> Any: :param obj_id: Unique object id of the object. :return: The fetched object. """ - if not is_fetchable(class_, obj_id): - raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or" - f"otherwise is unreachable.") - return fetch_equals(class_, 'obj_id', obj_id) + if not await is_fetchable(class_, obj_id): + raise KeyError( + f"An object with {obj_id} of type {class_.__name__} does not exist, or" + f"otherwise is unreachable." + ) + return await fetch_equals(class_, "obj_id", obj_id) -def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any: +def _convert_record_to_object( + class_: Type[T], record: Tuple[Any], field_names: List[str] +) -> T: """ Convert a given record fetched from an SQL instance to a Python Object of given class_. @@ -80,17 +108,30 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis :return: the created object. """ kwargs = dict(zip(field_names, record[1:])) - field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()} - for key in kwargs: + field_types = { + key: value.type for key, value in class_.__dataclass_fields__.items() + } + is_tweaked = getattr(class_, "tweaked") + types_table = getattr(class_, "types_table") + for key in kwargs.keys(): if field_types[key] == bytes: - kwargs[key] = bytes(kwargs[key], encoding='utf-8') + kwargs[key] = bytes(kwargs[key], encoding="utf-8") + + elif is_tweaked: + if field_types[key] not in types_table.keys(): + kwargs[key] = _tweaked_load_value( + kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8")) + ) + obj_id = record[0] obj = class_(**kwargs) setattr(obj, "obj_id", obj_id) return obj -def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple: +async def fetch_if( + class_: Type[T], condition: str, page: int = 0, element_count: int = 10 +) -> T: """ Fetch all class_ type variables from the bound db, provided they fit the given condition @@ -103,18 +144,27 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1 of given type class_. 
""" table_name = class_.__name__.lower() - with sql.connect(getattr(class_, 'db_path')) as con: - cur: sql.Cursor = con.cursor() - cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count)) - records: list = cur.fetchall() - field_names: List[str] = _get_table_cols(cur, table_name) - return tuple(_convert_record_to_object(class_, record, field_names) for record in records) + async with aiosqlite.connect(getattr(class_, "db_path")) as con: + cur: aiosqlite.Cursor = await con.cursor() + await cur.execute( + _insert_pagination( + f"SELECT * FROM {table_name} WHERE {condition}", page, element_count + ) + ) + # noinspection PyTypeChecker + records: list = await cur.fetchall() + field_names: List[str] = await _get_table_cols(cur, table_name) + return tuple( + _convert_record_to_object(class_, record, field_names) for record in records + ) -def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple: +async def fetch_where( + class_: Type[T], field: str, value: Any, page: int = 0, element_count: int = 10 +) -> tuple[T]: """ - Fetch all class_ type variables from the bound db, - provided that the field of the records fit the + Fetch all class_ type variables from the bound db + if the field of the records fit the given value. :param class_: Class of the records. @@ -124,10 +174,12 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou :param element_count: Element count in each page. :return: A tuple of the records. """ - return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count) + return await fetch_if( + class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count + ) -def fetch_range(class_: type, range_: range) -> tuple: +async def fetch_range(class_: Type[T], range_: range) -> tuple[T]: """ Fetch the records in a given range of object ids. @@ -136,12 +188,23 @@ def fetch_range(class_: type, range_: range) -> tuple: :return: A tuple of class_ type objects whose values come from the class_' bound database. """ - return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id)) + return cast( + tuple[T], + tuple( + [ + (await fetch_from(class_, obj_id)) + for obj_id in range_ + if (await is_fetchable(class_, obj_id)) + ] + ), + ) -def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple: +async def fetch_all( + class_: Type[T], page: int = 0, element_count: int = 10 +) -> tuple[T]: """ - Fetchall the records in the bound database. + Fetch all the records in the bound database. :param class_: Class of the records. :param page: Which page to retrieve, default all. (0 means closed). @@ -150,15 +213,26 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple: the bound database as a tuple. 
""" try: - db_path = getattr(class_, 'db_path') + db_path = getattr(class_, "db_path") except AttributeError: raise TypeError("Given class is not decorated with datalite.") - with sql.connect(db_path) as con: - cur: sql.Cursor = con.cursor() + async with aiosqlite.connect(db_path) as con: + cur: aiosqlite.Cursor = await con.cursor() try: - cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count)) - except sql.OperationalError: + await cur.execute( + _insert_pagination( + f"SELECT * FROM {class_.__name__.lower()}", page, element_count + ) + ) + except aiosqlite.OperationalError: raise TypeError(f"No record of type {class_.__name__.lower()}") - records = cur.fetchall() - field_names: List[str] = _get_table_cols(cur, class_.__name__.lower()) - return tuple(_convert_record_to_object(class_, record, field_names) for record in records) + records = await cur.fetchall() + field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower()) + # noinspection PyTypeChecker + return cast( + tuple[T], + tuple(_convert_record_to_object(class_, record, field_names) for record in records), + ) + + +__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"] diff --git a/datalite/mass_actions.py b/datalite/mass_actions.py index 2de2980..0f9369c 100644 --- a/datalite/mass_actions.py +++ b/datalite/mass_actions.py @@ -3,14 +3,15 @@ This module includes functions to insert multiple records to a bound database at one time, with one time open and closing of the database file. """ -from typing import TypeVar, Union, List, Tuple +import aiosqlite from dataclasses import asdict +from typing import List, Tuple, TypeVar, Union from warnings import warn -from .constraints import ConstraintFailedError -from .commons import _convert_sql_format, _create_table -import sqlite3 as sql -T = TypeVar('T') +from .commons import _convert_sql_format +from .constraints import ConstraintFailedError + +T = TypeVar("T") class HeterogeneousCollectionError(Exception): @@ -19,45 +20,56 @@ class HeterogeneousCollectionError(Exception): ie: If a List or Tuple has elements of multiple types. """ + pass def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None: """ - Check if all of the members a Tuple or a List + Check if all the members a Tuple or a List is of the same type. :param objects: Tuple or list to check. - :return: If all of the members of the same type. + :return: If all the members of the same type. """ class_ = objects[0].__class__ - if not all([isinstance(obj, class_) or isinstance(objects[0], obj.__class__) for obj in objects]): + if not all( + [ + isinstance(obj, class_) or isinstance(objects[0], obj.__class__) + for obj in objects + ] + ): raise HeterogeneousCollectionError("Tuple or List is not homogeneous.") -def _toggle_memory_protection(cur: sql.Cursor, protect_memory: bool) -> None: +async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None: """ - Given a cursor to an sqlite3 connection, if memory protection is false, + Given a cursor to a sqlite3 connection, if memory protection is false, toggle memory protections off. :param cur: Cursor to an open SQLite3 connection. - :param protect_memory: Whether or not should memory be protected. + :param protect_memory: Whether should memory be protected. :return: Memory protections off. 
""" if not protect_memory: - warn("Memory protections are turned off, " - "if operations are interrupted, file may get corrupt.", RuntimeWarning) - cur.execute("PRAGMA synchronous = OFF") - cur.execute("PRAGMA journal_mode = MEMORY") + warn( + "Memory protections are turned off, " + "if operations are interrupted, file may get corrupt.", + RuntimeWarning, + ) + await cur.execute("PRAGMA synchronous = OFF") + await cur.execute("PRAGMA journal_mode = MEMORY") -def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None: +async def _mass_insert( + objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True +) -> None: """ Insert multiple records into an SQLite3 database. :param objects: Objects to insert. :param db_name: Name of the database to insert. - :param protect_memory: Whether or not memory + :param protect_memory: Whether memory protections are on or off. :return: None """ @@ -69,24 +81,28 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory for i, obj in enumerate(objects): kv_pairs = asdict(obj).items() setattr(obj, "obj_id", first_index + i + 1) - sql_queries.append(f"INSERT INTO {table_name}(" + - f"{', '.join(item[0] for item in kv_pairs)})" + - f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});") - with sql.connect(db_name) as con: - cur: sql.Cursor = con.cursor() + sql_queries.append( + f"INSERT INTO {table_name}(" + + f"{', '.join(item[0] for item in kv_pairs)})" + + f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});" + ) + async with aiosqlite.connect(db_name) as con: + cur: aiosqlite.Cursor = await con.cursor() try: - _toggle_memory_protection(cur, protect_memory) - cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1") - index_tuple = cur.fetchone() + await _toggle_memory_protection(cur, protect_memory) + await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1") + index_tuple = await cur.fetchone() if index_tuple: - first_index = index_tuple[0] - cur.executescript("BEGIN TRANSACTION;\n" + '\n'.join(sql_queries) + '\nEND TRANSACTION;') - except sql.IntegrityError: + _ = index_tuple[0] + await cur.executescript( + "BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;" + ) + except aiosqlite.IntegrityError: raise ConstraintFailedError - con.commit() + await con.commit() -def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None: +async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None: """ Insert many records corresponding to objects in a tuple or a list. @@ -98,29 +114,34 @@ def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) :return: None. """ if objects: - _mass_insert(objects, getattr(objects[0], "db_path"), protect_memory) + await _mass_insert(objects, getattr(objects[0], "db_path"), protect_memory) else: raise ValueError("Collection is empty.") -def copy_many(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None: +async def copy_many( + objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True +) -> None: """ Copy many records to another database, from - their original database to new database, do + their original database to a new database, do not delete old records. :param objects: Objects to copy. :param db_name: Name of the new database. 
-    :param protect_memory: Wheter to protect memory during operation,
+    :param protect_memory: Whether to protect memory during operation.
         Setting this to False will quicken the operation, but if the
-        operation is cut short, database file will corrupt.
+        operation is cut short, the database file will be corrupted.
     :return: None
     """
     if objects:
-        with sql.connect(db_name) as con:
-            cur = con.cursor()
-            _create_table(objects[0].__class__, cur)
-            con.commit()
-        _mass_insert(objects, db_name, protect_memory)
+        async with aiosqlite.connect(db_name) as con:
+            cur = await con.cursor()
+            await objects[0].markup_table(class_=objects[0].__class__, cursor=cur)
+            await con.commit()
+        await _mass_insert(objects, db_name, protect_memory)
     else:
         raise ValueError("Collection is empty.")
+
+
+__all__ = ["copy_many", "create_many", "HeterogeneousCollectionError"]
diff --git a/datalite/migrations.py b/datalite/migrations.py
index 1b65984..160472f 100644
--- a/datalite/migrations.py
+++ b/datalite/migrations.py
@@ -2,53 +2,59 @@ Migrations module deals with migrating data
 when the object definitions change. These functions
 deal with Schema Migrations.
 """
+import shutil
+import time
 from dataclasses import Field
 from os.path import exists
-from typing import Dict, Tuple, List
-import sqlite3 as sql
+from typing import Any, Dict, List, Tuple, cast
+
+import aiosqlite
+
+from .commons import _get_table_cols, _tweaked_dump_value
 
 
-def _get_db_table(class_: type) -> Tuple[str, str]:
+async def _get_db_table(class_: type) -> Tuple[str, str]:
     """
-    Check if the class is a datalite class, the database exists
+    Check if the class is a datalite class, the database exists,
     and the table exists. Return database and table names.
 
     :param class_: A datalite class.
     :return: A tuple of database and table names.
     """
-    database_name: str = getattr(class_, 'db_path', None)
+    database_name: str = getattr(class_, "db_path", None)
     if not database_name:
         raise TypeError(f"{class_.__name__} is not a datalite class.")
     table_name: str = class_.__name__.lower()
     if not exists(database_name):
         raise FileNotFoundError(f"{database_name} does not exist")
-    with sql.connect(database_name) as con:
-        cur: sql.Cursor = con.cursor()
-        cur.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;", (table_name, ))
-        count: int = int(cur.fetchone()[0])
+    async with aiosqlite.connect(database_name) as con:
+        cur: aiosqlite.Cursor = await con.cursor()
+        await cur.execute(
+            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;",
+            (table_name,),
+        )
+        count: int = int((await cur.fetchone())[0])
     if not count:
         raise FileExistsError(f"Table, {table_name}, already exists.")
     return database_name, table_name
 
 
-def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
+async def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
     """
-    Get the column names of table.
+    Get the column names of the table.
 
-    :param database_name: Name of the database the table
+    :param database_name: The name of the database the table
         resides in.
     :param table_name: Name of the table.
     :return: A tuple holding the column names of the table.
""" - with sql.connect(database_name) as con: - cur: sql.Cursor = con.cursor() - cols: List[str] = _get_table_cols(cur, table_name) - return tuple(cols) + async with aiosqlite.connect(database_name) as con: + cur: aiosqlite.Cursor = await con.cursor() + cols: List[str] = await _get_table_cols(cur, table_name) + return cast(Tuple[str], tuple(cols)) -def _copy_records(database_name: str, table_name: str): +async def _copy_records(database_name: str, table_name: str): """ Copy all records from a table. @@ -56,17 +62,17 @@ def _copy_records(database_name: str, table_name: str): :param table_name: Name of the table. :return: A generator holding dataclass asdict representations. """ - with sql.connect(database_name) as con: - cur: sql.Cursor = con.cursor() - cur.execute(f'SELECT * FROM {table_name};') - values = cur.fetchall() - keys = _get_table_cols(cur, table_name) - keys.insert(0, 'obj_id') + async with aiosqlite.connect(database_name) as con: + cur: aiosqlite.Cursor = await con.cursor() + await cur.execute(f"SELECT * FROM {table_name};") + values = await cur.fetchall() + keys = await _get_table_cols(cur, table_name) + keys.insert(0, "obj_id") records = (dict(zip(keys, value)) for value in values) return records -def _drop_table(database_name: str, table_name: str) -> None: +async def _drop_table(database_name: str, table_name: str) -> None: """ Drop a table. @@ -74,14 +80,15 @@ def _drop_table(database_name: str, table_name: str) -> None: :param table_name: Name of the table to be dropped. :return: None. """ - with sql.connect(database_name) as con: - cur: sql.Cursor = con.cursor() - cur.execute(f'DROP TABLE {table_name};') - con.commit() + async with aiosqlite.connect(database_name) as con: + cur: aiosqlite.Cursor = await con.cursor() + await cur.execute(f"DROP TABLE {table_name};") + await con.commit() -def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str], - flow: Dict[str, str]) -> Tuple[Dict[str, str]]: +def _modify_records( + data, col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str] +) -> Tuple[Dict[str, str]]: """ Modify the asdict records in accordance with schema migration rules provided. @@ -89,9 +96,9 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str], :param data: Data kept as asdict in tuple. :param col_to_del: Column names to delete. :param col_to_add: Column names to add. - :param flow: A dictionary that explain + :param flow: A dictionary that explains if the data from a deleted column - will be transferred to a column + is transferred to a column to be added. :return: The modified data records. """ @@ -109,15 +116,22 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str], if key_to_add not in record_mod: record_mod[key_to_add] = None records.append(record_mod) - return records + return cast(Tuple[Dict[str, str]], records) -def _migrate_records(class_: type, database_name: str, data, - col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]) -> None: +async def _migrate_records( + class_: type, + database_name: str, + data, + col_to_del: Tuple[str], + col_to_add: Tuple[str], + flow: Dict[str, str], + safe_migration_defaults: Dict[str, Any] = None, +) -> None: """ Migrate the records into the modified table. - :param class_: Class of the entries. + :param class_: Class of entries. :param database_name: Name of the database. :param data: Data, asdict tuple. :param col_to_del: Columns to be deleted. 
@@ -126,40 +140,94 @@
         column data will be transferred.
     :return: None.
     """
-    with sql.connect(database_name) as con:
-        cur: sql.Cursor = con.cursor()
-        _create_table(class_, cur, getattr(class_, 'types_table'))
-        con.commit()
+    if safe_migration_defaults is None:
+        safe_migration_defaults = {}
+
+    async with aiosqlite.connect(database_name) as con:
+        cur: aiosqlite.Cursor = await con.cursor()
+        # noinspection PyUnresolvedReferences
+        await class_.markup_table(
+            class_=class_, cursor=cur, type_overload=getattr(class_, "types_table")
+        )
+        await con.commit()
     new_records = _modify_records(data, col_to_del, col_to_add, flow)
     for record in new_records:
-        del record['obj_id']
+        del record["obj_id"]
         keys_to_delete = [key for key in record if record[key] is None]
         for key in keys_to_delete:
             del record[key]
-        class_(**record).create_entry()
+        await class_(
+            **{
+                **record,
+                **{
+                    k: _tweaked_dump_value(class_, v)
+                    for k, v in safe_migration_defaults.items()
+                    if k not in record
+                },
+            }
+        ).create_entry()
 
 
-def basic_migrate(class_: type, column_transfer: dict = None) -> None:
+async def migrate(
+    class_: type,
+    column_transfer: dict = None,
+    do_backup: bool = True,
+    safe_migration_defaults: Dict[str, Any] = None,
+) -> None:
     """
     Given a class, compare its previous table,
     delete the fields that no longer exist,
     create new columns for new fields. If the
     column_transfer parameter is given, migrate elements
     from previous column to the new ones. It should be
-    noted that, the obj_ids do not persist.
+    noted that the obj_ids do not persist.
 
     :param class_: Datalite class to migrate.
     :param column_transfer: A dictionary showing which
         columns will be copied to new ones.
+    :param do_backup: Whether to copy the whole database file before dropping the table.
+    :param safe_migration_defaults: Key-value defaults written to old records during
+        migration so that the new schema is not broken.
     :return: None.
""" - database_name, table_name = _get_db_table(class_) - table_column_names: Tuple[str] = _get_table_column_names(database_name, table_name) - values = class_.__dataclass_fields__.values() - data_fields: Tuple[Field] = tuple(field for field in values) - data_field_names: Tuple[str] = tuple(field.name for field in data_fields) - columns_to_be_deleted: Tuple[str] = tuple(column for column in table_column_names if column not in data_field_names) - columns_to_be_added: Tuple[str] = tuple(column for column in data_field_names if column not in table_column_names) - records = _copy_records(database_name, table_name) - _drop_table(database_name, table_name) - _migrate_records(class_, database_name, records, columns_to_be_deleted, columns_to_be_added, column_transfer) + database_name, table_name = await _get_db_table(class_) + table_column_names: Tuple[str] = await _get_table_column_names( + database_name, table_name + ) + + # noinspection PyUnresolvedReferences + values: List[Field] = class_.__dataclass_fields__.values() + + data_fields: Tuple[Field] = cast(Tuple[Field], tuple(field for field in values)) + data_field_names: Tuple[str] = cast( + Tuple[str], tuple(field.name for field in data_fields) + ) + columns_to_be_deleted: Tuple[str] = cast( + Tuple[str], + tuple( + column for column in table_column_names if column not in data_field_names + ), + ) + columns_to_be_added: Tuple[str] = cast( + Tuple[str], + tuple( + column for column in data_field_names if column not in table_column_names + ), + ) + + records = await _copy_records(database_name, table_name) + if do_backup: + shutil.copy(database_name, f"{database_name}-{time.time()}") + await _drop_table(database_name, table_name) + await _migrate_records( + class_, + database_name, + records, + columns_to_be_deleted, + columns_to_be_added, + column_transfer, + safe_migration_defaults, + ) + + +__all__ = ["migrate"] diff --git a/datalite/typed.py b/datalite/typed.py new file mode 100644 index 0000000..2786393 --- /dev/null +++ b/datalite/typed.py @@ -0,0 +1,34 @@ +from typing import Dict, Optional + + +def _datalite_hinted_direct_use(): + raise ValueError( + "Don't use DataliteHinted directly. 
Inherited classes should also be wrapped in " + "datalite and dataclass decorators" + ) + + +class DataliteHinted: + db_path: str + types_table: Dict[Optional[type], str] + tweaked: bool + obj_id: int + + # noinspection PyMethodMayBeStatic + async def markup_table(self): + _datalite_hinted_direct_use() + + # noinspection PyMethodMayBeStatic + async def create_entry(self): + _datalite_hinted_direct_use() + + # noinspection PyMethodMayBeStatic + async def update_entry(self): + _datalite_hinted_direct_use() + + # noinspection PyMethodMayBeStatic + async def remove_entry(self): + _datalite_hinted_direct_use() + + +__all__ = ["DataliteHinted"] diff --git a/docs/conf.py b/docs/conf.py index f034226..7936711 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,18 +12,19 @@ # import os import sys -sys.path.insert(0, os.path.abspath('.')) -sys.path.insert(0, os.path.abspath('../.')) + +sys.path.insert(0, os.path.abspath(".")) +sys.path.insert(0, os.path.abspath("../.")) # -- Project information ----------------------------------------------------- -project = 'Datalite' -copyright = '2020, Ege Ozkan' -author = 'Ege Ozkan' +project = "Datalite" +copyright = "2020, Ege Ozkan" +author = "Ege Ozkan" # The full version, including alpha/beta/rc tags -release = 'v0.7.1' +release = "v0.7.1" # -- General configuration --------------------------------------------------- @@ -31,15 +32,15 @@ release = 'v0.7.1' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc'] +extensions = ["sphinx.ext.autodoc"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- @@ -47,9 +48,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"]
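
The snippets below are editorial usage sketches, not part of the patch itself. This first one shows the plain (tweaked=False) path of the new async API. The Student class, its field values, and the students.db path are hypothetical, and it assumes table creation now requires an explicit markup_table() call, since the decorator no longer creates the table at decoration time.

import asyncio
from dataclasses import dataclass

from datalite import datalite
from datalite.fetch import fetch_from


@datalite(db_path="students.db", tweaked=False)  # hypothetical database path
@dataclass
class Student:
    name: str = "John Smith"
    grade: int = 0


async def main() -> None:
    student = Student("Jane Doe", 2)
    await student.markup_table()  # explicit table creation replaces the old decoration-time CREATE TABLE
    await student.create_entry()  # INSERT; sets student.obj_id as a side effect
    student.grade = 3
    await student.update_entry()  # UPDATE ... WHERE obj_id = ...
    fetched = await fetch_from(Student, student.obj_id)
    assert fetched.grade == 3
    await student.remove_entry()  # DELETE ... WHERE obj_id = ...


asyncio.run(main())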
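A second sketch, under the same assumptions, for the tweaked path that this patch makes the default: a field whose type has no entry in type_table (here List[str], a hypothetical model) is stored as a pickled BLOB by _tweaked_dump on insert and restored by _tweaked_load_value on fetch. It also assumes annotations are real typing objects, i.e. the module does not use `from __future__ import annotations`.

import asyncio
from dataclasses import dataclass, field
from typing import List

from datalite import datalite
from datalite.fetch import fetch_from


@datalite(db_path="notes.db")  # tweaked defaults to True in this patch
@dataclass
class Note:
    title: str = ""
    tags: List[str] = field(default_factory=list)  # no SQLite type, so stored as a pickled BLOB


async def main() -> None:
    note = Note("groceries", ["food", "urgent"])
    await note.markup_table()
    await note.create_entry()
    same = await fetch_from(Note, note.obj_id)
    assert same.tags == ["food", "urgent"]  # unpickled transparently by the fetch helpers


asyncio.run(main())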
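A sketch of the renamed migrate() entry point, assuming a students table that predates the email field. do_backup=True copies the database file to a timestamped sibling before the table is dropped, and safe_migration_defaults supplies values for columns the old records lack. The class, path, and default address are illustrative only.

import asyncio
from dataclasses import dataclass

from datalite import datalite
from datalite.migrations import migrate


# This class represents the *new* schema: `email` did not exist
# when the table was first created.
@datalite(db_path="students.db")
@dataclass
class Student:
    name: str = ""
    email: str = ""


async def main() -> None:
    # Backs the file up (e.g. "students.db-1700000000.0"), drops and re-creates
    # the table, then re-inserts the old rows with `email` filled from the default.
    await migrate(
        Student,
        do_backup=True,
        safe_migration_defaults={"email": "unknown@example.com"},
    )


asyncio.run(main())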
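Finally, a sketch of the mass-action helpers under the same hypothetical setup: create_many() batches every INSERT into one scripted transaction over a single connection, and copy_many() re-creates the table in a second database file before copying the records across, leaving the originals in place.

import asyncio
from dataclasses import dataclass

from datalite import datalite
from datalite.mass_actions import copy_many, create_many


@datalite(db_path="points.db", tweaked=False)  # hypothetical paths throughout
@dataclass
class Point:
    x: int = 0
    y: int = 0


async def main() -> None:
    points = [Point(i, i * i) for i in range(100)]
    await points[0].markup_table()
    # One connection and one scripted transaction for all 100 INSERTs;
    # protect_memory=True is the default, shown here for clarity.
    await create_many(points, protect_memory=True)
    # Duplicate the records into a second database file.
    await copy_many(points, "points_backup.db")


asyncio.run(main())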