Migrate to aiosqlite, add some undocumented features, will be documented later. Some problems with mass_actions, don't use them, because loading algorythm may change
This commit is contained in:
@@ -1,2 +1,10 @@
|
||||
__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints', 'mass_actions']
|
||||
__all__ = [
|
||||
"commons",
|
||||
"datalite_decorator",
|
||||
"fetch",
|
||||
"migrations",
|
||||
"datalite",
|
||||
"constraints",
|
||||
"mass_actions",
|
||||
]
|
||||
from .datalite_decorator import datalite
|
||||
|
||||
@@ -1,15 +1,28 @@
|
||||
from dataclasses import Field
|
||||
from typing import Any, Optional, Dict, List
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
import base64
|
||||
|
||||
from .constraints import Unique
|
||||
import sqlite3 as sql
|
||||
|
||||
type_table: Dict[Optional[type], str] = {
|
||||
None: "NULL",
|
||||
int: "INTEGER",
|
||||
float: "REAL",
|
||||
str: "TEXT",
|
||||
bytes: "BLOB",
|
||||
bool: "INTEGER",
|
||||
}
|
||||
type_table.update(
|
||||
{Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()}
|
||||
)
|
||||
|
||||
|
||||
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
|
||||
str: "TEXT", bytes: "BLOB", bool: "INTEGER"}
|
||||
type_table.update({Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()})
|
||||
|
||||
|
||||
def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
|
||||
def _convert_type(
|
||||
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
"""
|
||||
Given a Python type, return the str name of its
|
||||
SQLlite equivalent.
|
||||
@@ -22,19 +35,24 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
|
||||
try:
|
||||
return type_overload[type_]
|
||||
except KeyError:
|
||||
raise TypeError("Requested type not in the default or overloaded type table.")
|
||||
raise TypeError(
|
||||
"Requested type not in the default or overloaded type table. Use @datalite(tweaked=True) to "
|
||||
"encode custom types"
|
||||
)
|
||||
|
||||
|
||||
def _convert_sql_format(value: Any) -> str:
|
||||
def _tweaked_convert_type(
|
||||
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
return type_overload.get(type_, "BLOB")
|
||||
|
||||
|
||||
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
|
||||
"""
|
||||
Given a Python value, convert to string representation
|
||||
of the equivalent SQL datatype.
|
||||
:param value: A value, ie: a literal, a variable etc.
|
||||
:return: The string representation of the SQL equivalent.
|
||||
>>> _convert_sql_format(1)
|
||||
"1"
|
||||
>>> _convert_sql_format("John Smith")
|
||||
'"John Smith"'
|
||||
"""
|
||||
if value is None:
|
||||
return "NULL"
|
||||
@@ -44,11 +62,13 @@ def _convert_sql_format(value: Any) -> str:
|
||||
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
||||
elif isinstance(value, bool):
|
||||
return "TRUE" if value else "FALSE"
|
||||
else:
|
||||
elif type(value) in type_overload:
|
||||
return str(value)
|
||||
else:
|
||||
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
|
||||
|
||||
|
||||
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
||||
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||
"""
|
||||
Get the column data of a table.
|
||||
|
||||
@@ -56,11 +76,13 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
||||
:param table_name: Name of the table.
|
||||
:return: the information about columns.
|
||||
"""
|
||||
cur.execute(f"PRAGMA table_info({table_name});")
|
||||
return [row_info[1] for row_info in cur.fetchall()][1:]
|
||||
await cur.execute(f"PRAGMA table_info({table_name});")
|
||||
return [row_info[1] for row_info in await cur.fetchall()][1:]
|
||||
|
||||
|
||||
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
|
||||
def _get_default(
|
||||
default_object: object, type_overload: Dict[Optional[type], str]
|
||||
) -> str:
|
||||
"""
|
||||
Check if the field's default object is filled,
|
||||
if filled return the string to be put in the,
|
||||
@@ -71,11 +93,28 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str
|
||||
empty string if no string is necessary.
|
||||
"""
|
||||
if type(default_object) in type_overload:
|
||||
return f' DEFAULT {_convert_sql_format(default_object)}'
|
||||
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
|
||||
return ""
|
||||
|
||||
|
||||
def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str] = type_table) -> None:
|
||||
# noinspection PyDefaultArgument
|
||||
async def _tweaked_create_table(
|
||||
class_: type,
|
||||
cursor: aiosqlite.Cursor,
|
||||
type_overload: Dict[Optional[type], str] = type_table,
|
||||
) -> None:
|
||||
await _create_table(
|
||||
class_, cursor, type_overload, type_converter=_tweaked_convert_type
|
||||
)
|
||||
|
||||
|
||||
# noinspection PyDefaultArgument
|
||||
async def _create_table(
|
||||
class_: type,
|
||||
cursor: aiosqlite.Cursor,
|
||||
type_overload: Dict[Optional[type], str] = type_table,
|
||||
type_converter=_convert_type,
|
||||
) -> None:
|
||||
"""
|
||||
Create the table for a specific dataclass given
|
||||
:param class_: A dataclass.
|
||||
@@ -84,10 +123,35 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional
|
||||
with a custom table, this is that custom table.
|
||||
:return: None.
|
||||
"""
|
||||
fields: List[Field] = [class_.__dataclass_fields__[key] for
|
||||
key in class_.__dataclass_fields__.keys()]
|
||||
# noinspection PyUnresolvedReferences
|
||||
fields: List[Field] = [
|
||||
class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()
|
||||
]
|
||||
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
||||
sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload)}" for field in fields)
|
||||
|
||||
sql_fields = ", ".join(
|
||||
f"{field.name} {type_converter(field.type, type_overload)}"
|
||||
f"{_get_default(field.default, type_overload)}"
|
||||
for field in fields
|
||||
)
|
||||
|
||||
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
||||
cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});")
|
||||
await cursor.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
|
||||
)
|
||||
|
||||
|
||||
def _tweaked_dump_value(self, value):
|
||||
if type(value) in self.types_table:
|
||||
return value
|
||||
else:
|
||||
return bytes(dumps(value, protocol=HIGHEST_PROTOCOL))
|
||||
|
||||
|
||||
def _tweaked_dump(self, name):
|
||||
value = getattr(self, name)
|
||||
return _tweaked_dump_value(self, value)
|
||||
|
||||
|
||||
def _tweaked_load_value(data):
|
||||
return loads(bytes(data))
|
||||
|
||||
@@ -4,15 +4,16 @@ datalite.constraints module introduces constraint
|
||||
that can be used to signal datalite decorator
|
||||
constraints in the database.
|
||||
"""
|
||||
from typing import TypeVar, Union, Tuple
|
||||
from typing import Tuple, TypeVar, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ConstraintFailedError(Exception):
|
||||
"""
|
||||
This exception is raised when a Constraint fails.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -24,3 +25,4 @@ Dataclass fields hinted with this type signals
|
||||
Unique = Union[Tuple[T], T]
|
||||
|
||||
|
||||
__all__ = ['Unique', 'ConstraintFailedError']
|
||||
|
||||
@@ -1,93 +1,160 @@
|
||||
"""
|
||||
Defines the Datalite decorator that can be used to convert a dataclass to
|
||||
a class bound to a sqlite3 database.
|
||||
a class bound to an sqlite3 database.
|
||||
"""
|
||||
from dataclasses import asdict, fields
|
||||
from sqlite3.dbapi2 import IntegrityError
|
||||
from typing import Dict, Optional, Callable
|
||||
from dataclasses import asdict
|
||||
import sqlite3 as sql
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table
|
||||
from .constraints import ConstraintFailedError
|
||||
from .commons import _convert_sql_format, _convert_type, _create_table, type_table
|
||||
|
||||
|
||||
def _create_entry(self) -> None:
|
||||
async def _create_entry(self) -> None:
|
||||
"""
|
||||
Given an object, create the entry for the object. As a side-effect,
|
||||
Given an object, create the entry for the object. As a side effect,
|
||||
this will set the object_id attribute of the object to the unique
|
||||
id of the entry.
|
||||
:param self: Instance of the object.
|
||||
:return: None.
|
||||
"""
|
||||
with sql.connect(getattr(self, "db_path")) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
table_name: str = self.__class__.__name__.lower()
|
||||
kv_pairs = [item for item in asdict(self).items()]
|
||||
kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields.
|
||||
try:
|
||||
cur.execute(f"INSERT INTO {table_name}("
|
||||
await cur.execute(
|
||||
f"INSERT INTO {table_name}("
|
||||
f"{', '.join(item[0] for item in kv_pairs)})"
|
||||
f" VALUES ({', '.join('?' for item in kv_pairs)})",
|
||||
[item[1] for item in kv_pairs])
|
||||
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
|
||||
[item[1] for item in kv_pairs],
|
||||
)
|
||||
self.__setattr__("obj_id", cur.lastrowid)
|
||||
con.commit()
|
||||
await con.commit()
|
||||
except IntegrityError:
|
||||
raise ConstraintFailedError("A constraint has failed.")
|
||||
|
||||
|
||||
def _update_entry(self) -> None:
|
||||
async def _tweaked_create_entry(self) -> None:
|
||||
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
table_name: str = self.__class__.__name__.lower()
|
||||
kv_pairs = [item for item in fields(self)]
|
||||
kv_pairs.sort(key=lambda item: item.name) # Sort by the name of the fields.
|
||||
try:
|
||||
await cur.execute(
|
||||
f"INSERT INTO {table_name}("
|
||||
f"{', '.join(item.name for item in kv_pairs)})"
|
||||
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
|
||||
[_tweaked_dump(self, item.name) for item in kv_pairs],
|
||||
)
|
||||
self.__setattr__("obj_id", cur.lastrowid)
|
||||
await con.commit()
|
||||
except IntegrityError:
|
||||
raise ConstraintFailedError("A constraint has failed.")
|
||||
|
||||
|
||||
async def _update_entry(self) -> None:
|
||||
"""
|
||||
Given an object, update the objects entry in the bound database.
|
||||
:param self: The object.
|
||||
:return: None.
|
||||
"""
|
||||
with sql.connect(getattr(self, "db_path")) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
table_name: str = self.__class__.__name__.lower()
|
||||
kv_pairs = [item for item in asdict(self).items()]
|
||||
kv_pairs.sort(key=lambda item: item[0])
|
||||
query = f"UPDATE {table_name} " + \
|
||||
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} " + \
|
||||
query = (
|
||||
f"UPDATE {table_name} "
|
||||
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} "
|
||||
f"WHERE obj_id = {getattr(self, 'obj_id')};"
|
||||
cur.execute(query, [item[1] for item in kv_pairs])
|
||||
con.commit()
|
||||
)
|
||||
await cur.execute(query, [item[1] for item in kv_pairs])
|
||||
await con.commit()
|
||||
|
||||
|
||||
def remove_from(class_: type, obj_id: int):
|
||||
with sql.connect(getattr(class_, "db_path")) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id, ))
|
||||
con.commit()
|
||||
async def _tweaked_update_entry(self) -> None:
|
||||
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
table_name: str = self.__class__.__name__.lower()
|
||||
kv_pairs = [item for item in fields(self)]
|
||||
kv_pairs.sort(key=lambda item: item.name)
|
||||
query = (
|
||||
f"UPDATE {table_name} "
|
||||
f"SET {', '.join(item.name + ' = ?' for item in kv_pairs)} "
|
||||
f"WHERE obj_id = {getattr(self, 'obj_id')};"
|
||||
)
|
||||
await cur.execute(query, [_tweaked_dump(self, item.name) for item in kv_pairs])
|
||||
await con.commit()
|
||||
|
||||
|
||||
def _remove_entry(self) -> None:
|
||||
async def remove_from(class_: type, obj_id: int):
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(
|
||||
f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id,)
|
||||
)
|
||||
await con.commit()
|
||||
|
||||
|
||||
async def _remove_entry(self) -> None:
|
||||
"""
|
||||
Remove the object's record in the underlying database.
|
||||
:param self: self instance.
|
||||
:return: None.
|
||||
"""
|
||||
remove_from(self.__class__, getattr(self, 'obj_id'))
|
||||
await remove_from(self.__class__, getattr(self, "obj_id"))
|
||||
|
||||
|
||||
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable:
|
||||
def _markup_table(markup_function):
|
||||
async def inner(self=None, **kwargs):
|
||||
if not kwargs:
|
||||
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await markup_function(self.__class__, cur, self.types_table)
|
||||
else:
|
||||
await markup_function(**kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def datalite(
|
||||
db_path: str,
|
||||
type_overload: Optional[Dict[Optional[type], str]] = None,
|
||||
tweaked: bool = True,
|
||||
) -> Callable:
|
||||
"""Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
|
||||
`create_entry()`, `remove_entry()` and `update_entry()`.
|
||||
|
||||
:param db_path: Path of the database to be binded.
|
||||
:param db_path: Path of the database to be bound.
|
||||
:param type_overload: Type overload dictionary.
|
||||
:param tweaked: Whether to use pickle type tweaks
|
||||
:return: The new dataclass.
|
||||
"""
|
||||
def decorator(dataclass_: type, *args_i, **kwargs_i):
|
||||
|
||||
def decorator(dataclass_: type, *_, **__):
|
||||
types_table = type_table.copy()
|
||||
if type_overload is not None:
|
||||
types_table.update(type_overload)
|
||||
with sql.connect(db_path) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
_create_table(dataclass_, cur, types_table)
|
||||
setattr(dataclass_, 'db_path', db_path) # We add the path of the database to class itself.
|
||||
setattr(dataclass_, 'types_table', types_table) # We add the type table for migration.
|
||||
|
||||
setattr(dataclass_, "db_path", db_path)
|
||||
setattr(dataclass_, "types_table", types_table)
|
||||
setattr(dataclass_, "tweaked", tweaked)
|
||||
|
||||
if tweaked:
|
||||
dataclass_.markup_table = _markup_table(_tweaked_create_table)
|
||||
dataclass_.create_entry = _tweaked_create_entry
|
||||
dataclass_.update_entry = _tweaked_update_entry
|
||||
else:
|
||||
dataclass_.markup_table = _markup_table(_create_table)
|
||||
dataclass_.create_entry = _create_entry
|
||||
dataclass_.remove_entry = _remove_entry
|
||||
dataclass_.update_entry = _update_entry
|
||||
dataclass_.remove_entry = _remove_entry
|
||||
|
||||
return dataclass_
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import sqlite3 as sql
|
||||
from typing import List, Tuple, Any
|
||||
from .commons import _convert_sql_format, _get_table_cols
|
||||
import aiosqlite
|
||||
from typing import Any, List, Tuple, TypeVar, Type, cast
|
||||
|
||||
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
|
||||
|
||||
import base64
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
||||
@@ -17,7 +23,7 @@ def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
||||
return query + ";"
|
||||
|
||||
|
||||
def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
async def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
"""
|
||||
Check if a record is fetchable given its obj_id and
|
||||
class_ type.
|
||||
@@ -26,16 +32,22 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||
:param obj_id: Unique obj_id of the object.
|
||||
:return: If the object is fetchable.
|
||||
"""
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
try:
|
||||
cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
|
||||
except sql.OperationalError:
|
||||
await cur.execute(
|
||||
f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id,)
|
||||
)
|
||||
except aiosqlite.OperationalError:
|
||||
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
|
||||
return bool(cur.fetchall())
|
||||
return bool(await cur.fetchall())
|
||||
|
||||
|
||||
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
||||
async def fetch_equals(
|
||||
class_: Type[T],
|
||||
field: str,
|
||||
value: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Fetch a class_ type variable from its bound db.
|
||||
|
||||
@@ -45,18 +57,30 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
||||
:return: The object whose data is taken from the database.
|
||||
"""
|
||||
table_name = class_.__name__.lower()
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, ))
|
||||
obj_id, *field_values = list(cur.fetchone())
|
||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value,))
|
||||
obj_id, *field_values = list(await cur.fetchone())
|
||||
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||
|
||||
kwargs = dict(zip(field_names, field_values))
|
||||
if getattr(class_, "tweaked"):
|
||||
types_table = getattr(class_, "types_table")
|
||||
field_types = {
|
||||
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||
}
|
||||
for key in kwargs.keys():
|
||||
if field_types[key] not in types_table.keys():
|
||||
kwargs[key] = _tweaked_load_value(
|
||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||
)
|
||||
|
||||
obj = class_(**kwargs)
|
||||
setattr(obj, "obj_id", obj_id)
|
||||
return obj
|
||||
|
||||
|
||||
def fetch_from(class_: type, obj_id: int) -> Any:
|
||||
async def fetch_from(class_: Type[T], obj_id: int) -> T:
|
||||
"""
|
||||
Fetch a class_ type variable from its bound dv.
|
||||
|
||||
@@ -64,13 +88,17 @@ def fetch_from(class_: type, obj_id: int) -> Any:
|
||||
:param obj_id: Unique object id of the object.
|
||||
:return: The fetched object.
|
||||
"""
|
||||
if not is_fetchable(class_, obj_id):
|
||||
raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
||||
f"otherwise is unreachable.")
|
||||
return fetch_equals(class_, 'obj_id', obj_id)
|
||||
if not await is_fetchable(class_, obj_id):
|
||||
raise KeyError(
|
||||
f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
||||
f"otherwise is unreachable."
|
||||
)
|
||||
return await fetch_equals(class_, "obj_id", obj_id)
|
||||
|
||||
|
||||
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
|
||||
def _convert_record_to_object(
|
||||
class_: Type[T], record: Tuple[Any], field_names: List[str]
|
||||
) -> T:
|
||||
"""
|
||||
Convert a given record fetched from an SQL instance to a Python Object of given class_.
|
||||
|
||||
@@ -80,17 +108,30 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis
|
||||
:return: the created object.
|
||||
"""
|
||||
kwargs = dict(zip(field_names, record[1:]))
|
||||
field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()}
|
||||
for key in kwargs:
|
||||
field_types = {
|
||||
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||
}
|
||||
is_tweaked = getattr(class_, "tweaked")
|
||||
types_table = getattr(class_, "types_table")
|
||||
for key in kwargs.keys():
|
||||
if field_types[key] == bytes:
|
||||
kwargs[key] = bytes(kwargs[key], encoding='utf-8')
|
||||
kwargs[key] = bytes(kwargs[key], encoding="utf-8")
|
||||
|
||||
elif is_tweaked:
|
||||
if field_types[key] not in types_table.keys():
|
||||
kwargs[key] = _tweaked_load_value(
|
||||
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||
)
|
||||
|
||||
obj_id = record[0]
|
||||
obj = class_(**kwargs)
|
||||
setattr(obj, "obj_id", obj_id)
|
||||
return obj
|
||||
|
||||
|
||||
def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_if(
|
||||
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
|
||||
) -> T:
|
||||
"""
|
||||
Fetch all class_ type variables from the bound db,
|
||||
provided they fit the given condition
|
||||
@@ -103,18 +144,27 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
|
||||
of given type class_.
|
||||
"""
|
||||
table_name = class_.__name__.lower()
|
||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
|
||||
records: list = cur.fetchall()
|
||||
field_names: List[str] = _get_table_cols(cur, table_name)
|
||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
||||
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(
|
||||
_insert_pagination(
|
||||
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
|
||||
)
|
||||
)
|
||||
# noinspection PyTypeChecker
|
||||
records: list = await cur.fetchall()
|
||||
field_names: List[str] = await _get_table_cols(cur, table_name)
|
||||
return tuple(
|
||||
_convert_record_to_object(class_, record, field_names) for record in records
|
||||
)
|
||||
|
||||
|
||||
def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_where(
|
||||
class_: Type[T], field: str, value: Any, page: int = 0, element_count: int = 10
|
||||
) -> tuple[T]:
|
||||
"""
|
||||
Fetch all class_ type variables from the bound db,
|
||||
provided that the field of the records fit the
|
||||
Fetch all class_ type variables from the bound db
|
||||
if the field of the records fit the
|
||||
given value.
|
||||
|
||||
:param class_: Class of the records.
|
||||
@@ -124,10 +174,12 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
|
||||
:param element_count: Element count in each page.
|
||||
:return: A tuple of the records.
|
||||
"""
|
||||
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count)
|
||||
return await fetch_if(
|
||||
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
|
||||
)
|
||||
|
||||
|
||||
def fetch_range(class_: type, range_: range) -> tuple:
|
||||
async def fetch_range(class_: Type[T], range_: range) -> tuple[T]:
|
||||
"""
|
||||
Fetch the records in a given range of object ids.
|
||||
|
||||
@@ -136,12 +188,23 @@ def fetch_range(class_: type, range_: range) -> tuple:
|
||||
:return: A tuple of class_ type objects whose values
|
||||
come from the class_' bound database.
|
||||
"""
|
||||
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
|
||||
return cast(
|
||||
tuple[T],
|
||||
tuple(
|
||||
[
|
||||
(await fetch_from(class_, obj_id))
|
||||
for obj_id in range_
|
||||
if (await is_fetchable(class_, obj_id))
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
||||
async def fetch_all(
|
||||
class_: Type[T], page: int = 0, element_count: int = 10
|
||||
) -> tuple[T]:
|
||||
"""
|
||||
Fetchall the records in the bound database.
|
||||
Fetch all the records in the bound database.
|
||||
|
||||
:param class_: Class of the records.
|
||||
:param page: Which page to retrieve, default all. (0 means closed).
|
||||
@@ -150,15 +213,26 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
||||
the bound database as a tuple.
|
||||
"""
|
||||
try:
|
||||
db_path = getattr(class_, 'db_path')
|
||||
db_path = getattr(class_, "db_path")
|
||||
except AttributeError:
|
||||
raise TypeError("Given class is not decorated with datalite.")
|
||||
with sql.connect(db_path) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
async with aiosqlite.connect(db_path) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
try:
|
||||
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
|
||||
except sql.OperationalError:
|
||||
await cur.execute(
|
||||
_insert_pagination(
|
||||
f"SELECT * FROM {class_.__name__.lower()}", page, element_count
|
||||
)
|
||||
)
|
||||
except aiosqlite.OperationalError:
|
||||
raise TypeError(f"No record of type {class_.__name__.lower()}")
|
||||
records = cur.fetchall()
|
||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
||||
records = await cur.fetchall()
|
||||
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||
# noinspection PyTypeChecker
|
||||
return cast(
|
||||
tuple[T],
|
||||
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]
|
||||
|
||||
@@ -3,14 +3,15 @@ This module includes functions to insert multiple records
|
||||
to a bound database at one time, with one time open and closing
|
||||
of the database file.
|
||||
"""
|
||||
from typing import TypeVar, Union, List, Tuple
|
||||
import aiosqlite
|
||||
from dataclasses import asdict
|
||||
from typing import List, Tuple, TypeVar, Union
|
||||
from warnings import warn
|
||||
from .constraints import ConstraintFailedError
|
||||
from .commons import _convert_sql_format, _create_table
|
||||
import sqlite3 as sql
|
||||
|
||||
T = TypeVar('T')
|
||||
from .commons import _convert_sql_format
|
||||
from .constraints import ConstraintFailedError
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HeterogeneousCollectionError(Exception):
|
||||
@@ -19,45 +20,56 @@ class HeterogeneousCollectionError(Exception):
|
||||
ie: If a List or Tuple has elements of multiple
|
||||
types.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
|
||||
"""
|
||||
Check if all of the members a Tuple or a List
|
||||
Check if all the members a Tuple or a List
|
||||
is of the same type.
|
||||
|
||||
:param objects: Tuple or list to check.
|
||||
:return: If all of the members of the same type.
|
||||
:return: If all the members of the same type.
|
||||
"""
|
||||
class_ = objects[0].__class__
|
||||
if not all([isinstance(obj, class_) or isinstance(objects[0], obj.__class__) for obj in objects]):
|
||||
if not all(
|
||||
[
|
||||
isinstance(obj, class_) or isinstance(objects[0], obj.__class__)
|
||||
for obj in objects
|
||||
]
|
||||
):
|
||||
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
|
||||
|
||||
|
||||
def _toggle_memory_protection(cur: sql.Cursor, protect_memory: bool) -> None:
|
||||
async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None:
|
||||
"""
|
||||
Given a cursor to an sqlite3 connection, if memory protection is false,
|
||||
Given a cursor to a sqlite3 connection, if memory protection is false,
|
||||
toggle memory protections off.
|
||||
|
||||
:param cur: Cursor to an open SQLite3 connection.
|
||||
:param protect_memory: Whether or not should memory be protected.
|
||||
:param protect_memory: Whether should memory be protected.
|
||||
:return: Memory protections off.
|
||||
"""
|
||||
if not protect_memory:
|
||||
warn("Memory protections are turned off, "
|
||||
"if operations are interrupted, file may get corrupt.", RuntimeWarning)
|
||||
cur.execute("PRAGMA synchronous = OFF")
|
||||
cur.execute("PRAGMA journal_mode = MEMORY")
|
||||
warn(
|
||||
"Memory protections are turned off, "
|
||||
"if operations are interrupted, file may get corrupt.",
|
||||
RuntimeWarning,
|
||||
)
|
||||
await cur.execute("PRAGMA synchronous = OFF")
|
||||
await cur.execute("PRAGMA journal_mode = MEMORY")
|
||||
|
||||
|
||||
def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
|
||||
async def _mass_insert(
|
||||
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Insert multiple records into an SQLite3 database.
|
||||
|
||||
:param objects: Objects to insert.
|
||||
:param db_name: Name of the database to insert.
|
||||
:param protect_memory: Whether or not memory
|
||||
:param protect_memory: Whether memory
|
||||
protections are on or off.
|
||||
:return: None
|
||||
"""
|
||||
@@ -69,24 +81,28 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory
|
||||
for i, obj in enumerate(objects):
|
||||
kv_pairs = asdict(obj).items()
|
||||
setattr(obj, "obj_id", first_index + i + 1)
|
||||
sql_queries.append(f"INSERT INTO {table_name}(" +
|
||||
f"{', '.join(item[0] for item in kv_pairs)})" +
|
||||
f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});")
|
||||
with sql.connect(db_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
sql_queries.append(
|
||||
f"INSERT INTO {table_name}("
|
||||
+ f"{', '.join(item[0] for item in kv_pairs)})"
|
||||
+ f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});"
|
||||
)
|
||||
async with aiosqlite.connect(db_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
try:
|
||||
_toggle_memory_protection(cur, protect_memory)
|
||||
cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
|
||||
index_tuple = cur.fetchone()
|
||||
await _toggle_memory_protection(cur, protect_memory)
|
||||
await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
|
||||
index_tuple = await cur.fetchone()
|
||||
if index_tuple:
|
||||
first_index = index_tuple[0]
|
||||
cur.executescript("BEGIN TRANSACTION;\n" + '\n'.join(sql_queries) + '\nEND TRANSACTION;')
|
||||
except sql.IntegrityError:
|
||||
_ = index_tuple[0]
|
||||
await cur.executescript(
|
||||
"BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;"
|
||||
)
|
||||
except aiosqlite.IntegrityError:
|
||||
raise ConstraintFailedError
|
||||
con.commit()
|
||||
await con.commit()
|
||||
|
||||
|
||||
def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
|
||||
async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
|
||||
"""
|
||||
Insert many records corresponding to objects
|
||||
in a tuple or a list.
|
||||
@@ -98,29 +114,34 @@ def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True)
|
||||
:return: None.
|
||||
"""
|
||||
if objects:
|
||||
_mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
|
||||
await _mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
|
||||
else:
|
||||
raise ValueError("Collection is empty.")
|
||||
|
||||
|
||||
def copy_many(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
|
||||
async def copy_many(
|
||||
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Copy many records to another database, from
|
||||
their original database to new database, do
|
||||
their original database to a new database, do
|
||||
not delete old records.
|
||||
|
||||
:param objects: Objects to copy.
|
||||
:param db_name: Name of the new database.
|
||||
:param protect_memory: Wheter to protect memory during operation,
|
||||
:param protect_memory: Whether to protect memory during operation,
|
||||
Setting this to False will quicken the operation, but if the
|
||||
operation is cut short, database file will corrupt.
|
||||
operation is cut short, the database file will corrupt.
|
||||
:return: None
|
||||
"""
|
||||
if objects:
|
||||
with sql.connect(db_name) as con:
|
||||
cur = con.cursor()
|
||||
_create_table(objects[0].__class__, cur)
|
||||
con.commit()
|
||||
_mass_insert(objects, db_name, protect_memory)
|
||||
async with aiosqlite.connect(db_name) as con:
|
||||
cur = await con.cursor()
|
||||
await objects[0].markup_table(class_=objects[0].__class__, cursor=cur)
|
||||
await con.commit()
|
||||
await _mass_insert(objects, db_name, protect_memory)
|
||||
else:
|
||||
raise ValueError("Collection is empty.")
|
||||
|
||||
|
||||
__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError']
|
||||
|
||||
@@ -2,53 +2,59 @@
|
||||
Migrations module deals with migrating data when the object
|
||||
definitions change. This functions deal with Schema Migrations.
|
||||
"""
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import Field
|
||||
from os.path import exists
|
||||
from typing import Dict, Tuple, List
|
||||
import sqlite3 as sql
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
from .commons import _create_table, _get_table_cols
|
||||
import aiosqlite
|
||||
|
||||
from .commons import _get_table_cols, _tweaked_dump_value
|
||||
|
||||
|
||||
def _get_db_table(class_: type) -> Tuple[str, str]:
|
||||
async def _get_db_table(class_: type) -> Tuple[str, str]:
|
||||
"""
|
||||
Check if the class is a datalite class, the database exists
|
||||
Check if the class is a datalite class, the database exists,
|
||||
and the table exists. Return database and table names.
|
||||
|
||||
:param class_: A datalite class.
|
||||
:return: A tuple of database and table names.
|
||||
"""
|
||||
database_name: str = getattr(class_, 'db_path', None)
|
||||
database_name: str = getattr(class_, "db_path", None)
|
||||
if not database_name:
|
||||
raise TypeError(f"{class_.__name__} is not a datalite class.")
|
||||
table_name: str = class_.__name__.lower()
|
||||
if not exists(database_name):
|
||||
raise FileNotFoundError(f"{database_name} does not exist")
|
||||
with sql.connect(database_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;", (table_name, ))
|
||||
count: int = int(cur.fetchone()[0])
|
||||
async with aiosqlite.connect(database_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;",
|
||||
(table_name,),
|
||||
)
|
||||
count: int = int((await cur.fetchone())[0])
|
||||
if not count:
|
||||
raise FileExistsError(f"Table, {table_name}, already exists.")
|
||||
return database_name, table_name
|
||||
|
||||
|
||||
def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
|
||||
async def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
|
||||
"""
|
||||
Get the column names of table.
|
||||
Get the column names of the table.
|
||||
|
||||
:param database_name: Name of the database the table
|
||||
:param database_name: The name of the database the table
|
||||
resides in.
|
||||
:param table_name: Name of the table.
|
||||
:return: A tuple holding the column names of the table.
|
||||
"""
|
||||
with sql.connect(database_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cols: List[str] = _get_table_cols(cur, table_name)
|
||||
return tuple(cols)
|
||||
async with aiosqlite.connect(database_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
cols: List[str] = await _get_table_cols(cur, table_name)
|
||||
return cast(Tuple[str], tuple(cols))
|
||||
|
||||
|
||||
def _copy_records(database_name: str, table_name: str):
|
||||
async def _copy_records(database_name: str, table_name: str):
|
||||
"""
|
||||
Copy all records from a table.
|
||||
|
||||
@@ -56,17 +62,17 @@ def _copy_records(database_name: str, table_name: str):
|
||||
:param table_name: Name of the table.
|
||||
:return: A generator holding dataclass asdict representations.
|
||||
"""
|
||||
with sql.connect(database_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(f'SELECT * FROM {table_name};')
|
||||
values = cur.fetchall()
|
||||
keys = _get_table_cols(cur, table_name)
|
||||
keys.insert(0, 'obj_id')
|
||||
async with aiosqlite.connect(database_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(f"SELECT * FROM {table_name};")
|
||||
values = await cur.fetchall()
|
||||
keys = await _get_table_cols(cur, table_name)
|
||||
keys.insert(0, "obj_id")
|
||||
records = (dict(zip(keys, value)) for value in values)
|
||||
return records
|
||||
|
||||
|
||||
def _drop_table(database_name: str, table_name: str) -> None:
|
||||
async def _drop_table(database_name: str, table_name: str) -> None:
|
||||
"""
|
||||
Drop a table.
|
||||
|
||||
@@ -74,14 +80,15 @@ def _drop_table(database_name: str, table_name: str) -> None:
|
||||
:param table_name: Name of the table to be dropped.
|
||||
:return: None.
|
||||
"""
|
||||
with sql.connect(database_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
cur.execute(f'DROP TABLE {table_name};')
|
||||
con.commit()
|
||||
async with aiosqlite.connect(database_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
await cur.execute(f"DROP TABLE {table_name};")
|
||||
await con.commit()
|
||||
|
||||
|
||||
def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
||||
flow: Dict[str, str]) -> Tuple[Dict[str, str]]:
|
||||
def _modify_records(
|
||||
data, col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]
|
||||
) -> Tuple[Dict[str, str]]:
|
||||
"""
|
||||
Modify the asdict records in accordance
|
||||
with schema migration rules provided.
|
||||
@@ -89,9 +96,9 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
||||
:param data: Data kept as asdict in tuple.
|
||||
:param col_to_del: Column names to delete.
|
||||
:param col_to_add: Column names to add.
|
||||
:param flow: A dictionary that explain
|
||||
:param flow: A dictionary that explains
|
||||
if the data from a deleted column
|
||||
will be transferred to a column
|
||||
is transferred to a column
|
||||
to be added.
|
||||
:return: The modified data records.
|
||||
"""
|
||||
@@ -109,15 +116,22 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
||||
if key_to_add not in record_mod:
|
||||
record_mod[key_to_add] = None
|
||||
records.append(record_mod)
|
||||
return records
|
||||
return cast(Tuple[Dict[str, str]], records)
|
||||
|
||||
|
||||
def _migrate_records(class_: type, database_name: str, data,
|
||||
col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]) -> None:
|
||||
async def _migrate_records(
|
||||
class_: type,
|
||||
database_name: str,
|
||||
data,
|
||||
col_to_del: Tuple[str],
|
||||
col_to_add: Tuple[str],
|
||||
flow: Dict[str, str],
|
||||
safe_migration_defaults: Dict[str, Any] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Migrate the records into the modified table.
|
||||
|
||||
:param class_: Class of the entries.
|
||||
:param class_: Class of entries.
|
||||
:param database_name: Name of the database.
|
||||
:param data: Data, asdict tuple.
|
||||
:param col_to_del: Columns to be deleted.
|
||||
@@ -126,40 +140,94 @@ def _migrate_records(class_: type, database_name: str, data,
|
||||
column data will be transferred.
|
||||
:return: None.
|
||||
"""
|
||||
with sql.connect(database_name) as con:
|
||||
cur: sql.Cursor = con.cursor()
|
||||
_create_table(class_, cur, getattr(class_, 'types_table'))
|
||||
con.commit()
|
||||
if safe_migration_defaults is None:
|
||||
safe_migration_defaults = {}
|
||||
|
||||
async with aiosqlite.connect(database_name) as con:
|
||||
cur: aiosqlite.Cursor = await con.cursor()
|
||||
# noinspection PyUnresolvedReferences
|
||||
await class_.markup_table(
|
||||
class_=class_, cursor=cur, type_overload=getattr(class_, "types_table")
|
||||
)
|
||||
await con.commit()
|
||||
new_records = _modify_records(data, col_to_del, col_to_add, flow)
|
||||
for record in new_records:
|
||||
del record['obj_id']
|
||||
del record["obj_id"]
|
||||
keys_to_delete = [key for key in record if record[key] is None]
|
||||
for key in keys_to_delete:
|
||||
del record[key]
|
||||
class_(**record).create_entry()
|
||||
await class_(
|
||||
**{
|
||||
**record,
|
||||
**{
|
||||
k: _tweaked_dump_value(class_, v)
|
||||
for k, v in safe_migration_defaults.items()
|
||||
if k not in record
|
||||
},
|
||||
}
|
||||
).create_entry()
|
||||
|
||||
|
||||
def basic_migrate(class_: type, column_transfer: dict = None) -> None:
|
||||
async def migrate(
|
||||
class_: type,
|
||||
column_transfer: dict = None,
|
||||
do_backup: bool = True,
|
||||
safe_migration_defaults: Dict[str, Any] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Given a class, compare its previous table,
|
||||
delete the fields that no longer exist,
|
||||
create new columns for new fields. If the
|
||||
column_flow parameter is given, migrate elements
|
||||
from previous column to the new ones. It should be
|
||||
noted that, the obj_ids do not persist.
|
||||
noted that the obj_ids do not persist.
|
||||
|
||||
:param class_: Datalite class to migrate.
|
||||
:param column_transfer: A dictionary showing which
|
||||
columns will be copied to new ones.
|
||||
:param do_backup: Whether to copy a whole database before dropping table
|
||||
:param safe_migration_defaults: Key-value that will be written to old records in the database during
|
||||
migration so as not to break the schema
|
||||
:return: None.
|
||||
"""
|
||||
database_name, table_name = _get_db_table(class_)
|
||||
table_column_names: Tuple[str] = _get_table_column_names(database_name, table_name)
|
||||
values = class_.__dataclass_fields__.values()
|
||||
data_fields: Tuple[Field] = tuple(field for field in values)
|
||||
data_field_names: Tuple[str] = tuple(field.name for field in data_fields)
|
||||
columns_to_be_deleted: Tuple[str] = tuple(column for column in table_column_names if column not in data_field_names)
|
||||
columns_to_be_added: Tuple[str] = tuple(column for column in data_field_names if column not in table_column_names)
|
||||
records = _copy_records(database_name, table_name)
|
||||
_drop_table(database_name, table_name)
|
||||
_migrate_records(class_, database_name, records, columns_to_be_deleted, columns_to_be_added, column_transfer)
|
||||
database_name, table_name = await _get_db_table(class_)
|
||||
table_column_names: Tuple[str] = await _get_table_column_names(
|
||||
database_name, table_name
|
||||
)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
values: List[Field] = class_.__dataclass_fields__.values()
|
||||
|
||||
data_fields: Tuple[Field] = cast(Tuple[Field], tuple(field for field in values))
|
||||
data_field_names: Tuple[str] = cast(
|
||||
Tuple[str], tuple(field.name for field in data_fields)
|
||||
)
|
||||
columns_to_be_deleted: Tuple[str] = cast(
|
||||
Tuple[str],
|
||||
tuple(
|
||||
column for column in table_column_names if column not in data_field_names
|
||||
),
|
||||
)
|
||||
columns_to_be_added: Tuple[str] = cast(
|
||||
Tuple[str],
|
||||
tuple(
|
||||
column for column in data_field_names if column not in table_column_names
|
||||
),
|
||||
)
|
||||
|
||||
records = await _copy_records(database_name, table_name)
|
||||
if do_backup:
|
||||
shutil.copy(database_name, f"{database_name}-{time.time()}")
|
||||
await _drop_table(database_name, table_name)
|
||||
await _migrate_records(
|
||||
class_,
|
||||
database_name,
|
||||
records,
|
||||
columns_to_be_deleted,
|
||||
columns_to_be_added,
|
||||
column_transfer,
|
||||
safe_migration_defaults,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["migrate"]
|
||||
|
||||
34
datalite/typed.py
Normal file
34
datalite/typed.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
def _datalite_hinted_direct_use():
|
||||
raise ValueError(
|
||||
"Don't use DataliteHinted directly. Inherited classes should also be wrapped in "
|
||||
"datalite and dataclass decorators"
|
||||
)
|
||||
|
||||
|
||||
class DataliteHinted:
|
||||
db_path: str
|
||||
types_table: Dict[Optional[type], str]
|
||||
tweaked: bool
|
||||
obj_id: int
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
async def markup_table(self):
|
||||
_datalite_hinted_direct_use()
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
async def create_entry(self):
|
||||
_datalite_hinted_direct_use()
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
async def update_entry(self):
|
||||
_datalite_hinted_direct_use()
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
async def remove_entry(self):
|
||||
_datalite_hinted_direct_use()
|
||||
|
||||
|
||||
__all__ = ["DataliteHinted"]
|
||||
23
docs/conf.py
23
docs/conf.py
@@ -12,18 +12,19 @@
|
||||
#
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
sys.path.insert(0, os.path.abspath('../.'))
|
||||
|
||||
sys.path.insert(0, os.path.abspath("."))
|
||||
sys.path.insert(0, os.path.abspath("../."))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'Datalite'
|
||||
copyright = '2020, Ege Ozkan'
|
||||
author = 'Ege Ozkan'
|
||||
project = "Datalite"
|
||||
copyright = "2020, Ege Ozkan"
|
||||
author = "Ege Ozkan"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = 'v0.7.1'
|
||||
release = "v0.7.1"
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
@@ -31,15 +32,15 @@ release = 'v0.7.1'
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = ['sphinx.ext.autodoc']
|
||||
extensions = ["sphinx.ext.autodoc"]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
@@ -47,9 +48,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
Reference in New Issue
Block a user