diff --git a/datalite/mass_actions.py b/datalite/mass_actions.py index 3da743e..08f4c06 100644 --- a/datalite/mass_actions.py +++ b/datalite/mass_actions.py @@ -65,8 +65,10 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory sql_queries = [] first_index: int = 0 table_name = objects[0].__class__.__name__.lower() - for obj in objects: + + for i, obj in enumerate(objects): kv_pairs = asdict(obj).items() + setattr(obj, "obj_id", first_index + i + 1) sql_queries.append(f"INSERT INTO {table_name}(" + f"{', '.join(item[0] for item in kv_pairs)})" + f" VALUES ({', '.join(_convert_sql_format(item[1]) for item in kv_pairs)});") @@ -81,8 +83,6 @@ def _mass_insert(objects: Union[List[T], Tuple[T]], db_name: str, protect_memory cur.executescript("BEGIN TRANSACTION;\n" + '\n'.join(sql_queries) + '\nEND TRANSACTION;') except sql.IntegrityError: raise ConstraintFailedError - for i, obj in enumerate(objects): - setattr(obj, "obj_id", first_index + i + 1) def create_many(objects: Union[List[T], Tuple[T]], protect_memory: bool = True) -> None: diff --git a/test/main_tests.py b/test/main_tests.py index d86d138..578c729 100644 --- a/test/main_tests.py +++ b/test/main_tests.py @@ -199,16 +199,23 @@ class DatabaseMassInsert(unittest.TestCase): def setUp(self) -> None: self.objs = [MassCommit(f'cat + {i}') for i in range(30)] + def testMassCreate(self): + with connect('other.db') as con: + cur = con.cursor() + cur.execute(f'CREATE TABLE IF NOT EXISTS MASSCOMMIT (obj_id, str_)') + + start_tup = fetch_all(MassCommit) create_many(self.objs, protect_memory=False) _objs = fetch_all(MassCommit) - self.assertEqual(_objs, tuple(self.objs)) + self.assertEqual(_objs, start_tup + tuple(self.objs)) def testMassCopy(self): - copy_many(self.objs, 'other.db', False) setattr(MassCommit, 'db_path', 'other.db') + start_tup = fetch_all(MassCommit) + copy_many(self.objs, 'other.db', False) tup = fetch_all(MassCommit) - self.assertEqual(tup, tuple(self.objs)) + self.assertEqual(tup, start_tup + tuple(self.objs)) def tearDown(self) -> None: [obj.remove_entry() for obj in self.objs]