From b0523e141ea068da45e703dc12dd42f8f70a52b7 Mon Sep 17 00:00:00 2001 From: hhh Date: Sun, 17 Mar 2024 14:56:23 +0200 Subject: [PATCH] Safer defaults and mass actions, base64 support from previous commit dropped --- datalite/commons.py | 30 ++++++++++++++-------- datalite/constraints.py | 2 +- datalite/fetch.py | 46 +++++++++++++++++++++------------ datalite/mass_actions.py | 55 +++++++++++++++++++++++++++------------- 4 files changed, 88 insertions(+), 45 deletions(-) diff --git a/datalite/commons.py b/datalite/commons.py index 9b9bc07..52e2332 100644 --- a/datalite/commons.py +++ b/datalite/commons.py @@ -1,9 +1,8 @@ -from dataclasses import Field +from dataclasses import MISSING, Field from pickle import HIGHEST_PROTOCOL, dumps, loads from typing import Any, Dict, List, Optional import aiosqlite -import base64 from .constraints import Unique @@ -47,7 +46,7 @@ def _tweaked_convert_type( return type_overload.get(type_, "BLOB") -def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str: +def _convert_sql_format(value: Any) -> str: """ Given a Python value, convert to string representation of the equivalent SQL datatype. @@ -62,10 +61,8 @@ def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> return '"' + str(value).replace("b'", "")[:-1] + '"' elif isinstance(value, bool): return "TRUE" if value else "FALSE" - elif type(value) in type_overload: - return str(value) else: - return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"' + return str(value) async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]: @@ -81,7 +78,9 @@ async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]: def _get_default( - default_object: object, type_overload: Dict[Optional[type], str] + default_object: object, + type_overload: Dict[Optional[type], str], + mutable_def_params: list, ) -> str: """ Check if the field's default object is filled, @@ -93,8 +92,14 @@ def _get_default( empty string if no string is necessary. """ if type(default_object) in type_overload: - return f" DEFAULT {_convert_sql_format(default_object, type_overload)}" - return "" + return f" DEFAULT {_convert_sql_format(default_object)}" + elif type(default_object) is type(MISSING): + return "" + else: + mutable_def_params.append( + bytes(dumps(default_object, protocol=HIGHEST_PROTOCOL)) + ) + return " DEFAULT ?" # noinspection PyDefaultArgument @@ -129,15 +134,18 @@ async def _create_table( ] fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. + def_params = list() + sql_fields = ", ".join( f"{field.name} {type_converter(field.type, type_overload)}" - f"{_get_default(field.default, type_overload)}" + f"{_get_default(field.default, type_overload, def_params)}" for field in fields ) sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields await cursor.execute( - f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});" + f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});", + def_params if def_params else None, ) diff --git a/datalite/constraints.py b/datalite/constraints.py index d7cb26d..763dbd7 100644 --- a/datalite/constraints.py +++ b/datalite/constraints.py @@ -25,4 +25,4 @@ Dataclass fields hinted with this type signals Unique = Union[Tuple[T], T] -__all__ = ['Unique', 'ConstraintFailedError'] +__all__ = ["Unique", "ConstraintFailedError"] diff --git a/datalite/fetch.py b/datalite/fetch.py index 51c1e49..08de29f 100644 --- a/datalite/fetch.py +++ b/datalite/fetch.py @@ -1,10 +1,8 @@ +from typing import Any, List, Tuple, Type, TypeVar, cast + import aiosqlite -from typing import Any, List, Tuple, TypeVar, Type, cast - -from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value - -import base64 +from .commons import _get_table_cols, _tweaked_dump_value, _tweaked_load_value T = TypeVar("T") @@ -71,9 +69,7 @@ async def fetch_equals( } for key in kwargs.keys(): if field_types[key] not in types_table.keys(): - kwargs[key] = _tweaked_load_value( - kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8")) - ) + kwargs[key] = _tweaked_load_value(kwargs[key]) obj = class_(**kwargs) setattr(obj, "obj_id", obj_id) @@ -119,9 +115,7 @@ def _convert_record_to_object( elif is_tweaked: if field_types[key] not in types_table.keys(): - kwargs[key] = _tweaked_load_value( - kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8")) - ) + kwargs[key] = _tweaked_load_value(kwargs[key]) obj_id = record[0] obj = class_(**kwargs) @@ -130,7 +124,11 @@ def _convert_record_to_object( async def fetch_if( - class_: Type[T], condition: str, page: int = 0, element_count: int = 10 + class_: Type[T], + condition: str, + page: int = 0, + element_count: int = 10, + parameter_values: tuple = None, ) -> T: """ Fetch all class_ type variables from the bound db, @@ -140,6 +138,7 @@ async def fetch_if( :param condition: Condition to check for. :param page: Which page to retrieve, default all. (0 means closed). :param element_count: Element count in each page. + :param parameter_values: If placeholders are used, they will be replaced with these values :return: A tuple of records that fit the given condition of given type class_. """ @@ -149,7 +148,8 @@ async def fetch_if( await cur.execute( _insert_pagination( f"SELECT * FROM {table_name} WHERE {condition}", page, element_count - ) + ), + parameter_values, ) # noinspection PyTypeChecker records: list = await cur.fetchall() @@ -175,7 +175,11 @@ async def fetch_where( :return: A tuple of the records. """ return await fetch_if( - class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count + class_, + f"{field} = ?", + page, + element_count, + parameter_values=(_tweaked_dump_value(class_, value),), ) @@ -231,8 +235,18 @@ async def fetch_all( # noinspection PyTypeChecker return cast( tuple[T], - tuple(_convert_record_to_object(class_, record, field_names) for record in records), + tuple( + _convert_record_to_object(class_, record, field_names) for record in records + ), ) -__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"] +__all__ = [ + "is_fetchable", + "fetch_equals", + "fetch_from", + "fetch_if", + "fetch_where", + "fetch_range", + "fetch_all", +] diff --git a/datalite/mass_actions.py b/datalite/mass_actions.py index 0f9369c..4b8a363 100644 --- a/datalite/mass_actions.py +++ b/datalite/mass_actions.py @@ -3,12 +3,13 @@ This module includes functions to insert multiple records to a bound database at one time, with one time open and closing of the database file. """ -import aiosqlite -from dataclasses import asdict +from dataclasses import asdict, fields from typing import List, Tuple, TypeVar, Union from warnings import warn -from .commons import _convert_sql_format +import aiosqlite + +from .commons import _tweaked_dump from .constraints import ConstraintFailedError T = TypeVar("T") @@ -42,7 +43,9 @@ def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None: raise HeterogeneousCollectionError("Tuple or List is not homogeneous.") -async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None: +async def _toggle_memory_protection( + cur: aiosqlite.Cursor, protect_memory: bool +) -> None: """ Given a cursor to a sqlite3 connection, if memory protection is false, toggle memory protections off. @@ -74,35 +77,53 @@ async def _mass_insert( :return: None """ _check_homogeneity(objects) - sql_queries = [] + is_tweaked = getattr(objects[0], "tweaked") first_index: int = 0 table_name = objects[0].__class__.__name__.lower() for i, obj in enumerate(objects): - kv_pairs = asdict(obj).items() setattr(obj, "obj_id", first_index + i + 1) - sql_queries.append( - f"INSERT INTO {table_name}(" - + f"{', '.join(item[0] for item in kv_pairs)})" - + f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});" - ) async with aiosqlite.connect(db_name) as con: cur: aiosqlite.Cursor = await con.cursor() try: await _toggle_memory_protection(cur, protect_memory) - await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1") + await cur.execute( + f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1" + ) index_tuple = await cur.fetchone() if index_tuple: _ = index_tuple[0] - await cur.executescript( - "BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;" - ) + + await cur.execute("BEGIN TRANSACTION;") + + for i, obj in enumerate(objects): + if is_tweaked: + kv_pairs = [item for item in fields(obj)] + kv_pairs.sort(key=lambda item: item.name) + column_names = ", ".join(item.name for item in kv_pairs) + vals = tuple(_tweaked_dump(obj, item.name) for item in kv_pairs) + else: + kv_pairs = [item for item in asdict(obj).items()] + kv_pairs.sort(key=lambda item: item[0]) + column_names = ", ".join(column[0] for column in kv_pairs) + vals = tuple(column[1] for column in kv_pairs) + setattr(obj, "obj_id", first_index + i + 1) + + placeholders = ", ".join("?" for _ in kv_pairs) + sql_statement = f"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders});" + + await cur.execute(sql_statement, vals) + + await cur.execute("END TRANSACTION;") + except aiosqlite.IntegrityError: raise ConstraintFailedError await con.commit() -async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None: +async def create_many( + objects: Union[List[T], Tuple[T]], protect_memory: bool = True +) -> None: """ Insert many records corresponding to objects in a tuple or a list. @@ -144,4 +165,4 @@ async def copy_many( raise ValueError("Collection is empty.") -__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError'] +__all__ = ["copy_many", "create_many", "HeterogeneousCollectionError"]