Safer defaults and mass actions, base64 support from previous commit dropped

This commit is contained in:
hhh
2024-03-17 14:56:23 +02:00
parent 6dfc3cebbe
commit b0523e141e
4 changed files with 88 additions and 45 deletions

View File

@@ -1,10 +1,8 @@
from typing import Any, List, Tuple, Type, TypeVar, cast
import aiosqlite
from typing import Any, List, Tuple, TypeVar, Type, cast
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
import base64
from .commons import _get_table_cols, _tweaked_dump_value, _tweaked_load_value
T = TypeVar("T")
@@ -71,9 +69,7 @@ async def fetch_equals(
}
for key in kwargs.keys():
if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value(
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
kwargs[key] = _tweaked_load_value(kwargs[key])
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
@@ -119,9 +115,7 @@ def _convert_record_to_object(
elif is_tweaked:
if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value(
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
kwargs[key] = _tweaked_load_value(kwargs[key])
obj_id = record[0]
obj = class_(**kwargs)
@@ -130,7 +124,11 @@ def _convert_record_to_object(
async def fetch_if(
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
class_: Type[T],
condition: str,
page: int = 0,
element_count: int = 10,
parameter_values: tuple = None,
) -> T:
"""
Fetch all class_ type variables from the bound db,
@@ -140,6 +138,7 @@ async def fetch_if(
:param condition: Condition to check for.
:param page: Which page to retrieve, default all. (0 means closed).
:param element_count: Element count in each page.
:param parameter_values: If placeholders are used, they will be replaced with these values
:return: A tuple of records that fit the given condition
of given type class_.
"""
@@ -149,7 +148,8 @@ async def fetch_if(
await cur.execute(
_insert_pagination(
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
)
),
parameter_values,
)
# noinspection PyTypeChecker
records: list = await cur.fetchall()
@@ -175,7 +175,11 @@ async def fetch_where(
:return: A tuple of the records.
"""
return await fetch_if(
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
class_,
f"{field} = ?",
page,
element_count,
parameter_values=(_tweaked_dump_value(class_, value),),
)
@@ -231,8 +235,18 @@ async def fetch_all(
# noinspection PyTypeChecker
return cast(
tuple[T],
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
tuple(
_convert_record_to_object(class_, record, field_names) for record in records
),
)
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]
__all__ = [
"is_fetchable",
"fetch_equals",
"fetch_from",
"fetch_if",
"fetch_where",
"fetch_range",
"fetch_all",
]