From 32a23a82dd9afeedcd9172caa6044943727afc76 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 21 Nov 2023 17:21:01 +0100 Subject: [PATCH] Don't update equivalent entities in import_data Fixes #314 --- spinedb_api/db_mapping_base.py | 11 +++++++++-- tests/test_import_functions.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f522bbd0..e8d8a21c 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -574,7 +574,8 @@ class MappedItemBase(dict): """A dictionary that represents a db item.""" fields = {} - """A dictionary mapping keys to a tuple of (type, value description)""" + """A dictionary mapping keys to a another dict mapping "type" to a Python type, + "value" to a description of the value for the key, and "optional" to a bool.""" _defaults = {} """A dictionary mapping keys to their default values""" _unique_keys = () @@ -597,6 +598,7 @@ class MappedItemBase(dict): source key. """ _private_fields = set() + """A set with fields that should be ignored in validations.""" def __init__(self, db_map, item_type, **kwargs): """ @@ -741,7 +743,12 @@ def _something_to_update(self, other): def _convert(x): return tuple(x) if isinstance(x, list) else x - return not all(_convert(self.get(key)) == _convert(value) for key, value in other.items()) + return not all( + _convert(self.get(key)) == _convert(value) + for key, value in other.items() + if value is not None + or self.fields.get(key, {}).get("optional", False) # Ignore mandatory fields that are None + ) def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first reference diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 98da973d..b6ebfd2c 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -340,6 +340,27 @@ def test_import_existing_relationship_class_parameter(self): db_map.close() +class TestImportEntity(unittest.TestCase): + def test_import_multi_d_entity_twice(self): + db_map = DatabaseMapping("sqlite://", create=True) + import_data( + db_map, + entity_classes=( + ("object_class1",), + ("object_class2",), + ("relationship_class", ("object_class1", "object_class2")), + ), + entities=( + ("object_class1", "object1"), + ("object_class2", "object2"), + ("relationship_class", ("object1", "object2")), + ), + ) + count, errors = import_data(db_map, entities=(("relationship_class", ("object1", "object2")),)) + self.assertEqual(count, 0) + self.assertEqual(errors, []) + + class TestImportRelationship(unittest.TestCase): @staticmethod def populate(db_map):