Migrate to aiosqlite, add some undocumented features, will be documented later. Some problems with mass_actions, don't use them, because loading algorythm may change
This commit is contained in:
@@ -1,2 +1,10 @@
|
|||||||
__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints', 'mass_actions']
|
__all__ = [
|
||||||
|
"commons",
|
||||||
|
"datalite_decorator",
|
||||||
|
"fetch",
|
||||||
|
"migrations",
|
||||||
|
"datalite",
|
||||||
|
"constraints",
|
||||||
|
"mass_actions",
|
||||||
|
]
|
||||||
from .datalite_decorator import datalite
|
from .datalite_decorator import datalite
|
||||||
|
|||||||
@@ -1,15 +1,28 @@
|
|||||||
from dataclasses import Field
|
from dataclasses import Field
|
||||||
from typing import Any, Optional, Dict, List
|
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import base64
|
||||||
|
|
||||||
from .constraints import Unique
|
from .constraints import Unique
|
||||||
import sqlite3 as sql
|
|
||||||
|
type_table: Dict[Optional[type], str] = {
|
||||||
|
None: "NULL",
|
||||||
|
int: "INTEGER",
|
||||||
|
float: "REAL",
|
||||||
|
str: "TEXT",
|
||||||
|
bytes: "BLOB",
|
||||||
|
bool: "INTEGER",
|
||||||
|
}
|
||||||
|
type_table.update(
|
||||||
|
{Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
type_table: Dict[Optional[type], str] = {None: "NULL", int: "INTEGER", float: "REAL",
|
def _convert_type(
|
||||||
str: "TEXT", bytes: "BLOB", bool: "INTEGER"}
|
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||||
type_table.update({Unique[key]: f"{value} NOT NULL UNIQUE" for key, value in type_table.items()})
|
) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str]) -> str:
|
|
||||||
"""
|
"""
|
||||||
Given a Python type, return the str name of its
|
Given a Python type, return the str name of its
|
||||||
SQLlite equivalent.
|
SQLlite equivalent.
|
||||||
@@ -22,19 +35,24 @@ def _convert_type(type_: Optional[type], type_overload: Dict[Optional[type], str
|
|||||||
try:
|
try:
|
||||||
return type_overload[type_]
|
return type_overload[type_]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise TypeError("Requested type not in the default or overloaded type table.")
|
raise TypeError(
|
||||||
|
"Requested type not in the default or overloaded type table. Use @datalite(tweaked=True) to "
|
||||||
|
"encode custom types"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_sql_format(value: Any) -> str:
|
def _tweaked_convert_type(
|
||||||
|
type_: Optional[type], type_overload: Dict[Optional[type], str]
|
||||||
|
) -> str:
|
||||||
|
return type_overload.get(type_, "BLOB")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_sql_format(value: Any, type_overload: Dict[Optional[type], str]) -> str:
|
||||||
"""
|
"""
|
||||||
Given a Python value, convert to string representation
|
Given a Python value, convert to string representation
|
||||||
of the equivalent SQL datatype.
|
of the equivalent SQL datatype.
|
||||||
:param value: A value, ie: a literal, a variable etc.
|
:param value: A value, ie: a literal, a variable etc.
|
||||||
:return: The string representation of the SQL equivalent.
|
:return: The string representation of the SQL equivalent.
|
||||||
>>> _convert_sql_format(1)
|
|
||||||
"1"
|
|
||||||
>>> _convert_sql_format("John Smith")
|
|
||||||
'"John Smith"'
|
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
return "NULL"
|
return "NULL"
|
||||||
@@ -44,11 +62,13 @@ def _convert_sql_format(value: Any) -> str:
|
|||||||
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
return '"' + str(value).replace("b'", "")[:-1] + '"'
|
||||||
elif isinstance(value, bool):
|
elif isinstance(value, bool):
|
||||||
return "TRUE" if value else "FALSE"
|
return "TRUE" if value else "FALSE"
|
||||||
else:
|
elif type(value) in type_overload:
|
||||||
return str(value)
|
return str(value)
|
||||||
|
else:
|
||||||
|
return '"' + base64.encodebytes(dumps(value, protocol=HIGHEST_PROTOCOL)).decode() + '"'
|
||||||
|
|
||||||
|
|
||||||
def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
async def _get_table_cols(cur: aiosqlite.Cursor, table_name: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get the column data of a table.
|
Get the column data of a table.
|
||||||
|
|
||||||
@@ -56,11 +76,13 @@ def _get_table_cols(cur: sql.Cursor, table_name: str) -> List[str]:
|
|||||||
:param table_name: Name of the table.
|
:param table_name: Name of the table.
|
||||||
:return: the information about columns.
|
:return: the information about columns.
|
||||||
"""
|
"""
|
||||||
cur.execute(f"PRAGMA table_info({table_name});")
|
await cur.execute(f"PRAGMA table_info({table_name});")
|
||||||
return [row_info[1] for row_info in cur.fetchall()][1:]
|
return [row_info[1] for row_info in await cur.fetchall()][1:]
|
||||||
|
|
||||||
|
|
||||||
def _get_default(default_object: object, type_overload: Dict[Optional[type], str]) -> str:
|
def _get_default(
|
||||||
|
default_object: object, type_overload: Dict[Optional[type], str]
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Check if the field's default object is filled,
|
Check if the field's default object is filled,
|
||||||
if filled return the string to be put in the,
|
if filled return the string to be put in the,
|
||||||
@@ -71,11 +93,28 @@ def _get_default(default_object: object, type_overload: Dict[Optional[type], str
|
|||||||
empty string if no string is necessary.
|
empty string if no string is necessary.
|
||||||
"""
|
"""
|
||||||
if type(default_object) in type_overload:
|
if type(default_object) in type_overload:
|
||||||
return f' DEFAULT {_convert_sql_format(default_object)}'
|
return f" DEFAULT {_convert_sql_format(default_object, type_overload)}"
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional[type], str] = type_table) -> None:
|
# noinspection PyDefaultArgument
|
||||||
|
async def _tweaked_create_table(
|
||||||
|
class_: type,
|
||||||
|
cursor: aiosqlite.Cursor,
|
||||||
|
type_overload: Dict[Optional[type], str] = type_table,
|
||||||
|
) -> None:
|
||||||
|
await _create_table(
|
||||||
|
class_, cursor, type_overload, type_converter=_tweaked_convert_type
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyDefaultArgument
|
||||||
|
async def _create_table(
|
||||||
|
class_: type,
|
||||||
|
cursor: aiosqlite.Cursor,
|
||||||
|
type_overload: Dict[Optional[type], str] = type_table,
|
||||||
|
type_converter=_convert_type,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Create the table for a specific dataclass given
|
Create the table for a specific dataclass given
|
||||||
:param class_: A dataclass.
|
:param class_: A dataclass.
|
||||||
@@ -84,10 +123,35 @@ def _create_table(class_: type, cursor: sql.Cursor, type_overload: Dict[Optional
|
|||||||
with a custom table, this is that custom table.
|
with a custom table, this is that custom table.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
fields: List[Field] = [class_.__dataclass_fields__[key] for
|
# noinspection PyUnresolvedReferences
|
||||||
key in class_.__dataclass_fields__.keys()]
|
fields: List[Field] = [
|
||||||
|
class_.__dataclass_fields__[key] for key in class_.__dataclass_fields__.keys()
|
||||||
|
]
|
||||||
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
fields.sort(key=lambda field: field.name) # Since dictionaries *may* be unsorted.
|
||||||
sql_fields = ', '.join(f"{field.name} {_convert_type(field.type, type_overload)}"
|
|
||||||
f"{_get_default(field.default, type_overload)}" for field in fields)
|
sql_fields = ", ".join(
|
||||||
|
f"{field.name} {type_converter(field.type, type_overload)}"
|
||||||
|
f"{_get_default(field.default, type_overload)}"
|
||||||
|
for field in fields
|
||||||
|
)
|
||||||
|
|
||||||
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
sql_fields = "obj_id INTEGER PRIMARY KEY AUTOINCREMENT, " + sql_fields
|
||||||
cursor.execute(f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});")
|
await cursor.execute(
|
||||||
|
f"CREATE TABLE IF NOT EXISTS {class_.__name__.lower()} ({sql_fields});"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tweaked_dump_value(self, value):
|
||||||
|
if type(value) in self.types_table:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
return bytes(dumps(value, protocol=HIGHEST_PROTOCOL))
|
||||||
|
|
||||||
|
|
||||||
|
def _tweaked_dump(self, name):
|
||||||
|
value = getattr(self, name)
|
||||||
|
return _tweaked_dump_value(self, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _tweaked_load_value(data):
|
||||||
|
return loads(bytes(data))
|
||||||
|
|||||||
@@ -4,15 +4,16 @@ datalite.constraints module introduces constraint
|
|||||||
that can be used to signal datalite decorator
|
that can be used to signal datalite decorator
|
||||||
constraints in the database.
|
constraints in the database.
|
||||||
"""
|
"""
|
||||||
from typing import TypeVar, Union, Tuple
|
from typing import Tuple, TypeVar, Union
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class ConstraintFailedError(Exception):
|
class ConstraintFailedError(Exception):
|
||||||
"""
|
"""
|
||||||
This exception is raised when a Constraint fails.
|
This exception is raised when a Constraint fails.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -24,3 +25,4 @@ Dataclass fields hinted with this type signals
|
|||||||
Unique = Union[Tuple[T], T]
|
Unique = Union[Tuple[T], T]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['Unique', 'ConstraintFailedError']
|
||||||
|
|||||||
@@ -1,93 +1,160 @@
|
|||||||
"""
|
"""
|
||||||
Defines the Datalite decorator that can be used to convert a dataclass to
|
Defines the Datalite decorator that can be used to convert a dataclass to
|
||||||
a class bound to a sqlite3 database.
|
a class bound to an sqlite3 database.
|
||||||
"""
|
"""
|
||||||
|
from dataclasses import asdict, fields
|
||||||
from sqlite3.dbapi2 import IntegrityError
|
from sqlite3.dbapi2 import IntegrityError
|
||||||
from typing import Dict, Optional, Callable
|
from typing import Callable, Dict, Optional
|
||||||
from dataclasses import asdict
|
|
||||||
import sqlite3 as sql
|
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from .commons import _create_table, _tweaked_create_table, _tweaked_dump, type_table
|
||||||
from .constraints import ConstraintFailedError
|
from .constraints import ConstraintFailedError
|
||||||
from .commons import _convert_sql_format, _convert_type, _create_table, type_table
|
|
||||||
|
|
||||||
|
|
||||||
def _create_entry(self) -> None:
|
async def _create_entry(self) -> None:
|
||||||
"""
|
"""
|
||||||
Given an object, create the entry for the object. As a side-effect,
|
Given an object, create the entry for the object. As a side effect,
|
||||||
this will set the object_id attribute of the object to the unique
|
this will set the object_id attribute of the object to the unique
|
||||||
id of the entry.
|
id of the entry.
|
||||||
:param self: Instance of the object.
|
:param self: Instance of the object.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
with sql.connect(getattr(self, "db_path")) as con:
|
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
table_name: str = self.__class__.__name__.lower()
|
table_name: str = self.__class__.__name__.lower()
|
||||||
kv_pairs = [item for item in asdict(self).items()]
|
kv_pairs = [item for item in asdict(self).items()]
|
||||||
kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields.
|
kv_pairs.sort(key=lambda item: item[0]) # Sort by the name of the fields.
|
||||||
try:
|
try:
|
||||||
cur.execute(f"INSERT INTO {table_name}("
|
await cur.execute(
|
||||||
|
f"INSERT INTO {table_name}("
|
||||||
f"{', '.join(item[0] for item in kv_pairs)})"
|
f"{', '.join(item[0] for item in kv_pairs)})"
|
||||||
f" VALUES ({', '.join('?' for item in kv_pairs)})",
|
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
|
||||||
[item[1] for item in kv_pairs])
|
[item[1] for item in kv_pairs],
|
||||||
|
)
|
||||||
self.__setattr__("obj_id", cur.lastrowid)
|
self.__setattr__("obj_id", cur.lastrowid)
|
||||||
con.commit()
|
await con.commit()
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise ConstraintFailedError("A constraint has failed.")
|
raise ConstraintFailedError("A constraint has failed.")
|
||||||
|
|
||||||
|
|
||||||
def _update_entry(self) -> None:
|
async def _tweaked_create_entry(self) -> None:
|
||||||
|
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||||
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
|
table_name: str = self.__class__.__name__.lower()
|
||||||
|
kv_pairs = [item for item in fields(self)]
|
||||||
|
kv_pairs.sort(key=lambda item: item.name) # Sort by the name of the fields.
|
||||||
|
try:
|
||||||
|
await cur.execute(
|
||||||
|
f"INSERT INTO {table_name}("
|
||||||
|
f"{', '.join(item.name for item in kv_pairs)})"
|
||||||
|
f" VALUES ({', '.join('?' for _ in kv_pairs)})",
|
||||||
|
[_tweaked_dump(self, item.name) for item in kv_pairs],
|
||||||
|
)
|
||||||
|
self.__setattr__("obj_id", cur.lastrowid)
|
||||||
|
await con.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
raise ConstraintFailedError("A constraint has failed.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_entry(self) -> None:
|
||||||
"""
|
"""
|
||||||
Given an object, update the objects entry in the bound database.
|
Given an object, update the objects entry in the bound database.
|
||||||
:param self: The object.
|
:param self: The object.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
with sql.connect(getattr(self, "db_path")) as con:
|
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
table_name: str = self.__class__.__name__.lower()
|
table_name: str = self.__class__.__name__.lower()
|
||||||
kv_pairs = [item for item in asdict(self).items()]
|
kv_pairs = [item for item in asdict(self).items()]
|
||||||
kv_pairs.sort(key=lambda item: item[0])
|
kv_pairs.sort(key=lambda item: item[0])
|
||||||
query = f"UPDATE {table_name} " + \
|
query = (
|
||||||
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} " + \
|
f"UPDATE {table_name} "
|
||||||
|
f"SET {', '.join(item[0] + ' = ?' for item in kv_pairs)} "
|
||||||
f"WHERE obj_id = {getattr(self, 'obj_id')};"
|
f"WHERE obj_id = {getattr(self, 'obj_id')};"
|
||||||
cur.execute(query, [item[1] for item in kv_pairs])
|
)
|
||||||
con.commit()
|
await cur.execute(query, [item[1] for item in kv_pairs])
|
||||||
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
def remove_from(class_: type, obj_id: int):
|
async def _tweaked_update_entry(self) -> None:
|
||||||
with sql.connect(getattr(class_, "db_path")) as con:
|
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute(f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id, ))
|
table_name: str = self.__class__.__name__.lower()
|
||||||
con.commit()
|
kv_pairs = [item for item in fields(self)]
|
||||||
|
kv_pairs.sort(key=lambda item: item.name)
|
||||||
|
query = (
|
||||||
|
f"UPDATE {table_name} "
|
||||||
|
f"SET {', '.join(item.name + ' = ?' for item in kv_pairs)} "
|
||||||
|
f"WHERE obj_id = {getattr(self, 'obj_id')};"
|
||||||
|
)
|
||||||
|
await cur.execute(query, [_tweaked_dump(self, item.name) for item in kv_pairs])
|
||||||
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
def _remove_entry(self) -> None:
|
async def remove_from(class_: type, obj_id: int):
|
||||||
|
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||||
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
|
await cur.execute(
|
||||||
|
f"DELETE FROM {class_.__name__.lower()} WHERE obj_id = ?", (obj_id,)
|
||||||
|
)
|
||||||
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def _remove_entry(self) -> None:
|
||||||
"""
|
"""
|
||||||
Remove the object's record in the underlying database.
|
Remove the object's record in the underlying database.
|
||||||
:param self: self instance.
|
:param self: self instance.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
remove_from(self.__class__, getattr(self, 'obj_id'))
|
await remove_from(self.__class__, getattr(self, "obj_id"))
|
||||||
|
|
||||||
|
|
||||||
def datalite(db_path: str, type_overload: Optional[Dict[Optional[type], str]] = None) -> Callable:
|
def _markup_table(markup_function):
|
||||||
|
async def inner(self=None, **kwargs):
|
||||||
|
if not kwargs:
|
||||||
|
async with aiosqlite.connect(getattr(self, "db_path")) as con:
|
||||||
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
|
await markup_function(self.__class__, cur, self.types_table)
|
||||||
|
else:
|
||||||
|
await markup_function(**kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def datalite(
|
||||||
|
db_path: str,
|
||||||
|
type_overload: Optional[Dict[Optional[type], str]] = None,
|
||||||
|
tweaked: bool = True,
|
||||||
|
) -> Callable:
|
||||||
"""Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
|
"""Bind a dataclass to a sqlite3 database. This adds new methods to the class, such as
|
||||||
`create_entry()`, `remove_entry()` and `update_entry()`.
|
`create_entry()`, `remove_entry()` and `update_entry()`.
|
||||||
|
|
||||||
:param db_path: Path of the database to be binded.
|
:param db_path: Path of the database to be bound.
|
||||||
:param type_overload: Type overload dictionary.
|
:param type_overload: Type overload dictionary.
|
||||||
|
:param tweaked: Whether to use pickle type tweaks
|
||||||
:return: The new dataclass.
|
:return: The new dataclass.
|
||||||
"""
|
"""
|
||||||
def decorator(dataclass_: type, *args_i, **kwargs_i):
|
|
||||||
|
def decorator(dataclass_: type, *_, **__):
|
||||||
types_table = type_table.copy()
|
types_table = type_table.copy()
|
||||||
if type_overload is not None:
|
if type_overload is not None:
|
||||||
types_table.update(type_overload)
|
types_table.update(type_overload)
|
||||||
with sql.connect(db_path) as con:
|
|
||||||
cur: sql.Cursor = con.cursor()
|
setattr(dataclass_, "db_path", db_path)
|
||||||
_create_table(dataclass_, cur, types_table)
|
setattr(dataclass_, "types_table", types_table)
|
||||||
setattr(dataclass_, 'db_path', db_path) # We add the path of the database to class itself.
|
setattr(dataclass_, "tweaked", tweaked)
|
||||||
setattr(dataclass_, 'types_table', types_table) # We add the type table for migration.
|
|
||||||
|
if tweaked:
|
||||||
|
dataclass_.markup_table = _markup_table(_tweaked_create_table)
|
||||||
|
dataclass_.create_entry = _tweaked_create_entry
|
||||||
|
dataclass_.update_entry = _tweaked_update_entry
|
||||||
|
else:
|
||||||
|
dataclass_.markup_table = _markup_table(_create_table)
|
||||||
dataclass_.create_entry = _create_entry
|
dataclass_.create_entry = _create_entry
|
||||||
dataclass_.remove_entry = _remove_entry
|
|
||||||
dataclass_.update_entry = _update_entry
|
dataclass_.update_entry = _update_entry
|
||||||
|
dataclass_.remove_entry = _remove_entry
|
||||||
|
|
||||||
return dataclass_
|
return dataclass_
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
@@ -1,6 +1,12 @@
|
|||||||
import sqlite3 as sql
|
import aiosqlite
|
||||||
from typing import List, Tuple, Any
|
from typing import Any, List, Tuple, TypeVar, Type, cast
|
||||||
from .commons import _convert_sql_format, _get_table_cols
|
|
||||||
|
from .commons import _convert_sql_format, _get_table_cols, _tweaked_load_value
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
||||||
@@ -17,7 +23,7 @@ def _insert_pagination(query: str, page: int, element_count: int) -> str:
|
|||||||
return query + ";"
|
return query + ";"
|
||||||
|
|
||||||
|
|
||||||
def is_fetchable(class_: type, obj_id: int) -> bool:
|
async def is_fetchable(class_: type, obj_id: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a record is fetchable given its obj_id and
|
Check if a record is fetchable given its obj_id and
|
||||||
class_ type.
|
class_ type.
|
||||||
@@ -26,16 +32,22 @@ def is_fetchable(class_: type, obj_id: int) -> bool:
|
|||||||
:param obj_id: Unique obj_id of the object.
|
:param obj_id: Unique obj_id of the object.
|
||||||
:return: If the object is fetchable.
|
:return: If the object is fetchable.
|
||||||
"""
|
"""
|
||||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
try:
|
try:
|
||||||
cur.execute(f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id, ))
|
await cur.execute(
|
||||||
except sql.OperationalError:
|
f"SELECT 1 FROM {class_.__name__.lower()} WHERE obj_id = ?;", (obj_id,)
|
||||||
|
)
|
||||||
|
except aiosqlite.OperationalError:
|
||||||
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
|
raise KeyError(f"Table {class_.__name__.lower()} does not exist.")
|
||||||
return bool(cur.fetchall())
|
return bool(await cur.fetchall())
|
||||||
|
|
||||||
|
|
||||||
def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
async def fetch_equals(
|
||||||
|
class_: Type[T],
|
||||||
|
field: str,
|
||||||
|
value: Any,
|
||||||
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Fetch a class_ type variable from its bound db.
|
Fetch a class_ type variable from its bound db.
|
||||||
|
|
||||||
@@ -45,18 +57,30 @@ def fetch_equals(class_: type, field: str, value: Any, ) -> Any:
|
|||||||
:return: The object whose data is taken from the database.
|
:return: The object whose data is taken from the database.
|
||||||
"""
|
"""
|
||||||
table_name = class_.__name__.lower()
|
table_name = class_.__name__.lower()
|
||||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value, ))
|
await cur.execute(f"SELECT * FROM {table_name} WHERE {field} = ?;", (value,))
|
||||||
obj_id, *field_values = list(cur.fetchone())
|
obj_id, *field_values = list(await cur.fetchone())
|
||||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||||
|
|
||||||
kwargs = dict(zip(field_names, field_values))
|
kwargs = dict(zip(field_names, field_values))
|
||||||
|
if getattr(class_, "tweaked"):
|
||||||
|
types_table = getattr(class_, "types_table")
|
||||||
|
field_types = {
|
||||||
|
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||||
|
}
|
||||||
|
for key in kwargs.keys():
|
||||||
|
if field_types[key] not in types_table.keys():
|
||||||
|
kwargs[key] = _tweaked_load_value(
|
||||||
|
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||||
|
)
|
||||||
|
|
||||||
obj = class_(**kwargs)
|
obj = class_(**kwargs)
|
||||||
setattr(obj, "obj_id", obj_id)
|
setattr(obj, "obj_id", obj_id)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def fetch_from(class_: type, obj_id: int) -> Any:
|
async def fetch_from(class_: Type[T], obj_id: int) -> T:
|
||||||
"""
|
"""
|
||||||
Fetch a class_ type variable from its bound dv.
|
Fetch a class_ type variable from its bound dv.
|
||||||
|
|
||||||
@@ -64,13 +88,17 @@ def fetch_from(class_: type, obj_id: int) -> Any:
|
|||||||
:param obj_id: Unique object id of the object.
|
:param obj_id: Unique object id of the object.
|
||||||
:return: The fetched object.
|
:return: The fetched object.
|
||||||
"""
|
"""
|
||||||
if not is_fetchable(class_, obj_id):
|
if not await is_fetchable(class_, obj_id):
|
||||||
raise KeyError(f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
raise KeyError(
|
||||||
f"otherwise is unreachable.")
|
f"An object with {obj_id} of type {class_.__name__} does not exist, or"
|
||||||
return fetch_equals(class_, 'obj_id', obj_id)
|
f"otherwise is unreachable."
|
||||||
|
)
|
||||||
|
return await fetch_equals(class_, "obj_id", obj_id)
|
||||||
|
|
||||||
|
|
||||||
def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: List[str]) -> Any:
|
def _convert_record_to_object(
|
||||||
|
class_: Type[T], record: Tuple[Any], field_names: List[str]
|
||||||
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Convert a given record fetched from an SQL instance to a Python Object of given class_.
|
Convert a given record fetched from an SQL instance to a Python Object of given class_.
|
||||||
|
|
||||||
@@ -80,17 +108,30 @@ def _convert_record_to_object(class_: type, record: Tuple[Any], field_names: Lis
|
|||||||
:return: the created object.
|
:return: the created object.
|
||||||
"""
|
"""
|
||||||
kwargs = dict(zip(field_names, record[1:]))
|
kwargs = dict(zip(field_names, record[1:]))
|
||||||
field_types = {key: value.type for key, value in class_.__dataclass_fields__.items()}
|
field_types = {
|
||||||
for key in kwargs:
|
key: value.type for key, value in class_.__dataclass_fields__.items()
|
||||||
|
}
|
||||||
|
is_tweaked = getattr(class_, "tweaked")
|
||||||
|
types_table = getattr(class_, "types_table")
|
||||||
|
for key in kwargs.keys():
|
||||||
if field_types[key] == bytes:
|
if field_types[key] == bytes:
|
||||||
kwargs[key] = bytes(kwargs[key], encoding='utf-8')
|
kwargs[key] = bytes(kwargs[key], encoding="utf-8")
|
||||||
|
|
||||||
|
elif is_tweaked:
|
||||||
|
if field_types[key] not in types_table.keys():
|
||||||
|
kwargs[key] = _tweaked_load_value(
|
||||||
|
kwargs[key] if isinstance(kwargs[key], bytes) else base64.decodebytes(kwargs[key].encode("utf-8"))
|
||||||
|
)
|
||||||
|
|
||||||
obj_id = record[0]
|
obj_id = record[0]
|
||||||
obj = class_(**kwargs)
|
obj = class_(**kwargs)
|
||||||
setattr(obj, "obj_id", obj_id)
|
setattr(obj, "obj_id", obj_id)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 10) -> tuple:
|
async def fetch_if(
|
||||||
|
class_: Type[T], condition: str, page: int = 0, element_count: int = 10
|
||||||
|
) -> T:
|
||||||
"""
|
"""
|
||||||
Fetch all class_ type variables from the bound db,
|
Fetch all class_ type variables from the bound db,
|
||||||
provided they fit the given condition
|
provided they fit the given condition
|
||||||
@@ -103,18 +144,27 @@ def fetch_if(class_: type, condition: str, page: int = 0, element_count: int = 1
|
|||||||
of given type class_.
|
of given type class_.
|
||||||
"""
|
"""
|
||||||
table_name = class_.__name__.lower()
|
table_name = class_.__name__.lower()
|
||||||
with sql.connect(getattr(class_, 'db_path')) as con:
|
async with aiosqlite.connect(getattr(class_, "db_path")) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute(_insert_pagination(f"SELECT * FROM {table_name} WHERE {condition}", page, element_count))
|
await cur.execute(
|
||||||
records: list = cur.fetchall()
|
_insert_pagination(
|
||||||
field_names: List[str] = _get_table_cols(cur, table_name)
|
f"SELECT * FROM {table_name} WHERE {condition}", page, element_count
|
||||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
)
|
||||||
|
)
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
records: list = await cur.fetchall()
|
||||||
|
field_names: List[str] = await _get_table_cols(cur, table_name)
|
||||||
|
return tuple(
|
||||||
|
_convert_record_to_object(class_, record, field_names) for record in records
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_count: int = 10) -> tuple:
|
async def fetch_where(
|
||||||
|
class_: Type[T], field: str, value: Any, page: int = 0, element_count: int = 10
|
||||||
|
) -> tuple[T]:
|
||||||
"""
|
"""
|
||||||
Fetch all class_ type variables from the bound db,
|
Fetch all class_ type variables from the bound db
|
||||||
provided that the field of the records fit the
|
if the field of the records fit the
|
||||||
given value.
|
given value.
|
||||||
|
|
||||||
:param class_: Class of the records.
|
:param class_: Class of the records.
|
||||||
@@ -124,10 +174,12 @@ def fetch_where(class_: type, field: str, value: Any, page: int = 0, element_cou
|
|||||||
:param element_count: Element count in each page.
|
:param element_count: Element count in each page.
|
||||||
:return: A tuple of the records.
|
:return: A tuple of the records.
|
||||||
"""
|
"""
|
||||||
return fetch_if(class_, f"{field} = {_convert_sql_format(value)}", page, element_count)
|
return await fetch_if(
|
||||||
|
class_, f"{field} = {_convert_sql_format(value, getattr(class_, 'types_table'))}", page, element_count
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_range(class_: type, range_: range) -> tuple:
|
async def fetch_range(class_: Type[T], range_: range) -> tuple[T]:
|
||||||
"""
|
"""
|
||||||
Fetch the records in a given range of object ids.
|
Fetch the records in a given range of object ids.
|
||||||
|
|
||||||
@@ -136,12 +188,23 @@ def fetch_range(class_: type, range_: range) -> tuple:
|
|||||||
:return: A tuple of class_ type objects whose values
|
:return: A tuple of class_ type objects whose values
|
||||||
come from the class_' bound database.
|
come from the class_' bound database.
|
||||||
"""
|
"""
|
||||||
return tuple(fetch_from(class_, obj_id) for obj_id in range_ if is_fetchable(class_, obj_id))
|
return cast(
|
||||||
|
tuple[T],
|
||||||
|
tuple(
|
||||||
|
[
|
||||||
|
(await fetch_from(class_, obj_id))
|
||||||
|
for obj_id in range_
|
||||||
|
if (await is_fetchable(class_, obj_id))
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
async def fetch_all(
|
||||||
|
class_: Type[T], page: int = 0, element_count: int = 10
|
||||||
|
) -> tuple[T]:
|
||||||
"""
|
"""
|
||||||
Fetchall the records in the bound database.
|
Fetch all the records in the bound database.
|
||||||
|
|
||||||
:param class_: Class of the records.
|
:param class_: Class of the records.
|
||||||
:param page: Which page to retrieve, default all. (0 means closed).
|
:param page: Which page to retrieve, default all. (0 means closed).
|
||||||
@@ -150,15 +213,26 @@ def fetch_all(class_: type, page: int = 0, element_count: int = 10) -> tuple:
|
|||||||
the bound database as a tuple.
|
the bound database as a tuple.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
db_path = getattr(class_, 'db_path')
|
db_path = getattr(class_, "db_path")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise TypeError("Given class is not decorated with datalite.")
|
raise TypeError("Given class is not decorated with datalite.")
|
||||||
with sql.connect(db_path) as con:
|
async with aiosqlite.connect(db_path) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
try:
|
try:
|
||||||
cur.execute(_insert_pagination(f"SELECT * FROM {class_.__name__.lower()}", page, element_count))
|
await cur.execute(
|
||||||
except sql.OperationalError:
|
_insert_pagination(
|
||||||
|
f"SELECT * FROM {class_.__name__.lower()}", page, element_count
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except aiosqlite.OperationalError:
|
||||||
raise TypeError(f"No record of type {class_.__name__.lower()}")
|
raise TypeError(f"No record of type {class_.__name__.lower()}")
|
||||||
records = cur.fetchall()
|
records = await cur.fetchall()
|
||||||
field_names: List[str] = _get_table_cols(cur, class_.__name__.lower())
|
field_names: List[str] = await _get_table_cols(cur, class_.__name__.lower())
|
||||||
return tuple(_convert_record_to_object(class_, record, field_names) for record in records)
|
# noinspection PyTypeChecker
|
||||||
|
return cast(
|
||||||
|
tuple[T],
|
||||||
|
tuple(_convert_record_to_object(class_, record, field_names) for record in records),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["is_fetchable", "fetch_equals", "fetch_from", "fetch_if", "fetch_where", "fetch_range", "fetch_all"]
|
||||||
|
|||||||
@@ -3,14 +3,15 @@ This module includes functions to insert multiple records
|
|||||||
to a bound database at one time, with one time open and closing
|
to a bound database at one time, with one time open and closing
|
||||||
of the database file.
|
of the database file.
|
||||||
"""
|
"""
|
||||||
from typing import TypeVar, Union, List, Tuple
|
import aiosqlite
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
from typing import List, Tuple, TypeVar, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
from .constraints import ConstraintFailedError
|
|
||||||
from .commons import _convert_sql_format, _create_table
|
|
||||||
import sqlite3 as sql
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
from .commons import _convert_sql_format
|
||||||
|
from .constraints import ConstraintFailedError
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class HeterogeneousCollectionError(Exception):
|
class HeterogeneousCollectionError(Exception):
|
||||||
@@ -19,45 +20,56 @@ class HeterogeneousCollectionError(Exception):
|
|||||||
ie: If a List or Tuple has elements of multiple
|
ie: If a List or Tuple has elements of multiple
|
||||||
types.
|
types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
|
def _check_homogeneity(objects: Union[List[T], Tuple[T]]) -> None:
|
||||||
"""
|
"""
|
||||||
Check if all of the members a Tuple or a List
|
Check if all the members a Tuple or a List
|
||||||
is of the same type.
|
is of the same type.
|
||||||
|
|
||||||
:param objects: Tuple or list to check.
|
:param objects: Tuple or list to check.
|
||||||
:return: If all of the members of the same type.
|
:return: If all the members of the same type.
|
||||||
"""
|
"""
|
||||||
class_ = objects[0].__class__
|
class_ = objects[0].__class__
|
||||||
if not all([isinstance(obj, class_) or isinstance(objects[0], obj.__class__) for obj in objects]):
|
if not all(
|
||||||
|
[
|
||||||
|
isinstance(obj, class_) or isinstance(objects[0], obj.__class__)
|
||||||
|
for obj in objects
|
||||||
|
]
|
||||||
|
):
|
||||||
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
|
raise HeterogeneousCollectionError("Tuple or List is not homogeneous.")
|
||||||
|
|
||||||
|
|
||||||
def _toggle_memory_protection(cur: sql.Cursor, protect_memory: bool) -> None:
|
async def _toggle_memory_protection(cur: aiosqlite.Cursor, protect_memory: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Given a cursor to an sqlite3 connection, if memory protection is false,
|
Given a cursor to a sqlite3 connection, if memory protection is false,
|
||||||
toggle memory protections off.
|
toggle memory protections off.
|
||||||
|
|
||||||
:param cur: Cursor to an open SQLite3 connection.
|
:param cur: Cursor to an open SQLite3 connection.
|
||||||
:param protect_memory: Whether or not should memory be protected.
|
:param protect_memory: Whether should memory be protected.
|
||||||
:return: Memory protections off.
|
:return: Memory protections off.
|
||||||
"""
|
"""
|
||||||
if not protect_memory:
|
if not protect_memory:
|
||||||
warn("Memory protections are turned off, "
|
warn(
|
||||||
"if operations are interrupted, file may get corrupt.", RuntimeWarning)
|
"Memory protections are turned off, "
|
||||||
cur.execute("PRAGMA synchronous = OFF")
|
"if operations are interrupted, file may get corrupt.",
|
||||||
cur.execute("PRAGMA journal_mode = MEMORY")
|
RuntimeWarning,
|
||||||
|
)
|
||||||
|
await cur.execute("PRAGMA synchronous = OFF")
|
||||||
|
await cur.execute("PRAGMA journal_mode = MEMORY")
|
||||||
|
|
||||||
|
|
||||||
def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
|
async def _mass_insert(
|
||||||
|
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Insert multiple records into an SQLite3 database.
|
Insert multiple records into an SQLite3 database.
|
||||||
|
|
||||||
:param objects: Objects to insert.
|
:param objects: Objects to insert.
|
||||||
:param db_name: Name of the database to insert.
|
:param db_name: Name of the database to insert.
|
||||||
:param protect_memory: Whether or not memory
|
:param protect_memory: Whether memory
|
||||||
protections are on or off.
|
protections are on or off.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
@@ -69,24 +81,28 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory
|
|||||||
for i, obj in enumerate(objects):
|
for i, obj in enumerate(objects):
|
||||||
kv_pairs = asdict(obj).items()
|
kv_pairs = asdict(obj).items()
|
||||||
setattr(obj, "obj_id", first_index + i + 1)
|
setattr(obj, "obj_id", first_index + i + 1)
|
||||||
sql_queries.append(f"INSERT INTO {table_name}(" +
|
sql_queries.append(
|
||||||
f"{', '.join(item[0] for item in kv_pairs)})" +
|
f"INSERT INTO {table_name}("
|
||||||
f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});")
|
+ f"{', '.join(item[0] for item in kv_pairs)})"
|
||||||
with sql.connect(db_name) as con:
|
+ f" VALUES ({', '.join(_convert_sql_format(item[1], getattr(obj, 'types_table')) for item in kv_pairs)});"
|
||||||
cur: sql.Cursor = con.cursor()
|
)
|
||||||
|
async with aiosqlite.connect(db_name) as con:
|
||||||
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
try:
|
try:
|
||||||
_toggle_memory_protection(cur, protect_memory)
|
await _toggle_memory_protection(cur, protect_memory)
|
||||||
cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
|
await cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1")
|
||||||
index_tuple = cur.fetchone()
|
index_tuple = await cur.fetchone()
|
||||||
if index_tuple:
|
if index_tuple:
|
||||||
first_index = index_tuple[0]
|
_ = index_tuple[0]
|
||||||
cur.executescript("BEGIN TRANSACTION;\n" + '\n'.join(sql_queries) + '\nEND TRANSACTION;')
|
await cur.executescript(
|
||||||
except sql.IntegrityError:
|
"BEGIN TRANSACTION;\n" + "\n".join(sql_queries) + "\nEND TRANSACTION;"
|
||||||
|
)
|
||||||
|
except aiosqlite.IntegrityError:
|
||||||
raise ConstraintFailedError
|
raise ConstraintFailedError
|
||||||
con.commit()
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
|
async def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
Insert many records corresponding to objects
|
Insert many records corresponding to objects
|
||||||
in a tuple or a list.
|
in a tuple or a list.
|
||||||
@@ -98,29 +114,34 @@ def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True)
|
|||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
if objects:
|
if objects:
|
||||||
_mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
|
await _mass_insert(objects, getattr(objects[0], "db_path"), protect_memory)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Collection is empty.")
|
raise ValueError("Collection is empty.")
|
||||||
|
|
||||||
|
|
||||||
def copy_many(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True) -> None:
|
async def copy_many(
|
||||||
|
objects: Union[List[T], Tuple[T]], db_name: str, protect_memory: bool = True
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Copy many records to another database, from
|
Copy many records to another database, from
|
||||||
their original database to new database, do
|
their original database to a new database, do
|
||||||
not delete old records.
|
not delete old records.
|
||||||
|
|
||||||
:param objects: Objects to copy.
|
:param objects: Objects to copy.
|
||||||
:param db_name: Name of the new database.
|
:param db_name: Name of the new database.
|
||||||
:param protect_memory: Wheter to protect memory during operation,
|
:param protect_memory: Whether to protect memory during operation,
|
||||||
Setting this to False will quicken the operation, but if the
|
Setting this to False will quicken the operation, but if the
|
||||||
operation is cut short, database file will corrupt.
|
operation is cut short, the database file will corrupt.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
if objects:
|
if objects:
|
||||||
with sql.connect(db_name) as con:
|
async with aiosqlite.connect(db_name) as con:
|
||||||
cur = con.cursor()
|
cur = await con.cursor()
|
||||||
_create_table(objects[0].__class__, cur)
|
await objects[0].markup_table(class_=objects[0].__class__, cursor=cur)
|
||||||
con.commit()
|
await con.commit()
|
||||||
_mass_insert(objects, db_name, protect_memory)
|
await _mass_insert(objects, db_name, protect_memory)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Collection is empty.")
|
raise ValueError("Collection is empty.")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['copy_many', 'create_many', 'HeterogeneousCollectionError']
|
||||||
|
|||||||
@@ -2,53 +2,59 @@
|
|||||||
Migrations module deals with migrating data when the object
|
Migrations module deals with migrating data when the object
|
||||||
definitions change. This functions deal with Schema Migrations.
|
definitions change. This functions deal with Schema Migrations.
|
||||||
"""
|
"""
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
from dataclasses import Field
|
from dataclasses import Field
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
from typing import Dict, Tuple, List
|
from typing import Any, Dict, List, Tuple, cast
|
||||||
import sqlite3 as sql
|
|
||||||
|
|
||||||
from .commons import _create_table, _get_table_cols
|
import aiosqlite
|
||||||
|
|
||||||
|
from .commons import _get_table_cols, _tweaked_dump_value
|
||||||
|
|
||||||
|
|
||||||
def _get_db_table(class_: type) -> Tuple[str, str]:
|
async def _get_db_table(class_: type) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Check if the class is a datalite class, the database exists
|
Check if the class is a datalite class, the database exists,
|
||||||
and the table exists. Return database and table names.
|
and the table exists. Return database and table names.
|
||||||
|
|
||||||
:param class_: A datalite class.
|
:param class_: A datalite class.
|
||||||
:return: A tuple of database and table names.
|
:return: A tuple of database and table names.
|
||||||
"""
|
"""
|
||||||
database_name: str = getattr(class_, 'db_path', None)
|
database_name: str = getattr(class_, "db_path", None)
|
||||||
if not database_name:
|
if not database_name:
|
||||||
raise TypeError(f"{class_.__name__} is not a datalite class.")
|
raise TypeError(f"{class_.__name__} is not a datalite class.")
|
||||||
table_name: str = class_.__name__.lower()
|
table_name: str = class_.__name__.lower()
|
||||||
if not exists(database_name):
|
if not exists(database_name):
|
||||||
raise FileNotFoundError(f"{database_name} does not exist")
|
raise FileNotFoundError(f"{database_name} does not exist")
|
||||||
with sql.connect(database_name) as con:
|
async with aiosqlite.connect(database_name) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;", (table_name, ))
|
await cur.execute(
|
||||||
count: int = int(cur.fetchone()[0])
|
"SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?;",
|
||||||
|
(table_name,),
|
||||||
|
)
|
||||||
|
count: int = int((await cur.fetchone())[0])
|
||||||
if not count:
|
if not count:
|
||||||
raise FileExistsError(f"Table, {table_name}, already exists.")
|
raise FileExistsError(f"Table, {table_name}, already exists.")
|
||||||
return database_name, table_name
|
return database_name, table_name
|
||||||
|
|
||||||
|
|
||||||
def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
|
async def _get_table_column_names(database_name: str, table_name: str) -> Tuple[str]:
|
||||||
"""
|
"""
|
||||||
Get the column names of table.
|
Get the column names of the table.
|
||||||
|
|
||||||
:param database_name: Name of the database the table
|
:param database_name: The name of the database the table
|
||||||
resides in.
|
resides in.
|
||||||
:param table_name: Name of the table.
|
:param table_name: Name of the table.
|
||||||
:return: A tuple holding the column names of the table.
|
:return: A tuple holding the column names of the table.
|
||||||
"""
|
"""
|
||||||
with sql.connect(database_name) as con:
|
async with aiosqlite.connect(database_name) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cols: List[str] = _get_table_cols(cur, table_name)
|
cols: List[str] = await _get_table_cols(cur, table_name)
|
||||||
return tuple(cols)
|
return cast(Tuple[str], tuple(cols))
|
||||||
|
|
||||||
|
|
||||||
def _copy_records(database_name: str, table_name: str):
|
async def _copy_records(database_name: str, table_name: str):
|
||||||
"""
|
"""
|
||||||
Copy all records from a table.
|
Copy all records from a table.
|
||||||
|
|
||||||
@@ -56,17 +62,17 @@ def _copy_records(database_name: str, table_name: str):
|
|||||||
:param table_name: Name of the table.
|
:param table_name: Name of the table.
|
||||||
:return: A generator holding dataclass asdict representations.
|
:return: A generator holding dataclass asdict representations.
|
||||||
"""
|
"""
|
||||||
with sql.connect(database_name) as con:
|
async with aiosqlite.connect(database_name) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute(f'SELECT * FROM {table_name};')
|
await cur.execute(f"SELECT * FROM {table_name};")
|
||||||
values = cur.fetchall()
|
values = await cur.fetchall()
|
||||||
keys = _get_table_cols(cur, table_name)
|
keys = await _get_table_cols(cur, table_name)
|
||||||
keys.insert(0, 'obj_id')
|
keys.insert(0, "obj_id")
|
||||||
records = (dict(zip(keys, value)) for value in values)
|
records = (dict(zip(keys, value)) for value in values)
|
||||||
return records
|
return records
|
||||||
|
|
||||||
|
|
||||||
def _drop_table(database_name: str, table_name: str) -> None:
|
async def _drop_table(database_name: str, table_name: str) -> None:
|
||||||
"""
|
"""
|
||||||
Drop a table.
|
Drop a table.
|
||||||
|
|
||||||
@@ -74,14 +80,15 @@ def _drop_table(database_name: str, table_name: str) -> None:
|
|||||||
:param table_name: Name of the table to be dropped.
|
:param table_name: Name of the table to be dropped.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
with sql.connect(database_name) as con:
|
async with aiosqlite.connect(database_name) as con:
|
||||||
cur: sql.Cursor = con.cursor()
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
cur.execute(f'DROP TABLE {table_name};')
|
await cur.execute(f"DROP TABLE {table_name};")
|
||||||
con.commit()
|
await con.commit()
|
||||||
|
|
||||||
|
|
||||||
def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
def _modify_records(
|
||||||
flow: Dict[str, str]) -> Tuple[Dict[str, str]]:
|
data, col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]
|
||||||
|
) -> Tuple[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Modify the asdict records in accordance
|
Modify the asdict records in accordance
|
||||||
with schema migration rules provided.
|
with schema migration rules provided.
|
||||||
@@ -89,9 +96,9 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
|||||||
:param data: Data kept as asdict in tuple.
|
:param data: Data kept as asdict in tuple.
|
||||||
:param col_to_del: Column names to delete.
|
:param col_to_del: Column names to delete.
|
||||||
:param col_to_add: Column names to add.
|
:param col_to_add: Column names to add.
|
||||||
:param flow: A dictionary that explain
|
:param flow: A dictionary that explains
|
||||||
if the data from a deleted column
|
if the data from a deleted column
|
||||||
will be transferred to a column
|
is transferred to a column
|
||||||
to be added.
|
to be added.
|
||||||
:return: The modified data records.
|
:return: The modified data records.
|
||||||
"""
|
"""
|
||||||
@@ -109,15 +116,22 @@ def _modify_records(data, col_to_del: Tuple[str], col_to_add: Tuple[str],
|
|||||||
if key_to_add not in record_mod:
|
if key_to_add not in record_mod:
|
||||||
record_mod[key_to_add] = None
|
record_mod[key_to_add] = None
|
||||||
records.append(record_mod)
|
records.append(record_mod)
|
||||||
return records
|
return cast(Tuple[Dict[str, str]], records)
|
||||||
|
|
||||||
|
|
||||||
def _migrate_records(class_: type, database_name: str, data,
|
async def _migrate_records(
|
||||||
col_to_del: Tuple[str], col_to_add: Tuple[str], flow: Dict[str, str]) -> None:
|
class_: type,
|
||||||
|
database_name: str,
|
||||||
|
data,
|
||||||
|
col_to_del: Tuple[str],
|
||||||
|
col_to_add: Tuple[str],
|
||||||
|
flow: Dict[str, str],
|
||||||
|
safe_migration_defaults: Dict[str, Any] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Migrate the records into the modified table.
|
Migrate the records into the modified table.
|
||||||
|
|
||||||
:param class_: Class of the entries.
|
:param class_: Class of entries.
|
||||||
:param database_name: Name of the database.
|
:param database_name: Name of the database.
|
||||||
:param data: Data, asdict tuple.
|
:param data: Data, asdict tuple.
|
||||||
:param col_to_del: Columns to be deleted.
|
:param col_to_del: Columns to be deleted.
|
||||||
@@ -126,40 +140,94 @@ def _migrate_records(class_: type, database_name: str, data,
|
|||||||
column data will be transferred.
|
column data will be transferred.
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
with sql.connect(database_name) as con:
|
if safe_migration_defaults is None:
|
||||||
cur: sql.Cursor = con.cursor()
|
safe_migration_defaults = {}
|
||||||
_create_table(class_, cur, getattr(class_, 'types_table'))
|
|
||||||
con.commit()
|
async with aiosqlite.connect(database_name) as con:
|
||||||
|
cur: aiosqlite.Cursor = await con.cursor()
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
await class_.markup_table(
|
||||||
|
class_=class_, cursor=cur, type_overload=getattr(class_, "types_table")
|
||||||
|
)
|
||||||
|
await con.commit()
|
||||||
new_records = _modify_records(data, col_to_del, col_to_add, flow)
|
new_records = _modify_records(data, col_to_del, col_to_add, flow)
|
||||||
for record in new_records:
|
for record in new_records:
|
||||||
del record['obj_id']
|
del record["obj_id"]
|
||||||
keys_to_delete = [key for key in record if record[key] is None]
|
keys_to_delete = [key for key in record if record[key] is None]
|
||||||
for key in keys_to_delete:
|
for key in keys_to_delete:
|
||||||
del record[key]
|
del record[key]
|
||||||
class_(**record).create_entry()
|
await class_(
|
||||||
|
**{
|
||||||
|
**record,
|
||||||
|
**{
|
||||||
|
k: _tweaked_dump_value(class_, v)
|
||||||
|
for k, v in safe_migration_defaults.items()
|
||||||
|
if k not in record
|
||||||
|
},
|
||||||
|
}
|
||||||
|
).create_entry()
|
||||||
|
|
||||||
|
|
||||||
def basic_migrate(class_: type, column_transfer: dict = None) -> None:
|
async def migrate(
|
||||||
|
class_: type,
|
||||||
|
column_transfer: dict = None,
|
||||||
|
do_backup: bool = True,
|
||||||
|
safe_migration_defaults: Dict[str, Any] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Given a class, compare its previous table,
|
Given a class, compare its previous table,
|
||||||
delete the fields that no longer exist,
|
delete the fields that no longer exist,
|
||||||
create new columns for new fields. If the
|
create new columns for new fields. If the
|
||||||
column_flow parameter is given, migrate elements
|
column_flow parameter is given, migrate elements
|
||||||
from previous column to the new ones. It should be
|
from previous column to the new ones. It should be
|
||||||
noted that, the obj_ids do not persist.
|
noted that the obj_ids do not persist.
|
||||||
|
|
||||||
:param class_: Datalite class to migrate.
|
:param class_: Datalite class to migrate.
|
||||||
:param column_transfer: A dictionary showing which
|
:param column_transfer: A dictionary showing which
|
||||||
columns will be copied to new ones.
|
columns will be copied to new ones.
|
||||||
|
:param do_backup: Whether to copy a whole database before dropping table
|
||||||
|
:param safe_migration_defaults: Key-value that will be written to old records in the database during
|
||||||
|
migration so as not to break the schema
|
||||||
:return: None.
|
:return: None.
|
||||||
"""
|
"""
|
||||||
database_name, table_name = _get_db_table(class_)
|
database_name, table_name = await _get_db_table(class_)
|
||||||
table_column_names: Tuple[str] = _get_table_column_names(database_name, table_name)
|
table_column_names: Tuple[str] = await _get_table_column_names(
|
||||||
values = class_.__dataclass_fields__.values()
|
database_name, table_name
|
||||||
data_fields: Tuple[Field] = tuple(field for field in values)
|
)
|
||||||
data_field_names: Tuple[str] = tuple(field.name for field in data_fields)
|
|
||||||
columns_to_be_deleted: Tuple[str] = tuple(column for column in table_column_names if column not in data_field_names)
|
# noinspection PyUnresolvedReferences
|
||||||
columns_to_be_added: Tuple[str] = tuple(column for column in data_field_names if column not in table_column_names)
|
values: List[Field] = class_.__dataclass_fields__.values()
|
||||||
records = _copy_records(database_name, table_name)
|
|
||||||
_drop_table(database_name, table_name)
|
data_fields: Tuple[Field] = cast(Tuple[Field], tuple(field for field in values))
|
||||||
_migrate_records(class_, database_name, records, columns_to_be_deleted, columns_to_be_added, column_transfer)
|
data_field_names: Tuple[str] = cast(
|
||||||
|
Tuple[str], tuple(field.name for field in data_fields)
|
||||||
|
)
|
||||||
|
columns_to_be_deleted: Tuple[str] = cast(
|
||||||
|
Tuple[str],
|
||||||
|
tuple(
|
||||||
|
column for column in table_column_names if column not in data_field_names
|
||||||
|
),
|
||||||
|
)
|
||||||
|
columns_to_be_added: Tuple[str] = cast(
|
||||||
|
Tuple[str],
|
||||||
|
tuple(
|
||||||
|
column for column in data_field_names if column not in table_column_names
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
records = await _copy_records(database_name, table_name)
|
||||||
|
if do_backup:
|
||||||
|
shutil.copy(database_name, f"{database_name}-{time.time()}")
|
||||||
|
await _drop_table(database_name, table_name)
|
||||||
|
await _migrate_records(
|
||||||
|
class_,
|
||||||
|
database_name,
|
||||||
|
records,
|
||||||
|
columns_to_be_deleted,
|
||||||
|
columns_to_be_added,
|
||||||
|
column_transfer,
|
||||||
|
safe_migration_defaults,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["migrate"]
|
||||||
|
|||||||
34
datalite/typed.py
Normal file
34
datalite/typed.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def _datalite_hinted_direct_use():
|
||||||
|
raise ValueError(
|
||||||
|
"Don't use DataliteHinted directly. Inherited classes should also be wrapped in "
|
||||||
|
"datalite and dataclass decorators"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataliteHinted:
|
||||||
|
db_path: str
|
||||||
|
types_table: Dict[Optional[type], str]
|
||||||
|
tweaked: bool
|
||||||
|
obj_id: int
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
|
async def markup_table(self):
|
||||||
|
_datalite_hinted_direct_use()
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
|
async def create_entry(self):
|
||||||
|
_datalite_hinted_direct_use()
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
|
async def update_entry(self):
|
||||||
|
_datalite_hinted_direct_use()
|
||||||
|
|
||||||
|
# noinspection PyMethodMayBeStatic
|
||||||
|
async def remove_entry(self):
|
||||||
|
_datalite_hinted_direct_use()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DataliteHinted"]
|
||||||
23
docs/conf.py
23
docs/conf.py
@@ -12,18 +12,19 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, os.path.abspath('.'))
|
|
||||||
sys.path.insert(0, os.path.abspath('../.'))
|
sys.path.insert(0, os.path.abspath("."))
|
||||||
|
sys.path.insert(0, os.path.abspath("../."))
|
||||||
|
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = 'Datalite'
|
project = "Datalite"
|
||||||
copyright = '2020, Ege Ozkan'
|
copyright = "2020, Ege Ozkan"
|
||||||
author = 'Ege Ozkan'
|
author = "Ege Ozkan"
|
||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
release = 'v0.7.1'
|
release = "v0.7.1"
|
||||||
|
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
@@ -31,15 +32,15 @@ release = 'v0.7.1'
|
|||||||
# Add any Sphinx extension module names here, as strings. They can be
|
# Add any Sphinx extension module names here, as strings. They can be
|
||||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||||
# ones.
|
# ones.
|
||||||
extensions = ['sphinx.ext.autodoc']
|
extensions = ["sphinx.ext.autodoc"]
|
||||||
|
|
||||||
# Add any paths that contain templates here, relative to this directory.
|
# Add any paths that contain templates here, relative to this directory.
|
||||||
templates_path = ['_templates']
|
templates_path = ["_templates"]
|
||||||
|
|
||||||
# List of patterns, relative to source directory, that match files and
|
# List of patterns, relative to source directory, that match files and
|
||||||
# directories to ignore when looking for source files.
|
# directories to ignore when looking for source files.
|
||||||
# This pattern also affects html_static_path and html_extra_path.
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||||
|
|
||||||
|
|
||||||
# -- Options for HTML output -------------------------------------------------
|
# -- Options for HTML output -------------------------------------------------
|
||||||
@@ -47,9 +48,9 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
|||||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
# a list of builtin themes.
|
# a list of builtin themes.
|
||||||
#
|
#
|
||||||
html_theme = 'sphinx_rtd_theme'
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
|
||||||
# Add any paths that contain custom static files (such as style sheets) here,
|
# Add any paths that contain custom static files (such as style sheets) here,
|
||||||
# relative to this directory. They are copied after the builtin static files,
|
# relative to this directory. They are copied after the builtin static files,
|
||||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||||
html_static_path = ['_static']
|
html_static_path = ["_static"]
|
||||||
|
|||||||
Reference in New Issue
Block a user