Safer defaults and mass actions, base64 support from previous commit dropped
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from dataclasses import Field
|
||||
from dataclasses import MISSING, Field
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
import base64
|
||||
|
||||
from .constraints import Unique
|
||||
|
||||
@@ -47,7 +46,7 @@ def _tweaked_convert_type(
|
||||
return type_overload.get(type_, "BLOB")
|
||||
|
||||
|
||||
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
|
||||
def _convert_sql_format(value: Any) -> str:
|
||||
"""
|
||||
Given a Python value, convert to string representation
|
||||
of the equivalent SQL datatype.
|
||||
@@ -62,10 +61,8 @@ def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) ->
|
||||
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
||||
elif isinstance(value, bool):
|
||||
return "TRUE" if value else "FALSE"
|
||||
elif type(value) in type_overload:
|
||||
return str(value)
|
||||
else:
|
||||
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
|
||||
return str(value)
|
||||
|
||||
|
||||
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||
@@ -81,7 +78,9 @@ async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||
|
||||
|
||||
def _get_default(
|
||||
default_object: object, type_overload: Dict[Optional[type], str]
|
||||
default_object: object,
|
||||
type_overload: Dict[Optional[type], str],
|
||||
mutable_def_params: list,
|
||||
) -> str:
|
||||
"""
|
||||
Check if the field's default object is filled,
|
||||
@@ -93,8 +92,14 @@ def _get_default(
|
||||
empty string if no string is necessary.
|
||||
"""
|
||||
if type(default_object) in type_overload:
|
||||
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
|
||||
return ""
|
||||
return f" DEFAULT {_convert_sql_format(default_object)}"
|
||||
elif type(default_object) is type(MISSING):
|
||||
return ""
|
||||
else:
|
||||
mutable_def_params.append(
|
||||
bytes(dumps(default_object, protocol=HIGHEST_PROTOCOL))
|
||||
)
|
||||
return " DEFAULT ?"
|
||||
|
||||
|
||||
# noinspection PyDefaultArgument
|
||||
@@ -129,15 +134,18 @@ async def _create_table(
|
||||
]
|
||||
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
||||
|
||||
def_params = list()
|
||||
|
||||
sql_fields = ", ".join(
|
||||
f"{field.name} {type_converter(field.type, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload, def_params)}"
|
||||
for field in fields
|
||||
)
|
||||
|
||||
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
||||
await cursor.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
|
||||
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});",
|
||||
def_params if def_params else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user