Safer defaults and mass actions, base64 support from previous commit dropped

This commit is contained in:
hhh
2024-03-17 14:56:23 +02:00
parent 6dfc3cebbe
commit b0523e141e
4 changed files with 88 additions and 45 deletions

View File

@@ -1,9 +1,8 @@
from dataclasses import Field
from dataclasses import MISSING, Field
from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import Any, Dict, List, Optional
import aiosqlite
import base64
from .constraints import Unique
@@ -47,7 +46,7 @@ def _tweaked_convert_type(
return type_overload.get(type_, "BLOB")
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
def _convert_sql_format(value: Any) -> str:
"""
Given a Python value, convert to string representation
of the equivalent SQL datatype.
@@ -62,10 +61,8 @@ def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) ->
return '"' + str(value).replace("b'", "")[:-1] + '"'
elif isinstance(value, bool):
return "TRUE" if value else "FALSE"
elif type(value) in type_overload:
return str(value)
else:
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
return str(value)
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
@@ -81,7 +78,9 @@ async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
def _get_default(
default_object: object, type_overload: Dict[Optional[type], str]
default_object: object,
type_overload: Dict[Optional[type], str],
mutable_def_params: list,
) -> str:
"""
Check if the field's default object is filled,
@@ -93,8 +92,14 @@ def _get_default(
empty string if no string is necessary.
"""
if type(default_object) in type_overload:
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
return ""
return f" DEFAULT {_convert_sql_format(default_object)}"
elif type(default_object) is type(MISSING):
return ""
else:
mutable_def_params.append(
bytes(dumps(default_object, protocol=HIGHEST_PROTOCOL))
)
return " DEFAULT ?"
# noinspection PyDefaultArgument
@@ -129,15 +134,18 @@ async def _create_table(
]
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
def_params = list()
sql_fields = ", ".join(
f"{field.name} {type_converter(field.type, type_overload)}"
f"{_get_default(field.default, type_overload)}"
f"{_get_default(field.default, type_overload, def_params)}"
for field in fields
)
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
await cursor.execute(
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});",
def_params if def_params else None,
)