Safer defaults and mass actions; the base64 support introduced in the previous commit has been dropped

This commit is contained in:
hhh
2024-03-17 14:56:23 +02:00
parent 6dfc3cebbe
commit b0523e141e
4 changed files with 88 additions and 45 deletions

View File

@@ -1,9 +1,8 @@
from dataclasses import Field from dataclasses import MISSING, Field
from pickle import HIGHEST_PROTOCOL, dumps, loads from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import aiosqlite import aiosqlite
import base64
from .constraints import Unique from .constraints import Unique
@@ -47,7 +46,7 @@ def _tweaked_convert_type(
return type_overload.get(type_, "BLOB") return type_overload.get(type_, "BLOB")
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str: def _convert_sql_format(value: Any) -> str:
""" """
Given a Python value, convert to string representation Given a Python value, convert to string representation
of the equivalent SQL datatype. of the equivalent SQL datatype.
@@ -62,10 +61,8 @@ def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) ->
return '"' + str(value).replace("b'", "")[:-1] + '"' return '"' + str(value).replace("b'", "")[:-1] + '"'
elif isinstance(value, bool): elif isinstance(value, bool):
return "TRUE" if value else "FALSE" return "TRUE" if value else "FALSE"
elif type(value) in type_overload:
return str(value)
else: else:
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"' return str(value)
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]: async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
@@ -81,7 +78,9 @@ async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
def _get_default( def _get_default(
default_object: object, type_overload: Dict[Optional[type], str] default_object: object,
type_overload: Dict[Optional[type], str],
mutable_def_params: list,
) -> str: ) -> str:
""" """
Check if the field's default object is filled, Check if the field's default object is filled,
@@ -93,8 +92,14 @@ def _get_default(
empty string if no string is necessary. empty string if no string is necessary.
""" """
if type(default_object) in type_overload: if type(default_object) in type_overload:
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}" return f" DEFAULT {_convert_sql_format(default_object)}"
elif type(default_object) is type(MISSING):
return "" return ""
else:
mutable_def_params.append(
bytes(dumps(default_object, protocol=HIGHEST_PROTOCOL))
)
return " DEFAULT ?"
# noinspection PyDefaultArgument # noinspection PyDefaultArgument
@@ -129,15 +134,18 @@ async def _create_table(
] ]
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted. fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
def_params = list()
sql_fields = ", ".join( sql_fields = ", ".join(
f"{field.name} {type_converter(field.type, type_overload)}" f"{field.name} {type_converter(field.type, type_overload)}"
f"{_get_default(field.default, type_overload)}" f"{_get_default(field.default, type_overload, def_params)}"
for field in fields for field in fields
) )
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
await cursor.execute( await cursor.execute(
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});" f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});",
def_params if def_params else None,
) )

View File

@@ -25,4 +25,4 @@ Dataclass fields hinted with this type signals
Unique = Union[Tuple[T], T] Unique = Union[Tuple[T], T]
__all__ = ['Unique', 'ConstraintFailedError'] __all__ = ["Unique", "ConstraintFailedError"]

View File

@@ -1,10 +1,8 @@
from typing import Any, List, Tuple, Type, TypeVar, cast
import aiosqlite import aiosqlite
from typing import Any, List, Tuple, TypeVar, Type, cast
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
import base64
from .commons import _get_table_cols, _tweaked_dump_value, _tweaked_load_value
T = TypeVar("T") T = TypeVar("T")
@@ -71,9 +69,7 @@ async def fetch_equals(
} }
for key in kwargs.keys(): for key in kwargs.keys():
if field_types[key] not in types_table.keys(): if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value( kwargs[key] = _tweaked_load_value(kwargs[key])
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
obj = class_(**kwargs) obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id) setattr(obj, "obj_id", obj_id)
@@ -119,9 +115,7 @@ def _convert_record_to_object(
elif is_tweaked: elif is_tweaked:
if field_types[key] not in types_table.keys(): if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value( kwargs[key] = _tweaked_load_value(kwargs[key])
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
obj_id = record[0] obj_id = record[0]
obj = class_(**kwargs) obj = class_(**kwargs)
@@ -130,7 +124,11 @@ def _convert_record_to_object(
async def fetch_if( async def fetch_if(
class_: Type[T], condition: str, page: int = 0, element_count: int = 10 class_: Type[T],
condition: str,
page: int = 0,
element_count: int = 10,
parameter_values: tuple = None,
) -> T: ) -> T:
""" """
Fetch all class_ type variables from the bound db, Fetch all class_ type variables from the bound db,
@@ -140,6 +138,7 @@ async def fetch_if(
:param condition: Condition to check for. :param condition: Condition to check for.
:param page: Which page to retrieve, default all. (0 means closed). :param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page. :param element_count: Element count in each page.
:param parameter_values: If placeholders are used, they will be replaced with these values
:return: A tuple of records that fit the given condition :return: A tuple of records that fit the given condition
of given type class_. of given type class_.
""" """
@@ -149,7 +148,8 @@ async def fetch_if(
await cur.execute( await cur.execute(
_insert_pagination( _insert_pagination(
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
) ),
parameter_values,
) )
# noinspection PyTypeChecker # noinspection PyTypeChecker
records: list = await cur.fetchall() records: list = await cur.fetchall()
@@ -175,7 +175,11 @@ async def fetch_where(
:return: A tuple of the records. :return: A tuple of the records.
""" """
return await fetch_if( return await fetch_if(
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count class_,
f"{field} = ?",
page,
element_count,
parameter_values=(_tweaked_dump_value(class_, value),),
) )
@@ -231,8 +235,18 @@ async def fetch_all(
# noinspection PyTypeChecker # noinspection PyTypeChecker
return cast( return cast(
tuple[T], tuple[T],
tuple(_convert_record_to_object(class_, record, field_names) for record in records), tuple(
_convert_record_to_object(class_, record, field_names) for record in records
),
) )
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"] __all__ = [
"is_fetchable",
"fetch_equals",
"fetch_from",
"fetch_if",
"fetch_where",
"fetch_range",
"fetch_all",
]

View File

@@ -3,12 +3,13 @@ This module includes functions to insert multiple records
to a bound database at one time, with one time open and closing to a bound database at one time, with one time open and closing
of the database file. of the database file.
""" """
import aiosqlite from dataclasses import asdict, fields
from dataclasses import asdict
from typing import List, Tuple, TypeVar, Union from typing import List, Tuple, TypeVar, Union
from warnings import warn from warnings import warn
from .commons import _convert_sql_format import aiosqlite
from .commons import _tweaked_dump
from .constraints import ConstraintFailedError from .constraints import ConstraintFailedError
T = TypeVar("T") T = TypeVar("T")
@@ -42,7 +43,9 @@ def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.") raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None: async def _toggle_memory_protection(
cur: aiosqlite.Cursor, protect_memory: bool
) -> None:
""" """
Given a cursor to a sqlite3 connection, if memory protection is false, Given a cursor to a sqlite3 connection, if memory protection is false,
toggle memory protections off. toggle memory protections off.
@@ -74,35 +77,53 @@ async def _mass_insert(
:return: None :return: None
""" """
_check_homogeneity(objects) _check_homogeneity(objects)
sql_queries = [] is_tweaked = getattr(objects[0], "tweaked")
first_index: int = 0 first_index: int = 0
table_name = objects[0].__class__.__name__.lower() table_name = objects[0].__class__.__name__.lower()
for i, obj in enumerate(objects): for i, obj in enumerate(objects):
kv_pairs = asdict(obj).items()
setattr(obj, "obj_id", first_index + i + 1) setattr(obj, "obj_id", first_index + i + 1)
sql_queries.append(
f"INSERT INTO {table_name}("
+ f"{', '.join(item[0] for item in kv_pairs)})"
+ f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});"
)
async with aiosqlite.connect(db_name) as con: async with aiosqlite.connect(db_name) as con:
cur: aiosqlite.Cursor = await con.cursor() cur: aiosqlite.Cursor = await con.cursor()
try: try:
await _toggle_memory_protection(cur, protect_memory) await _toggle_memory_protection(cur, protect_memory)
await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1") await cur.execute(
f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1"
)
index_tuple = await cur.fetchone() index_tuple = await cur.fetchone()
if index_tuple: if index_tuple:
_ = index_tuple[0] _ = index_tuple[0]
await cur.executescript(
"BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;" await cur.execute("BEGIN TRANSACTION;")
)
for i, obj in enumerate(objects):
if is_tweaked:
kv_pairs = [item for item in fields(obj)]
kv_pairs.sort(key=lambda item: item.name)
column_names = ", ".join(item.name for item in kv_pairs)
vals = tuple(_tweaked_dump(obj, item.name) for item in kv_pairs)
else:
kv_pairs = [item for item in asdict(obj).items()]
kv_pairs.sort(key=lambda item: item[0])
column_names = ", ".join(column[0] for column in kv_pairs)
vals = tuple(column[1] for column in kv_pairs)
setattr(obj, "obj_id", first_index + i + 1)
placeholders = ", ".join("?" for _ in kv_pairs)
sql_statement = f"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders});"
await cur.execute(sql_statement, vals)
await cur.execute("END TRANSACTION;")
except aiosqlite.IntegrityError: except aiosqlite.IntegrityError:
raise ConstraintFailedError raise ConstraintFailedError
await con.commit() await con.commit()
async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None: async def create_many(
objects: Union[List[T], Tuple[T]], protect_memory: bool = True
) -> None:
""" """
Insert many records corresponding to objects Insert many records corresponding to objects
in a tuple or a list. in a tuple or a list.
@@ -144,4 +165,4 @@ async def copy_many(
raise ValueError("Collection is empty.") raise ValueError("Collection is empty.")
__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError'] __all__ = ["copy_many", "create_many", "HeterogeneousCollectionError"]