Migrate to aiosqlite and add some undocumented features (to be documented later). There are known problems with mass_actions; avoid them for now, because the loading algorithm may change.

This commit is contained in:
hhh
2024-03-17 00:25:56 +02:00
parent ac7e9055a5
commit 6dfc3cebbe
9 changed files with 560 additions and 221 deletions
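For orientation before the per-file diffs, here is a minimal sketch of the new async API as it reads from this commit; the import paths, the markup_table()/create_entry() call order, and the Student class are assumptions inferred from the diff below, not documented behavior.

import asyncio
from dataclasses import dataclass

from datalite import datalite
from datalite.fetch import fetch_from


@datalite(db_path="example.db")  # tweaked=True is the new default
@dataclass
class Student:
    name: str = ""
    age: int = 0


async def main() -> None:
    student = Student("Jane Doe", 21)
    await student.markup_table()  # table creation moved out of the decorator
    await student.create_entry()  # async INSERT; sets student.obj_id
    print(await fetch_from(Student, student.obj_id))


asyncio.run(main())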

View File

@@ -1,2 +1,10 @@
__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints', 'mass_actions']
__all__ = [
"commons",
"datalite_decorator",
"fetch",
"migrations",
"datalite",
"constraints",
"mass_actions",
]
from .datalite_decorator import datalite

View File

@@ -1,15 +1,28 @@
from dataclasses import Field
from typing import Any, Optional, Dict, List
from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import Any, Dict, List, Optional
import aiosqlite
import base64
from .constraints import Unique
import sqlite3 as sql
type_table: Dict[Optional[type], str] = {
None: "NULL",
int: "INTEGER",
float: "REAL",
str: "TEXT",
bytes: "BLOB",
bool: "INTEGER",
}
type_table.update(
{Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()}
)
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
str: "TEXT", bytes: "BLOB", bool: "INTEGER"}
type_table.update({Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()})
def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
def _convert_type(
type_: Optional[type], type_overload: Dict[Optional[type], str]
) -> str:
"""
Given a Python type, return the str name of its
SQLite equivalent.
@@ -22,19 +35,24 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
try:
return type_overload[type_]
except KeyError:
raise TypeError("Requested type not in the default or overloaded type table.")
raise TypeError(
"Requested type not in the default or overloaded type table. Use @datalite(tweaked=True) to "
"encode custom types"
)
def _convert_sql_format(value: Any) -> str:
def _tweaked_convert_type(
type_: Optional[type], type_overload: Dict[Optional[type], str]
) -> str:
return type_overload.get(type_, "BLOB")
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
"""
Given a Python value, convert to string representation
of the equivalent SQL datatype.
:param value: A value, ie: a literal, a variable etc.
:return: The string representation of the SQL equivalent.
>>> _convert_sql_format(1, type_table)
'1'
>>> _convert_sql_format("John Smith", type_table)
'"John Smith"'
"""
if value is None:
return "NULL"
@@ -44,11 +62,13 @@ def _convert_sql_format(value: Any) -> str:
return '"' + str(value).replace("b'", "")[:-1] + '"'
elif isinstance(value, bool):
return "TRUE" if value else "FALSE"
else:
elif type(value) in type_overload:
return str(value)
else:
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
"""
Get the column data of a table.
@@ -56,11 +76,13 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
:param table_name: Name of the table.
:return: the information about columns.
"""
cur.execute(f"PRAGMA table_info({table_name});")
return [row_info[1] for row_info in cur.fetchall()][1:]
await cur.execute(f"PRAGMA table_info({table_name});")
return [row_info[1] for row_info in await cur.fetchall()][1:]
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
def _get_default(
default_object: object, type_overload: Dict[Optional[type], str]
) -> str:
"""
Check if the field's default object is filled,
if filled return the string to be put in the,
@@ -71,11 +93,28 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str
empty string if no string is necessary.
"""
if type(default_object) in type_overload:
return f' DEFAULT {_convert_sql_format(default_object)}'
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
return ""
def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str] = type_table) -> None:
# noinspection PyDefaultArgument
async def _tweaked_create_table(
class_: type,
cursor: aiosqlite.Cursor,
type_overload: Dict[Optional[type], str] = type_table,
) -> None:
await _create_table(
class_, cursor, type_overload, type_converter=_tweaked_convert_type
)
# noinspection PyDefaultArgument
async def _create_table(
class_: type,
cursor: aiosqlite.Cursor,
type_overload: Dict[Optional[type], str] = type_table,
type_converter=_convert_type,
) -> None:
"""
Create the table for a specific dataclass given
:param class_: A dataclass.
@@ -84,10 +123,35 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional
with a custom table, this is that custom table.
:return: None.
"""
fields: List[Field] = [class_.__dataclass_fields__[key] for
key in class_.__dataclass_fields__.keys()]
# noinspection PyUnresolvedReferences
fields: List[Field] = [
class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()
]
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}"
f"{_get_default(field.default, type_overload)}" for field in fields)
sql_fields = ", ".join(
f"{field.name} {type_converter(field.type, type_overload)}"
f"{_get_default(field.default, type_overload)}"
for field in fields
)
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});")
await cursor.execute(
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
)
def _tweaked_dump_value(self, value):
if type(value) in self.types_table:
return value
else:
return bytes(dumps(value, protocol=HIGHEST_PROTOCOL))
def _tweaked_dump(self, name):
value = getattr(self, name)
return _tweaked_dump_value(self, value)
def _tweaked_load_value(data):
return loads(bytes(data))
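The fallback branch above stores unknown types by pickling and base64-encoding them, and _tweaked_load_value reverses the pickle step. A standalone round-trip sketch of those mechanics (the helper names encode_custom/decode_custom are illustrative, not part of the module):

import base64
from pickle import HIGHEST_PROTOCOL, dumps, loads


def encode_custom(value):
    # Mirrors the _convert_sql_format fallback: pickle, then base64,
    # so the value can travel as a SQL TEXT literal.
    return base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode()


def decode_custom(data):
    # Mirrors the load path: accept raw BLOB bytes or the base64 TEXT form.
    if isinstance(data, str):
        data = base64.decodebytes(data.encode("utf-8"))
    return loads(bytes(data))


payload = {"nested": [1, 2, 3]}  # dict is not in type_table
assert decode_custom(encode_custom(payload)) == payload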

View File

@@ -4,15 +4,16 @@ datalite.constraints module introduces constraint
that can be used to signal datalite decorator
constraints in the database.
"""
from typing import TypeVar, Union, Tuple
from typing import Tuple, TypeVar, Union
T = TypeVar('T')
T = TypeVar("T")
class ConstraintFailedError(Exception):
"""
This exception is raised when a Constraint fails.
"""
pass
@@ -24,3 +25,4 @@ Dataclass fields hinted with this type signals
Unique = Union[Tuple[T], T]
__all__ = ['Unique', 'ConstraintFailedError']
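For reference, a sketch of Unique in use: per the type_table update in commons.py, a field annotated Unique[str] maps to a TEXT NOT NULL UNIQUE column, and inserting a duplicate value surfaces as ConstraintFailedError (decorator arguments assumed):

from dataclasses import dataclass

from datalite import datalite
from datalite.constraints import Unique


@datalite(db_path="users.db")
@dataclass
class User:
    username: Unique[str] = ""  # column rendered as "TEXT NOT NULL UNIQUE"
    age: int = 0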

View File

@@ -1,93 +1,160 @@
"""
Defines the Datalite decorator that can be used to convert a dataclass to
a class bound to a sqlite3 database.
a class bound to an sqlite3 database.
"""
from dataclasses import asdict, fields
from sqlite3.dbapi2 import IntegrityError
from typing import Dict, Optional, Callable
from dataclasses import asdict
import sqlite3 as sql
from typing import Callable, Dict, Optional
import aiosqlite
from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table
from .constraints import ConstraintFailedError
from .commons import _convert_sql_format, _convert_type, _create_table, type_table
def _create_entry(self) -> None:
async def _create_entry(self) -> None:
"""
Given an object, create the entry for the object. As a side-effect,
Given an object, create the entry for the object. As a side effect,
this will set the object_id attribute of the object to the unique
id of the entry.
:param self: Instance of the object.
:return: None.
"""
with sql.connect(getattr(self, "db_path")) as con:
cur: sql.Cursor = con.cursor()
async with aiosqlite.connect(getattr(self, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
table_name: str = self.__class__.__name__.lower()
kv_pairs = [item for item in asdict(self).items()]
kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields.
try:
cur.execute(f"INSERT INTO {table_name}("
await cur.execute(
f"INSERT INTO {table_name}("
f"{', '.join(item[0] for item in kv_pairs)})"
f" VALUES ({', '.join('?' for item in kv_pairs)})",
[item[1] for item in kv_pairs])
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
[item[1] for item in kv_pairs],
)
self.__setattr__("obj_id", cur.lastrowid)
con.commit()
await con.commit()
except IntegrityError:
raise ConstraintFailedError("A constraint has failed.")
def _update_entry(self) -> None:
async def _tweaked_create_entry(self) -> None:
async with aiosqlite.connect(getattr(self, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
table_name: str = self.__class__.__name__.lower()
kv_pairs = [item for item in fields(self)]
kv_pairs.sort(key=lambda item: item.name) # Sort by the name of the fields.
try:
await cur.execute(
f"INSERT INTO {table_name}("
f"{', '.join(item.name for item in kv_pairs)})"
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
[_tweaked_dump(self, item.name) for item in kv_pairs],
)
self.__setattr__("obj_id", cur.lastrowid)
await con.commit()
except IntegrityError:
raise ConstraintFailedError("A constraint has failed.")
async def _update_entry(self) -> None:
"""
Given an object, update the object's entry in the bound database.
:param self: The object.
:return: None.
"""
with sql.connect(getattr(self, "db_path")) as con:
cur: sql.Cursor = con.cursor()
async with aiosqlite.connect(getattr(self, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
table_name: str = self.__class__.__name__.lower()
kv_pairs = [item for item in asdict(self).items()]
kv_pairs.sort(key=lambda item: item[0])
query = f"UPDATE {table_name} " + \
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} " + \
query = (
f"UPDATE {table_name} "
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} "
f"WHERE obj_id = {getattr(self, 'obj_id')};"
cur.execute(query, [item[1] for item in kv_pairs])
con.commit()
)
await cur.execute(query, [item[1] for item in kv_pairs])
await con.commit()
def remove_from(class_: type, obj_id: int):
with sql.connect(getattr(class_, "db_path")) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id, ))
con.commit()
async def _tweaked_update_entry(self) -> None:
async with aiosqlite.connect(getattr(self, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
table_name: str = self.__class__.__name__.lower()
kv_pairs = [item for item in fields(self)]
kv_pairs.sort(key=lambda item: item.name)
query = (
f"UPDATE {table_name} "
f"SET {', '.join(item.name + ' = ?' for item in kv_pairs)} "
f"WHERE obj_id = {getattr(self, 'obj_id')};"
)
await cur.execute(query, [_tweaked_dump(self, item.name) for item in kv_pairs])
await con.commit()
def _remove_entry(self) -> None:
async def remove_from(class_: type, obj_id: int):
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(
f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id,)
)
await con.commit()
async def _remove_entry(self) -> None:
"""
Remove the object's record in the underlying database.
:param self: self instance.
:return: None.
"""
remove_from(self.__class__, getattr(self, 'obj_id'))
await remove_from(self.__class__, getattr(self, "obj_id"))
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable:
def _markup_table(markup_function):
async def inner(self=None, **kwargs):
if not kwargs:
async with aiosqlite.connect(getattr(self, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
await markup_function(self.__class__, cur, self.types_table)
else:
await markup_function(**kwargs)
return inner
def datalite(
db_path: str,
type_overload: Optional[Dict[Optional[type], str]] = None,
tweaked: bool = True,
) -> Callable:
"""Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
`create_entry()`, `remove_entry()` and `update_entry()`.
:param db_path: Path of the database to be binded.
:param db_path: Path of the database to be bound.
:param type_overload: Type overload dictionary.
:param tweaked: Whether to use the pickle type tweaks.
:return: The new dataclass.
"""
def decorator(dataclass_: type, *args_i, **kwargs_i):
def decorator(dataclass_: type, *_, **__):
types_table = type_table.copy()
if type_overload is not None:
types_table.update(type_overload)
with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor()
_create_table(dataclass_, cur, types_table)
setattr(dataclass_, 'db_path', db_path) # We add the path of the database to class itself.
setattr(dataclass_, 'types_table', types_table) # We add the type table for migration.
setattr(dataclass_, "db_path", db_path)
setattr(dataclass_, "types_table", types_table)
setattr(dataclass_, "tweaked", tweaked)
if tweaked:
dataclass_.markup_table = _markup_table(_tweaked_create_table)
dataclass_.create_entry = _tweaked_create_entry
dataclass_.update_entry = _tweaked_update_entry
else:
dataclass_.markup_table = _markup_table(_create_table)
dataclass_.create_entry = _create_entry
dataclass_.remove_entry = _remove_entry
dataclass_.update_entry = _update_entry
dataclass_.remove_entry = _remove_entry
return dataclass_
return decorator
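Because tweaked now defaults to True, fields whose types are missing from the type table are pickled into BLOB columns instead of raising TypeError. A hedged sketch of the full entry lifecycle under that default (the Job class and its fields are hypothetical):

import asyncio
from dataclasses import dataclass, field
from typing import Dict

from datalite import datalite


@datalite(db_path="app.db")  # tweaked defaults to True
@dataclass
class Job:
    name: str = ""
    payload: Dict[str, int] = field(default_factory=dict)  # stored as a pickled BLOB


async def main() -> None:
    job = Job("resize", {"width": 640})
    await job.markup_table()
    await job.create_entry()  # routed to _tweaked_create_entry
    job.payload["height"] = 480
    await job.update_entry()  # routed to _tweaked_update_entry
    await job.remove_entry()


asyncio.run(main())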

View File

@@ -1,6 +1,12 @@
import sqlite3 as sql
from typing import List, Tuple, Any
from .commons import _convert_sql_format, _get_table_cols
import aiosqlite
from typing import Any, List, Tuple, TypeVar, Type, cast
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
import base64
T = TypeVar("T")
def _insert_pagination(query: str, page: int, element_count: int) -> str:
@@ -17,7 +23,7 @@ def _insert_pagination(query: str, page: int, element_count: int) -> str:
return query + ";"
def is_fetchable(class_: type, obj_id: int) -> bool:
async def is_fetchable(class_: type, obj_id: int) -> bool:
"""
Check if a record is fetchable given its obj_id and
class_ type.
@@ -26,16 +32,22 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
:param obj_id: Unique obj_id of the object.
:return: If the object is fetchable.
"""
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
try:
cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
except sql.OperationalError:
await cur.execute(
f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id,)
)
except aiosqlite.OperationalError:
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
return bool(cur.fetchall())
return bool(await cur.fetchall())
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
async def fetch_equals(
class_: Type[T],
field: str,
value: Any,
) -> T:
"""
Fetch a class_ type variable from its bound db.
@@ -45,18 +57,30 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
:return: The object whose data is taken from the database.
"""
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, ))
obj_id, *field_values = list(cur.fetchone())
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value,))
obj_id, *field_values = list(await cur.fetchone())
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
kwargs = dict(zip(field_names, field_values))
if getattr(class_, "tweaked"):
types_table = getattr(class_, "types_table")
field_types = {
key: value.type for key, value in class_.__dataclass_fields__.items()
}
for key in kwargs.keys():
if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value(
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
return obj
def fetch_from(class_: type, obj_id: int) -> Any:
async def fetch_from(class_: Type[T], obj_id: int) -> T:
"""
Fetch a class_ type variable from its bound db.
@@ -64,13 +88,17 @@ def fetch_from(class_: type, obj_id: int) -> Any:
:param obj_id: Unique object id of the object.
:return: The fetched object.
"""
if not is_fetchable(class_, obj_id):
raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or"
f"otherwise is unreachable.")
return fetch_equals(class_, 'obj_id', obj_id)
if not await is_fetchable(class_, obj_id):
raise KeyError(
f"An object with {obj_id} of type {class_.__name__} does not exist, or"
f"otherwise is unreachable."
)
return await fetch_equals(class_, "obj_id", obj_id)
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
def _convert_record_to_object(
class_: Type[T], record: Tuple[Any], field_names: List[str]
) -> T:
"""
Convert a given record fetched from an SQL instance to a Python Object of given class_.
@@ -80,17 +108,30 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis
:return: the created object.
"""
kwargs = dict(zip(field_names, record[1:]))
field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()}
for key in kwargs:
field_types = {
key: value.type for key, value in class_.__dataclass_fields__.items()
}
is_tweaked = getattr(class_, "tweaked")
types_table = getattr(class_, "types_table")
for key in kwargs.keys():
if field_types[key] == bytes:
kwargs[key] = bytes(kwargs[key], encoding='utf-8')
kwargs[key] = bytes(kwargs[key], encoding="utf-8")
elif is_tweaked:
if field_types[key] not in types_table.keys():
kwargs[key] = _tweaked_load_value(
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
)
obj_id = record[0]
obj = class_(**kwargs)
setattr(obj, "obj_id", obj_id)
return obj
def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple:
async def fetch_if(
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
) -> tuple[T]:
"""
Fetch all class_ type variables from the bound db,
provided they fit the given condition
@@ -103,18 +144,27 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
of given type class_.
"""
table_name = class_.__name__.lower()
with sql.connect(getattr(class_, 'db_path')) as con:
cur: sql.Cursor = con.cursor()
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
records: list = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, table_name)
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(
_insert_pagination(
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
)
)
# noinspection PyTypeChecker
records: list = await cur.fetchall()
field_names: List[str] = await _get_table_cols(cur, table_name)
return tuple(
_convert_record_to_object(class_, record, field_names) for record in records
)
def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple:
async def fetch_where(
class_: Type[T], field: str, value: Any, page: int = 0, element_count: int = 10
) -> tuple[T]:
"""
Fetch all class_ type variables from the bound db,
provided that the field of the records fit the
Fetch all class_ type variables from the bound db
whose given field matches the
given value.
:param class_: Class of the records.
@@ -124,10 +174,12 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
:param element_count: Element count in each page.
:return: A tuple of the records.
"""
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count)
return await fetch_if(
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
)
def fetch_range(class_: type, range_: range) -> tuple:
async def fetch_range(class_: Type[T], range_: range) -> tuple[T]:
"""
Fetch the records in a given range of object ids.
@@ -136,10 +188,21 @@ def fetch_range(class_: type, range_: range) -> tuple:
:return: A tuple of class_ type objects whose values
come from the class_' bound database.
"""
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
return cast(
tuple[T],
tuple(
[
(await fetch_from(class_, obj_id))
for obj_id in range_
if (await is_fetchable(class_, obj_id))
]
),
)
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
async def fetch_all(
class_: Type[T], page: int = 0, element_count: int = 10
) -> tuple[T]:
"""
Fetch all the records in the bound database.
@@ -150,15 +213,26 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
the bound database as a tuple.
"""
try:
db_path = getattr(class_, 'db_path')
db_path = getattr(class_, "db_path")
except AttributeError:
raise TypeError("Given class is not decorated with datalite.")
with sql.connect(db_path) as con:
cur: sql.Cursor = con.cursor()
async with aiosqlite.connect(db_path) as con:
cur: aiosqlite.Cursor = await con.cursor()
try:
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
except sql.OperationalError:
await cur.execute(
_insert_pagination(
f"SELECT * FROM {class_.__name__.lower()}", page, element_count
)
)
except aiosqlite.OperationalError:
raise TypeError(f"No record of type {class_.__name__.lower()}")
records = cur.fetchall()
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
records = await cur.fetchall()
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
# noinspection PyTypeChecker
return cast(
tuple[T],
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
)
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]

View File

@@ -3,14 +3,15 @@ This module includes functions to insert multiple records
to a bound database at one time, with one time open and closing
of the database file.
"""
from typing import TypeVar, Union, List, Tuple
import aiosqlite
from dataclasses import asdict
from typing import List, Tuple, TypeVar, Union
from warnings import warn
from .constraints import ConstraintFailedError
from .commons import _convert_sql_format, _create_table
import sqlite3 as sql
T = TypeVar('T')
from .commons import _convert_sql_format
from .constraints import ConstraintFailedError
T = TypeVar("T")
class HeterogeneousCollectionError(Exception):
@@ -19,45 +20,56 @@ class HeterogeneousCollectionError(Exception):
ie: If a List or Tuple has elements of multiple
types.
"""
pass
def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
"""
Check if all of the members a Tuple or a List
Check if all the members of a Tuple or a List
are of the same type.
:param objects: Tuple or list to check.
:return: If all of the members of the same type.
:return: Whether all the members are of the same type.
"""
class_ = objects[0].__class__
if not all([isinstance(obj, class_) or isinstance(objects[0], obj.__class__) for obj in objects]):
if not all(
[
isinstance(obj, class_) or isinstance(objects[0], obj.__class__)
for obj in objects
]
):
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
def _toggle_memory_protection(cur: sql.Cursor, protect_memory: bool) -> None:
async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None:
"""
Given a cursor to an sqlite3 connection, if memory protection is false,
Given a cursor to a sqlite3 connection, if memory protection is false,
toggle memory protections off.
:param cur: Cursor to an open SQLite3 connection.
:param protect_memory: Whether or not should memory be protected.
:param protect_memory: Whether memory should be protected.
:return: Memory protections off.
"""
if not protect_memory:
warn("Memory protections are turned off, "
"if operations are interrupted, file may get corrupt.", RuntimeWarning)
cur.execute("PRAGMA synchronous = OFF")
cur.execute("PRAGMA journal_mode = MEMORY")
warn(
"Memory protections are turned off, "
"if operations are interrupted, file may get corrupt.",
RuntimeWarning,
)
await cur.execute("PRAGMA synchronous = OFF")
await cur.execute("PRAGMA journal_mode = MEMORY")
def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
async def _mass_insert(
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
) -> None:
"""
Insert multiple records into an SQLite3 database.
:param objects: Objects to insert.
:param db_name: Name of the database to insert.
:param protect_memory: Whether or not memory
:param protect_memory: Whether memory
protections are on or off.
:return: None
"""
@@ -69,24 +81,28 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory
for i, obj in enumerate(objects):
kv_pairs = asdict(obj).items()
setattr(obj, "obj_id", first_index + i + 1)
sql_queries.append(f"INSERT INTO {table_name}(" +
f"{', '.join(item[0] for item in kv_pairs)})" +
f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});")
with sql.connect(db_name) as con:
cur: sql.Cursor = con.cursor()
sql_queries.append(
f"INSERT INTO {table_name}("
+ f"{', '.join(item[0] for item in kv_pairs)})"
+ f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});"
)
async with aiosqlite.connect(db_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
try:
_toggle_memory_protection(cur, protect_memory)
cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
index_tuple = cur.fetchone()
await _toggle_memory_protection(cur, protect_memory)
await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
index_tuple = await cur.fetchone()
if index_tuple:
first_index = index_tuple[0]
cur.executescript("BEGIN TRANSACTION;\n" + '\n'.join(sql_queries) + '\nEND TRANSACTION;')
except sql.IntegrityError:
_ = index_tuple[0]
await cur.executescript(
"BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;"
)
except aiosqlite.IntegrityError:
raise ConstraintFailedError
con.commit()
await con.commit()
def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
"""
Insert many records corresponding to objects
in a tuple or a list.
@@ -98,29 +114,34 @@ def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True)
:return: None.
"""
if objects:
_mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
await _mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
else:
raise ValueError("Collection is empty.")
def copy_many(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
async def copy_many(
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
) -> None:
"""
Copy many records to another database, from
their original database to new database, do
their original database to a new database, do
not delete old records.
:param objects: Objects to copy.
:param db_name: Name of the new database.
:param protect_memory: Wheter to protect memory during operation,
:param protect_memory: Whether to protect memory during operation,
Setting this to False will quicken the operation, but if the
operation is cut short, database file will corrupt.
operation is cut short, the database file will corrupt.
:return: None
"""
if objects:
with sql.connect(db_name) as con:
cur = con.cursor()
_create_table(objects[0].__class__, cur)
con.commit()
_mass_insert(objects, db_name, protect_memory)
async with aiosqlite.connect(db_name) as con:
cur = await con.cursor()
await objects[0].markup_table(class_=objects[0].__class__, cursor=cur)
await con.commit()
await _mass_insert(objects, db_name, protect_memory)
else:
raise ValueError("Collection is empty.")
__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError']
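The commit message warns against using mass_actions for now, so the sketch below only records the intended call shapes from this diff rather than endorsing them; the students argument is a hypothetical list of datalite-decorated objects.

from typing import List

from datalite.mass_actions import copy_many, create_many


async def bulk(students: List) -> None:
    # protect_memory=False would speed this up via PRAGMA synchronous = OFF
    # and journal_mode = MEMORY, at the risk of corruption on interruption.
    await create_many(students, protect_memory=True)
    # Copy the same records into a second database; the target table is
    # created through the class's markup_table() hook.
    await copy_many(students, "backup.db")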

View File

@@ -2,53 +2,59 @@
Migrations module deals with migrating data when the object
definitions change. These functions deal with schema migrations.
"""
import shutil
import time
from dataclasses import Field
from os.path import exists
from typing import Dict, Tuple, List
import sqlite3 as sql
from typing import Any, Dict, List, Tuple, cast
from .commons import _create_table, _get_table_cols
import aiosqlite
from .commons import _get_table_cols, _tweaked_dump_value
def _get_db_table(class_: type) -> Tuple[str, str]:
async def _get_db_table(class_: type) -> Tuple[str, str]:
"""
Check if the class is a datalite class, the database exists
Check if the class is a datalite class, the database exists,
and the table exists. Return database and table names.
:param class_: A datalite class.
:return: A tuple of database and table names.
"""
database_name: str = getattr(class_, 'db_path', None)
database_name: str = getattr(class_, "db_path", None)
if not database_name:
raise TypeError(f"{class_.__name__} is not a datalite class.")
table_name: str = class_.__name__.lower()
if not exists(database_name):
raise FileNotFoundError(f"{database_name} does not exist")
with sql.connect(database_name) as con:
cur: sql.Cursor = con.cursor()
cur.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;", (table_name, ))
count: int = int(cur.fetchone()[0])
async with aiosqlite.connect(database_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(
"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;",
(table_name,),
)
count: int = int((await cur.fetchone())[0])
if not count:
raise FileExistsError(f"Table {table_name} does not exist.")
return database_name, table_name
def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
async def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
"""
Get the column names of table.
Get the column names of the table.
:param database_name: Name of the database the table
:param database_name: The name of the database the table
resides in.
:param table_name: Name of the table.
:return: A tuple holding the column names of the table.
"""
with sql.connect(database_name) as con:
cur: sql.Cursor = con.cursor()
cols: List[str] = _get_table_cols(cur, table_name)
return tuple(cols)
async with aiosqlite.connect(database_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
cols: List[str] = await _get_table_cols(cur, table_name)
return cast(Tuple[str], tuple(cols))
def _copy_records(database_name: str, table_name: str):
async def _copy_records(database_name: str, table_name: str):
"""
Copy all records from a table.
@@ -56,17 +62,17 @@ def _copy_records(database_name: str, table_name: str):
:param table_name: Name of the table.
:return: A generator holding dataclass asdict representations.
"""
with sql.connect(database_name) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f'SELECT * FROM {table_name};')
values = cur.fetchall()
keys = _get_table_cols(cur, table_name)
keys.insert(0, 'obj_id')
async with aiosqlite.connect(database_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(f"SELECT * FROM {table_name};")
values = await cur.fetchall()
keys = await _get_table_cols(cur, table_name)
keys.insert(0, "obj_id")
records = (dict(zip(keys, value)) for value in values)
return records
def _drop_table(database_name: str, table_name: str) -> None:
async def _drop_table(database_name: str, table_name: str) -> None:
"""
Drop a table.
@@ -74,14 +80,15 @@ def _drop_table(database_name: str, table_name: str) -> None:
:param table_name: Name of the table to be dropped.
:return: None.
"""
with sql.connect(database_name) as con:
cur: sql.Cursor = con.cursor()
cur.execute(f'DROP TABLE {table_name};')
con.commit()
async with aiosqlite.connect(database_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
await cur.execute(f"DROP TABLE {table_name};")
await con.commit()
def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
flow: Dict[str, str]) -> Tuple[Dict[str, str]]:
def _modify_records(
data, col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]
) -> Tuple[Dict[str, str]]:
"""
Modify the asdict records in accordance
with schema migration rules provided.
@@ -89,9 +96,9 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
:param data: Data kept as asdict in tuple.
:param col_to_del: Column names to delete.
:param col_to_add: Column names to add.
:param flow: A dictionary that explain
:param flow: A dictionary that explains
if the data from a deleted column
will be transferred to a column
is transferred to a column
to be added.
:return: The modified data records.
"""
@@ -109,15 +116,22 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
if key_to_add not in record_mod:
record_mod[key_to_add] = None
records.append(record_mod)
return records
return cast(Tuple[Dict[str, str]], records)
def _migrate_records(class_: type, database_name: str, data,
col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]) -> None:
async def _migrate_records(
class_: type,
database_name: str,
data,
col_to_del: Tuple[str],
col_to_add: Tuple[str],
flow: Dict[str, str],
safe_migration_defaults: Dict[str, Any] = None,
) -> None:
"""
Migrate the records into the modified table.
:param class_: Class of the entries.
:param class_: Class of entries.
:param database_name: Name of the database.
:param data: Data, asdict tuple.
:param col_to_del: Columns to be deleted.
@@ -126,40 +140,94 @@ def _migrate_records(class_: type, database_name: str, data,
column data will be transferred.
:return: None.
"""
with sql.connect(database_name) as con:
cur: sql.Cursor = con.cursor()
_create_table(class_, cur, getattr(class_, 'types_table'))
con.commit()
if safe_migration_defaults is None:
safe_migration_defaults = {}
async with aiosqlite.connect(database_name) as con:
cur: aiosqlite.Cursor = await con.cursor()
# noinspection PyUnresolvedReferences
await class_.markup_table(
class_=class_, cursor=cur, type_overload=getattr(class_, "types_table")
)
await con.commit()
new_records = _modify_records(data, col_to_del, col_to_add, flow)
for record in new_records:
del record['obj_id']
del record["obj_id"]
keys_to_delete = [key for key in record if record[key] is None]
for key in keys_to_delete:
del record[key]
class_(**record).create_entry()
await class_(
**{
**record,
**{
k: _tweaked_dump_value(class_, v)
for k, v in safe_migration_defaults.items()
if k not in record
},
}
).create_entry()
def basic_migrate(class_: type, column_transfer: dict = None) -> None:
async def migrate(
class_: type,
column_transfer: dict = None,
do_backup: bool = True,
safe_migration_defaults: Dict[str, Any] = None,
) -> None:
"""
Given a class, compare its previous table,
delete the fields that no longer exist,
create new columns for new fields. If the
column_flow parameter is given, migrate elements
from previous column to the new ones. It should be
noted that, the obj_ids do not persist.
noted that the obj_ids do not persist.
:param class_: Datalite class to migrate.
:param column_transfer: A dictionary showing which
columns will be copied to new ones.
:param do_backup: Whether to copy the whole database file before dropping the table
:param safe_migration_defaults: Key-value pairs written to old records during
migration so that the new schema is not broken
:return: None.
"""
database_name, table_name = _get_db_table(class_)
table_column_names: Tuple[str] = _get_table_column_names(database_name, table_name)
values = class_.__dataclass_fields__.values()
data_fields: Tuple[Field] = tuple(field for field in values)
data_field_names: Tuple[str] = tuple(field.name for field in data_fields)
columns_to_be_deleted: Tuple[str] = tuple(column for column in table_column_names if column not in data_field_names)
columns_to_be_added: Tuple[str] = tuple(column for column in data_field_names if column not in table_column_names)
records = _copy_records(database_name, table_name)
_drop_table(database_name, table_name)
_migrate_records(class_, database_name, records, columns_to_be_deleted, columns_to_be_added, column_transfer)
database_name, table_name = await _get_db_table(class_)
table_column_names: Tuple[str] = await _get_table_column_names(
database_name, table_name
)
# noinspection PyUnresolvedReferences
values: List[Field] = class_.__dataclass_fields__.values()
data_fields: Tuple[Field] = cast(Tuple[Field], tuple(field for field in values))
data_field_names: Tuple[str] = cast(
Tuple[str], tuple(field.name for field in data_fields)
)
columns_to_be_deleted: Tuple[str] = cast(
Tuple[str],
tuple(
column for column in table_column_names if column not in data_field_names
),
)
columns_to_be_added: Tuple[str] = cast(
Tuple[str],
tuple(
column for column in data_field_names if column not in table_column_names
),
)
records = await _copy_records(database_name, table_name)
if do_backup:
shutil.copy(database_name, f"{database_name}-{time.time()}")
await _drop_table(database_name, table_name)
await _migrate_records(
class_,
database_name,
records,
columns_to_be_deleted,
columns_to_be_added,
column_transfer,
safe_migration_defaults,
)
__all__ = ["migrate"]

datalite/typed.py Normal file
View File

@@ -0,0 +1,34 @@
from typing import Dict, Optional
def _datalite_hinted_direct_use():
raise ValueError(
"Don't use DataliteHinted directly. Inherited classes should also be wrapped in "
"datalite and dataclass decorators"
)
class DataliteHinted:
db_path: str
types_table: Dict[Optional[type], str]
tweaked: bool
obj_id: int
# noinspection PyMethodMayBeStatic
async def markup_table(self):
_datalite_hinted_direct_use()
# noinspection PyMethodMayBeStatic
async def create_entry(self):
_datalite_hinted_direct_use()
# noinspection PyMethodMayBeStatic
async def update_entry(self):
_datalite_hinted_direct_use()
# noinspection PyMethodMayBeStatic
async def remove_entry(self):
_datalite_hinted_direct_use()
__all__ = ["DataliteHinted"]

View File

@@ -12,18 +12,19 @@
#
import os
import sys
sys.path.insert(0, os.path.abspath('.'))
sys.path.insert(0, os.path.abspath('../.'))
sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../."))
# -- Project information -----------------------------------------------------
project = 'Datalite'
copyright = '2020, Ege Ozkan'
author = 'Ege Ozkan'
project = "Datalite"
copyright = "2020, Ege Ozkan"
author = "Ege Ozkan"
# The full version, including alpha/beta/rc tags
release = 'v0.7.1'
release = "v0.7.1"
# -- General configuration ---------------------------------------------------
@@ -31,15 +32,15 @@ release = 'v0.7.1'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc']
extensions = ["sphinx.ext.autodoc"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
@@ -47,9 +48,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]