Migrate to aiosqlite, add some undocumented features, will be documented later. Some problems with mass_actions, don't use them, because loading algorythm may change
This commit is contained in:
@@ -1,15 +1,28 @@
|
||||
from dataclasses import Field
|
||||
from typing import Any, Optional, Dict, List
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
import base64
|
||||
|
||||
from .constraints import Unique
|
||||
import sqlite3 as sql
|
||||
|
||||
type_table: Dict[Optional[type], str] = {
|
||||
None: "NULL",
|
||||
int: "INTEGER",
|
||||
float: "REAL",
|
||||
str: "TEXT",
|
||||
bytes: "BLOB",
|
||||
bool: "INTEGER",
|
||||
}
|
||||
type_table.update(
|
||||
{Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()}
|
||||
)
|
||||
|
||||
|
||||
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
|
||||
str: "TEXT", bytes: "BLOB", bool: "INTEGER"}
|
||||
type_table.update({Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()})
|
||||
|
||||
|
||||
def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
|
||||
def _convert_type(
|
||||
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
"""
|
||||
Given a Python type, return the str name of its
|
||||
SQLlite equivalent.
|
||||
@@ -22,19 +35,24 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
|
||||
try:
|
||||
return type_overload[type_]
|
||||
except KeyError:
|
||||
raise TypeError("Requested type not in the default or overloaded type table.")
|
||||
raise TypeError(
|
||||
"Requested type not in the default or overloaded type table. Use @datalite(tweaked=True) to "
|
||||
"encode custom types"
|
||||
)
|
||||
|
||||
|
||||
def _convert_sql_format(value: Any) -> str:
|
||||
def _tweaked_convert_type(
|
||||
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
return type_overload.get(type_, "BLOB")
|
||||
|
||||
|
||||
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
|
||||
"""
|
||||
Given a Python value, convert to string representation
|
||||
of the equivalent SQL datatype.
|
||||
:param value: A value, ie: a literal, a variable etc.
|
||||
:return: The string representation of the SQL equivalent.
|
||||
>>> _convert_sql_format(1)
|
||||
"1"
|
||||
>>> _convert_sql_format("John Smith")
|
||||
'"John Smith"'
|
||||
"""
|
||||
if value is None:
|
||||
return "NULL"
|
||||
@@ -44,11 +62,13 @@ def _convert_sql_format(value: Any) -> str:
|
||||
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
||||
elif isinstance(value, bool):
|
||||
return "TRUE" if value else "FALSE"
|
||||
else:
|
||||
elif type(value) in type_overload:
|
||||
return str(value)
|
||||
else:
|
||||
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
|
||||
|
||||
|
||||
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
||||
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||
"""
|
||||
Get the column data of a table.
|
||||
|
||||
@@ -56,11 +76,13 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
||||
:param table_name: Name of the table.
|
||||
:return: the information about columns.
|
||||
"""
|
||||
cur.execute(f"PRAGMA table_info({table_name});")
|
||||
return [row_info[1] for row_info in cur.fetchall()][1:]
|
||||
await cur.execute(f"PRAGMA table_info({table_name});")
|
||||
return [row_info[1] for row_info in await cur.fetchall()][1:]
|
||||
|
||||
|
||||
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
|
||||
def _get_default(
|
||||
default_object: object, type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
"""
|
||||
Check if the field's default object is filled,
|
||||
if filled return the string to be put in the,
|
||||
@@ -71,11 +93,28 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str
|
||||
empty string if no string is necessary.
|
||||
"""
|
||||
if type(default_object) in type_overload:
|
||||
return f' DEFAULT {_convert_sql_format(default_object)}'
|
||||
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
|
||||
return ""
|
||||
|
||||
|
||||
def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str] = type_table) -> None:
|
||||
# noinspection PyDefaultArgument
|
||||
async def _tweaked_create_table(
|
||||
class_: type,
|
||||
cursor: aiosqlite.Cursor,
|
||||
type_overload: Dict[Optional[type], str] = type_table,
|
||||
) -> None:
|
||||
await _create_table(
|
||||
class_, cursor, type_overload, type_converter=_tweaked_convert_type
|
||||
)
|
||||
|
||||
|
||||
# noinspection PyDefaultArgument
|
||||
async def _create_table(
|
||||
class_: type,
|
||||
cursor: aiosqlite.Cursor,
|
||||
type_overload: Dict[Optional[type], str] = type_table,
|
||||
type_converter=_convert_type,
|
||||
) -> None:
|
||||
"""
|
||||
Create the table for a specific dataclass given
|
||||
:param class_: A dataclass.
|
||||
@@ -84,10 +123,35 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional
|
||||
with a custom table, this is that custom table.
|
||||
:return: None.
|
||||
"""
|
||||
fields: List[Field] = [class_.__dataclass_fields__[key] for
|
||||
key in class_.__dataclass_fields__.keys()]
|
||||
# noinspection PyUnresolvedReferences
|
||||
fields: List[Field] = [
|
||||
class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()
|
||||
]
|
||||
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
||||
sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload)}" for field in fields)
|
||||
|
||||
sql_fields = ", ".join(
|
||||
f"{field.name} {type_converter(field.type, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload)}"
|
||||
for field in fields
|
||||
)
|
||||
|
||||
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
||||
cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});")
|
||||
await cursor.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
|
||||
)
|
||||
|
||||
|
||||
def _tweaked_dump_value(self, value):
|
||||
if type(value) in self.types_table:
|
||||
return value
|
||||
else:
|
||||
return bytes(dumps(value, protocol=HIGHEST_PROTOCOL))
|
||||
|
||||
|
||||
def _tweaked_dump(self, name):
|
||||
value = getattr(self, name)
|
||||
return _tweaked_dump_value(self, value)
|
||||
|
||||
|
||||
def _tweaked_load_value(data):
|
||||
return loads(bytes(data))
|
||||
|
||||
Reference in New Issue
Block a user