diff --git a/datalite/__init__.py b/datalite/__init__.py index b01f482..2834220 100644 --- a/datalite/__init__.py +++ b/datalite/__init__.py @@ -1,2 +1,2 @@ -__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints'] +__all__ = ['commons', 'datalite_decorator', 'fetch', 'migrations', 'datalite', 'constraints', 'mass_actions'] from .datalite_decorator import datalite diff --git a/datalite/mass_actions.py b/datalite/mass_actions.py new file mode 100644 index 0000000..d3f1d6e --- /dev/null +++ b/datalite/mass_actions.py @@ -0,0 +1,62 @@ +""" +This module includes functions to insert multiple records + to a bound database at one time, with one time open and closing + of the database file. +""" +from typing import TypeVar, Union, List, Tuple +from dataclasses import asdict +from .constraints import ConstraintFailedError +from .commons import _convert_sql_format +import sqlite3 as sql + +T = TypeVar('T') + + +class MisformedCollectionError(Exception): + pass + + +def is_homogeneous(objects: Union[List[T], Tuple[T]]) -> bool: + """ + Check if all of the members a Tuple or a List + is of the same type. + + :param objects: Tuple or list to check. + :return: If all of the members of the same type. + """ + class_ = objects[0].__class__ + return all([isinstance(obj, class_) for obj in objects]) + + +def create_many_entries(objects: Union[List[T], Tuple[T]]) -> None: + """ + Insert many records corresponding to objects + in a tuple or a list. + + :param objects: A tuple or a list of objects decorated + with datalite. + :return: None. + """ + if not objects or not is_homogeneous(objects): + raise MisformedCollectionError("Tuple or List is empty or homogeneous.") + sql_queries = [] + first_index: int = 0 + table_name = objects[0].__class__.__name__.lower() + for obj in objects: + kv_pairs = asdict(obj).items() + sql_queries.append(f"INSERT INTO {table_name}(" + + f"{', '.join(item[0] for item in kv_pairs)})" + + f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});") + with sql.connect(getattr(objects[0], "db_path")) as con: + cur: sql.Cursor = con.cursor() + try: + cur.execute(f"SELECT obj_id FROM {table_name} ORDER BY obj_id DESC LIMIT 1") + index_tuple = cur.fetchone() + if index_tuple: + first_index = index_tuple[0] + cur.executescript('\n'.join(sql_queries)) + except sql.IntegrityError: + raise ConstraintFailedError + con.commit() + for i, obj in enumerate(objects): + setattr(obj, "obj_id", first_index + i) diff --git a/docs/datalite.rst b/docs/datalite.rst index 05d1711..9a9af74 100644 --- a/docs/datalite.rst +++ b/docs/datalite.rst @@ -22,6 +22,14 @@ datalite.fetch module :undoc-members: :show-inheritance: +datalite.mass_actions module +------------------------------ + +.. automodule:: datalite.mass_actions + :members: + :undoc-members: + :show-inheritence: + datalite.migrations module ---------------------------- diff --git a/test/main_tests.py b/test/main_tests.py index 34b4633..5ab28e2 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -2,11 +2,10 @@ import unittest from datalite import datalite from datalite.constraints import Unique, ConstraintFailedError from datalite.fetch import fetch_if, fetch_all, fetch_range, fetch_from, fetch_equals, fetch_where +from datalite.mass_actions import create_many_entries from sqlite3 import connect from dataclasses import dataclass, asdict from math import floor -from os import remove - from datalite.migrations import basic_migrate, _drop_table @@ -51,6 +50,12 @@ class ConstraintedClass: unique_str: Unique[str] +@datalite(db_path='test.db') +@dataclass +class MassCommit: + str_: str + + def getValFromDB(obj_id = 1): with connect('test.db') as db: cur = db.cursor() @@ -190,5 +195,16 @@ class DatabaseConstraints(unittest.TestCase): self.obj.remove_entry() +class DatabaseMassInsert(unittest.TestCase): + def setUp(self) -> None: + self.objs = [MassCommit('cat') for _ in range(30)] + + def testMassCreate(self): + create_many_entries(self.objs) + + def tearDown(self) -> None: + [obj.remove_entry() for obj in self.objs] + + if __name__ == '__main__': unittest.main()