Safer defaults and mass actions, base64 support from previous commit dropped
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
from dataclasses import Field
|
from dataclasses import MISSING, Field
|
||||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import base64
|
|
||||||
|
|
||||||
from .constraints import Unique
|
from .constraints import Unique
|
||||||
|
|
||||||
@@ -47,7 +46,7 @@ def _tweaked_convert_type(
|
|||||||
return type_overload.get(type_, "BLOB")
|
return type_overload.get(type_, "BLOB")
|
||||||
|
|
||||||
|
|
||||||
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
|
def _convert_sql_format(value: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Given a Python value, convert to string representation
|
Given a Python value, convert to string representation
|
||||||
of the equivalent SQL datatype.
|
of the equivalent SQL datatype.
|
||||||
@@ -62,10 +61,8 @@ def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) ->
|
|||||||
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
||||||
elif isinstance(value, bool):
|
elif isinstance(value, bool):
|
||||||
return "TRUE" if value else "FALSE"
|
return "TRUE" if value else "FALSE"
|
||||||
elif type(value) in type_overload:
|
|
||||||
return str(value)
|
|
||||||
else:
|
else:
|
||||||
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||||
@@ -81,7 +78,9 @@ async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _get_default(
|
def _get_default(
|
||||||
default_object: object, type_overload: Dict[Optional[type], str]
|
default_object: object,
|
||||||
|
type_overload: Dict[Optional[type], str],
|
||||||
|
mutable_def_params: list,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Check if the field's default object is filled,
|
Check if the field's default object is filled,
|
||||||
@@ -93,8 +92,14 @@ def _get_default(
|
|||||||
empty string if no string is necessary.
|
empty string if no string is necessary.
|
||||||
"""
|
"""
|
||||||
if type(default_object) in type_overload:
|
if type(default_object) in type_overload:
|
||||||
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
|
return f" DEFAULT {_convert_sql_format(default_object)}"
|
||||||
return ""
|
elif type(default_object) is type(MISSING):
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
mutable_def_params.append(
|
||||||
|
bytes(dumps(default_object, protocol=HIGHEST_PROTOCOL))
|
||||||
|
)
|
||||||
|
return " DEFAULT ?"
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyDefaultArgument
|
# noinspection PyDefaultArgument
|
||||||
@@ -129,15 +134,18 @@ async def _create_table(
|
|||||||
]
|
]
|
||||||
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
||||||
|
|
||||||
|
def_params = list()
|
||||||
|
|
||||||
sql_fields = ", ".join(
|
sql_fields = ", ".join(
|
||||||
f"{field.name} {type_converter(field.type, type_overload)}"
|
f"{field.name} {type_converter(field.type, type_overload)}"
|
||||||
f"{_get_default(field.default, type_overload)}"
|
f"{_get_default(field.default, type_overload, def_params)}"
|
||||||
for field in fields
|
for field in fields
|
||||||
)
|
)
|
||||||
|
|
||||||
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
||||||
await cursor.execute(
|
await cursor.execute(
|
||||||
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
|
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});",
|
||||||
|
def_params if def_params else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,4 +25,4 @@ Dataclass fields hinted with this type signals
|
|||||||
Unique = Union[Tuple[T], T]
|
Unique = Union[Tuple[T], T]
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['Unique', 'ConstraintFailedError']
|
__all__ = ["Unique", "ConstraintFailedError"]
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
|
from typing import Any, List, Tuple, Type, TypeVar, cast
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from typing import Any, List, Tuple, TypeVar, Type, cast
|
|
||||||
|
|
||||||
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
|
|
||||||
|
|
||||||
import base64
|
|
||||||
|
|
||||||
|
from .commons import _get_table_cols, _tweaked_dump_value, _tweaked_load_value
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -71,9 +69,7 @@ async def fetch_equals(
|
|||||||
}
|
}
|
||||||
for key in kwargs.keys():
|
for key in kwargs.keys():
|
||||||
if field_types[key] not in types_table.keys():
|
if field_types[key] not in types_table.keys():
|
||||||
kwargs[key] = _tweaked_load_value(
|
kwargs[key] = _tweaked_load_value(kwargs[key])
|
||||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
|
||||||
)
|
|
||||||
|
|
||||||
obj = class_(**kwargs)
|
obj = class_(**kwargs)
|
||||||
setattr(obj, "obj_id", obj_id)
|
setattr(obj, "obj_id", obj_id)
|
||||||
@@ -119,9 +115,7 @@ def _convert_record_to_object(
|
|||||||
|
|
||||||
elif is_tweaked:
|
elif is_tweaked:
|
||||||
if field_types[key] not in types_table.keys():
|
if field_types[key] not in types_table.keys():
|
||||||
kwargs[key] = _tweaked_load_value(
|
kwargs[key] = _tweaked_load_value(kwargs[key])
|
||||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
|
||||||
)
|
|
||||||
|
|
||||||
obj_id = record[0]
|
obj_id = record[0]
|
||||||
obj = class_(**kwargs)
|
obj = class_(**kwargs)
|
||||||
@@ -130,7 +124,11 @@ def _convert_record_to_object(
|
|||||||
|
|
||||||
|
|
||||||
async def fetch_if(
|
async def fetch_if(
|
||||||
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
|
class_: Type[T],
|
||||||
|
condition: str,
|
||||||
|
page: int = 0,
|
||||||
|
element_count: int = 10,
|
||||||
|
parameter_values: tuple = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Fetch all class_ type variables from the bound db,
|
Fetch all class_ type variables from the bound db,
|
||||||
@@ -140,6 +138,7 @@ async def fetch_if(
|
|||||||
:param condition: Condition to check for.
|
:param condition: Condition to check for.
|
||||||
:param page: Which page to retrieve, default all. (0 means closed).
|
:param page: Which page to retrieve, default all. (0 means closed).
|
||||||
:param element_count: Element count in each page.
|
:param element_count: Element count in each page.
|
||||||
|
:param parameter_values: If placeholders are used, they will be replaced with these values
|
||||||
:return: A tuple of records that fit the given condition
|
:return: A tuple of records that fit the given condition
|
||||||
of given type class_.
|
of given type class_.
|
||||||
"""
|
"""
|
||||||
@@ -149,7 +148,8 @@ async def fetch_if(
|
|||||||
await cur.execute(
|
await cur.execute(
|
||||||
_insert_pagination(
|
_insert_pagination(
|
||||||
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
|
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
|
||||||
)
|
),
|
||||||
|
parameter_values,
|
||||||
)
|
)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
records: list = await cur.fetchall()
|
records: list = await cur.fetchall()
|
||||||
@@ -175,7 +175,11 @@ async def fetch_where(
|
|||||||
:return: A tuple of the records.
|
:return: A tuple of the records.
|
||||||
"""
|
"""
|
||||||
return await fetch_if(
|
return await fetch_if(
|
||||||
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
|
class_,
|
||||||
|
f"{field} = ?",
|
||||||
|
page,
|
||||||
|
element_count,
|
||||||
|
parameter_values=(_tweaked_dump_value(class_, value),),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -231,8 +235,18 @@ async def fetch_all(
|
|||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return cast(
|
return cast(
|
||||||
tuple[T],
|
tuple[T],
|
||||||
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
|
tuple(
|
||||||
|
_convert_record_to_object(class_, record, field_names) for record in records
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]
|
__all__ = [
|
||||||
|
"is_fetchable",
|
||||||
|
"fetch_equals",
|
||||||
|
"fetch_from",
|
||||||
|
"fetch_if",
|
||||||
|
"fetch_where",
|
||||||
|
"fetch_range",
|
||||||
|
"fetch_all",
|
||||||
|
]
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ This module includes functions to insert multiple records
|
|||||||
to a bound database at one time, with one time open and closing
|
to a bound database at one time, with one time open and closing
|
||||||
of the database file.
|
of the database file.
|
||||||
"""
|
"""
|
||||||
import aiosqlite
|
from dataclasses import asdict, fields
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import List, Tuple, TypeVar, Union
|
from typing import List, Tuple, TypeVar, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from .commons import _convert_sql_format
|
import aiosqlite
|
||||||
|
|
||||||
|
from .commons import _tweaked_dump
|
||||||
from .constraints import ConstraintFailedError
|
from .constraints import ConstraintFailedError
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -42,7 +43,9 @@ def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
|
|||||||
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
|
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
|
||||||
|
|
||||||
|
|
||||||
async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None:
|
async def _toggle_memory_protection(
|
||||||
|
cur: aiosqlite.Cursor, protect_memory: bool
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given a cursor to a sqlite3 connection, if memory protection is false,
|
Given a cursor to a sqlite3 connection, if memory protection is false,
|
||||||
toggle memory protections off.
|
toggle memory protections off.
|
||||||
@@ -74,35 +77,53 @@ async def _mass_insert(
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
_check_homogeneity(objects)
|
_check_homogeneity(objects)
|
||||||
sql_queries = []
|
is_tweaked = getattr(objects[0], "tweaked")
|
||||||
first_index: int = 0
|
first_index: int = 0
|
||||||
table_name = objects[0].__class__.__name__.lower()
|
table_name = objects[0].__class__.__name__.lower()
|
||||||
|
|
||||||
for i, obj in enumerate(objects):
|
for i, obj in enumerate(objects):
|
||||||
kv_pairs = asdict(obj).items()
|
|
||||||
setattr(obj, "obj_id", first_index + i + 1)
|
setattr(obj, "obj_id", first_index + i + 1)
|
||||||
sql_queries.append(
|
|
||||||
f"INSERT INTO {table_name}("
|
|
||||||
+ f"{', '.join(item[0] for item in kv_pairs)})"
|
|
||||||
+ f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});"
|
|
||||||
)
|
|
||||||
async with aiosqlite.connect(db_name) as con:
|
async with aiosqlite.connect(db_name) as con:
|
||||||
cur: aiosqlite.Cursor = await con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
try:
|
try:
|
||||||
await _toggle_memory_protection(cur, protect_memory)
|
await _toggle_memory_protection(cur, protect_memory)
|
||||||
await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
|
await cur.execute(
|
||||||
|
f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1"
|
||||||
|
)
|
||||||
index_tuple = await cur.fetchone()
|
index_tuple = await cur.fetchone()
|
||||||
if index_tuple:
|
if index_tuple:
|
||||||
_ = index_tuple[0]
|
_ = index_tuple[0]
|
||||||
await cur.executescript(
|
|
||||||
"BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;"
|
await cur.execute("BEGIN TRANSACTION;")
|
||||||
)
|
|
||||||
|
for i, obj in enumerate(objects):
|
||||||
|
if is_tweaked:
|
||||||
|
kv_pairs = [item for item in fields(obj)]
|
||||||
|
kv_pairs.sort(key=lambda item: item.name)
|
||||||
|
column_names = ", ".join(item.name for item in kv_pairs)
|
||||||
|
vals = tuple(_tweaked_dump(obj, item.name) for item in kv_pairs)
|
||||||
|
else:
|
||||||
|
kv_pairs = [item for item in asdict(obj).items()]
|
||||||
|
kv_pairs.sort(key=lambda item: item[0])
|
||||||
|
column_names = ", ".join(column[0] for column in kv_pairs)
|
||||||
|
vals = tuple(column[1] for column in kv_pairs)
|
||||||
|
setattr(obj, "obj_id", first_index + i + 1)
|
||||||
|
|
||||||
|
placeholders = ", ".join("?" for _ in kv_pairs)
|
||||||
|
sql_statement = f"INSERT INTO {table_name} ({column_names}) VALUES ({placeholders});"
|
||||||
|
|
||||||
|
await cur.execute(sql_statement, vals)
|
||||||
|
|
||||||
|
await cur.execute("END TRANSACTION;")
|
||||||
|
|
||||||
except aiosqlite.IntegrityError:
|
except aiosqlite.IntegrityError:
|
||||||
raise ConstraintFailedError
|
raise ConstraintFailedError
|
||||||
await con.commit()
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
|
async def create_many(
|
||||||
|
objects: Union[List[T], Tuple[T]], protect_memory: bool = True
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Insert many records corresponding to objects
|
Insert many records corresponding to objects
|
||||||
in a tuple or a list.
|
in a tuple or a list.
|
||||||
@@ -144,4 +165,4 @@ async def copy_many(
|
|||||||
raise ValueError("Collection is empty.")
|
raise ValueError("Collection is empty.")
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError']
|
__all__ = ["copy_many", "create_many", "HeterogeneousCollectionError"]
|
||||||
|
|||||||
Reference in New Issue
Block a user