From a4a685ca796bcc297d472f7a80affeb1f9f6dc4d Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 8 Feb 2023 14:39:48 +0100 Subject: [PATCH 001/317] Introduce the dimension and element concepts --- spinedb_api/check_functions.py | 10 +- spinedb_api/db_mapping.py | 1 - spinedb_api/db_mapping_add_mixin.py | 49 +++---- spinedb_api/db_mapping_base.py | 191 ++++++++++++++------------ spinedb_api/db_mapping_check_mixin.py | 4 +- spinedb_api/diff_db_mapping.py | 1 - spinedb_api/helpers.py | 111 +++------------ 7 files changed, 149 insertions(+), 218 deletions(-) diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index 48c37064..5d90e7f0 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -88,7 +88,7 @@ def check_scenario_alternative(item, ids_by_alt_id, ids_by_rank, scenario_names, ) -def check_object_class(item, current_items, object_class_type): +def check_object_class(item, current_items): """Check whether the insertion of an object class item results in the violation of an integrity constraint. @@ -107,15 +107,11 @@ def check_object_class(item, current_items, object_class_type): ) if not name: raise SpineIntegrityError("Object class name is an empty string and therefore not valid") - if "type_id" in item and item["type_id"] != object_class_type: - raise SpineIntegrityError( - f"Object class '{name}' does not have a type_id of an object class.", id=current_items[name] - ) if name in current_items: raise SpineIntegrityError(f"There can't be more than one object class called '{name}'.", id=current_items[name]) -def check_object(item, current_items, object_class_ids, object_entity_type): +def check_object(item, current_items, object_class_ids): """Check whether the insertion of an object item results in the violation of an integrity constraint. 
@@ -135,8 +131,6 @@ def check_object(item, current_items, object_class_ids, object_entity_type): ) if not name: raise SpineIntegrityError("Object name is an empty string and therefore not valid") - if "type_id" in item and item["type_id"] != object_entity_type: - raise SpineIntegrityError(f"Object '{name}' does not have entity type of and object", id=current_items[name]) try: class_id = item["class_id"] except KeyError: diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 77e7d896..f4c5218e 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -44,7 +44,6 @@ class DatabaseMapping( def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._init_type_attributes() if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 1dbe72ac..9ac6b624 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -106,6 +106,12 @@ def _do_reserve_ids(self, connection, tablename, count): next_id = getattr(next_id_row, fieldname) stmt = self._next_id.update() if next_id is None: + tablename = { + "object_class": "entity_class", + "relationship_class": "entity_class", + "object": "entity", + "relationship": "entity", + }.get(tablename, tablename) table = self._metadata.tables[tablename] id_col = self.table_ids.get(tablename, "id") select_max_id = select([func.max(getattr(table.c, id_col))]) @@ -200,7 +206,6 @@ def _get_table_for_insert(self, tablename): def _do_add_items(self, tablename, *items_to_add): if not self.committing: return - items_to_add = tuple(self._items_with_type_id(tablename, *items_to_add)) try: for tablename_, items_to_add_ in self._items_to_add_per_table(tablename, items_to_add): table = self._get_table_for_insert(tablename_) @@ -223,45 +228,35 @@ def _items_to_add_per_table(self, tablename, items_to_add): tuple: database table name, items to add """ if tablename == "object_class": - oc_items_to_add = list() - append_oc_items_to_add = oc_items_to_add.append - for item in items_to_add: - append_oc_items_to_add({"entity_class_id": item["id"], "type_id": self.object_class_type}) yield ("entity_class", items_to_add) - yield ("object_class", oc_items_to_add) elif tablename == "object": - o_items_to_add = list() - append_o_items_to_add = o_items_to_add.append - for item in items_to_add: - append_o_items_to_add({"entity_id": item["id"], "type_id": item["type_id"]}) yield ("entity", items_to_add) - yield ("object", o_items_to_add) elif tablename == "relationship_class": - rc_items_to_add = list() - rec_items_to_add = list() + ecd_items_to_add = [] for item in items_to_add: - rc_items_to_add.append({"entity_class_id": item["id"], "type_id": self.relationship_class_type}) - rec_items_to_add += get_relationship_entity_class_items(item, self.object_class_type) + ecd_items_to_add += [ + {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} + for position, dimension_id in enumerate(item["object_class_id_list"]) + ] yield ("entity_class", items_to_add) - yield ("relationship_class", rc_items_to_add) - yield ("relationship_entity_class", rec_items_to_add) + yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "relationship": - re_items_to_add = list() - r_items_to_add = list() + ee_items_to_add = [] for item in items_to_add: - r_items_to_add.append( + ee_items_to_add += [ { "entity_id": item["id"], "entity_class_id": 
item["class_id"], - "type_id": self.relationship_entity_type, + "position": position, + "element_id": object_id, + "dimension_id": object_class_id, } - ) - re_items_to_add += get_relationship_entity_items( - item, self.relationship_entity_type, self.object_entity_type - ) + for position, (object_id, object_class_id) in enumerate( + zip(item["object_id_list"], item["object_class_id_list"]) + ) + ] yield ("entity", items_to_add) - yield ("relationship", r_items_to_add) - yield ("relationship_entity", re_items_to_add) + yield ("entity_element", ee_items_to_add) elif tablename == "parameter_definition": for item in items_to_add: item["entity_class_id"] = ( diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f3e1b13e..19cade85 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -164,6 +164,8 @@ def __init__( self._entity_metadata_sq = None self._clean_parameter_value_sq = None # Special convenience subqueries that join two or more tables + self._ext_entity_class_sq = None + self._ext_entity_sq = None self._ext_parameter_value_list_sq = None self._wide_parameter_value_list_sq = None self._ord_list_value_sq = None @@ -267,24 +269,6 @@ def __init__( self.descendant_tablenames = { tablename: set(self._descendant_tablenames(tablename)) for tablename in self.cache_sqs } - self.object_class_type = None - self.relationship_class_type = None - self.object_entity_type = None - self.relationship_entity_type = None - - def _init_type_attributes(self): - self.object_class_type = ( - self.query(self.entity_class_type_sq).filter(self.entity_class_type_sq.c.name == "object").first().id - ) - self.relationship_class_type = ( - self.query(self.entity_class_type_sq).filter(self.entity_class_type_sq.c.name == "relationship").first().id - ) - self.object_entity_type = ( - self.query(self.entity_type_sq).filter(self.entity_type_sq.c.name == "object").first().id - ) - self.relationship_entity_type = ( - self.query(self.entity_type_sq).filter(self.entity_type_sq.c.name == "relationship").first().id - ) def __enter__(self): return self @@ -560,64 +544,102 @@ def scenario_alternative_sq(self): return self._scenario_alternative_sq @property - def entity_class_type_sq(self): + def entity_class_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM class_type + SELECT * FROM class Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_class_type_sq is None: - self._entity_class_type_sq = self._subquery("entity_class_type") - return self._entity_class_type_sq + if self._entity_class_sq is None: + self._entity_class_sq = self._make_entity_class_sq() + return self._entity_class_sq @property - def entity_type_sq(self): + def entity_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM class_type + SELECT * FROM entity Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_type_sq is None: - self._entity_type_sq = self._subquery("entity_type") - return self._entity_type_sq + if self._entity_sq is None: + self._entity_sq = self._make_entity_sq() + return self._entity_sq @property - def entity_class_sq(self): + def ext_entity_class_sq(self): """A subquery of the form: .. 
code-block:: sql - SELECT * FROM class + SELECT + ec.*, + group_concat(ecd.dimension_id) AS dimension_id_list + FROM + entity_class AS ec + LEFT JOIN entity_class_dimension AS ecd + ON ec.id = ecd.entity_class_id + GROUP BY + ec.id Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_class_sq is None: - self._entity_class_sq = self._make_entity_class_sq() - return self._entity_class_sq + if self._ext_entity_class_sq is None: + entity_class_dimension_sq = self._subquery("entity_class_dimension") + self._ext_entity_class_sq = ( + self.query( + self.entity_class_sq, + group_concat(entity_class_dimension_sq.c.dimension_id, entity_class_dimension_sq.c.position).label( + "dimension_id_list" + ), + ) + .outerjoin( + entity_class_dimension_sq, self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id + ) + .group_by(entity_class_dimension_sq.c.entity_class_id) + .subquery() + ) + return self._ext_entity_class_sq @property - def entity_sq(self): + def ext_entity_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM entity + SELECT + e.*, + group_concat(ee.element_id) AS element_id_list + FROM + entity AS e + LEFT JOIN entity_element AS ee + ON e.id = ee.entity_id + GROUP BY + e.id Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_sq is None: - self._entity_sq = self._make_entity_sq() - return self._entity_sq + if self._ext_entity_sq is None: + entity_element_sq = self._subquery("entity_element") + self._ext_entity_sq = ( + self.query( + self.entity_sq, + group_concat(entity_element_sq.c.element_id, entity_element_sq.c.position).label("element_id_list"), + ) + .outerjoin(entity_element_sq, self.entity_sq.c.id == entity_element_sq.c.entity_id) + .group_by(entity_element_sq.c.entity_id) + .subquery() + ) + return self._ext_entity_sq @property def object_class_sq(self): @@ -631,18 +653,16 @@ def object_class_sq(self): sqlalchemy.sql.expression.Alias """ if self._object_class_sq is None: - object_class_sq = self._subquery("object_class") self._object_class_sq = ( self.query( - self.entity_class_sq.c.id.label("id"), - self.entity_class_sq.c.name.label("name"), - self.entity_class_sq.c.description.label("description"), - self.entity_class_sq.c.display_order.label("display_order"), - self.entity_class_sq.c.display_icon.label("display_icon"), - self.entity_class_sq.c.hidden.label("hidden"), - self.entity_class_sq.c.commit_id.label("commit_id"), + self.ext_entity_class_sq.c.id.label("id"), + self.ext_entity_class_sq.c.name.label("name"), + self.ext_entity_class_sq.c.description.label("description"), + self.ext_entity_class_sq.c.display_order.label("display_order"), + self.ext_entity_class_sq.c.display_icon.label("display_icon"), + self.ext_entity_class_sq.c.hidden.label("hidden"), ) - .filter(self.entity_class_sq.c.id == object_class_sq.c.entity_class_id) + .filter(self.ext_entity_class_sq.c.dimension_id_list == None) .subquery() ) return self._object_class_sq @@ -659,16 +679,15 @@ def object_sq(self): sqlalchemy.sql.expression.Alias """ if self._object_sq is None: - object_sq = self._subquery("object") self._object_sq = ( self.query( - self.entity_sq.c.id.label("id"), - self.entity_sq.c.class_id.label("class_id"), - self.entity_sq.c.name.label("name"), - self.entity_sq.c.description.label("description"), - self.entity_sq.c.commit_id.label("commit_id"), + self.ext_entity_sq.c.id.label("id"), + self.ext_entity_sq.c.class_id.label("class_id"), + self.ext_entity_sq.c.name.label("name"), +
self.ext_entity_sq.c.description.label("description"), + self.ext_entity_sq.c.commit_id.label("commit_id"), ) - .filter(self.entity_sq.c.id == object_sq.c.entity_id) + .filter(self.ext_entity_sq.c.element_id_list == None) .subquery() ) return self._object_sq @@ -685,19 +704,18 @@ def relationship_class_sq(self): sqlalchemy.sql.expression.Alias """ if self._relationship_class_sq is None: - rel_ent_cls_sq = self._subquery("relationship_entity_class") + ent_cls_dim_sq = self._subquery("entity_class_dimension") self._relationship_class_sq = ( self.query( - rel_ent_cls_sq.c.entity_class_id.label("id"), - rel_ent_cls_sq.c.dimension.label("dimension"), - rel_ent_cls_sq.c.member_class_id.label("object_class_id"), - self.entity_class_sq.c.name.label("name"), - self.entity_class_sq.c.description.label("description"), - self.entity_class_sq.c.display_icon.label("display_icon"), - self.entity_class_sq.c.hidden.label("hidden"), - self.entity_class_sq.c.commit_id.label("commit_id"), + ent_cls_dim_sq.c.entity_class_id.label("id"), + ent_cls_dim_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept + ent_cls_dim_sq.c.dimension_id.label("object_class_id"), + self.ext_entity_class_sq.c.name.label("name"), + self.ext_entity_class_sq.c.description.label("description"), + self.ext_entity_class_sq.c.display_icon.label("display_icon"), + self.ext_entity_class_sq.c.hidden.label("hidden"), ) - .filter(self.entity_class_sq.c.id == rel_ent_cls_sq.c.entity_class_id) + .filter(self.ext_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) .subquery() ) return self._relationship_class_sq @@ -714,17 +732,17 @@ def relationship_sq(self): sqlalchemy.sql.expression.Alias """ if self._relationship_sq is None: - rel_ent_sq = self._subquery("relationship_entity") + ent_el_sq = self._subquery("entity_element") self._relationship_sq = ( self.query( - rel_ent_sq.c.entity_id.label("id"), - rel_ent_sq.c.dimension.label("dimension"), - rel_ent_sq.c.member_id.label("object_id"), - rel_ent_sq.c.entity_class_id.label("class_id"), - self.entity_sq.c.name.label("name"), - self.entity_sq.c.commit_id.label("commit_id"), + ent_el_sq.c.entity_id.label("id"), + ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept + ent_el_sq.c.element_id.label("object_id"), + ent_el_sq.c.entity_class_id.label("class_id"), + self.ext_entity_sq.c.name.label("name"), + self.ext_entity_sq.c.commit_id.label("commit_id"), ) - .filter(self.entity_sq.c.id == rel_ent_sq.c.entity_id) + .filter(self.ext_entity_sq.c.id == ent_el_sq.c.entity_id) .subquery() ) return self._relationship_sq @@ -1099,7 +1117,6 @@ def ext_relationship_class_sq(self): self.relationship_class_sq.c.display_icon.label("display_icon"), self.object_class_sq.c.id.label("object_class_id"), self.object_class_sq.c.name.label("object_class_name"), - self.relationship_class_sq.c.commit_id.label("commit_id"), ) .filter(self.relationship_class_sq.c.object_class_id == self.object_class_sq.c.id) .order_by(self.relationship_class_sq.c.id, self.relationship_class_sq.c.dimension) @@ -1140,7 +1157,6 @@ def wide_relationship_class_sq(self): self.ext_relationship_class_sq.c.name, self.ext_relationship_class_sq.c.description, self.ext_relationship_class_sq.c.display_icon, - self.ext_relationship_class_sq.c.commit_id, group_concat( self.ext_relationship_class_sq.c.object_class_id, self.ext_relationship_class_sq.c.dimension ).label("object_class_id_list"), @@ -1153,7 +1169,6 @@ def wide_relationship_class_sq(self): 
self.ext_relationship_class_sq.c.name, self.ext_relationship_class_sq.c.description, self.ext_relationship_class_sq.c.display_icon, - self.ext_relationship_class_sq.c.commit_id, ) .subquery() ) @@ -2058,34 +2073,32 @@ def _items_with_type_id(self, tablename, *items): yield item def _object_class_id(self): - return case([(self.entity_class_sq.c.type_id == self.object_class_type, self.entity_class_sq.c.id)], else_=None) + return case([(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.id)], else_=None) def _relationship_class_id(self): - return case( - [(self.entity_class_sq.c.type_id == self.relationship_class_type, self.entity_class_sq.c.id)], else_=None - ) + return case([(self.ext_entity_class_sq.c.dimension_id_list != None, self.ext_entity_class_sq.c.id)], else_=None) def _object_id(self): - return case([(self.entity_sq.c.type_id == self.object_entity_type, self.entity_sq.c.id)], else_=None) + return case([(self.ext_entity_sq.c.element_id_list == None, self.ext_entity_sq.c.id)], else_=None) def _relationship_id(self): - return case([(self.entity_sq.c.type_id == self.relationship_entity_type, self.entity_sq.c.id)], else_=None) + return case([(self.ext_entity_sq.c.element_id_list != None, self.ext_entity_sq.c.id)], else_=None) def _object_class_name(self): return case( - [(self.entity_class_sq.c.type_id == self.object_class_type, self.entity_class_sq.c.name)], else_=None + [(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.name)], else_=None ) def _relationship_class_name(self): return case( - [(self.entity_class_sq.c.type_id == self.relationship_class_type, self.entity_class_sq.c.name)], else_=None + [(self.ext_entity_class_sq.c.dimension_id_list != None, self.ext_entity_class_sq.c.name)], else_=None ) def _object_class_id_list(self): return case( [ ( - self.entity_class_sq.c.type_id == self.relationship_class_type, + self.ext_entity_class_sq.c.dimension_id_list != None, self.wide_relationship_class_sq.c.object_class_id_list, ) ], else_=None, ) def _object_class_name_list(self): return case( [ ( - self.entity_class_sq.c.type_id == self.relationship_class_type, + self.ext_entity_class_sq.c.dimension_id_list != None, self.wide_relationship_class_sq.c.object_class_name_list, ) ], else_=None, ) def _object_name(self): - return case([(self.entity_sq.c.type_id == self.object_entity_type, self.entity_sq.c.name)], else_=None) + return case([(self.ext_entity_sq.c.element_id_list == None, self.ext_entity_sq.c.name)], else_=None) def _object_id_list(self): return case( - [(self.entity_sq.c.type_id == self.relationship_entity_type, self.wide_relationship_sq.c.object_id_list)], - else_=None, + [(self.ext_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None ) def _object_name_list(self): return case( - [(self.entity_sq.c.type_id == self.relationship_entity_type, self.wide_relationship_sq.c.object_name_list)], - else_=None, + [(self.ext_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None ) @staticmethod diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 001e9433..a67fb5bf 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -348,7 +348,7 @@ def check_object_classes(self, *items, for_update=False, strict=False, cache=Non with self._manage_stocks( "object_class", item, {("name",): object_class_ids}, for_update, cache, 
intgr_error_log ) as item: - check_object_class(item, object_class_ids, self.object_class_type) + check_object_class(item, object_class_ids) checked_items.append(item) except SpineIntegrityError as e: if strict: @@ -378,7 +378,7 @@ def check_objects(self, *items, for_update=False, strict=False, cache=None): with self._manage_stocks( "object", item, {("class_id", "name"): object_ids}, for_update, cache, intgr_error_log ) as item: - check_object(item, object_ids, object_class_ids, self.object_entity_type) + check_object(item, object_ids, object_class_ids) checked_items.append(item) except SpineIntegrityError as e: if strict: diff --git a/spinedb_api/diff_db_mapping.py b/spinedb_api/diff_db_mapping.py index 8caa1ef9..e6e7bea1 100644 --- a/spinedb_api/diff_db_mapping.py +++ b/spinedb_api/diff_db_mapping.py @@ -52,7 +52,6 @@ class DiffDatabaseMapping( def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._init_type_attributes() if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 979ff709..1a1efcf3 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -347,130 +347,65 @@ def create_spine_metadata(): UniqueConstraint("scenario_id", "rank"), UniqueConstraint("scenario_id", "alternative_id"), ) - Table( - "entity_class_type", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(255), nullable=False), - Column("commit_id", Integer, ForeignKey("commit.id")), - ) - Table( - "entity_type", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(255), nullable=False), - Column("commit_id", Integer, ForeignKey("commit.id")), - ) Table( "entity_class", meta, Column("id", Integer, primary_key=True), - Column( - "type_id", - Integer, - ForeignKey("entity_class_type.id", onupdate="CASCADE", ondelete="CASCADE"), - nullable=False, - ), Column("name", String(255), nullable=False), Column("description", Text(), server_default=null()), Column("display_order", Integer, server_default="99"), Column("display_icon", BigInteger, server_default=null()), Column("hidden", Integer, server_default="0"), - Column("commit_id", Integer, ForeignKey("commit.id")), - UniqueConstraint("id", "type_id"), - UniqueConstraint("type_id", "name"), - ) - Table( - "object_class", - meta, - Column("entity_class_id", Integer, primary_key=True), - Column("type_id", Integer, nullable=False), - ForeignKeyConstraint( - ("entity_class_id", "type_id"), ("entity_class.id", "entity_class.type_id"), ondelete="CASCADE" - ), - CheckConstraint("`type_id` = 1", name="type_id"), # make sure object class can only have object type - ) - Table( - "relationship_class", - meta, - Column("entity_class_id", Integer, primary_key=True), - Column("type_id", Integer, nullable=False), - ForeignKeyConstraint( - ("entity_class_id", "type_id"), ("entity_class.id", "entity_class.type_id"), ondelete="CASCADE" - ), - CheckConstraint("`type_id` = 2", name="type_id"), ) Table( - "relationship_entity_class", + "entity_class_dimension", meta, Column( "entity_class_id", Integer, - ForeignKey("relationship_class.entity_class_id", onupdate="CASCADE", ondelete="CASCADE"), + ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, ), - Column("dimension", Integer, primary_key=True), - Column("member_class_id", Integer, nullable=False), - Column("member_class_type_id", Integer, nullable=False), - UniqueConstraint("entity_class_id", "dimension", 
"member_class_id", name="uq_relationship_entity_class"), - ForeignKeyConstraint(("member_class_id", "member_class_type_id"), ("entity_class.id", "entity_class.type_id")), - CheckConstraint("`member_class_type_id` != 2", name="member_class_type_id"), + Column( + "dimension_id", + Integer, + ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), + primary_key=True, + ), + Column("position", Integer, primary_key=True), + UniqueConstraint("entity_class_id", "dimension_id", "position", name="uq_entity_class_dimension"), ) Table( "entity", meta, Column("id", Integer, primary_key=True), - Column("type_id", Integer, ForeignKey("entity_type.id", onupdate="CASCADE", ondelete="CASCADE")), Column("class_id", Integer, ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE")), Column("name", String(255), nullable=False), Column("description", Text(), server_default=null()), Column("commit_id", Integer, ForeignKey("commit.id")), - UniqueConstraint("id", "class_id"), - UniqueConstraint("id", "type_id", "class_id"), UniqueConstraint("class_id", "name"), ) Table( - "object", - meta, - Column("entity_id", Integer, primary_key=True), - Column("type_id", Integer, nullable=False), - ForeignKeyConstraint(("entity_id", "type_id"), ("entity.id", "entity.type_id"), ondelete="CASCADE"), - CheckConstraint("`type_id` = 1", name="type_id"), # make sure object can only have object type - ) - Table( - "relationship", - meta, - Column("entity_id", Integer, primary_key=True), - Column("entity_class_id", Integer, nullable=False), - Column("type_id", Integer, nullable=False), - UniqueConstraint("entity_id", "entity_class_id"), - ForeignKeyConstraint(("entity_id", "type_id"), ("entity.id", "entity.type_id"), ondelete="CASCADE"), - CheckConstraint("`type_id` = 2", name="type_id"), - ) - Table( - "relationship_entity", + "entity_element", meta, Column("entity_id", Integer, primary_key=True), Column("entity_class_id", Integer, nullable=False), - Column("dimension", Integer, primary_key=True), - Column("member_id", Integer, nullable=False), - Column("member_class_id", Integer, nullable=False), + Column("element_id", Integer, nullable=False), + Column("dimension_id", Integer, nullable=False), + Column("position", Integer, primary_key=True), ForeignKeyConstraint( - ("member_id", "member_class_id"), ("entity.id", "entity.class_id"), onupdate="CASCADE", ondelete="CASCADE" + ("entity_id", "entity_class_id"), ("entity.id", "entity.class_id"), onupdate="CASCADE", ondelete="CASCADE" ), ForeignKeyConstraint( - ("entity_class_id", "dimension", "member_class_id"), - ( - "relationship_entity_class.entity_class_id", - "relationship_entity_class.dimension", - "relationship_entity_class.member_class_id", - ), - onupdate="CASCADE", - ondelete="CASCADE", + ("element_id", "dimension_id"), ("entity.id", "entity.class_id"), onupdate="CASCADE", ondelete="CASCADE" ), ForeignKeyConstraint( - ("entity_id", "entity_class_id"), - ("relationship.entity_id", "relationship.entity_class_id"), + ("entity_class_id", "dimension_id", "position"), + ( + "entity_class_dimension.entity_class_id", + "entity_class_dimension.dimension_id", + "entity_class_dimension.position", + ), onupdate="CASCADE", ondelete="CASCADE", ), @@ -698,8 +633,6 @@ def create_new_spine_database(db_url): meta.create_all(engine) engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute("INSERT INTO entity_class_type 
VALUES (1, 'object', 1), (2, 'relationship', 1)") - engine.execute("INSERT INTO entity_type VALUES (1, 'object', 1), (2, 'relationship', 1)") engine.execute("INSERT INTO alembic_version VALUES ('989fccf80441')") except DatabaseError as e: raise SpineDBAPIError("Unable to create Spine database: {}".format(e)) from None From be03b6108fd92ab5595d4d6e8e6b95e40551bfdd Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 8 Feb 2023 16:33:46 +0100 Subject: [PATCH 002/317] Select and insert objects and relationships --- spinedb_api/check_functions.py | 13 ++------ spinedb_api/db_mapping_add_mixin.py | 1 - spinedb_api/db_mapping_base.py | 43 +++++++++------------------ spinedb_api/db_mapping_check_mixin.py | 5 +--- spinedb_api/helpers.py | 29 ------------------ 5 files changed, 17 insertions(+), 74 deletions(-) diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index 5d90e7f0..94bfa0dd 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -143,7 +143,7 @@ def check_object(item, current_items, object_class_ids): ) -def check_wide_relationship_class(wide_item, current_items, object_class_ids, relationship_class_type): +def check_wide_relationship_class(wide_item, current_items, object_class_ids): """Check whether the insertion of a relationship class item results in the violation of an integrity constraint. @@ -177,17 +177,13 @@ def check_wide_relationship_class(wide_item, current_items, object_class_ids, re raise SpineIntegrityError( f"At least one of the object class ids of the relationship class '{name}' is not in the database." ) - if "type_id" in wide_item and wide_item["type_id"] != relationship_class_type: - raise SpineIntegrityError(f"Relationship class '{name}' must have correct type_id .", id=current_items[name]) if name in current_items: raise SpineIntegrityError( f"There can't be more than one relationship class with the name '{name}'.", id=current_items[name] ) -def check_wide_relationship( - wide_item, current_items_by_name, current_items_by_obj_lst, relationship_classes, objects, relationship_entity_type -): +def check_wide_relationship(wide_item, current_items_by_name, current_items_by_obj_lst, relationship_classes, objects): """Check whether the insertion of a relationship item results in the violation of an integrity constraint. @@ -219,11 +215,6 @@ def check_wide_relationship( f"Python KeyError: There is no dictionary key for the relationship class id of relationship '{name}'. 
" "Probably a bug, please report" ) - if "type_id" in wide_item and wide_item["type_id"] != relationship_entity_type: - raise SpineIntegrityError( - f"Relationship '{name}' does not have entity type of a relationship.", - id=current_items_by_name[class_id, name], - ) if (class_id, name) in current_items_by_name: raise SpineIntegrityError( f"There's already a relationship called '{name}' in the same class.", diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 9ac6b624..4c1a7651 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -20,7 +20,6 @@ from sqlalchemy import func, Table, Column, Integer, String, null, select from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError -from .helpers import get_relationship_entity_class_items, get_relationship_entity_items class DatabaseMappingAddMixin: diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 19cade85..17bd65f5 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -132,11 +132,6 @@ def __init__( self._tablenames = [t.name for t in self._metadata.sorted_tables] self.session = Session(self.connection, **self._session_kwargs) self.cache = DBCache(self._advance_cache_query) - # class and entity type id - self._object_class_type = None - self._relationship_class_type = None - self._object_entity_type = None - self._relationship_entity_type = None # Subqueries that select everything from each table self._commit_sq = None self._alternative_sq = None @@ -604,7 +599,7 @@ def ext_entity_class_sq(self): .outerjoin( entity_class_dimension_sq, self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id ) - .group_by(entity_class_dimension_sq.c.entity_class_id) + .group_by(self.entity_class_sq.c.id) .subquery() ) return self._ext_entity_class_sq @@ -636,7 +631,7 @@ def ext_entity_sq(self): group_concat(entity_element_sq.c.element_id, entity_element_sq.c.position).label("element_id_list"), ) .outerjoin(entity_element_sq, self.entity_sq.c.id == entity_element_sq.c.entity_id) - .group_by(entity_element_sq.c.entity_id) + .group_by(self.entity_sq.c.id) .subquery() ) return self._ext_entity_sq @@ -1296,13 +1291,13 @@ def ext_entity_group_sq(self): self.entity_group_sq.c.entity_class_id.label("class_id"), self.entity_group_sq.c.entity_id.label("group_id"), self.entity_group_sq.c.member_id.label("member_id"), - self.entity_class_sq.c.name.label("class_name"), + self.ext_entity_class_sq.c.name.label("class_name"), group_entity.c.name.label("group_name"), member_entity.c.name.label("member_name"), label("object_class_id", self._object_class_id()), label("relationship_class_id", self._relationship_class_id()), ) - .filter(self.entity_group_sq.c.entity_class_id == self.entity_class_sq.c.id) + .filter(self.entity_group_sq.c.entity_class_id == self.ext_entity_class_sq.c.id) .join(group_entity, self.entity_group_sq.c.entity_id == group_entity.c.id) .join(member_entity, self.entity_group_sq.c.member_id == member_entity.c.id) .subquery() @@ -1322,7 +1317,7 @@ def entity_parameter_definition_sq(self): self.parameter_definition_sq.c.entity_class_id, self.parameter_definition_sq.c.object_class_id, self.parameter_definition_sq.c.relationship_class_id, - self.entity_class_sq.c.name.label("entity_class_name"), + self.ext_entity_class_sq.c.name.label("entity_class_name"), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), 
label("object_class_id_list", self._object_class_id_list()), @@ -1336,13 +1331,17 @@ def entity_parameter_definition_sq(self): self.parameter_definition_sq.c.description, self.parameter_definition_sq.c.commit_id, ) - .join(self.entity_class_sq, self.entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id) + .join( + self.ext_entity_class_sq, + self.ext_entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id, + ) .outerjoin( self.parameter_value_list_sq, self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, ) .outerjoin( - self.wide_relationship_class_sq, self.wide_relationship_class_sq.c.id == self.entity_class_sq.c.id + self.wide_relationship_class_sq, + self.wide_relationship_class_sq.c.id == self.ext_entity_class_sq.c.id, ) .subquery() ) @@ -1798,7 +1797,7 @@ def _make_parameter_definition_sq(self): par_def_sq.c.commit_id.label("commit_id"), par_def_sq.c.parameter_value_list_id.label("parameter_value_list_id"), ) - .join(self.entity_class_sq, self.entity_class_sq.c.id == par_def_sq.c.entity_class_id) + .join(self.ext_entity_class_sq, self.ext_entity_class_sq.c.id == par_def_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) .subquery() ) @@ -1830,8 +1829,8 @@ def _make_parameter_value_sq(self): par_val_sq.c.commit_id.label("commit_id"), par_val_sq.c.alternative_id, ) - .join(self.entity_sq, self.entity_sq.c.id == par_val_sq.c.entity_id) - .join(self.entity_class_sq, self.entity_class_sq.c.id == par_val_sq.c.entity_class_id) + .join(self.ext_entity_sq, self.ext_entity_sq.c.id == par_val_sq.c.entity_id) + .join(self.ext_entity_class_sq, self.ext_entity_class_sq.c.id == par_val_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) .subquery() ) @@ -2058,20 +2057,6 @@ def _do_advance_cache_query(self, tablename): for x in self.query(getattr(self, self.cache_sqs[tablename])).yield_per(1000).enable_eagerloads(False): table_cache.add_item(x._asdict()) - def _items_with_type_id(self, tablename, *items): - type_id = { - "object_class": self.object_class_type, - "relationship_class": self.relationship_class_type, - "object": self.object_entity_type, - "relationship": self.relationship_entity_type, - }.get(tablename) - if type_id is None: - yield from items - return - for item in items: - item["type_id"] = type_id - yield item - def _object_class_id(self): return case([(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.id)], else_=None) diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index a67fb5bf..569fd9e8 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -414,9 +414,7 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= cache, intgr_error_log, ) as wide_item: - check_wide_relationship_class( - wide_item, relationship_class_ids, object_class_ids, self.relationship_class_type - ) + check_wide_relationship_class(wide_item, relationship_class_ids, object_class_ids) checked_wide_items.append(wide_item) except SpineIntegrityError as e: if strict: @@ -468,7 +466,6 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, relationship_ids_by_obj_lst, relationship_classes, objects, - self.relationship_entity_type, ) checked_wide_items.append(wide_item) except SpineIntegrityError as e: diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 1a1efcf3..e11718ad 100644 
--- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -821,35 +821,6 @@ def forward_sweep(root, fn): break -def get_relationship_entity_class_items(item, object_class_type): - return [ - { - "entity_class_id": item["id"], - "dimension": dimension, - "member_class_id": object_class_id, - "member_class_type_id": object_class_type, - } - for dimension, object_class_id in enumerate(item["object_class_id_list"]) - ] - - -def get_relationship_entity_items(item, relationship_entity_type, object_entity_type): - return [ - { - "entity_id": item["id"], - "type_id": relationship_entity_type, - "entity_class_id": item["class_id"], - "dimension": dimension, - "member_id": object_id, - "member_class_type_id": object_entity_type, - "member_class_id": object_class_id, - } - for dimension, (object_id, object_class_id) in enumerate( - zip(item["object_id_list"], item["object_class_id_list"]) - ) - ] - - def labelled_columns(table): return [c.label(c.name) for c in table.columns] From cbaf3b2f0bacaa11476188946a6e98a2ddedfcf6 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 8 Feb 2023 16:54:28 +0100 Subject: [PATCH 003/317] Also update and remove objects and relationships via entities --- spinedb_api/db_mapping_add_mixin.py | 11 +++-------- spinedb_api/db_mapping_base.py | 16 ++++++++++++---- spinedb_api/db_mapping_remove_mixin.py | 9 +++++---- spinedb_api/db_mapping_update_mixin.py | 25 +++++++++---------------- 4 files changed, 29 insertions(+), 32 deletions(-) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 4c1a7651..1d38fb6a 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -105,14 +105,9 @@ def _do_reserve_ids(self, connection, tablename, count): next_id = getattr(next_id_row, fieldname) stmt = self._next_id.update() if next_id is None: - tablename = { - "object_class": "entity_class", - "relationship_class": "entity_class", - "object": "entity", - "relationship": "entity", - }.get(tablename, tablename) - table = self._metadata.tables[tablename] - id_col = self.table_ids.get(tablename, "id") + real_tablename = self._real_tablename(tablename) + table = self._metadata.tables[real_tablename] + id_col = self.table_ids.get(real_tablename, "id") select_max_id = select([func.max(getattr(table.c, id_col))]) max_id = connection.execute(select_max_id).scalar() next_id = max_id + 1 if max_id else 1 diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 17bd65f5..336db3e6 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -193,16 +193,16 @@ def __init__( self._table_to_sq_attr = {} # Table primary ids map: self.table_ids = { - "relationship_entity_class": "entity_class_id", "object_class": "entity_class_id", "relationship_class": "entity_class_id", + "entity_class_dimension": "entity_class_id", "object": "entity_id", "relationship": "entity_id", - "relationship_entity": "entity_id", + "entity_element": "entity_id", } self.composite_pks = { - "relationship_entity": ("entity_id", "dimension"), - "relationship_entity_class": ("entity_class_id", "dimension"), + "entity_element": ("entity_id", "position"), + "entity_class_dimension": ("entity_class_id", "position"), } # Subqueries used to populate cache self.cache_sqs = { @@ -315,6 +315,14 @@ def sorted_tablenames(self): tablenames.append(tablename) return sorted_tablenames + def _real_tablename(self, tablename): + return { + "object_class": "entity_class", + "relationship_class": "entity_class", + "object": 
"entity", + "relationship": "entity", + }.get(tablename, tablename) + def commit_id(self): return self._commit_id diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index adca2883..abf61aad 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -46,8 +46,9 @@ def remove_items(self, **kwargs): for tablename, ids in kwargs.items(): if not ids: continue - table_id = self.table_ids.get(tablename, "id") - table = self._metadata.tables[tablename] + real_tablename = self._real_tablename(tablename) + table_id = self.table_ids.get(real_tablename, "id") + table = self._metadata.tables[real_tablename] delete = table.delete().where(self.in_(getattr(table.c, table_id), ids)) try: self.connection.execute(delete) @@ -163,7 +164,7 @@ def _relationship_class_cascading_ids(self, ids, cache): """Returns relationship class cascading ids.""" cascading_ids = { "relationship_class": set(ids), - "relationship_entity_class": set(ids), + "entity_class_dimension": set(ids), "entity_class": set(ids), } relationships = [x for x in cache.get("relationship", {}).values() if x.class_id in ids] @@ -176,7 +177,7 @@ def _relationship_class_cascading_ids(self, ids, cache): def _relationship_cascading_ids(self, ids, cache): """Returns relationship cascading ids.""" - cascading_ids = {"relationship": set(ids), "entity": set(ids), "relationship_entity": set(ids)} + cascading_ids = {"relationship": set(ids), "entity": set(ids), "entity_element": set(ids)} parameter_values = [x for x in cache.get("parameter_value", {}).values() if x.entity_id in ids] groups = [x for x in cache.get("entity_group", {}).values() if {x.group_id, x.member_id}.intersection(ids)] entity_metadata_ids = {x.id for x in cache.get("entity_metadata", {}).values() if x.entity_id in ids} diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index d73e3780..9a8db577 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -33,13 +33,7 @@ def _update_items(self, tablename, *items): # Special cases if tablename == "relationship": return self._update_wide_relationships(*items) - real_tablename = { - "object_class": "entity_class", - "relationship_class": "entity_class", - "object": "entity", - "relationship": "entity", - }.get(tablename, tablename) - items = self._items_with_type_id(tablename, *items) + real_tablename = self._real_tablename(tablename) return self._do_update_items(real_tablename, *items) def _do_update_items(self, tablename, *items): @@ -125,9 +119,8 @@ def update_wide_relationships(self, *items, **kwargs): return self.update_items("relationship", *items, **kwargs) def _update_wide_relationships(self, *items): - items = self._items_with_type_id("relationship", *items) entity_items = [] - relationship_entity_items = [] + entity_element_items = [] for item in items: entity_id = item["id"] class_id = item["class_id"] @@ -137,21 +130,21 @@ def _update_wide_relationships(self, *items): "name": item["name"], "description": item.get("description"), } + entity_items.append(ent_item) object_class_id_list = item["object_class_id_list"] object_id_list = item["object_id_list"] - entity_items.append(ent_item) - for dimension, (member_class_id, member_id) in enumerate(zip(object_class_id_list, object_id_list)): + for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)): rel_ent_item = { "id": None, # Need to have an "id" field to make _update_items() happy. 
"entity_class_id": class_id, "entity_id": entity_id, - "dimension": dimension, - "member_class_id": member_class_id, - "member_id": member_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, } - relationship_entity_items.append(rel_ent_item) + entity_element_items.append(rel_ent_item) entity_ids = self._update_items("entity", *entity_items) - self._update_items("relationship_entity", *relationship_entity_items) + self._update_items("entity_element", *entity_element_items) return entity_ids def update_parameter_definitions(self, *items, **kwargs): From 285771eb45c713286b8a99adb9ed4debc5f41a06 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Feb 2023 07:10:35 +0100 Subject: [PATCH 004/317] Add migration script to drop object and relationship --- ...c61_drop_object_and_relationship_tables.py | 124 ++++++++++++++++++ spinedb_api/helpers.py | 10 +- 2 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py diff --git a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py new file mode 100644 index 00000000..703c0547 --- /dev/null +++ b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py @@ -0,0 +1,124 @@ +"""drop_object_and_relationship_tables + +Revision ID: 6b7c994c1c61 +Revises: 989fccf80441 +Create Date: 2023-02-09 06:48:46.585108 + +""" +from alembic import op +import sqlalchemy as sa +from spinedb_api.helpers import naming_convention + + +# revision identifiers, used by Alembic. +revision = '6b7c994c1c61' +down_revision = '989fccf80441' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'entity_class_dimension', + sa.Column('entity_class_id', sa.Integer(), nullable=False), + sa.Column('dimension_id', sa.Integer(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ['dimension_id'], + ['entity_class.id'], + name=op.f('fk_entity_class_dimension_dimension_id_entity_class'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['entity_class_id'], + ['entity_class.id'], + name=op.f('fk_entity_class_dimension_entity_class_id_entity_class'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('entity_class_id', 'dimension_id', 'position', name=op.f('pk_entity_class_dimension')), + sa.UniqueConstraint('entity_class_id', 'dimension_id', 'position', name='uq_entity_class_dimension'), + ) + op.create_table( + 'entity_element', + sa.Column('entity_id', sa.Integer(), nullable=False), + sa.Column('entity_class_id', sa.Integer(), nullable=False), + sa.Column('element_id', sa.Integer(), nullable=False), + sa.Column('dimension_id', sa.Integer(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ['element_id', 'dimension_id'], + ['entity.id', 'entity.class_id'], + name=op.f('fk_entity_element_element_id_entity'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['entity_class_id', 'dimension_id', 'position'], + [ + 'entity_class_dimension.entity_class_id', + 'entity_class_dimension.dimension_id', + 'entity_class_dimension.position', + ], + name=op.f('fk_entity_element_entity_class_id_entity_class_dimension'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['entity_id', 'entity_class_id'], + ['entity.id', 
'entity.class_id'], + name=op.f('fk_entity_element_entity_id_entity'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('entity_id', 'position', name=op.f('pk_entity_element')), + ) + _persist_data() + with op.batch_alter_table("entity", naming_convention=naming_convention) as batch_op: + batch_op.drop_constraint('uq_entity_idclass_id', type_='unique') + batch_op.drop_constraint('uq_entity_idtype_idclass_id', type_='unique') + batch_op.drop_constraint('fk_entity_type_id_entity_type', type_='foreignkey') + batch_op.drop_column('type_id') + with op.batch_alter_table("entity_class", naming_convention=naming_convention) as batch_op: + batch_op.drop_constraint('uq_entity_class_idtype_id', type_='unique') + batch_op.drop_constraint('uq_entity_class_type_idname', type_='unique') + batch_op.drop_constraint('fk_entity_class_type_id_entity_class_type', type_='foreignkey') + batch_op.drop_constraint('fk_entity_class_commit_id_commit', type_='foreignkey') + batch_op.drop_column('commit_id') + batch_op.drop_column('type_id') + op.drop_table('object_class') + op.drop_table('entity_class_type') + # op.drop_table('next_id') + op.drop_table('object') + op.drop_table('relationship_entity_class') + op.drop_table('relationship') + op.drop_table('entity_type') + op.drop_table('relationship_class') + op.drop_table('relationship_entity') + + +def _persist_data(): + conn = op.get_bind() + meta = sa.MetaData(conn) + meta.reflect() + ecd_items = [ + {"entity_class_id": x["entity_class_id"], "dimension_id": x["member_class_id"], "position": x["dimension"]} + for x in conn.execute("SELECT * FROM relationship_entity_class") + ] + ee_items = [ + { + "entity_id": x["entity_id"], + "entity_class_id": x["entity_class_id"], + "element_id": x["member_id"], + "dimension_id": x["member_class_id"], + "position": x["dimension"], + } + for x in conn.execute("SELECT * FROM relationship_entity") + ] + op.bulk_insert(meta.tables["entity_class_dimension"], ecd_items) + op.bulk_insert(meta.tables["entity_element"], ee_items) + + +def downgrade(): + pass diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index e11718ad..070fccea 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -622,7 +622,7 @@ def create_new_spine_database(db_url): try: engine = create_engine(db_url) except DatabaseError as e: - raise SpineDBAPIError("Could not connect to '{}': {}".format(db_url, e.orig.args)) from None + raise SpineDBAPIError(f"Could not connect to '{db_url}': {e.orig.args}") from None # Drop existing tables. This is a Spine db now... 
meta = MetaData(engine) meta.reflect() @@ -633,9 +633,9 @@ def create_new_spine_database(db_url): meta.create_all(engine) engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute("INSERT INTO alembic_version VALUES ('989fccf80441')") + engine.execute("INSERT INTO alembic_version VALUES ('6b7c994c1c61')") except DatabaseError as e: - raise SpineDBAPIError("Unable to create Spine database: {}".format(e)) from None + raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None return engine @@ -646,7 +646,7 @@ def _create_first_spine_database(db_url): try: engine = create_engine(db_url) except DatabaseError as e: - raise SpineDBAPIError("Could not connect to '{}': {}".format(db_url, e.orig.args)) from None + raise SpineDBAPIError(f"Could not connect to '{db_url}': {e.orig.args}") from None # Drop existing tables. This is a Spine db now... meta = MetaData(engine) meta.reflect() @@ -788,7 +788,7 @@ def _create_first_spine_database(db_url): try: meta.create_all(engine) except DatabaseError as e: - raise SpineDBAPIError("Unable to create Spine database: {}".format(e.orig.args)) + raise SpineDBAPIError(f"Unable to create Spine database: {e.orig.args}") return engine From 7571279438531eb3026833192380ac3a291a35dc Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Feb 2023 14:45:37 +0100 Subject: [PATCH 005/317] Update DBCache to work with entity rather than object and relationship --- spinedb_api/db_cache.py | 137 ++++++++++++--------------------- spinedb_api/db_mapping_base.py | 42 ++++++---- 2 files changed, 76 insertions(+), 103 deletions(-) diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 197c30c7..ccec6eae 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -56,10 +56,8 @@ def make_item(self, item_type, item): CacheItem """ factory = { - "object_class": ObjectClassItem, - "object": ObjectItem, - "relationship_class": RelationshipClassItem, - "relationship": RelationshipItem, + "entity_class": EntityClassItem, + "entity": EntityItem, "parameter_definition": ParameterDefinitionItem, "parameter_value": ParameterValueItem, "entity_group": EntityGroupItem, @@ -273,98 +271,67 @@ def __getitem__(self, key): return super().__getitem__(key) -class ObjectClassItem(DisplayIconMixin, DescriptionMixin, CacheItem): - pass - - -class ObjectItem(DescriptionMixin, CacheItem): - def __getitem__(self, key): - if key == "class_name": - return self._get_ref("object_class", self["class_id"], key).get("name") - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ("class_name",) - - -class ObjectClassIdListMixin: +class EntityClassItem(DisplayIconMixin, DescriptionMixin, CacheItem): def __init__(self, *args, **kwargs): - object_class_id_list = kwargs["object_class_id_list"] - if isinstance(object_class_id_list, str): - object_class_id_list = (int(id_) for id_ in object_class_id_list.split(",")) - kwargs["object_class_id_list"] = tuple(object_class_id_list) + dimension_id_list = kwargs["dimension_id_list"] + if dimension_id_list is None: + dimension_id_list = () + if isinstance(dimension_id_list, str): + dimension_id_list = (int(id_) for id_ in dimension_id_list.split(",")) + kwargs["dimension_id_list"] = tuple(dimension_id_list) super().__init__(*args, **kwargs) def __getitem__(self, key): - if key == "object_class_name_list": - return tuple(self._get_ref("object_class", 
id_, key).get("name") for id_ in self["object_class_id_list"]) + if key == "dimension_name_list": + return tuple(self._get_ref("entity_class", id_, key).get("name") for id_ in self["dimension_id_list"]) return super().__getitem__(key) def _reference_keys(self): - return super()._reference_keys() + ("object_class_name_list",) - - -class RelationshipClassItem(DisplayIconMixin, ObjectClassIdListMixin, DescriptionMixin, CacheItem): - pass + return super()._reference_keys() + ("dimension_name_list",) -class RelationshipItem(ObjectClassIdListMixin, CacheItem): - def __init__(self, db_cache, *args, **kwargs): - if "object_class_id_list" not in kwargs: - kwargs["object_class_id_list"] = db_cache.get_item("relationship_class", kwargs["class_id"]).get( - "object_class_id_list", () - ) - object_id_list = kwargs.get("object_id_list", ()) - if isinstance(object_id_list, str): - object_id_list = (int(id_) for id_ in object_id_list.split(",")) - kwargs["object_id_list"] = tuple(object_id_list) - super().__init__(db_cache, *args, **kwargs) +class EntityItem(DescriptionMixin, CacheItem): + def __init__(self, *args, **kwargs): + element_id_list = kwargs["element_id_list"] + if element_id_list is None: + element_id_list = () + if isinstance(element_id_list, str): + element_id_list = (int(id_) for id_ in element_id_list.split(",")) + kwargs["element_id_list"] = tuple(element_id_list) + super().__init__(*args, **kwargs) def __getitem__(self, key): if key == "class_name": - return self._get_ref("relationship_class", self["class_id"], key).get("name") - if key == "object_name_list": - return tuple(self._get_ref("object", id_, key).get("name") for id_ in self["object_id_list"]) + return self._get_ref("entity_class", self["class_id"], key).get("name") + if key == "dimension_id_list": + return self._get_ref("entity_class", self["class_id"], key).get("dimension_id_list") + if key == "dimension_name_list": + return self._get_ref("entity_class", self["class_id"], key).get("dimension_name_list") + if key == "element_name_list": + return tuple(self._get_ref("entity", id_, key).get("name") for id_ in self["element_id_list"]) return super().__getitem__(key) def _reference_keys(self): - return super()._reference_keys() + ("class_name", "object_name_list") + return super()._reference_keys() + ( + "class_name", + "dimension_id_list", + "dimension_name_list", + "element_name_list", + ) class ParameterMixin: - def __init__(self, *args, **kwargs): - if "entity_class_id" not in kwargs: - kwargs["entity_class_id"] = kwargs.get("object_class_id") or kwargs.get("relationship_class_id") - super().__init__(*args, **kwargs) - def __getitem__(self, key): - if key in ("object_class_id", "relationship_class_id"): - return dict.get(self, key) - if key == "object_class_name": - if self["object_class_id"] is None: - return None - return self._get_ref("object_class", self["object_class_id"], key).get("name") - if key == "relationship_class_name": - if self["relationship_class_id"] is None: - return None - return self._get_ref("relationship_class", self["relationship_class_id"], key).get("name") - if key in ("object_class_id_list", "object_class_name_list"): - if self["relationship_class_id"] is None: - return None - return self._get_ref("relationship_class", self["relationship_class_id"], key).get(key) + if key in ("dimension_id_list", "dimension_name_list"): + return self._get_ref("entity_class", self["entity_class_id"], key)[key] if key == "entity_class_name": - return self["relationship_class_name"] if self["object_class_id"] is None else 
self["object_class_name"] + return self._get_ref("entity_class", self["entity_class_id"], key)["name"] if key == "parameter_value_list_id": return dict.get(self, key) return super().__getitem__(key) def _reference_keys(self): - keys = super()._reference_keys() - if self["object_class_id"]: - keys += ("object_class_name",) - elif self["relationship_class_id"]: - keys += ("relationship_class_name", "object_class_id_list", "object_class_name_list") - return keys + return super()._reference_keys() + ("entity_class_name", "dimension_id_list", "dimension_name_list") class ParameterDefinitionItem(DescriptionMixin, ParameterMixin, CacheItem): @@ -391,27 +358,19 @@ def __getitem__(self, key): class ParameterValueItem(ParameterMixin, CacheItem): def __init__(self, *args, **kwargs): - if "entity_id" not in kwargs: - kwargs["entity_id"] = kwargs.get("object_id") or kwargs.get("relationship_id") if kwargs.get("list_value_id") is None: kwargs["list_value_id"] = int(kwargs["value"]) if kwargs.get("type") == "list_value_ref" else None super().__init__(*args, **kwargs) def __getitem__(self, key): - if key in ("object_id", "relationship_id"): - return dict.get(self, key) if key == "parameter_id": return super().__getitem__("parameter_definition_id") if key == "parameter_name": return self._get_ref("parameter_definition", self["parameter_definition_id"], key).get("name") - if key == "object_name": - if self["object_id"] is None: - return None - return self._get_ref("object", self["object_id"], key).get("name") - if key in ("object_id_list", "object_name_list"): - if self["relationship_id"] is None: - return None - return self._get_ref("relationship", self["relationship_id"], key).get(key) + if key == "entity_name": + return self._get_ref("entity", self["entity_id"], key)["name"] + if key in ("element_id_list", "element_name_list"): + return self._get_ref("entity", self["entity_id"], key)[key] if key == "alternative_name": return self._get_ref("alternative", self["alternative_id"], key).get("name") if key in ("value", "type") and self["list_value_id"] is not None: @@ -419,11 +378,13 @@ def __getitem__(self, key): return super().__getitem__(key) def _reference_keys(self): - keys = super()._reference_keys() + ("parameter_name", "alternative_name") - if self["object_id"]: - keys += ("object_name",) - elif self["relationship_id"]: - keys += ("object_id_list", "object_name_list") + keys = super()._reference_keys() + ( + "parameter_name", + "alternative_name", + "entity_name", + "element_id_list", + "element_name_list", + ) return keys diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 336db3e6..78753527 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -139,8 +139,8 @@ def __init__( self._scenario_alternative_sq = None self._entity_class_sq = None self._entity_sq = None - self._entity_class_type_sq = None - self._entity_type_sq = None + self._entity_class_dimension_sq = None + self._entity_element_sq = None self._object_class_sq = None self._object_sq = None self._relationship_class_sq = None @@ -206,7 +206,10 @@ def __init__( } # Subqueries used to populate cache self.cache_sqs = { - "entity": "entity_sq", + "entity_class": "ext_entity_class_sq", + "entity_class_dimension": "entity_class_dimension_sq", + "entity": "ext_entity_sq", + "entity_element": "entity_element_sq", "feature": "feature_sq", "tool": "tool_sq", "tool_feature": "tool_feature_sq", @@ -216,10 +219,6 @@ def __init__( "alternative": "alternative_sq", "scenario": "scenario_sq", 
"scenario_alternative": "scenario_alternative_sq", - "object_class": "object_class_sq", - "object": "object_sq", - "relationship_class": "wide_relationship_class_sq", - "relationship": "wide_relationship_sq", "entity_group": "entity_group_sq", "parameter_definition": "parameter_definition_sq", "parameter_value": "clean_parameter_value_sq", @@ -561,6 +560,18 @@ def entity_class_sq(self): self._entity_class_sq = self._make_entity_class_sq() return self._entity_class_sq + @property + def entity_class_dimension_sq(self): + if self._entity_class_dimension_sq is None: + self._entity_class_dimension_sq = self._subquery("entity_class_dimension") + return self._entity_class_dimension_sq + + @property + def entity_element_sq(self): + if self._entity_element_sq is None: + self._entity_element_sq = self._subquery("entity_element") + return self._entity_element_sq + @property def entity_sq(self): """A subquery of the form: @@ -596,16 +607,16 @@ def ext_entity_class_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_class_sq is None: - entity_class_dimension_sq = self._subquery("entity_class_dimension") self._ext_entity_class_sq = ( self.query( self.entity_class_sq, - group_concat(entity_class_dimension_sq.c.dimension_id, entity_class_dimension_sq.c.position).label( - "dimension_id_list" - ), + group_concat( + self.entity_class_dimension_sq.c.dimension_id, self.entity_class_dimension_sq.c.position + ).label("dimension_id_list"), ) .outerjoin( - entity_class_dimension_sq, self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id + self.entity_class_dimension_sq, + self.entity_class_sq.c.id == self.entity_class_dimension_sq.c.entity_class_id, ) .group_by(self.entity_class_sq.c.id) .subquery() @@ -632,13 +643,14 @@ def ext_entity_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_sq is None: - entity_element_sq = self._subquery("entity_element") self._ext_entity_sq = ( self.query( self.entity_sq, - group_concat(entity_element_sq.c.element_id, entity_element_sq.c.position).label("element_id_list"), + group_concat(self.entity_element_sq.c.element_id, self.entity_element_sq.c.position).label( + "element_id_list" + ), ) - .outerjoin(entity_element_sq, self.entity_sq.c.id == entity_element_sq.c.entity_id) + .outerjoin(self.entity_element_sq, self.entity_sq.c.id == self.entity_element_sq.c.entity_id) .group_by(self.entity_sq.c.id) .subquery() ) From 043bc13dfda9464a699926681d6ebcb66ff1bd4f Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 13 Feb 2023 08:07:51 +0100 Subject: [PATCH 006/317] Fix entity queries --- spinedb_api/db_mapping_base.py | 72 ++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 78753527..214dccd9 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -607,18 +607,42 @@ def ext_entity_class_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_class_sq is None: - self._ext_entity_class_sq = ( + ecd_sq = ( self.query( - self.entity_class_sq, - group_concat( - self.entity_class_dimension_sq.c.dimension_id, self.entity_class_dimension_sq.c.position - ).label("dimension_id_list"), + self.entity_class_sq.c.id, + self.entity_class_sq.c.name, + self.entity_class_sq.c.description, + self.entity_class_sq.c.display_order, + self.entity_class_sq.c.display_icon, + self.entity_class_sq.c.hidden, + self.entity_class_dimension_sq.c.dimension_id, + self.entity_class_dimension_sq.c.position, ) .outerjoin( 
self.entity_class_dimension_sq, self.entity_class_sq.c.id == self.entity_class_dimension_sq.c.entity_class_id, ) - .group_by(self.entity_class_sq.c.id) + .order_by(self.entity_class_sq.c.id, self.entity_class_dimension_sq.c.position) + .subquery() + ) + self._ext_entity_class_sq = ( + self.query( + ecd_sq.c.id, + ecd_sq.c.name, + ecd_sq.c.description, + ecd_sq.c.display_order, + ecd_sq.c.display_icon, + ecd_sq.c.hidden, + group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"), + ) + .group_by( + ecd_sq.c.id, + ecd_sq.c.name, + ecd_sq.c.description, + ecd_sq.c.display_order, + ecd_sq.c.display_icon, + ecd_sq.c.hidden, + ) .subquery() ) return self._ext_entity_class_sq @@ -643,15 +667,39 @@ def ext_entity_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_sq is None: + ee_sq = ( + self.query( + self.entity_sq.c.id, + self.entity_sq.c.class_id, + self.entity_sq.c.name, + self.entity_sq.c.description, + self.entity_sq.c.commit_id, + self.entity_element_sq.c.element_id, + self.entity_element_sq.c.position, + ) + .outerjoin( + self.entity_element_sq, + self.entity_sq.c.id == self.entity_element_sq.c.entity_id, + ) + .order_by(self.entity_sq.c.id, self.entity_element_sq.c.position) + .subquery() + ) self._ext_entity_sq = ( self.query( - self.entity_sq, - group_concat(self.entity_element_sq.c.element_id, self.entity_element_sq.c.position).label( - "element_id_list" - ), + ee_sq.c.id, + ee_sq.c.class_id, + ee_sq.c.name, + ee_sq.c.description, + ee_sq.c.commit_id, + group_concat(ee_sq.c.element_id, ee_sq.c.position).label("element_id_list"), + ) + .group_by( + ee_sq.c.id, + ee_sq.c.class_id, + ee_sq.c.name, + ee_sq.c.description, + ee_sq.c.commit_id, ) - .outerjoin(self.entity_element_sq, self.entity_sq.c.id == self.entity_element_sq.c.entity_id) - .group_by(self.entity_sq.c.id) .subquery() ) return self._ext_entity_sq From 8b8e021e9b9cda1ad7248071aef626cb372ff914 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 14 Feb 2023 07:52:00 +0100 Subject: [PATCH 007/317] Initial support to add and update entities --- spinedb_api/check_functions.py | 83 +++++++++++++ spinedb_api/db_cache.py | 3 + spinedb_api/db_mapping_add_mixin.py | 44 ++++++- spinedb_api/db_mapping_base.py | 6 +- spinedb_api/db_mapping_check_mixin.py | 163 ++++++++++++++++++------- spinedb_api/db_mapping_update_mixin.py | 44 ++++++- 6 files changed, 285 insertions(+), 58 deletions(-) diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index 94bfa0dd..56b586a0 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -88,6 +88,89 @@ def check_scenario_alternative(item, ids_by_alt_id, ids_by_rank, scenario_names, ) +def check_entity_class(item, current_items): + """Check whether the insertion of an entity class item results in the violation of an integrity constraint. + + Args: + wide_item (dict): An entity class item to be checked. + current_items (dict): A dictionary mapping names to ids of entity classes already in the database. + + Raises: + SpineIntegrityError: if the insertion of the item violates an integrity constraint. 
+    """
+    try:
+        name = item["name"]
+    except KeyError:
+        raise SpineIntegrityError("The name for the entity class is missing.")
+    if not name:
+        raise SpineIntegrityError("Entity class name is an empty string, which is not valid")
+    try:
+        dimension_id_list = item["dimension_id_list"]
+    except KeyError:
+        item["dimension_id_list"] = dimension_id_list = ()
+    if not all(id_ in current_items.values() for id_ in dimension_id_list):
+        raise SpineIntegrityError(f"One or more dimension ids for the entity class '{name}' are not in the database.")
+    if name in current_items:
+        raise SpineIntegrityError(
+            f"There can't be more than one entity class with the name '{name}'.", id=current_items[name]
+        )
+
+
+def check_entity(item, current_items_by_name, current_items_by_el_id_lst, entity_classes, entities):
+    """Check whether the insertion of an entity item results in the violation of an integrity constraint.
+
+    Args:
+        item (dict): An entity item to be checked.
+        current_items_by_name (dict): A dictionary mapping tuples (class_id, name) to ids of
+            entities already in the database.
+        current_items_by_el_id_lst (dict): A dictionary mapping tuples (class_id, element_id_list) to ids
+            of entities already in the database.
+        entity_classes (dict): A dictionary of entity class items in the database keyed by id.
+        entities (dict): A dictionary of entity items in the database keyed by id.
+
+    Raises:
+        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
+    """
+    try:
+        name = item["name"]
+    except KeyError:
+        raise SpineIntegrityError("The name for the entity is missing.")
+    if not name:
+        raise SpineIntegrityError("Entity name is an empty string, which is not valid")
+    try:
+        class_id = item["class_id"]
+    except KeyError:
+        raise SpineIntegrityError(f"The entity class id for entity '{name}' is missing.")
+    if (class_id, name) in current_items_by_name:
+        raise SpineIntegrityError(
+            f"There's already an entity called '{name}' in the same class.",
+            id=current_items_by_name[class_id, name],
+        )
+    entity_class = entity_classes.get(class_id)
+    if entity_class is None:
+        raise SpineIntegrityError(f"Entity class for entity '{name}' not found.")
+    dimension_id_list = entity_class["dimension_id_list"]
+    if not dimension_id_list:
+        return
+    try:
+        element_id_list = tuple(item["element_id_list"])
+    except KeyError:
+        item["element_id_list"] = element_id_list = ()
+    try:
+        given_dimension_id_list = tuple(entities[id_]["class_id"] for id_ in element_id_list)
+    except KeyError:
+        raise SpineIntegrityError(f"Some of the elements in entity '{name}' are not in the database.")
+    if given_dimension_id_list != dimension_id_list:
+        element_name_list = [entities[id_]["name"] for id_ in element_id_list]
+        entity_class_name = entity_class["name"]
+        raise SpineIntegrityError(f"Incorrect elements '{element_name_list}' for entity class '{entity_class_name}'.")
+    if (class_id, element_id_list) in current_items_by_el_id_lst:
+        element_name_list = [entities[id_]["name"] for id_ in element_id_list]
+        entity_class_name = entity_class["name"]
+        raise SpineIntegrityError(
+            f"There's already an entity with elements {element_name_list} in class {entity_class_name}.",
+            id=current_items_by_el_id_lst[class_id, element_id_list],
+        )
+
+
 def check_object_class(item, current_items):
     """Check whether the insertion of an object class item results in the violation
     of an integrity constraint.
 
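A minimal sketch of driving the new check directly, with cache dictionaries shaped as the
docstrings above describe; all ids and names below are invented for illustration:

    from spinedb_api.check_functions import check_entity
    from spinedb_api.exception import SpineIntegrityError

    entity_classes = {
        1: {"dimension_id_list": (), "name": "unit"},
        2: {"dimension_id_list": (1, 1), "name": "unit__unit"},
    }
    entities = {10: {"class_id": 1, "name": "u1"}, 11: {"class_id": 1, "name": "u2"}}
    item = {"name": "u1__u2", "class_id": 2, "element_id_list": (10, 11)}
    try:
        check_entity(item, {}, {}, entity_classes, entities)  # passes: element classes match the dimensions
    except SpineIntegrityError as error:
        print(error.msg)
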
diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index ccec6eae..de76316c 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -183,6 +183,9 @@ def get(self, key, default=None): def copy(self): return type(self)(self._db_cache, self._item_type, **self) + def updated(self, other): + return type(self)(self._db_cache, self._item_type, **{**self, **other}) + def is_valid(self): if self._valid is not None: return self._valid diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 1d38fb6a..5703ab1f 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -76,9 +76,11 @@ def _reserve_ids(self, tablename, count): def _do_reserve_ids(self, connection, tablename, count): fieldname = { + "entity_class": "entity_class_id", "object_class": "entity_class_id", - "object": "entity_id", "relationship_class": "entity_class_id", + "entity": "entity_id", + "object": "entity_id", "relationship": "entity_id", "entity_group": "entity_group_id", "parameter_definition": "parameter_definition_id", @@ -221,7 +223,33 @@ def _items_to_add_per_table(self, tablename, items_to_add): Yields: tuple: database table name, items to add """ - if tablename == "object_class": + if tablename == "entity_class": + ecd_items_to_add = [] + for item in items_to_add: + ecd_items_to_add += [ + {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} + for position, dimension_id in enumerate(item["dimension_id_list"]) + ] + yield ("entity_class", items_to_add) + yield ("entity_class_dimension", ecd_items_to_add) + elif tablename == "entity": + ee_items_to_add = [] + for item in items_to_add: + ee_items_to_add += [ + { + "entity_id": item["id"], + "entity_class_id": item["class_id"], + "position": position, + "element_id": element_id, + "dimension_id": dimension_id, + } + for position, (element_id, dimension_id) in enumerate( + zip(item["element_id_list"], item["dimension_id_list"]) + ) + ] + yield ("entity", items_to_add) + yield ("entity_element", ee_items_to_add) + elif tablename == "object_class": yield ("entity_class", items_to_add) elif tablename == "object": yield ("entity", items_to_add) @@ -242,10 +270,10 @@ def _items_to_add_per_table(self, tablename, items_to_add): "entity_id": item["id"], "entity_class_id": item["class_id"], "position": position, - "element_id": object_id, - "dimension_id": object_class_id, + "element_id": element_id, + "dimension_id": dimension_id, } - for position, (object_id, object_class_id) in enumerate( + for position, (element_id, dimension_id) in enumerate( zip(item["object_id_list"], item["object_class_id_list"]) ) ] @@ -273,6 +301,12 @@ def add_object_classes(self, *items, **kwargs): def add_objects(self, *items, **kwargs): return self.add_items("object", *items, **kwargs) + def add_entity_classes(self, *items, **kwargs): + return self.add_items("entity_class", *items, **kwargs) + + def add_entities(self, *items, **kwargs): + return self.add_items("entity", *items, **kwargs) + def add_wide_relationship_classes(self, *items, **kwargs): return self.add_items("relationship_class", *items, **kwargs) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 214dccd9..aaef64d1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -60,13 +60,11 @@ class DatabaseMappingBase: _session_kwargs = {} ITEM_TYPES = ( - "object_class", - "relationship_class", + "entity_class", "parameter_value_list", "list_value", 
"parameter_definition", - "object", - "relationship", + "entity", "entity_group", "parameter_value", "alternative", diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 569fd9e8..fec8b978 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -23,6 +23,8 @@ check_alternative, check_scenario, check_scenario_alternative, + check_entity_class, + check_entity, check_object_class, check_object, check_wide_relationship_class, @@ -54,6 +56,8 @@ def check_items(self, tablename, *items, for_update=False, strict=False, cache=N "alternative": self.check_alternatives, "scenario": self.check_scenarios, "scenario_alternative": self.check_scenario_alternatives, + "entity": self.check_entities, + "entity_class": self.check_entity_classes, "object": self.check_objects, "object_class": self.check_object_classes, "relationship_class": self.check_wide_relationship_classes, @@ -326,6 +330,86 @@ def check_scenario_alternatives(self, *items, for_update=False, strict=False, ca intgr_error_log.append(e) return checked_items, intgr_error_log + def check_entity_classes(self, *items, for_update=False, strict=False, cache=None): + """Check whether entity classes passed as argument respect integrity constraints. + + Args: + items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. + strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` + if one of the items violates an integrity constraint. + + Returns + list: items that passed the check. + list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. + """ + if cache is None: + cache = self.make_cache({"entity_class"}, include_ancestors=True) + intgr_error_log = [] + checked_items = list() + entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} + for item in items: + try: + with self._manage_stocks( + "entity_class", item, {("name",): entity_class_ids}, for_update, cache, intgr_error_log + ) as item: + check_entity_class(item, entity_class_ids) + checked_items.append(item) + except SpineIntegrityError as e: + if strict: + raise e + intgr_error_log.append(e) + return checked_items, intgr_error_log + + def check_entities(self, *items, for_update=False, strict=False, cache=None): + """Check whether entities passed as argument respect integrity constraints. + + Args: + items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. + strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` + if one of the items violates an integrity constraint. + + Returns + list: items that passed the check. + list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
+ """ + if cache is None: + cache = self.make_cache({"entity"}, include_ancestors=True) + intgr_error_log = [] + checked_items = list() + entity_ids_by_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} + entity_ids_by_el_id_lst = {(x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values()} + entity_classes = { + x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} + for x in cache.get("entity_class", {}).values() + } + entities = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("entity", {}).values()} + for item in items: + try: + with self._manage_stocks( + "entity", + item, + { + ("class_id", "name"): entity_ids_by_name, + ("class_id", "element_id_list"): entity_ids_by_el_id_lst, + }, + for_update, + cache, + intgr_error_log, + ) as item: + check_entity( + item, + entity_ids_by_name, + entity_ids_by_el_id_lst, + entity_classes, + entities, + ) + checked_items.append(item) + except SpineIntegrityError as e: + if strict: + raise e + intgr_error_log.append(e) + return checked_items, intgr_error_log + def check_object_classes(self, *items, for_update=False, strict=False, cache=None): """Check whether object classes passed as argument respect integrity constraints. @@ -339,14 +423,14 @@ def check_object_classes(self, *items, for_update=False, strict=False, cache=Non list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ if cache is None: - cache = self.make_cache({"object_class"}, include_ancestors=True) + cache = self.make_cache({"entity_class"}, include_ancestors=True) intgr_error_log = [] checked_items = list() - object_class_ids = {x.name: x.id for x in cache.get("object_class", {}).values()} + object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} for item in items: try: with self._manage_stocks( - "object_class", item, {("name",): object_class_ids}, for_update, cache, intgr_error_log + "entity_class", item, {("name",): object_class_ids}, for_update, cache, intgr_error_log ) as item: check_object_class(item, object_class_ids) checked_items.append(item) @@ -368,15 +452,15 @@ def check_objects(self, *items, for_update=False, strict=False, cache=None): list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ if cache is None: - cache = self.make_cache({"object"}, include_ancestors=True) + cache = self.make_cache({"entity"}, include_ancestors=True) intgr_error_log = [] checked_items = list() - object_ids = {(x.class_id, x.name): x.id for x in cache.get("object", {}).values()} - object_class_ids = [x.id for x in cache.get("object_class", {}).values()] + object_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} + object_class_ids = [x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list] for item in items: try: with self._manage_stocks( - "object", item, {("class_id", "name"): object_ids}, for_update, cache, intgr_error_log + "entity", item, {("class_id", "name"): object_ids}, for_update, cache, intgr_error_log ) as item: check_object(item, object_ids, object_class_ids) checked_items.append(item) @@ -399,21 +483,22 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
""" if cache is None: - cache = self.make_cache({"relationship_class"}, include_ancestors=True) + cache = self.make_cache({"entity_class"}, include_ancestors=True) intgr_error_log = [] checked_wide_items = list() - relationship_class_ids = {x.name: x.id for x in cache.get("relationship_class", {}).values()} - object_class_ids = [x.id for x in cache.get("object_class", {}).values()] + relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} + object_class_ids = [x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list] for wide_item in wide_items: try: with self._manage_stocks( - "relationship_class", + "entity_class", wide_item, {("name",): relationship_class_ids}, for_update, cache, intgr_error_log, ) as wide_item: + wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) check_wide_relationship_class(wide_item, relationship_class_ids, object_class_ids) checked_wide_items.append(wide_item) except SpineIntegrityError as e: @@ -435,31 +520,36 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ if cache is None: - cache = self.make_cache({"relationship"}, include_ancestors=True) + cache = self.make_cache({"entity"}, include_ancestors=True) intgr_error_log = [] checked_wide_items = list() - relationship_ids_by_name = {(x.class_id, x.name): x.id for x in cache.get("relationship", {}).values()} + relationship_ids_by_name = { + (x.class_id, x.name): x.id for x in cache.get("entity", {}).values() if x.element_id_list + } relationship_ids_by_obj_lst = { - (x.class_id, x.object_id_list): x.id for x in cache.get("relationship", {}).values() + (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if x.element_id_list } relationship_classes = { - x.id: {"object_class_id_list": x.object_class_id_list, "name": x.name} - for x in cache.get("relationship_class", {}).values() + x.id: {"object_class_id_list": x.dimension_id_list, "name": x.name} + for x in cache.get("entity_class", {}).values() + if x.dimension_id_list } objects = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("object", {}).values()} for wide_item in wide_items: try: with self._manage_stocks( - "relationship", + "entity", wide_item, { ("class_id", "name"): relationship_ids_by_name, - ("class_id", "object_id_list"): relationship_ids_by_obj_lst, + ("class_id", "element_id_list"): relationship_ids_by_obj_lst, }, for_update, cache, intgr_error_log, ) as wide_item: + wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) + wide_item["object_id_list"] = wide_item.pop("element_id_list", ()) check_wide_relationship( wide_item, relationship_ids_by_name, @@ -492,7 +582,7 @@ def check_entity_groups(self, *items, for_update=False, strict=False, cache=None checked_items = list() current_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} entities = {} - for entity in chain(cache.get("object", {}).values(), cache.get("relationship", {}).values()): + for entity in cache.get("entity", {}).values(): entities.setdefault(entity.class_id, dict())[entity.id] = entity._asdict() for item in items: try: @@ -529,29 +619,12 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca parameter_definition_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } - 
object_class_ids = {x.id for x in cache.get("object_class", {}).values()} - relationship_class_ids = {x.id for x in cache.get("relationship_class", {}).values()} - entity_class_ids = object_class_ids | relationship_class_ids + entity_class_ids = {x.id for x in cache.get("entity_class", {}).values()} parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} for item in items: - object_class_id = item.get("object_class_id") - relationship_class_id = item.get("relationship_class_id") - if object_class_id and relationship_class_id: - e = SpineIntegrityError("Can't associate a parameter to both an object and a relationship class.") - if strict: - raise e - intgr_error_log.append(e) - continue - if object_class_id: - class_ids = object_class_ids - elif relationship_class_id: - class_ids = relationship_class_ids - else: - class_ids = entity_class_ids - entity_class_id = object_class_id or relationship_class_id - if entity_class_id is not None: - item["entity_class_id"] = entity_class_id + if "entity_class_id" not in item: + item["entity_class_id"] = item.get("object_class_id") or item.get("relationship_class_id") try: if ( for_update @@ -570,7 +643,7 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca intgr_error_log, ) as full_item: check_parameter_definition( - full_item, parameter_definition_ids, class_ids, parameter_value_lists, list_values + full_item, parameter_definition_ids, entity_class_ids, parameter_value_lists, list_values ) checked_items.append(full_item) except SpineIntegrityError as e: @@ -606,17 +679,13 @@ def check_parameter_values(self, *items, for_update=False, strict=False, cache=N } for x in cache.get("parameter_definition", {}).values() } - entities = { - x.id: {"class_id": x.class_id, "name": x.name} - for x in chain(cache.get("object", {}).values(), cache.get("relationship", {}).values()) - } + entities = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("entity", {}).values()} parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} alternatives = set(a.id for a in cache.get("alternative", {}).values()) for item in items: - entity_id = item.get("object_id") or item.get("relationship_id") - if entity_id is not None: - item["entity_id"] = entity_id + if "entity_id" not in item: + item["entity_id"] = item.get("object_id") or item.get("relationship_id") try: with self._manage_stocks( "parameter_value", diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 9a8db577..927a9b88 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -31,6 +31,8 @@ def _update_items(self, tablename, *items): if not items: return set() # Special cases + if tablename == "entity": + return self._update_entities(*items) if tablename == "relationship": return self._update_wide_relationships(*items) real_tablename = self._real_tablename(tablename) @@ -97,6 +99,44 @@ def update_scenario_alternatives(self, *items, **kwargs): def _update_scenario_alternatives(self, *items): return self._update_items("scenario_alternative", *items) + def update_entity_classes(self, *items, **kwargs): + return self.update_items("entity_class", *items, **kwargs) + + def _update_entity_classes(self, 
*items): + return self._update_items("entity_class", *items) + + def update_entities(self, *items, **kwargs): + return self.update_items("entity", *items, **kwargs) + + def _update_entities(self, *items): + entity_items = [] + entity_element_items = [] + for item in items: + entity_id = item["id"] + class_id = item["class_id"] + ent_item = { + "id": entity_id, + "class_id": class_id, + "name": item["name"], + "description": item.get("description"), + } + entity_items.append(ent_item) + dimension_id_list = item["dimension_id_list"] + element_id_list = item["element_id_list"] + for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)): + rel_ent_item = { + "id": None, # Need to have an "id" field to make _update_items() happy. + "entity_class_id": class_id, + "entity_id": entity_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, + } + entity_element_items.append(rel_ent_item) + entity_ids = self._do_update_items("entity", *entity_items) + self._do_update_items("entity_element", *entity_element_items) + return entity_ids + def update_object_classes(self, *items, **kwargs): return self.update_items("object_class", *items, **kwargs) @@ -143,8 +183,8 @@ def _update_wide_relationships(self, *items): "element_id": element_id, } entity_element_items.append(rel_ent_item) - entity_ids = self._update_items("entity", *entity_items) - self._update_items("entity_element", *entity_element_items) + entity_ids = self._do_update_items("entity", *entity_items) + self._do_update_items("entity_element", *entity_element_items) return entity_ids def update_parameter_definitions(self, *items, **kwargs): From 5c814e55c5f78d4f94e2183cdd06891d41a3cd01 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 14 Feb 2023 15:39:31 +0100 Subject: [PATCH 008/317] Update export/import functions to work with entity --- spinedb_api/db_cache.py | 24 +- spinedb_api/db_mapping_add_mixin.py | 6 + spinedb_api/db_mapping_base.py | 24 +- spinedb_api/export_functions.py | 132 +++++-- spinedb_api/import_functions.py | 544 +++++++++++++++++++++++++--- 5 files changed, 626 insertions(+), 104 deletions(-) diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index de76316c..0d808308 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -381,14 +381,13 @@ def __getitem__(self, key): return super().__getitem__(key) def _reference_keys(self): - keys = super()._reference_keys() + ( + return super()._reference_keys() + ( "parameter_name", "alternative_name", "entity_name", "element_id_list", "element_name_list", ) - return keys class EntityGroupItem(CacheItem): @@ -398,26 +397,17 @@ def __getitem__(self, key): if key == "group_id": return self["entity_id"] if key == "class_name": - return ( - self._get_ref("object_class", self["entity_class_id"], key) - or self._get_ref("relationship_class", self["entity_class_id"], key) - ).get("name") + return self._get_ref("entity_class", self["entity_class_id"], key)["name"] if key == "group_name": - return ( - self._get_ref("object", self["entity_id"], key) or self._get_ref("relationship", self["entity_id"], key) - ).get("name") + return self._get_ref("entity", self["entity_id"], key)["name"] if key == "member_name": - return ( - self._get_ref("object", self["member_id"], key) or self._get_ref("relationship", self["member_id"], key) - ).get("name") - if key == "object_class_id": - return self._get_ref("object_class", self["entity_class_id"], key).get("id") - if key == "relationship_class_id": - return 
self._get_ref("relationship_class", self["entity_class_id"], key).get("id") + return self._get_ref("entity", self["member_id"], key)["name"] + if key == "dimension_id_list": + return self._get_ref("entity_class", self["entity_class_id"], key)["dimension_id_list"] return super().__getitem__(key) def _reference_keys(self): - return super()._reference_keys() + ("class_name", "group_name", "member_name") + return super()._reference_keys() + ("class_name", "group_name", "member_name", "dimension_id_list") class ScenarioItem(CacheItem): diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 5703ab1f..4557b812 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -419,6 +419,12 @@ def add_ext_parameter_value_metadata( "parameter_value_metadata", *items, check=check, strict=strict, return_items=return_items, cache=cache ) + def _add_entity_classes(self, *items): + return self._add_items("entity_class", *items) + + def _add_entities(self, *items): + return self._add_items("entities", *items) + def _add_object_classes(self, *items): return self._add_items("object_class", *items) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index aaef64d1..4db3960b 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -205,9 +205,7 @@ def __init__( # Subqueries used to populate cache self.cache_sqs = { "entity_class": "ext_entity_class_sq", - "entity_class_dimension": "entity_class_dimension_sq", "entity": "ext_entity_sq", - "entity_element": "entity_element_sq", "feature": "feature_sq", "tool": "tool_sq", "tool_feature": "tool_feature_sq", @@ -230,30 +228,24 @@ def __init__( "tool_feature": ("tool", "feature"), "tool_feature_method": ("tool_feature", "parameter_value_list", "list_value"), "scenario_alternative": ("scenario", "alternative"), - "relationship_class": ("object_class",), - "object": ("object_class",), - "entity_group": ("object_class", "relationship_class", "object", "relationship"), - "relationship": ("relationship_class", "object"), - "parameter_definition": ("object_class", "relationship_class", "parameter_value_list", "list_value"), + "entity": ("entity_class",), + "entity_group": ("entity_class", "entity"), + "parameter_definition": ("entity_class", "parameter_value_list", "list_value"), "parameter_value": ( "alternative", - "object_class", - "relationship_class", - "object", - "relationship", + "entity_class", + "entity", "parameter_definition", "parameter_value_list", "list_value", ), - "entity_metadata": ("metadata", "object", "object_class", "relationship", "relationship_class"), + "entity_metadata": ("metadata", "entity_class", "entity"), "parameter_value_metadata": ( "metadata", "parameter_value", "parameter_definition", - "object", - "object_class", - "relationship", - "relationship_class", + "entity_class", + "entity", "alternative", ), "list_value": ("parameter_value_list",), diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 469dd4e4..65bc5008 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -24,6 +24,11 @@ def export_data( db_map, + entity_class_ids=Asterisk, + entity_ids=Asterisk, + parameter_definition_ids=Asterisk, + parameter_value_ids=Asterisk, + entity_group_ids=Asterisk, object_class_ids=Asterisk, relationship_class_ids=Asterisk, parameter_value_list_ids=Asterisk, @@ -70,6 +75,15 @@ def export_data( dict: exported data """ data = { + "entity_classes": 
export_entity_classes(db_map, entity_class_ids, make_cache=make_cache), + "entities": export_entities(db_map, entity_ids, make_cache=make_cache), + "entity_groups": export_entity_groups(db_map, entity_group_ids, make_cache=make_cache), + "parameter_definitions": export_parameter_definitions( + db_map, parameter_definition_ids, make_cache=make_cache, parse_value=parse_value + ), + "parameter_values": export_parameter_values( + db_map, parameter_value_ids, make_cache=make_cache, parse_value=parse_value + ), "object_classes": export_object_classes(db_map, object_class_ids, make_cache=make_cache), "relationship_classes": export_relationship_classes(db_map, relationship_class_ids, make_cache=make_cache), "parameter_value_lists": export_parameter_value_lists( @@ -145,74 +159,134 @@ def __call__(self, item): yield KeyedTuple([item.name, val.value, val.type], fields) -def export_object_classes(db_map, ids=Asterisk, make_cache=None): - return sorted((x.name, x.description, x.display_icon) for x in _get_items(db_map, "object_class", ids, make_cache)) +def export_parameter_value_lists(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): + return sorted( + ((x.name, parse_value(x.value, x.type)) for x in _get_items(db_map, "parameter_value_list", ids, make_cache)), + key=itemgetter(0), + ) -def export_objects(db_map, ids=Asterisk, make_cache=None): - return sorted((x.class_name, x.name, x.description) for x in _get_items(db_map, "object", ids, make_cache)) +def export_entity_classes(db_map, ids=Asterisk, make_cache=None): + return sorted( + (x.name, x.description, x.display_icon, x.dimension_name_list) + for x in _get_items(db_map, "entity_class", ids, make_cache) + ) -def export_relationship_classes(db_map, ids=Asterisk, make_cache=None): +def export_entities(db_map, ids=Asterisk, make_cache=None): return sorted( - (x.name, x.object_class_name_list, x.description, x.display_icon) - for x in _get_items(db_map, "relationship_class", ids, make_cache) + (x.class_name, x.element_name_list or x.name, x.description) + for x in _get_items(db_map, "entity", ids, make_cache) ) -def export_parameter_value_lists(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_entity_groups(db_map, ids=Asterisk, make_cache=None): return sorted( - ((x.name, parse_value(x.value, x.type)) for x in _get_items(db_map, "parameter_value_list", ids, make_cache)), - key=itemgetter(0), + (x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids, make_cache) ) -def export_object_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_parameter_definitions(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): return sorted( ( - x.object_class_name, + x.entity_class_name, x.parameter_name, parse_value(x.default_value, x.default_type), x.value_list_name, x.description, ) for x in _get_items(db_map, "parameter_definition", ids, make_cache) - if x.object_class_id ) -def export_relationship_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): return sorted( ( - x.relationship_class_name, - x.parameter_name, - parse_value(x.default_value, x.default_type), - x.value_list_name, - x.description, - ) - for x in _get_items(db_map, "parameter_definition", ids, make_cache) - if x.relationship_class_id + ( + x.entity_class_name, + x.element_name_list or x.name, + x.parameter_name, + parse_value(x.value, x.type), 
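+                # raw (value, type) blobs from the database are decoded by parse_value (from_database by default)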
+ x.alternative_name, + ) + for x in _get_items(db_map, "parameter_value", ids, make_cache) + ), + key=lambda x: x[:3] + (x[-1],), + ) + + +def export_object_classes(db_map, ids=Asterisk, make_cache=None): + return sorted( + (x.name, x.description, x.display_icon) + for x in _get_items(db_map, "entity_class", ids, make_cache) + if not x.dimension_id_list + ) + + +def export_relationship_classes(db_map, ids=Asterisk, make_cache=None): + return sorted( + (x.name, x.dimension_name_list, x.description, x.display_icon) + for x in _get_items(db_map, "entity_class", ids, make_cache) + if x.dimension_id_list + ) + + +def export_objects(db_map, ids=Asterisk, make_cache=None): + return sorted( + (x.class_name, x.name, x.description) + for x in _get_items(db_map, "entity", ids, make_cache) + if not x.element_id_list ) def export_relationships(db_map, ids=Asterisk, make_cache=None): - return sorted((x.class_name, x.object_name_list) for x in _get_items(db_map, "relationship", ids, make_cache)) + return sorted( + (x.class_name, x.element_name_list) for x in _get_items(db_map, "entity", ids, make_cache) if x.element_id_list + ) def export_object_groups(db_map, ids=Asterisk, make_cache=None): return sorted( (x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids, make_cache) - if x.object_class_id + if not x.dimension_id_list + ) + + +def export_object_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): + return sorted( + ( + x.entity_class_name, + x.parameter_name, + parse_value(x.default_value, x.default_type), + x.value_list_name, + x.description, + ) + for x in _get_items(db_map, "parameter_definition", ids, make_cache) + if not x.dimension_id_list + ) + + +def export_relationship_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): + return sorted( + ( + x.entity_class_name, + x.parameter_name, + parse_value(x.default_value, x.default_type), + x.value_list_name, + x.description, + ) + for x in _get_items(db_map, "parameter_definition", ids, make_cache) + if x.dimension_id_list ) def export_object_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): return sorted( ( - (x.object_class_name, x.object_name, x.parameter_name, parse_value(x.value, x.type), x.alternative_name) + (x.entity_class_name, x.entity_name, x.parameter_name, parse_value(x.value, x.type), x.alternative_name) for x in _get_items(db_map, "parameter_value", ids, make_cache) - if x.object_id + if not x.element_id_list ), key=lambda x: x[:3] + (x[-1],), ) @@ -222,14 +296,14 @@ def export_relationship_parameter_values(db_map, ids=Asterisk, make_cache=None, return sorted( ( ( - x.relationship_class_name, - x.object_name_list, + x.entity_class_name, + x.element_name_list, x.parameter_name, parse_value(x.value, x.type), x.alternative_name, ) for x in _get_items(db_map, "parameter_value", ids, make_cache) - if x.relationship_id + if x.element_id_list ), key=lambda x: x[:3] + (x[-1],), ) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 5342b37e..e495e88c 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -19,6 +19,8 @@ import uuid from .exception import SpineIntegrityError, SpineDBAPIError from .check_functions import ( + check_entity_class, + check_entity, check_tool, check_feature, check_tool_feature, @@ -121,6 +123,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "alternative": db_map._add_alternatives, "scenario": 
db_map._add_scenarios, "scenario_alternative": db_map._add_scenario_alternatives, + "entity_class": db_map._add_entity_classes, "object_class": db_map._add_object_classes, "relationship_class": db_map._add_wide_relationship_classes, "parameter_value_list": db_map._add_parameter_value_lists, @@ -130,6 +133,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "tool": db_map._add_tools, "tool_feature": db_map._add_tool_features, "tool_feature_method": db_map._add_tool_feature_methods, + "entity": db_map._add_entities, "object": db_map._add_objects, "relationship": db_map._add_wide_relationships, "entity_group": db_map._add_entity_groups, @@ -142,6 +146,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "alternative": db_map._update_alternatives, "scenario": db_map._update_scenarios, "scenario_alternative": db_map._update_scenario_alternatives, + "entity_class": db_map._update_entity_classes, "object_class": db_map._update_object_classes, "relationship_class": db_map._update_wide_relationship_classes, "parameter_value_list": db_map._update_parameter_value_lists, @@ -150,6 +155,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "feature": db_map._update_features, "tool": db_map._update_tools, "tool_feature": db_map._update_tool_features, + "entity": db_map._update_entities, "object": db_map._update_objects, "parameter_value": db_map._update_parameter_values, } @@ -180,6 +186,11 @@ def get_data_for_import( make_cache=None, unparse_value=to_database, on_conflict="merge", + entity_classes=(), + entities=(), + parameter_definitions=(), + parameter_values=(), + entity_groups=(), object_classes=(), relationship_classes=(), parameter_value_lists=(), @@ -247,6 +258,8 @@ def get_data_for_import( alternatives = (item[1] for item in scenario_alternatives) yield ("alternative", _get_alternatives_for_import(alternatives, make_cache)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(scenario_alternatives, make_cache)) + if entity_classes: + yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, make_cache)) if object_classes: yield ("object_class", _get_object_classes_for_import(db_map, object_classes, make_cache)) if relationship_classes: @@ -254,6 +267,11 @@ def get_data_for_import( if parameter_value_lists: yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists, make_cache)) yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, make_cache, unparse_value)) + if parameter_definitions: + yield ( + "parameter_definition", + _get_parameter_definitions_for_import(db_map, parameter_definitions, make_cache, unparse_value), + ) if object_parameters: yield ( "parameter_definition", @@ -275,12 +293,21 @@ def get_data_for_import( "tool_feature_method", _get_tool_feature_methods_for_import(db_map, tool_feature_methods, make_cache, unparse_value), ) + if entities: + yield ("entity", _get_entities_for_import(db_map, entities, make_cache)) if objects: yield ("object", _get_objects_for_import(db_map, objects, make_cache)) if relationships: yield ("relationship", _get_relationships_for_import(db_map, relationships, make_cache)) + if entity_groups: + yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups, make_cache)) if object_groups: yield ("entity_group", _get_object_groups_for_import(db_map, object_groups, make_cache)) + if parameter_values: + yield ( + "parameter_value", + 
_get_parameter_values_for_import(db_map, parameter_values, make_cache, unparse_value, on_conflict),
+        )
     if object_parameter_values:
         yield (
             "parameter_value",
@@ -315,6 +342,427 @@
 
 
+def import_entity_classes(db_map, data, make_cache=None):
+    """Imports entity classes.
+
+    Example::
+
+        data = [
+            ('new_class',),
+            ('another_class', 'description', 123456),
+            ('multidimensional_class', 'description', 654321, ("new_class", "another_class"))
+        ]
+        import_entity_classes(db_map, data)
+
+    Args:
+        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
+        data (Iterable): list/set/iterable of tuples, each with an entity class name,
+            and optionally a description, an integer display icon reference,
+            and a list/tuple of dimension names
+
+    Returns:
+        tuple of int and list: Number of successfully inserted entity classes, list of errors
+    """
+    return import_data(db_map, entity_classes=data, make_cache=make_cache)
+
+
+def _get_entity_classes_for_import(db_map, data, make_cache):
+    cache = make_cache({"entity_class"}, include_ancestors=True)
+    entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()}
+    checked = set()
+    error_log = []
+    to_add = []
+    to_update = []
+    for name, *optionals in data:
+        if name in checked:
+            continue
+        ec_id = entity_class_ids.pop(name, None)
+        item = (
+            cache["entity_class"][ec_id]._asdict()
+            if ec_id is not None
+            else {"name": name, "description": None, "display_icon": None}
+        )
+        item.update(dict(zip(("description", "display_icon", "dimension_name_list"), optionals)))
+        item["dimension_id_list"] = tuple(entity_class_ids.get(x, None) for x in item.get("dimension_name_list", ()))
+        try:
+            check_entity_class(item, entity_class_ids)
+        except SpineIntegrityError as e:
+            error_log.append(
+                ImportErrorLogItem(f"Could not import entity class '{name}': {e.msg}", db_type="entity_class")
+            )
+            continue
+        finally:
+            if ec_id is not None:
+                entity_class_ids[name] = ec_id
+            checked.add(name)
+        if ec_id is not None:
+            item["id"] = ec_id
+            to_update.append(item)
+        else:
+            to_add.append(item)
+    return to_add, to_update, error_log
+
+
+def import_entities(db_map, data, make_cache=None):
+    """Imports entities.
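+    Entities generalize the legacy objects (zero dimensions) and relationships (one or more dimensions).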
+
+    Example::
+
+        data = [
+            ('class_name1', 'entity_name1'),
+            ('class_name2', 'entity_name2'),
+            ('class_name3', ('entity_name1', 'entity_name2'))
+        ]
+        import_entities(db_map, data)
+
+    Args:
+        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
+        data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name
+            and entity name or list/tuple of element names
+
+    Returns:
+        (Int, List) Number of successfully inserted entities, list of errors
+    """
+    return import_data(db_map, entities=data, make_cache=make_cache)
+
+
+def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_id_name_tuples):
+    if isinstance(ent_name_or_el_names, str):
+        return ent_name_or_el_names
+    base_name = class_name + "_" + "__".join([en if en is not None else "None" for en in ent_name_or_el_names])
+    name = base_name
+    while (class_id, name) in class_id_name_tuples:
+        name = base_name + uuid.uuid4().hex
+    return name
+
+
+def _get_entities_for_import(db_map, data, make_cache):
+    cache = make_cache({"entity"}, include_ancestors=True)
+    entities = {x.id: x for x in cache.get("entity", {}).values()}
+    entity_ids_per_name = {(x.class_id, x.name): x.id for x in entities.values()}
+    entity_ids_per_el_id_lst = {(x.class_id, x.element_id_list): x.id for x in entities.values()}
+    entity_classes = {
+        x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} for x in cache.get("entity_class", {}).values()
+    }
+    entity_ids = {(x["name"], x["class_id"]): id_ for id_, x in entities.items()}
+    entity_class_ids = {x["name"]: id_ for id_, x in entity_classes.items()}
+    dimension_id_lists = {id_: x["dimension_id_list"] for id_, x in entity_classes.items()}
+    error_log = []
+    to_add = []
+    to_update = []
+    checked = set()
+    for class_name, ent_name_or_el_names, *optionals in data:
+        ec_id = entity_class_ids.get(class_name, None)
+        dim_ids = dimension_id_lists.get(ec_id, ())
+        el_ids = tuple(entity_ids.get((name, dim_id), None) for name, dim_id in zip(ent_name_or_el_names, dim_ids))
+        e_key = el_ids or ent_name_or_el_names
+        if (ec_id, e_key) in checked:
+            continue
+        e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None)
+        if e_id is not None:
+            e_name = cache["entity"][e_id].name
+            entity_ids_per_name.pop((ec_id, e_name))
+        else:
+            e_name = _make_unique_entity_name(ec_id, class_name, ent_name_or_el_names, entity_ids_per_name)
+        item = (
+            cache["entity"][e_id]._asdict()
+            if e_id is not None
+            else {
+                "name": e_name,
+                "class_id": ec_id,
+                "element_id_list": el_ids,
+            }
+        )
+        item.update(dict(zip(("description",), optionals)))
+        try:
+            check_entity(item, entity_ids_per_name, entity_ids_per_el_id_lst, entity_classes, entities)
+        except SpineIntegrityError as e:
+            msg = f"Could not import entity '{ent_name_or_el_names}' into '{class_name}': {e.msg}"
+            error_log.append(ImportErrorLogItem(msg=msg, db_type="entity"))
+            continue
+        finally:
+            if e_id is not None:
+                entity_ids_per_el_id_lst[ec_id, el_ids] = e_id
+                entity_ids_per_name[ec_id, e_name] = e_id
+            checked.add((ec_id, e_key))
+        if e_id is not None:
+            item["id"] = e_id
+            to_update.append(item)
+        else:
+            to_add.append(item)
+    return to_add, to_update, error_log
+
+
+def import_entity_groups(db_map, data, make_cache=None):
+    """Imports entity groups by name with associated class name into the given database mapping.
+    Ignores duplicate and existing (group, member) tuples.
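+    Both the group and the member must be existing entities of the given class.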
+ + Example:: + + data = [ + ('class_name', 'group_name', 'member_name'), + ('class_name', 'group_name', 'another_member_name') + ] + import_entity_groups(db_map, data) + + Args: + db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into + data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name, group name, + and member name + + Returns: + (Int, List) Number of successful inserted entity groups, list of errors + """ + return import_data(db_map, entity_groups=data, make_cache=make_cache) + + +def _get_entity_groups_for_import(db_map, data, make_cache): + cache = make_cache({"entity_group"}, include_ancestors=True) + entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} + entity_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} + entities = {} + for ent in cache.get("entity", {}).values(): + entities.setdefault(ent.class_id, {})[ent.id] = ent._asdict() + entity_group_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} + error_log = [] + to_add = [] + seen = set() + for class_name, group_name, member_name in data: + ec_id = entity_class_ids.get(class_name) + g_id = entity_ids.get((ec_id, group_name)) + m_id = entity_ids.get((ec_id, member_name)) + if (g_id, m_id) in seen | entity_group_ids.keys(): + continue + item = {"entity_class_id": ec_id, "entity_id": g_id, "member_id": m_id} + try: + check_entity_group(item, entity_group_ids, entities) + to_add.append(item) + seen.add((g_id, m_id)) + except SpineIntegrityError as e: + error_log.append( + ImportErrorLogItem( + msg=f"Could not import entity '{member_name}' into group '{group_name}': {e.msg}", + db_type="entity group", + ) + ) + return to_add, [], error_log + + +def import_parameter_definitions(db_map, data, make_cache=None, unparse_value=to_database): + """Imports list of parameter definitions: + + Example:: + + data = [ + ('entity_class_1', 'new_parameter'), + ('entity_class_2', 'other_parameter', 'default_value', 'value_list_name', 'description') + ] + import_parameter_definitions(db_map, data) + + Args: + db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into + data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name, parameter name, + and optionally default value, value list name, and description + + Returns: + (Int, List) Number of successful inserted parameter definitions, list of errors + """ + return import_data(db_map, parameter_definitions=data, make_cache=make_cache, unparse_value=unparse_value) + + +def _get_parameter_definitions_for_import(db_map, data, make_cache, unparse_value): + cache = make_cache({"parameter_definition"}, include_ancestors=True) + parameter_definition_ids = { + (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() + } + entity_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values()} + entity_class_ids = {ec_name: id_ for id_, ec_name in entity_class_names.items()} + parameter_value_lists = {} + parameter_value_list_ids = {} + for x in cache.get("parameter_value_list", {}).values(): + parameter_value_lists[x.id] = x.value_id_list + parameter_value_list_ids[x.name] = x.id + list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} + error_log = [] + to_add = [] + to_update = [] + checked = set() + functions = [unparse_value, lambda x: (parameter_value_list_ids.get(x),), lambda x: (x,)] + for class_name, parameter_name, 
*optionals in data: + ec_id = entity_class_ids.get(class_name, None) + checked_key = (ec_id, parameter_name) + if checked_key in checked: + continue + p_id = parameter_definition_ids.pop((ec_id, parameter_name), None) + item = ( + cache["parameter_definition"][p_id]._asdict() + if p_id is not None + else { + "name": parameter_name, + "entity_class_id": ec_id, + "default_value": None, + "default_type": None, + "parameter_value_list_id": None, + "description": None, + } + ) + optionals = [y for f, x in zip(functions, optionals) for y in f(x)] + item.update(dict(zip(("default_value", "default_type", "parameter_value_list_id", "description"), optionals))) + try: + check_parameter_definition( + item, parameter_definition_ids, entity_class_names.keys(), parameter_value_lists, list_values + ) + except SpineIntegrityError as e: + # Relationship class doesn't exists + error_log.append( + ImportErrorLogItem( + msg=f"Could not import parameter definition '{parameter_name}' with class '{class_name}': {e.msg}", + db_type="parameter definition", + ) + ) + continue + finally: + if p_id is not None: + parameter_definition_ids[ec_id, parameter_name] = p_id + checked.add(checked_key) + if p_id is not None: + item["id"] = p_id + to_update.append(item) + else: + to_add.append(item) + return to_add, to_update, error_log + + +def import_parameter_values(db_map, data, make_cache=None, unparse_value=to_database, on_conflict="merge"): + """Imports parameter values: + + Example:: + + data = [ + ['example_class2', 'example_entity', 'parameter', 5.5, 'alternative'], + ['example_class1', ('example_entity', 'other_entity'), 'parameter', 2.718] + ] + import_parameter_values(db_map, data) + + Args: + db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into + data (List[List/Tuple]): list/set/iterable of lists/tuples with + entity class name, entity name or list of element names, parameter name, (deserialized) parameter value, + optional name of an alternative + + Returns: + (Int, List) Number of successful inserted parameter values, list of errors + """ + return import_data( + db_map, parameter_values=data, make_cache=make_cache, unparse_value=unparse_value, on_conflict=on_conflict + ) + + +def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict): + cache = make_cache({"parameter_value"}, include_ancestors=True) + dimension_id_lists = {x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values()} + parameter_value_ids = { + (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() + } + parameters = { + x.id: { + "name": x.parameter_name, + "entity_class_id": x.entity_class_id, + "parameter_value_list_id": x.value_list_id, + } + for x in cache.get("parameter_definition", {}).values() + } + entities = { + x.id: {"class_id": x.class_id, "name": x.name, "element_id_list": x.element_id_list} + for x in cache.get("entity", {}).values() + } + parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} + list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} + parameter_ids = {(p["entity_class_id"], p["name"]): p_id for p_id, p in parameters.items()} + entity_ids = {(x["class_id"], x["element_id_list"]): e_id for e_id, x in entities.items()} + entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} + alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} + alternative_ids = 
+
+
+def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict):
+    cache = make_cache({"parameter_value"}, include_ancestors=True)
+    dimension_id_lists = {x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values()}
+    parameter_value_ids = {
+        (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values()
+    }
+    parameters = {
+        x.id: {
+            "name": x.parameter_name,
+            "entity_class_id": x.entity_class_id,
+            "parameter_value_list_id": x.value_list_id,
+        }
+        for x in cache.get("parameter_definition", {}).values()
+    }
+    entities = {
+        x.id: {"class_id": x.class_id, "name": x.name, "element_id_list": x.element_id_list}
+        for x in cache.get("entity", {}).values()
+    }
+    parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()}
+    list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()}
+    parameter_ids = {(p["entity_class_id"], p["name"]): p_id for p_id, p in parameters.items()}
+    entity_ids = {(x["class_id"], x["element_id_list"]): e_id for e_id, x in entities.items()}
+    entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()}
+    alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()}
+    alternative_ids = set(alternatives.values())
+    error_log = []
+    to_add = []
+    to_update = []
+    checked = set()
+    for class_name, ent_name_or_el_names, parameter_name, value, *optionals in data:
+        ec_id = entity_class_ids.get(class_name, None)
+        dim_ids = dimension_id_lists.get(ec_id, ())
+        el_ids = tuple(entity_ids.get((dim_id, name)) for dim_id, name in zip(dim_ids, ent_name_or_el_names))
+        ent_key = el_ids or ent_name_or_el_names
+        e_id = entity_ids.get((ec_id, ent_key), None)
+        p_id = parameter_ids.get((ec_id, parameter_name), None)
+        if optionals:
+            alternative_name = optionals[0]
+            alt_id = alternatives.get(alternative_name)
+            if not alt_id:
+                error_log.append(
+                    ImportErrorLogItem(
+                        msg=(
+                            f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', "
+                            f"parameter '{parameter_name}': alternative {alternative_name} does not exist."
+                        ),
+                        db_type="parameter value",
+                    )
+                )
+                continue
+        else:
+            alt_id, alternative_name = db_map.get_import_alternative(cache=cache)
+            alternative_ids.add(alt_id)
+        checked_key = (e_id, p_id, alt_id)
+        if checked_key in checked:
+            msg = (
+                f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', "
+                f"parameter '{parameter_name}', alternative {alternative_name}: "
+                "Duplicate parameter value, only first value will be considered."
+            )
+            error_log.append(ImportErrorLogItem(msg=msg, db_type="parameter_value"))
+            continue
+        pv_id = parameter_value_ids.pop((e_id, p_id, alt_id), None)
+        value, type_ = unparse_value(value)
+        if pv_id is not None:
+            current_pv = cache["parameter_value"][pv_id]
+            value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict)
+        item = {
+            "parameter_definition_id": p_id,
+            "entity_class_id": ec_id,
+            "entity_id": e_id,
+            "value": value,
+            "type": type_,
+            "alternative_id": alt_id,
+        }
+        try:
+            check_parameter_value(
+                item,
+                parameter_value_ids,
+                parameters,
+                entities,
+                parameter_value_lists,
+                list_values,
+                alternative_ids,
+            )
+        except SpineIntegrityError as e:
+            error_log.append(
+                ImportErrorLogItem(
+                    msg=f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', "
+                    f"parameter '{parameter_name}', alternative {alternative_name}: {e.msg}",
+                    db_type="parameter_value",
+                )
+            )
+            continue
+        finally:
+            if pv_id is not None:
+                parameter_value_ids[e_id, p_id, alt_id] = pv_id
+            checked.add(checked_key)
+        if pv_id is not None:
+            item["id"] = pv_id
+            to_update.append(item)
+        else:
+            to_add.append(item)
+    return to_add, to_update, error_log
+
+
 def import_features(db_map, data, make_cache=None):
     """
     Imports features.
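(The hunks that follow repeatedly recover the old object/relationship split from the unified entity tables. A condensed, illustrative sketch of that idiom, with the cache shape assumed from the functions above; the helper name is not part of the patch.)::

    def _split_entity_classes(cache):
        """Illustrative helper, not part of the patch: a class with an empty
        dimension_id_list plays the old object class role, a class with
        dimensions the old relationship class role."""
        classes = list(cache.get("entity_class", {}).values())
        object_class_ids = {x.name: x.id for x in classes if not x.dimension_id_list}
        relationship_class_ids = {x.name: x.id for x in classes if x.dimension_id_list}
        return object_class_ids, relationship_class_ids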
@@ -815,8 +1263,10 @@ def import_object_classes(db_map, data, make_cache=None):
 
 
 def _get_object_classes_for_import(db_map, data, make_cache):
-    cache = make_cache({"object_class"}, include_ancestors=True)
-    object_class_ids = {oc.name: oc.id for oc in cache.get("object_class", {}).values()}
+    cache = make_cache({"entity_class"}, include_ancestors=True)
+    object_class_ids = {
+        oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list
+    }
     checked = set()
     to_add = []
     to_update = []
@@ -829,14 +1279,13 @@ def _get_object_classes_for_import(db_map, data, make_cache):
             continue
         oc_id = object_class_ids.pop(name, None)
         item = (
-            cache["object_class"][oc_id]._asdict()
+            cache["entity_class"][oc_id]._asdict()
             if oc_id is not None
             else {"name": name, "description": None, "display_icon": None}
         )
-        item["type_id"] = db_map.object_class_type
         item.update(dict(zip(("description", "display_icon"), optionals)))
         try:
-            check_object_class(item, object_class_ids, db_map.object_class_type)
+            check_object_class(item, object_class_ids)
         except SpineIntegrityError as e:
             error_log.append(
                 ImportErrorLogItem(msg=f"Could not import object class '{name}': {e.msg}", db_type="object class")
@@ -877,9 +1326,9 @@ def import_relationship_classes(db_map, data, make_cache=None):
 
 
 def _get_relationship_classes_for_import(db_map, data, make_cache):
-    cache = make_cache({"relationship_class"}, include_ancestors=True)
-    object_class_ids = {oc.name: oc.id for oc in cache.get("object_class", {}).values()}
-    relationship_class_ids = {x.name: x.id for x in cache.get("relationship_class", {}).values()}
+    cache = make_cache({"entity_class"}, include_ancestors=True)
+    object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
+    relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list}
     checked = set()
     error_log = []
     to_add = []
@@ -889,7 +1338,7 @@ def _get_relationship_classes_for_import(db_map, data, make_cache):
             continue
         rc_id = relationship_class_ids.pop(name, None)
         item = (
-            cache["relationship_class"][rc_id]._asdict()
+            cache["entity_class"][rc_id]._asdict()
             if rc_id is not None
             else {
                 "name": name,
@@ -898,12 +1347,9 @@
                 "display_icon": None,
             }
         )
-        item["type_id"] = db_map.relationship_class_type
         item.update(dict(zip(("description", "display_icon"), optionals)))
         try:
-            check_wide_relationship_class(
-                item, relationship_class_ids, set(object_class_ids.values()), db_map.relationship_class_type
-            )
+            check_wide_relationship_class(item, relationship_class_ids, set(object_class_ids.values()))
         except SpineIntegrityError as e:
             error_log.append(
                 ImportErrorLogItem(
@@ -947,9 +1393,9 @@ def import_objects(db_map, data, make_cache=None):
 
 
 def _get_objects_for_import(db_map, data, make_cache):
-    cache = make_cache({"object"}, include_ancestors=True)
-    object_class_ids = {oc.name: oc.id for oc in cache.get("object_class", {}).values()}
-    object_ids = {(o.class_id, o.name): o.id for o in cache.get("object", {}).values()}
+    cache = make_cache({"entity"}, include_ancestors=True)
+    object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
+    object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list}
     checked = set()
    error_log = []
     to_add = []
@@ -960,14 +1406,13 @@ def _get_objects_for_import(db_map, data, make_cache):
continue o_id = object_ids.pop((oc_id, name), None) item = ( - cache["object"][o_id]._asdict() + cache["entity"][o_id]._asdict() if o_id is not None else {"name": name, "class_id": oc_id, "description": None} ) - item["type_id"] = db_map.object_entity_type item.update(dict(zip(("description",), optionals))) try: - check_object(item, object_ids, set(object_class_ids.values()), db_map.object_entity_type) + check_object(item, object_ids, set(object_class_ids.values())) except SpineIntegrityError as e: error_log.append( ImportErrorLogItem( @@ -1012,11 +1457,12 @@ def import_object_groups(db_map, data, make_cache=None): def _get_object_groups_for_import(db_map, data, make_cache): cache = make_cache({"entity_group"}, include_ancestors=True) - object_class_ids = {oc.name: oc.id for oc in cache.get("object_class", {}).values()} - object_ids = {(o.class_id, o.name): o.id for o in cache.get("object", {}).values()} + object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} + object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} objects = {} - for obj in cache.get("object", {}).values(): - objects.setdefault(obj.class_id, dict())[obj.id] = obj._asdict() + for obj in cache.get("entity", {}).values(): + if not obj.element_id_list: + objects.setdefault(obj.class_id, dict())[obj.id] = obj._asdict() entity_group_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} error_log = [] to_add = [] @@ -1070,15 +1516,20 @@ def _make_unique_relationship_name(class_id, class_name, object_names, class_id_ def _get_relationships_for_import(db_map, data, make_cache): - cache = make_cache({"relationship"}, include_ancestors=True) - relationships = {x.name: x for x in cache.get("relationship", {}).values()} + cache = make_cache({"entity"}, include_ancestors=True) + relationships = {x.name: x for x in cache.get("entity", {}).values() if x.element_id_list} relationship_ids_per_name = {(x.class_id, x.name): x.id for x in relationships.values()} - relationship_ids_per_obj_lst = {(x.class_id, x.object_id_list): x.id for x in relationships.values()} + relationship_ids_per_obj_lst = {(x.class_id, x.element_id_list): x.id for x in relationships.values()} relationship_classes = { - x.id: {"object_class_id_list": x.object_class_id_list, "name": x.name} - for x in cache.get("relationship_class", {}).values() + x.id: {"object_class_id_list": x.dimension_id_list, "name": x.name} + for x in cache.get("entity_class", {}).values() + if x.dimension_id_list + } + objects = { + x.id: {"class_id": x.class_id, "name": x.name} + for x in cache.get("entity", {}).values() + if not x.element_id_list } - objects = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("object", {}).values()} object_ids = {(o["name"], o["class_id"]): o_id for o_id, o in objects.items()} relationship_class_ids = {rc["name"]: rc_id for rc_id, rc in relationship_classes.items()} object_class_id_lists = {rc_id: rc["object_class_id_list"] for rc_id, rc in relationship_classes.items()} @@ -1094,17 +1545,16 @@ def _get_relationships_for_import(db_map, data, make_cache): continue r_id = relationship_ids_per_obj_lst.pop((rc_id, o_ids), None) if r_id is not None: - r_name = cache["relationship"][r_id].name + r_name = cache["entity"][r_id].name relationship_ids_per_name.pop((rc_id, r_name)) item = ( - cache["relationship"][r_id]._asdict() + cache["entity"][r_id]._asdict() if r_id is not None else { "name": 
_make_unique_relationship_name(rc_id, class_name, object_names, relationship_ids_per_name), "class_id": rc_id, "object_id_list": list(o_ids), "object_class_id_list": oc_ids, - "type_id": db_map.relationship_entity_type, } ) item.update(dict(zip(("description",), optionals))) @@ -1115,7 +1565,6 @@ def _get_relationships_for_import(db_map, data, make_cache): relationship_ids_per_obj_lst, relationship_classes, objects, - db_map.relationship_entity_type, ) except SpineIntegrityError as e: msg = f"Could not import relationship with objects {tuple(object_names)} into '{class_name}': {e.msg}" @@ -1161,7 +1610,7 @@ def _get_object_parameters_for_import(db_map, data, make_cache, unparse_value): parameter_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } - object_class_names = {x.id: x.name for x in cache.get("object_class", {}).values()} + object_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} object_class_ids = {oc_name: oc_id for oc_id, oc_name in object_class_names.items()} parameter_value_lists = {} parameter_value_list_ids = {} @@ -1246,7 +1695,7 @@ def _get_relationship_parameters_for_import(db_map, data, make_cache, unparse_va parameter_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } - relationship_class_names = {x.id: x.name for x in cache.get("relationship_class", {}).values()} + relationship_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values() if x.dimension_id_list} relationship_class_ids = {rc_name: rc_id for rc_id, rc_name in relationship_class_names.items()} parameter_value_lists = {} parameter_value_list_ids = {} @@ -1335,7 +1784,7 @@ def import_object_parameter_values(db_map, data, make_cache=None, unparse_value= def _get_object_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict): cache = make_cache({"parameter_value"}, include_ancestors=True) - object_class_ids = {x.name: x.id for x in cache.get("object_class", {}).values()} + object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} parameter_value_ids = { (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() } @@ -1347,7 +1796,11 @@ def _get_object_parameter_values_for_import(db_map, data, make_cache, unparse_va } for x in cache.get("parameter_definition", {}).values() } - objects = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("object", {}).values()} + objects = { + x.id: {"class_id": x.class_id, "name": x.name} + for x in cache.get("entity", {}).values() + if not x.element_id_list + } parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} object_ids = {(o["name"], o["class_id"]): o_id for o_id, o in objects.items()} @@ -1460,7 +1913,9 @@ def import_relationship_parameter_values(db_map, data, make_cache=None, unparse_ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict): cache = make_cache({"parameter_value"}, include_ancestors=True) - object_class_id_lists = {x.id: x.object_class_id_list for x in cache.get("relationship_class", {}).values()} + object_class_id_lists = { + x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list + } 
parameter_value_ids = { (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() } @@ -1473,15 +1928,16 @@ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unpa for x in cache.get("parameter_definition", {}).values() } relationships = { - x.id: {"class_id": x.class_id, "name": x.name, "object_id_list": x.object_id_list} - for x in cache.get("relationship", {}).values() + x.id: {"class_id": x.class_id, "name": x.name, "object_id_list": x.element_id_list} + for x in cache.get("entity", {}).values() + if x.element_id_list } parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} parameter_ids = {(p["entity_class_id"], p["name"]): p_id for p_id, p in parameters.items()} relationship_ids = {(r["class_id"], tuple(r["object_id_list"])): r_id for r_id, r in relationships.items()} - object_ids = {(o.name, o.class_id): o.id for o in cache.get("object", {}).values()} - relationship_class_ids = {oc.name: oc.id for oc in cache.get("relationship_class", {}).values()} + object_ids = {(o.name, o.class_id): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} + relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} alternative_ids = set(alternatives.values()) error_log = [] @@ -1497,6 +1953,10 @@ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unpa o_ids = tuple(None for _ in object_names) r_id = relationship_ids.get((rc_id, o_ids), None) p_id = parameter_ids.get((rc_id, parameter_name), None) + if p_id is None: + print(class_name, object_names, parameter_name, value) + for x in parameter_ids.items(): + print(x) if optionals: alternative_name = optionals[0] alt_id = alternatives.get(alternative_name) From 9a30d04895c68deec649abcff7911966895be858 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 14 Feb 2023 18:57:44 +0100 Subject: [PATCH 009/317] Fix import entities --- spinedb_api/import_functions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index e495e88c..38464d71 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -366,6 +366,8 @@ def import_entity_classes(db_map, data, make_cache=None): def _get_entity_classes_for_import(db_map, data, make_cache): + # FIXME: We need to find a way to set the ids for newly added single dimensional entities + # so that they can be used in this same function for adding multi dimensional ones cache = make_cache({"entity_class"}, include_ancestors=True) entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} checked = set() @@ -437,9 +439,11 @@ def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_i def _get_entities_for_import(db_map, data, make_cache): cache = make_cache({"entity"}, include_ancestors=True) - entities = {x.name: x for x in cache.get("entity", {}).values()} - entity_ids_per_name = {(x.class_id, x.name): x.id for x in entities.values()} - entity_ids_per_el_id_lst = {(x.class_id, x.element_id_list): x.id for x in entities.values()} + entities = {x.id: x for x in cache.get("entity", {}).values()} + entity_ids_per_name = {(x.class_id, x.name): x.id for x in cache.get("entity", 
{}).values()} + entity_ids_per_el_id_lst = { + (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if x.element_id_list + } entity_classes = { x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} for x in cache.get("entity_class", {}).values() } From f6faf5c3a006089980b6e8251ba72badfd93b0cc Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 15 Feb 2023 10:53:41 +0100 Subject: [PATCH 010/317] Introduce entity byname --- spinedb_api/db_cache.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 0d808308..ba960f52 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -312,6 +312,8 @@ def __getitem__(self, key): return self._get_ref("entity_class", self["class_id"], key).get("dimension_name_list") if key == "element_name_list": return tuple(self._get_ref("entity", id_, key).get("name") for id_ in self["element_id_list"]) + if key == "byname": + return self["element_name_list"] or (self["name"],) return super().__getitem__(key) def _reference_keys(self): @@ -372,6 +374,8 @@ def __getitem__(self, key): return self._get_ref("parameter_definition", self["parameter_definition_id"], key).get("name") if key == "entity_name": return self._get_ref("entity", self["entity_id"], key)["name"] + if key == "entity_byname": + return self._get_ref("entity", self["entity_id"], key)["byname"] if key in ("element_id_list", "element_name_list"): return self._get_ref("entity", self["entity_id"], key)[key] if key == "alternative_name": From 504baf60d09dffff574179329ad4d34f792866b3 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 13 Mar 2023 18:30:23 +0100 Subject: [PATCH 011/317] Fix some tests --- spinedb_api/check_functions.py | 1 - spinedb_api/db_cache.py | 4 +- spinedb_api/db_mapping_base.py | 3 + spinedb_api/db_mapping_check_mixin.py | 15 +- spinedb_api/import_functions.py | 48 +++--- tests/export_mapping/test_export_mapping.py | 142 ++++++++-------- tests/filters/test_alternative_filter.py | 6 +- tests/filters/test_renamer.py | 6 +- tests/filters/test_scenario_filter.py | 6 +- tests/filters/test_tool_filter.py | 4 +- tests/filters/test_tools.py | 6 +- tests/filters/test_value_transformer.py | 4 +- tests/spine_io/exporters/test_csv_writer.py | 10 +- tests/spine_io/exporters/test_excel_writer.py | 14 +- tests/test_DiffDatabaseMapping.py | 158 ++++++++---------- tests/test_export_functions.py | 4 +- tests/test_import_functions.py | 5 +- tests/test_migration.py | 4 +- 18 files changed, 222 insertions(+), 218 deletions(-) diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index 56b586a0..bdebbf94 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -282,7 +282,6 @@ def check_wide_relationship(wide_item, current_items_by_name, current_items_by_o Raises: SpineIntegrityError: if the insertion of the item violates an integrity constraint. 
""" - try: name = wide_item["name"] except KeyError: diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index ba960f52..b6818b13 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -276,7 +276,7 @@ def __getitem__(self, key): class EntityClassItem(DisplayIconMixin, DescriptionMixin, CacheItem): def __init__(self, *args, **kwargs): - dimension_id_list = kwargs["dimension_id_list"] + dimension_id_list = kwargs.get("dimension_id_list") if dimension_id_list is None: dimension_id_list = () if isinstance(dimension_id_list, str): @@ -295,7 +295,7 @@ def _reference_keys(self): class EntityItem(DescriptionMixin, CacheItem): def __init__(self, *args, **kwargs): - element_id_list = kwargs["element_id_list"] + element_id_list = kwargs.get("element_id_list") if element_id_list is None: element_id_list = () if isinstance(element_id_list, str): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 5c7a63cd..d0187acd 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -315,6 +315,9 @@ def _real_tablename(self, tablename): "relationship": "entity", }.get(tablename, tablename) + def get_table(self, tablename): + return self._metadata.tables[tablename] + def commit_id(self): return self._commit_id diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index fec8b978..eb6de908 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -498,7 +498,8 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= cache, intgr_error_log, ) as wide_item: - wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) + if "object_class_id_list" not in wide_item: + wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) check_wide_relationship_class(wide_item, relationship_class_ids, object_class_ids) checked_wide_items.append(wide_item) except SpineIntegrityError as e: @@ -534,7 +535,11 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, for x in cache.get("entity_class", {}).values() if x.dimension_id_list } - objects = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("object", {}).values()} + objects = { + x.id: {"class_id": x.class_id, "name": x.name} + for x in cache.get("entity", {}).values() + if not x.element_id_list + } for wide_item in wide_items: try: with self._manage_stocks( @@ -548,8 +553,10 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, cache, intgr_error_log, ) as wide_item: - wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) - wide_item["object_id_list"] = wide_item.pop("element_id_list", ()) + if "object_class_id_list" not in wide_item: + wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) + if "object_id_list" not in wide_item: + wide_item["object_id_list"] = wide_item.pop("element_id_list", ()) check_wide_relationship( wide_item, relationship_ids_by_name, diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 38464d71..95d43f50 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -1346,11 +1346,12 @@ def _get_relationship_classes_for_import(db_map, data, make_cache): if rc_id is not None else { "name": name, - "object_class_id_list": [object_class_ids.get(oc, None) for oc in oc_names], + "dimension_id_list": [object_class_ids.get(oc, None) for oc in oc_names], "description": None, 
"display_icon": None, } ) + item["object_class_id_list"] = item.pop("dimension_id_list") item.update(dict(zip(("description", "display_icon"), optionals))) try: check_wide_relationship_class(item, relationship_class_ids, set(object_class_ids.values())) @@ -1557,10 +1558,12 @@ def _get_relationships_for_import(db_map, data, make_cache): else { "name": _make_unique_relationship_name(rc_id, class_name, object_names, relationship_ids_per_name), "class_id": rc_id, - "object_id_list": list(o_ids), - "object_class_id_list": oc_ids, + "element_id_list": list(o_ids), + "dimension_id_list": oc_ids, } ) + item["object_id_list"] = item.pop("element_id_list") + item["object_class_id_list"] = item.pop("dimension_id_list", ()) item.update(dict(zip(("description",), optionals))) try: check_wide_relationship( @@ -1957,10 +1960,6 @@ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unpa o_ids = tuple(None for _ in object_names) r_id = relationship_ids.get((rc_id, o_ids), None) p_id = parameter_ids.get((rc_id, parameter_name), None) - if p_id is None: - print(class_name, object_names, parameter_name, value) - for x in parameter_ids.items(): - print(x) if optionals: alternative_name = optionals[0] alt_id = alternatives.get(alternative_name) @@ -2162,6 +2161,9 @@ def _get_metadata_for_import(db_map, data, make_cache): return to_add, [], [] +# TODO: import_entity_metadata, import_parameter_value_metadata + + def import_object_metadata(db_map, data, make_cache=None): """Imports object metadata. Ignores duplicates. @@ -2183,9 +2185,9 @@ def import_object_metadata(db_map, data, make_cache=None): def _get_object_metadata_for_import(db_map, data, make_cache): cache = make_cache({"object", "entity_metadata"}, include_ancestors=True) - object_class_ids = {x.name: x.id for x in cache.get("object_class", {}).values()} + object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - object_ids = {(x.name, x.class_id): x.id for x in cache.get("object", {}).values()} + object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} seen = {(x.entity_id, x.metadata_id) for x in cache.get("entity_metadata", {}).values()} error_log = [] to_add = [] @@ -2243,11 +2245,15 @@ def import_relationship_metadata(db_map, data, make_cache=None): def _get_relationship_metadata_for_import(db_map, data, make_cache): cache = make_cache({"relationship", "entity_metadata"}, include_ancestors=True) - relationship_class_ids = {oc.name: oc.id for oc in cache.get("relationship_class", {}).values()} - object_class_id_lists = {x.id: x.object_class_id_list for x in cache.get("relationship_class", {}).values()} + relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} + object_class_id_lists = { + x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list + } metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - object_ids = {(x.name, x.class_id): x.id for x in cache.get("object", {}).values()} - relationship_ids = {(x.class_id, x.object_id_list): x.id for x in cache.get("relationship", {}).values()} + object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} + relationship_ids = { + (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if 
x.element_id_list
+    }
     seen = {(x.entity_id, x.metadata_id) for x in cache.get("entity_metadata", {}).values()}
     error_log = []
     to_add = []
@@ -2309,8 +2315,8 @@ def import_object_parameter_value_metadata(db_map, data, make_cache=None):
 
 def _get_object_parameter_value_metadata_for_import(db_map, data, make_cache):
     cache = make_cache({"parameter_value", "parameter_value_metadata"}, include_ancestors=True)
-    object_class_ids = {x.name: x.id for x in cache.get("object_class", {}).values()}
-    object_ids = {(x.name, x.class_id): x.id for x in cache.get("object", {}).values()}
+    object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list}
+    object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list}
     parameter_ids = {
         (x.parameter_name, x.entity_class_id): x.id for x in cache.get("parameter_definition", {}).values()
     }
@@ -2384,10 +2390,14 @@ def import_relationship_parameter_value_metadata(db_map, data, make_cache=None):
 
 def _get_relationship_parameter_value_metadata_for_import(db_map, data, make_cache):
     cache = make_cache({"parameter_value", "parameter_value_metadata"}, include_ancestors=True)
-    relationship_class_ids = {oc.name: oc.id for oc in cache.get("relationship_class", {}).values()}
-    object_class_id_lists = {x.id: x.object_class_id_list for x in cache.get("relationship_class", {}).values()}
-    object_ids = {(x.name, x.class_id): x.id for x in cache.get("object", {}).values()}
-    relationship_ids = {(x.object_id_list, x.class_id): x.id for x in cache.get("relationship", {}).values()}
+    relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list}
+    object_class_id_lists = {
+        x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list
+    }
+    object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list}
+    relationship_ids = {
+        (x.element_id_list, x.class_id): x.id for x in cache.get("entity", {}).values() if x.element_id_list
+    }
     parameter_ids = {
         (x.parameter_name, x.entity_class_id): x.id for x in cache.get("parameter_definition", {}).values()
     }
diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py
index 2a123202..b6d7d0c4 100644
--- a/tests/export_mapping/test_export_mapping.py
+++ b/tests/export_mapping/test_export_mapping.py
@@ -18,7 +18,7 @@
 import unittest
 from spinedb_api import (
     DatabaseMapping,
-    DiffDatabaseMapping,
+    DatabaseMapping,
     import_alternatives,
     import_features,
     import_object_classes,
@@ -95,7 +95,7 @@ def test_export_empty_table(self):
         db_map.connection.close()
 
     def test_export_single_object_class(self):
-        db_map = DiffDatabaseMapping("sqlite://", create=True)
+        db_map = DatabaseMapping("sqlite://", create=True)
         import_object_classes(db_map, ("object_class",))
         db_map.commit_session("Add test data.")
         object_class_mapping = ObjectClassMapping(0)
@@ -103,7 +103,7 @@ def test_export_single_object_class(self):
         db_map.connection.close()
 
     def test_export_objects(self):
-        db_map = DiffDatabaseMapping("sqlite://", create=True)
+        db_map = DatabaseMapping("sqlite://", create=True)
         import_object_classes(db_map, ("oc1", "oc2", "oc3"))
         import_objects(
             db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33"))
@@ -118,7 +118,7 @@ def test_export_objects(self):
         db_map.connection.close()
 
     def test_hidden_tail(self):
-        db_map = DiffDatabaseMapping("sqlite://", create=True)
+ db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1",)) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) db_map.commit_session("Add test data.") @@ -128,7 +128,7 @@ def test_hidden_tail(self): db_map.connection.close() def test_pivot_without_values(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1",)) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) db_map.commit_session("Add test data.") @@ -138,7 +138,7 @@ def test_pivot_without_values(self): db_map.connection.close() def test_hidden_tail_pivoted(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"), ("oc", "p2"))) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -168,7 +168,7 @@ def test_hidden_leaf_item_in_pivot_table_not_valid(self): self.assertEqual(object_class_mapping.check_validity(), ["Cannot be pivoted."]) def test_object_groups(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"), ("oc", "o3"), ("oc", "g1"), ("oc", "g2"))) import_object_groups(db_map, (("oc", "g1", "o1"), ("oc", "g1", "o2"), ("oc", "g2", "o3"))) @@ -179,7 +179,7 @@ def test_object_groups(self): db_map.connection.close() def test_object_groups_with_objects(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"), ("oc", "o3"), ("oc", "g1"), ("oc", "g2"))) import_object_groups(db_map, (("oc", "g1", "o1"), ("oc", "g1", "o2"), ("oc", "g2", "o3"))) @@ -190,7 +190,7 @@ def test_object_groups_with_objects(self): db_map.connection.close() def test_object_groups_with_parameter_values(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"), ("oc", "o3"), ("oc", "g1"), ("oc", "g2"))) import_object_groups(db_map, (("oc", "g1", "o1"), ("oc", "g1", "o2"), ("oc", "g2", "o3"))) @@ -215,7 +215,7 @@ def test_object_groups_with_parameter_values(self): db_map.connection.close() def test_export_parameter_definitions(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_object_parameters(db_map, (("oc1", "p11"), ("oc1", "p12"), ("oc2", "p21"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) @@ -235,7 +235,7 @@ def test_export_parameter_definitions(self): db_map.connection.close() def test_export_single_parameter_value_when_there_are_multiple_objects(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_object_parameters(db_map, (("oc1", "p11"), ("oc1", "p12"), ("oc2", "p21"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) @@ -254,7 +254,7 @@ def test_export_single_parameter_value_when_there_are_multiple_objects(self): db_map.connection.close() def test_export_single_parameter_value_pivoted_by_object_name(self): - db_map = 
DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1",)) import_object_parameters(db_map, (("oc1", "p11"), ("oc1", "p12"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) @@ -282,7 +282,7 @@ def test_export_single_parameter_value_pivoted_by_object_name(self): db_map.connection.close() def test_minimum_pivot_index_need_not_be_minus_one(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) @@ -305,7 +305,7 @@ def test_minimum_pivot_index_need_not_be_minus_one(self): db_map.connection.close() def test_pivot_row_order(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1",)) import_object_parameters(db_map, (("oc1", "p11"), ("oc1", "p12"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) @@ -349,7 +349,7 @@ def test_pivot_row_order(self): db_map.connection.close() def test_export_parameter_indexes(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"), ("oc", "p2"))) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -386,7 +386,7 @@ def test_export_parameter_indexes(self): db_map.connection.close() def test_export_nested_parameter_indexes(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -421,7 +421,7 @@ def test_export_nested_parameter_indexes(self): db_map.connection.close() def test_export_nested_map_values_only(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -451,7 +451,7 @@ def test_export_nested_map_values_only(self): db_map.connection.close() def test_full_pivot_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -487,7 +487,7 @@ def test_full_pivot_table(self): db_map.connection.close() def test_full_pivot_table_with_hidden_columns(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -505,7 +505,7 @@ def test_full_pivot_table_with_hidden_columns(self): db_map.connection.close() def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) @@ -532,7 +532,7 @@ def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): db_map.connection.close() def 
test_objects_and_indexes_as_pivot_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -550,7 +550,7 @@ def test_objects_and_indexes_as_pivot_header(self): db_map.connection.close() def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_parameters(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"),)) @@ -582,7 +582,7 @@ def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_para db_map.connection.close() def test_empty_column_while_pivoted_handled_gracefully(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) @@ -599,7 +599,7 @@ def test_empty_column_while_pivoted_handled_gracefully(self): db_map.connection.close() def test_object_classes_as_header_row_and_objects_in_columns(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_objects( db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33")) @@ -614,7 +614,7 @@ def test_object_classes_as_header_row_and_objects_in_columns(self): db_map.connection.close() def test_object_classes_as_table_names(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_objects( db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33")) @@ -629,7 +629,7 @@ def test_object_classes_as_table_names(self): db_map.connection.close() def test_object_class_and_parameter_definition_as_table_name(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11"), ("oc2", "p21"), ("oc2", "p22"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"))) @@ -648,7 +648,7 @@ def test_object_class_and_parameter_definition_as_table_name(self): db_map.connection.close() def test_object_relationship_name_as_table_name(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o1"), ("oc1", "o2"), ("oc2", "O"))) import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) @@ -664,7 +664,7 @@ def test_object_relationship_name_as_table_name(self): db_map.connection.close() def test_parameter_definitions_with_value_lists(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_parameter_value_lists(db_map, (("vl1", -1.0), ("vl2", -2.0))) import_object_parameters(db_map, (("oc", "p1", None, "vl1"), ("oc", "p2"))) @@ -681,7 +681,7 @@ def test_parameter_definitions_with_value_lists(self): 
db_map.connection.close() def test_parameter_definitions_and_values_and_value_lists(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_parameter_value_lists(db_map, (("vl", -1.0),)) import_object_parameters(db_map, (("oc", "p1", None, "vl"), ("oc", "p2"))) @@ -704,7 +704,7 @@ def test_parameter_definitions_and_values_and_value_lists(self): db_map.connection.close() def test_parameter_definitions_and_values_and_ignorable_value_lists(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_parameter_value_lists(db_map, (("vl", -1.0),)) import_object_parameters(db_map, (("oc", "p1", None, "vl"), ("oc", "p2"))) @@ -729,7 +729,7 @@ def test_parameter_definitions_and_values_and_ignorable_value_lists(self): db_map.connection.close() def test_parameter_value_lists(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_parameter_value_lists(db_map, (("vl1", -1.0), ("vl2", -2.0))) db_map.commit_session("Add test data.") value_list_mapping = ParameterValueListMapping(0) @@ -740,7 +740,7 @@ def test_parameter_value_lists(self): db_map.connection.close() def test_parameter_value_list_values(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_parameter_value_lists(db_map, (("vl1", -1.0), ("vl2", -2.0))) db_map.commit_session("Add test data.") value_list_mapping = ParameterValueListMapping(Position.table_name) @@ -753,7 +753,7 @@ def test_parameter_value_list_values(self): db_map.connection.close() def test_no_item_declared_as_title_gives_full_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11"), ("oc2", "p21"), ("oc2", "p22"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"))) @@ -770,7 +770,7 @@ def test_no_item_declared_as_title_gives_full_table(self): db_map.connection.close() def test_missing_values_for_alternatives(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"), ("oc", "p2"))) import_alternatives(db_map, ("alt1", "alt2")) @@ -804,7 +804,7 @@ def test_missing_values_for_alternatives(self): db_map.connection.close() def test_export_relationship_classes(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_relationship_classes( db_map, (("rc1", ("oc1",)), ("rc2", ("oc3", "oc2")), ("rc3", ("oc2", "oc3", "oc1"))) @@ -815,7 +815,7 @@ def test_export_relationship_classes(self): db_map.connection.close() def test_export_relationships(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) import_relationship_classes(db_map, (("rc1", ("oc1",)), ("rc2", ("oc2", "oc1")))) @@ -829,7 +829,7 @@ def test_export_relationships(self): db_map.connection.close() def 
test_relationships_with_different_dimensions(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc2", "o22"))) import_relationship_classes(db_map, (("rc1D", ("oc1",)), ("rc2D", ("oc1", "oc2")))) @@ -865,7 +865,7 @@ def test_relationships_with_different_dimensions(self): db_map.connection.close() def test_default_parameter_values(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11", 3.14), ("oc2", "p21", 14.3), ("oc2", "p22", -1.0))) db_map.commit_session("Add test data.") @@ -879,7 +879,7 @@ def test_default_parameter_values(self): db_map.connection.close() def test_indexed_default_parameter_values(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters( db_map, @@ -909,7 +909,7 @@ def test_indexed_default_parameter_values(self): db_map.connection.close() def test_replace_parameter_indexes_by_external_data(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"),)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) @@ -939,7 +939,7 @@ def test_replace_parameter_indexes_by_external_data(self): db_map.connection.close() def test_constant_mapping_as_title(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) db_map.commit_session("Add test data.") constant_mapping = FixedValueMapping(Position.table_name, "title_text") @@ -952,7 +952,7 @@ def test_constant_mapping_as_title(self): db_map.connection.close() def test_scenario_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_scenarios(db_map, ("s1", "s2")) db_map.commit_session("Add test data.") scenario_mapping = ScenarioMapping(0) @@ -963,7 +963,7 @@ def test_scenario_mapping(self): db_map.connection.close() def test_scenario_active_flag_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_scenarios(db_map, (("s1", True), ("s2", False))) db_map.commit_session("Add test data.") scenario_mapping = ScenarioMapping(0) @@ -976,7 +976,7 @@ def test_scenario_active_flag_mapping(self): db_map.connection.close() def test_scenario_alternative_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("a1", "a2", "a3")) import_scenarios(db_map, ("s1", "s2", "empty")) import_scenario_alternatives(db_map, (("s1", "a2"), ("s1", "a1", "a2"), ("s2", "a2"), ("s2", "a3", "a2"))) @@ -991,7 +991,7 @@ def test_scenario_alternative_mapping(self): db_map.connection.close() def test_tool_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_tools(db_map, ("tool1", "tool2")) db_map.commit_session("Add test data.") tool_mapping = ToolMapping(0) @@ -1002,7 +1002,7 @@ def test_tool_mapping(self): 
db_map.connection.close() def test_feature_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) import_object_parameters( @@ -1025,7 +1025,7 @@ def test_feature_mapping(self): db_map.connection.close() def test_tool_feature_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) import_object_parameters( @@ -1058,7 +1058,7 @@ def test_tool_feature_mapping(self): db_map.connection.close() def test_tool_feature_method_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) import_object_parameters( @@ -1103,7 +1103,7 @@ def test_tool_feature_method_mapping(self): db_map.connection.close() def test_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -1113,21 +1113,21 @@ def test_header(self): db_map.connection.close() def test_header_without_data_still_creates_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([ObjectClassMapping(0, header="class"), ObjectMapping(1, header="object")]) expected = [["class", "object"]] self.assertEqual(list(rows(root, db_map)), expected) db_map.connection.close() def test_header_in_half_pivot_table_without_data_still_creates_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([ObjectClassMapping(-1, header="class"), ObjectMapping(9, header="object")]) expected = [["class"]] self.assertEqual(list(rows(root, db_map)), expected) db_map.connection.close() def test_header_in_pivot_table_without_data_still_creates_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root = unflatten( [ ObjectClassMapping(-1, header="class"), @@ -1142,21 +1142,21 @@ def test_header_in_pivot_table_without_data_still_creates_header(self): db_map.connection.close() def test_disabled_empty_data_header(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([ObjectClassMapping(0, header="class"), ObjectMapping(1, header="object")]) expected = [] self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected) db_map.connection.close() def test_disabled_empty_data_header_in_pivot_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([ObjectClassMapping(-1, header="class"), ObjectMapping(0)]) expected = [] self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected) db_map.connection.close() def test_header_position(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) 
import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -1166,7 +1166,7 @@ def test_header_position(self): db_map.connection.close() def test_header_position_with_relationships(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o11"), ("oc2", "o21"))) import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) @@ -1187,7 +1187,7 @@ def test_header_position_with_relationships(self): db_map.connection.close() def test_header_position_with_relationships_but_no_data(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) db_map.commit_session("Add test data.") @@ -1206,7 +1206,7 @@ def test_header_position_with_relationships_but_no_data(self): db_map.connection.close() def test_header_and_pivot(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"),)) @@ -1247,7 +1247,7 @@ def test_header_and_pivot(self): db_map.connection.close() def test_pivot_without_left_hand_side_has_padding_column_for_headers(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_alternatives(db_map, ("alt",)) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p1"),)) @@ -1367,7 +1367,7 @@ def test_serialization(self): self.assertEqual(m.highlight_dimension, highlight_dimension) def test_setting_ignorable_flag(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) db_map.commit_session("Add test data.") object_mapping = ObjectMapping(1) @@ -1379,7 +1379,7 @@ def test_setting_ignorable_flag(self): db_map.connection.close() def test_unsetting_ignorable_flag(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -1394,7 +1394,7 @@ def test_unsetting_ignorable_flag(self): db_map.connection.close() def test_filter(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) db_map.commit_session("Add test data.") @@ -1406,7 +1406,7 @@ def test_filter(self): db_map.connection.close() def test_hidden_tail_filter(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o1"), ("oc2", "o2"))) db_map.commit_session("Add test data.") @@ -1418,7 +1418,7 @@ def test_hidden_tail_filter(self): db_map.connection.close() def test_index_names(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o"),)) @@ -1430,7 +1430,7 @@ def test_index_names(self): db_map.connection.close() def 
test_default_value_index_names_with_nested_map(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters( db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3], index_name="idx2")], index_name="idx1")),) @@ -1444,7 +1444,7 @@ def test_default_value_index_names_with_nested_map(self): db_map.connection.close() def test_multiple_index_names_with_empty_database(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) mapping = relationship_parameter_export( 0, 4, Position.hidden, 1, [2], [3], 5, Position.hidden, 8, [Position.header, Position.header], [6, 7] ) @@ -1453,7 +1453,7 @@ def test_multiple_index_names_with_empty_database(self): db_map.connection.close() def test_parameter_default_value_type(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11", 3.14), ("oc2", "p21", 14.3), ("oc2", "p22", -1.0))) db_map.commit_session("Add test data.") @@ -1467,7 +1467,7 @@ def test_parameter_default_value_type(self): db_map.connection.close() def test_map_with_more_dimensions_than_index_mappings(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o"),)) @@ -1481,7 +1481,7 @@ def test_map_with_more_dimensions_than_index_mappings(self): db_map.connection.close() def test_default_map_value_with_more_dimensions_than_index_mappings(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3])])),)) db_map.commit_session("Add test data.") @@ -1491,7 +1491,7 @@ def test_default_map_value_with_more_dimensions_than_index_mappings(self): db_map.connection.close() def test_map_with_single_value_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o"),)) @@ -1503,7 +1503,7 @@ def test_map_with_single_value_mapping(self): db_map.connection.close() def test_default_map_value_with_single_value_mapping(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [2.3])),)) db_map.commit_session("Add test data.") diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index 0a09a8f5..80cc5b62 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -23,7 +23,7 @@ apply_alternative_filter_to_parameter_value_sq, create_new_spine_database, DatabaseMapping, - DiffDatabaseMapping, + DatabaseMapping, import_alternatives, import_object_classes, import_object_parameter_values, @@ -50,9 +50,9 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DiffDatabaseMapping(self._db_url) + self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) - 
self._diff_db_map = DiffDatabaseMapping(self._db_url) + self._diff_db_map = DatabaseMapping(self._db_url) def tearDown(self): self._out_map.connection.close() diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index 9637787b..d799fbfb 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -24,7 +24,7 @@ apply_renaming_to_entity_class_sq, create_new_spine_database, DatabaseMapping, - DiffDatabaseMapping, + DatabaseMapping, import_object_classes, import_object_parameters, import_relationship_classes, @@ -49,7 +49,7 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DiffDatabaseMapping(self._db_url) + self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): @@ -154,7 +154,7 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DiffDatabaseMapping(self._db_url) + self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 6925d4f6..2cfc409b 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -23,7 +23,7 @@ apply_scenario_filter_to_subqueries, create_new_spine_database, DatabaseMapping, - DiffDatabaseMapping, + DatabaseMapping, import_alternatives, import_object_classes, import_object_parameter_values, @@ -56,9 +56,9 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DiffDatabaseMapping(self._db_url) + self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) - self._diff_db_map = DiffDatabaseMapping(self._db_url) + self._diff_db_map = DatabaseMapping(self._db_url) def tearDown(self): self._out_map.connection.close() diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index e4f7fa63..213a3d7e 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -22,7 +22,7 @@ from spinedb_api import ( apply_tool_filter_to_entity_sq, create_new_spine_database, - DiffDatabaseMapping, + DatabaseMapping, import_object_classes, import_relationship_classes, import_object_parameter_values, @@ -57,7 +57,7 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._db_map = DiffDatabaseMapping(self._db_url) + self._db_map = DatabaseMapping(self._db_url) def tearDown(self): self._db_map.connection.close() diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index ead77e46..b1b7f53b 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -22,7 +22,7 @@ append_filter_config, clear_filter_configs, DatabaseMapping, - DiffDatabaseMapping, + DatabaseMapping, export_object_classes, import_object_classes, pop_filter_configs, @@ -93,7 +93,7 @@ class TestApplyFilterStack(unittest.TestCase): def setUpClass(cls): cls._dir = TemporaryDirectory() cls._db_url.database = os.path.join(cls._dir.name, ".json") - db_map = DiffDatabaseMapping(cls._db_url, create=True) + db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") db_map.connection.close() @@ -140,7 +140,7 @@ class TestFilteredDatabaseMap(unittest.TestCase): def setUpClass(cls): cls._dir = TemporaryDirectory() cls._db_url.database = os.path.join(cls._dir.name, "TestFilteredDatabaseMap.json") - db_map 
= DiffDatabaseMapping(cls._db_url, create=True) + db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") db_map.connection.close() diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index fe612e8b..dc4f5047 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -22,7 +22,7 @@ from spinedb_api import ( DatabaseMapping, - DiffDatabaseMapping, + DatabaseMapping, import_object_classes, import_object_parameter_values, import_object_parameters, @@ -110,7 +110,7 @@ def tearDownClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DiffDatabaseMapping(self._db_url) + self._out_map = DatabaseMapping(self._db_url) def tearDown(self): self._out_map.connection.close() diff --git a/tests/spine_io/exporters/test_csv_writer.py b/tests/spine_io/exporters/test_csv_writer.py index feb50e5b..989299b8 100644 --- a/tests/spine_io/exporters/test_csv_writer.py +++ b/tests/spine_io/exporters/test_csv_writer.py @@ -17,7 +17,7 @@ from pathlib import Path from tempfile import TemporaryDirectory import unittest -from spinedb_api import DiffDatabaseMapping, import_object_classes, import_objects +from spinedb_api import DatabaseMapping, import_object_classes, import_objects from spinedb_api.mapping import Position from spinedb_api.export_mapping import object_export from spinedb_api.spine_io.exporters.writer import write @@ -32,7 +32,7 @@ def tearDown(self): self._temp_dir.cleanup() def test_write_empty_database(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root_mapping = object_export(0, 1) out_path = Path(self._temp_dir.name, "out.csv") writer = CsvWriter(out_path.parent, out_path.name) @@ -43,7 +43,7 @@ def test_write_empty_database(self): db_map.connection.close() def test_write_single_object_class_and_object(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -57,7 +57,7 @@ def test_write_single_object_class_and_object(self): db_map.connection.close() def test_tables_are_written_to_separate_files(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o1"), ("oc2", "o2"))) db_map.commit_session("Add test data.") @@ -81,7 +81,7 @@ def test_tables_are_written_to_separate_files(self): db_map.connection.close() def test_append_to_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index ff3eb3d4..357f84b7 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -18,7 +18,7 @@ from tempfile import TemporaryDirectory import unittest from openpyxl import load_workbook -from spinedb_api import DiffDatabaseMapping, import_object_classes, import_objects +from spinedb_api import DatabaseMapping, import_object_classes, import_objects from spinedb_api.mapping import Position from 
spinedb_api.export_mapping import object_export from spinedb_api.spine_io.exporters.writer import write @@ -33,7 +33,7 @@ def tearDown(self): self._temp_dir.cleanup() def test_write_empty_database(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) root_mapping = object_export(0, 1) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) @@ -46,7 +46,7 @@ def test_write_empty_database(self): db_map.connection.close() def test_write_single_object_class_and_object(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -62,7 +62,7 @@ def test_write_single_object_class_and_object(self): db_map.connection.close() def test_write_to_existing_sheet(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("Sheet1",)) import_objects(db_map, (("Sheet1", "o1"), ("Sheet1", "o2"))) db_map.commit_session("Add test data.") @@ -78,7 +78,7 @@ def test_write_to_existing_sheet(self): db_map.connection.close() def test_write_to_named_sheets(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", ("oc2"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) db_map.commit_session("Add test data.") @@ -96,7 +96,7 @@ def test_write_to_named_sheets(self): db_map.connection.close() def test_append_to_anonymous_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") @@ -113,7 +113,7 @@ def test_append_to_anonymous_table(self): db_map.connection.close() def test_append_to_named_table(self): - db_map = DiffDatabaseMapping("sqlite://", create=True) + db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 71dac987..3122a532 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -10,7 +10,7 @@ ###################################################################################################################### """ -Unit tests for DiffDatabaseMapping class. +Unit tests for DatabaseMapping class. :author: P. 
Vennström (VTT) :date: 29.11.2018 @@ -22,7 +22,7 @@ from unittest import mock from sqlalchemy.engine.url import make_url, URL from sqlalchemy.util import KeyedTuple -from spinedb_api.diff_db_mapping import DiffDatabaseMapping +from spinedb_api.db_mapping import DatabaseMapping from spinedb_api.exception import SpineIntegrityError from spinedb_api.db_cache import DBCache from spinedb_api import import_functions, SpineDBAPIError @@ -42,17 +42,17 @@ def query_wrapper(*args, orig_query=db_map.query, **kwargs): def create_diff_db_map(): - return DiffDatabaseMapping(IN_MEMORY_DB_URL, username="UnitTest", create=True) + return DatabaseMapping(IN_MEMORY_DB_URL, username="UnitTest", create=True) -class TestDiffDatabaseMappingConstruction(unittest.TestCase): +class TestDatabaseMappingConstruction(unittest.TestCase): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" with mock.patch("spinedb_api.diff_db_mapping.apply_filter_stack") as mock_apply: with mock.patch( "spinedb_api.diff_db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DiffDatabaseMapping(db_url, create=True) + db_map = DatabaseMapping(db_url, create=True) db_map.connection.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -64,7 +64,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): with mock.patch( "spinedb_api.diff_db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DiffDatabaseMapping(sa_url, create=True) + db_map = DatabaseMapping(sa_url, create=True) db_map.connection.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -73,19 +73,19 @@ def test_shorthand_filter_query_works(self): with TemporaryDirectory() as temp_dir: url = URL("sqlite") url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") - out_db = DiffDatabaseMapping(url, create=True) + out_db = DatabaseMapping(url, create=True) out_db.add_tools({"name": "object_activity_control", "id": 1}) out_db.commit_session("Add tool.") out_db.connection.close() try: - db_map = DiffDatabaseMapping(url) + db_map = DatabaseMapping(url) except: - self.fail("DiffDatabaseMapping.__init__() should not raise.") + self.fail("DatabaseMapping.__init__() should not raise.") else: db_map.connection.close() -class TestDiffDatabaseMappingRemove(unittest.TestCase): +class TestDatabaseMappingRemove(unittest.TestCase): def setUp(self): self._db_map = create_diff_db_map() @@ -409,7 +409,7 @@ def test_cascade_remove_parameter_value_removes_its_metadata(self): self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) -class TestDiffDatabaseMappingAdd(unittest.TestCase): +class TestDatabaseMappingAdd(unittest.TestCase): def setUp(self): self._db_map = create_diff_db_map() @@ -428,10 +428,7 @@ def test_add_and_retrieve_many_objects(self): def test_add_object_classes(self): """Test that adding object classes works.""" self._db_map.add_object_classes({"name": "fish"}, {"name": "dog"}) - diff_table = self._db_map._diff_table("entity_class") - object_classes = ( - self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.object_class_type).all() - ) + object_classes = self._db_map.query(self._db_map.object_class_sq).all() self.assertEqual(len(object_classes), 2) 
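# The replacement above is the recurring pattern in this series: where the
# old model had to query a diff table and filter on type_id, the new
# dimension/element model uses a dedicated subquery directly. A minimal
# before/after sketch (db_map stands for any mapping instance; the "before"
# form is the one being deleted in these hunks):
#
#     # before: filter the entity_class diff table by the object class type id
#     diff_table = db_map._diff_table("entity_class")
#     object_classes = (
#         db_map.query(diff_table).filter(diff_table.c.type_id == db_map.object_class_type).all()
#     )
#
#     # after: object_class_sq selects just the entity classes without dimensions
#     object_classes = db_map.query(db_map.object_class_sq).all()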
self.assertEqual(object_classes[0].name, "fish") self.assertEqual(object_classes[1].name, "dog") @@ -444,10 +441,7 @@ def test_add_object_class_with_invalid_name(self): def test_add_object_classes_with_same_name(self): """Test that adding two object classes with the same name only adds one of them.""" self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"}) - diff_table = self._db_map._diff_table("entity_class") - object_classes = ( - self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.object_class_type).all() - ) + object_classes = self._db_map.query(self._db_map.object_class_sq).all() self.assertEqual(len(object_classes), 1) self.assertEqual(object_classes[0].name, "fish") @@ -461,8 +455,7 @@ def test_add_objects(self): """Test that adding objects works.""" self._db_map.add_object_classes({"name": "fish"}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "dory", "class_id": 1}) - diff_table = self._db_map._diff_table("entity") - objects = self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.object_entity_type).all() + objects = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(objects), 2) self.assertEqual(objects[0].name, "nemo") self.assertEqual(objects[0].class_id, 1) @@ -479,8 +472,7 @@ def test_add_objects_with_same_name(self): """Test that adding two objects with the same name only adds one of them.""" self._db_map.add_object_classes({"name": "fish"}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "nemo", "class_id": 1}) - diff_table = self._db_map._diff_table("entity") - objects = self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.object_entity_type).all() + objects = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(objects), 1) self.assertEqual(objects[0].name, "nemo") self.assertEqual(objects[0].class_id, 1) @@ -504,19 +496,16 @@ def test_add_relationship_classes(self): self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]} ) - diff_table = self._db_map._diff_table("relationship_entity_class") - rel_ent_clss = self._db_map.query(diff_table).all() - diff_table = self._db_map._diff_table("entity_class") - rel_clss = ( - self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.relationship_class_type).all() - ) - self.assertEqual(len(rel_ent_clss), 4) + diff_table = self._db_map.get_table("entity_class_dimension") + ent_cls_dims = self._db_map.query(diff_table).all() + rel_clss = self._db_map.query(self._db_map.wide_relationship_class_sq).all() + self.assertEqual(len(ent_cls_dims), 4) self.assertEqual(rel_clss[0].name, "rc1") - self.assertEqual(rel_ent_clss[0].member_class_id, 1) - self.assertEqual(rel_ent_clss[1].member_class_id, 2) + self.assertEqual(ent_cls_dims[0].dimension_id, 1) + self.assertEqual(ent_cls_dims[1].dimension_id, 2) self.assertEqual(rel_clss[1].name, "rc2") - self.assertEqual(rel_ent_clss[2].member_class_id, 2) - self.assertEqual(rel_ent_clss[3].member_class_id, 1) + self.assertEqual(ent_cls_dims[2].dimension_id, 2) + self.assertEqual(ent_cls_dims[3].dimension_id, 1) def test_add_relationship_classes_with_invalid_name(self): """Test that adding object classes with empty name raises error""" @@ -530,25 +519,22 @@ def test_add_relationship_classes_with_same_name(self): self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc1", "object_class_id_list": [1, 
2]} ) - diff_table = self._db_map._diff_table("relationship_entity_class") - relationship_members = self._db_map.query(diff_table).all() - diff_table = self._db_map._diff_table("entity_class") - relationships = ( - self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.relationship_class_type).all() - ) - self.assertEqual(len(relationship_members), 2) - self.assertEqual(len(relationships), 1) - self.assertEqual(relationships[0].name, "rc1") - self.assertEqual(relationship_members[0].member_class_id, 1) - self.assertEqual(relationship_members[1].member_class_id, 2) + diff_table = self._db_map.get_table("entity_class_dimension") + ecs_dims = self._db_map.query(diff_table).all() + relationship_classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() + self.assertEqual(len(ecs_dims), 2) + self.assertEqual(len(relationship_classes), 1) + self.assertEqual(relationship_classes[0].name, "rc1") + self.assertEqual(ecs_dims[0].dimension_id, 1) + self.assertEqual(ecs_dims[1].dimension_id, 2) def test_add_relationship_class_with_same_name_as_existing_one(self): """Test that adding a relationship class with an already taken name raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DiffDatabaseMapping, "query") as mock_query, mock.patch.object( - DiffDatabaseMapping, "object_class_sq" + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_class_sq" ) as mock_object_class_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_class_sq" + DatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq: mock_query.side_effect = query_wrapper mock_object_class_sq.return_value = [ @@ -566,9 +552,9 @@ def test_add_relationship_class_with_same_name_as_existing_one(self): def test_add_relationship_class_with_invalid_object_class(self): """Test that adding a relationship class with a non existing object class raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DiffDatabaseMapping, "query") as mock_query, mock.patch.object( - DiffDatabaseMapping, "object_class_sq" - ) as mock_object_class_sq, mock.patch.object(DiffDatabaseMapping, "wide_relationship_class_sq"): + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_class_sq" + ) as mock_object_class_sq, mock.patch.object(DatabaseMapping, "wide_relationship_class_sq"): mock_query.side_effect = query_wrapper mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])] with self.assertRaises(SpineIntegrityError): @@ -582,19 +568,19 @@ def test_add_relationships(self): self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) - diff_table = self._db_map._diff_table("relationship_entity") - rel_ents = self._db_map.query(diff_table).all() - diff_table = self._db_map._diff_table("entity") - relationships = ( - self._db_map.query(diff_table).filter(diff_table.c.type_id == self._db_map.relationship_entity_type).all() - ) - self.assertEqual(len(rel_ents), 2) + # FIXME + self._db_map.commit_session("Ok") + diff_table = self._db_map.get_table("entity_element") + ent_els = self._db_map.query(diff_table).all() + diff_table = 
self._db_map.get_table("entity") + relationships = self._db_map.query(diff_table).filter(diff_table.c.id.in_({x.entity_id for x in ent_els})).all() + self.assertEqual(len(ent_els), 2) self.assertEqual(len(relationships), 1) self.assertEqual(relationships[0].name, "nemo__pluto") - self.assertEqual(rel_ents[0].entity_class_id, 3) - self.assertEqual(rel_ents[0].member_id, 1) - self.assertEqual(rel_ents[1].entity_class_id, 3) - self.assertEqual(rel_ents[1].member_id, 2) + self.assertEqual(ent_els[0].entity_class_id, 3) + self.assertEqual(ent_els[0].member_id, 1) + self.assertEqual(ent_els[1].entity_class_id, 3) + self.assertEqual(ent_els[1].member_id, 2) def test_add_relationship_with_invalid_name(self): """Test that adding object classes with empty name raises error""" @@ -607,14 +593,14 @@ def test_add_relationship_with_invalid_name(self): def test_add_identical_relationships(self): """Test that adding two relationships with the same class and same objects only adds the first one.""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "dimension_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) self._db_map.add_wide_relationships( - {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, - {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, + {"name": "nemo__pluto", "class_id": 3, "element_id_list": [1, 2]}, + {"name": "nemo__pluto_duplicate", "class_id": 3, "element_id_list": [1, 2]}, ) - diff_table = self._db_map._diff_table("relationship") - relationships = self._db_map.query(diff_table).all() + self._db_map.commit_session("Ok") + relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationships), 1) def test_add_relationship_identical_to_existing_one(self): @@ -622,12 +608,12 @@ def test_add_relationship_identical_to_existing_one(self): raises an integrity error. 
""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DiffDatabaseMapping, "query") as mock_query, mock.patch.object( - DiffDatabaseMapping, "object_sq" + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_class_sq" + DatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_sq" + DatabaseMapping, "wide_relationship_sq" ) as mock_wide_rel_sq: mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -648,12 +634,12 @@ def test_add_relationship_identical_to_existing_one(self): def test_add_relationship_with_invalid_class(self): """Test that adding a relationship with an invalid class raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DiffDatabaseMapping, "query") as mock_query, mock.patch.object( - DiffDatabaseMapping, "object_sq" + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_class_sq" + DatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_sq" + DatabaseMapping, "wide_relationship_sq" ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -671,12 +657,12 @@ def test_add_relationship_with_invalid_class(self): def test_add_relationship_with_invalid_object(self): """Test that adding a relationship with an invalid object raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DiffDatabaseMapping, "query") as mock_query, mock.patch.object( - DiffDatabaseMapping, "object_sq" + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_class_sq" + DatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DiffDatabaseMapping, "wide_relationship_sq" + DatabaseMapping, "wide_relationship_sq" ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -696,7 +682,7 @@ def test_add_entity_groups(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - diff_table = self._db_map._diff_table("entity_group") + diff_table = self._db_map.get_table("entity_group") entity_groups = self._db_map.query(diff_table).all() self.assertEqual(len(entity_groups), 1) self.assertEqual(entity_groups[0].entity_id, 1) @@ -740,7 +726,7 @@ def test_add_parameter_definitions(self): {"name": "color", "object_class_id": 1, "description": "test1"}, {"name": "relative_speed", "relationship_class_id": 3, "description": "test2"}, ) - diff_table = self._db_map._diff_table("parameter_definition") + diff_table = self._db_map.get_table("parameter_definition") parameter_definitions = self._db_map.query(diff_table).all() self.assertEqual(len(parameter_definitions), 2) self.assertEqual(parameter_definitions[0].name, "color") @@ -763,7 +749,7 @@ def test_add_parameter_definitions_with_same_name(self): self._db_map.add_parameter_definitions( 
{"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3} ) - diff_table = self._db_map._diff_table("parameter_definition") + diff_table = self._db_map.get_table("parameter_definition") parameter_definitions = self._db_map.query(diff_table).all() self.assertEqual(len(parameter_definitions), 2) self.assertEqual(parameter_definitions[0].name, "color") @@ -837,7 +823,7 @@ def test_add_parameter_values(self): "alternative_id": 1, }, ) - diff_table = self._db_map._diff_table("parameter_value") + diff_table = self._db_map.get_table("parameter_value") parameter_values = self._db_map.query(diff_table).all() self.assertEqual(len(parameter_values), 2) self.assertEqual(parameter_values[0].parameter_definition_id, 1) @@ -893,7 +879,7 @@ def test_add_same_parameter_value_twice(self): "alternative_id": 1, }, ) - diff_table = self._db_map._diff_table("parameter_value") + diff_table = self._db_map.get_table("parameter_value") parameter_values = self._db_map.query(diff_table).all() self.assertEqual(len(parameter_values), 1) self.assertEqual(parameter_values[0].parameter_definition_id, 1) @@ -1198,7 +1184,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): ) -class TestDiffDatabaseMappingUpdate(unittest.TestCase): +class TestDatabaseMappingUpdate(unittest.TestCase): def setUp(self): self._db_map = create_diff_db_map() @@ -1296,7 +1282,7 @@ def test_update_relationships(self): self.assertEqual(rels[4]["object_id_list"], "1,3") -class TestDiffDatabaseMappingCommit(unittest.TestCase): +class TestDatabaseMappingCommit(unittest.TestCase): def setUp(self): self._db_map = create_diff_db_map() diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 28d82f73..7d3d9ce0 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -18,7 +18,7 @@ import unittest from spinedb_api import ( - DiffDatabaseMapping, + DatabaseMapping, export_alternatives, export_data, export_scenarios, @@ -49,7 +49,7 @@ class TestExportFunctions(unittest.TestCase): def setUp(self): db_url = "sqlite://" - self._db_map = DiffDatabaseMapping(db_url, username="UnitTest", create=True) + self._db_map = DatabaseMapping(db_url, username="UnitTest", create=True) def tearDown(self): self._db_map.connection.close() diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index d6df9e02..83393d73 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -19,7 +19,6 @@ import unittest from spinedb_api.spine_db_server import _unparse_value -from spinedb_api.diff_db_mapping import DiffDatabaseMapping from spinedb_api.db_mapping import DatabaseMapping from spinedb_api.import_functions import ( import_alternatives, @@ -79,13 +78,13 @@ def _assert_same_elements(test, obs_vals, exp_vals): def create_diff_db_map(): db_url = "sqlite://" - return DiffDatabaseMapping(db_url, username="UnitTest", create=True) + return DatabaseMapping(db_url, username="UnitTest", create=True) class TestIntegrationImportData(unittest.TestCase): def test_import_data_integration(self): database_url = "sqlite://" - db_map = DiffDatabaseMapping(database_url, username="IntegrationTest", create=True) + db_map = DatabaseMapping(database_url, username="IntegrationTest", create=True) object_c = ["example_class", "other_class"] # 2 items objects = [["example_class", "example_object"], ["other_class", "other_object"]] # 2 items diff --git a/tests/test_migration.py b/tests/test_migration.py index bdfb27b2..509e6c89 100644 --- 
a/tests/test_migration.py
+++ b/tests/test_migration.py
@@ -21,7 +21,7 @@
 from sqlalchemy import inspect
 from sqlalchemy.engine.url import URL
 from spinedb_api.helpers import create_new_spine_database, _create_first_spine_database, is_head_engine, schema_dict
-from spinedb_api import DiffDatabaseMapping
+from spinedb_api import DatabaseMapping
 
 
 class TestMigration(unittest.TestCase):
@@ -94,7 +94,7 @@ def test_upgrade_content(self):
         engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 1, '100')")
         engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')")
         # Upgrade the db and check that our stuff is still there
-        db_map = DiffDatabaseMapping(db_url, upgrade=True)
+        db_map = DatabaseMapping(db_url, upgrade=True)
         object_classes = {x.id: x.name for x in db_map.object_class_list()}
         objects = {x.id: (object_classes[x.class_id], x.name) for x in db_map.object_list()}
         rel_clss = {x.id: (x.name, x.object_class_name_list) for x in db_map.wide_relationship_class_list()}

From 5070d7c54215e4dd5547084ec1d8f71a71a26fe5 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 29 Mar 2023 16:26:13 +0200
Subject: [PATCH 012/317] Fix lots of tests and missing bits

---
 spinedb_api/db_cache.py                     |  3 ++
 spinedb_api/db_mapping_base.py              | 18 ++++++----
 spinedb_api/db_mapping_check_mixin.py       | 39 ++++++++++++++-------
 spinedb_api/filters/renamer.py              |  2 --
 spinedb_api/filters/tool_filter.py          |  4 +--
 spinedb_api/filters/value_transformer.py    | 12 +++---
 spinedb_api/import_functions.py             |  2 +-
 tests/export_mapping/test_export_mapping.py |  4 +--
 tests/filters/test_alternative_filter.py    |  5 ++-
 tests/filters/test_renamer.py               |  4 +--
 tests/filters/test_scenario_filter.py       |  5 ++-
 tests/filters/test_value_transformer.py     |  2 --
 tests/test_DatabaseMapping.py               | 38 +++-----------------
 tests/test_DiffDatabaseMapping.py           | 26 ++++++--------
 tests/test_import_functions.py              |  8 ++---
 15 files changed, 75 insertions(+), 97 deletions(-)

diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py
index b6818b13..ba328024 100644
--- a/spinedb_api/db_cache.py
+++ b/spinedb_api/db_cache.py
@@ -17,6 +17,9 @@
 from operator import itemgetter
 
 
+# TODO: Implement CacheItem.pop() to do lookup?
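The TODO above is terse; judging by the NOTE comments this same patch adds to
db_mapping_check_mixin.py, the lookup in question is the reference resolution
that CacheItem.get() performs and a plain dict pop bypasses. A purely
hypothetical sketch with illustrative names only (CacheItem's internals are
not part of this series):

    class CacheItem(dict):
        """Illustrative stand-in for the real class in spinedb_api.db_cache."""

        def get(self, key, default=None):
            # The real implementation would resolve references here.
            return super().get(key, default)

        def pop(self, key, default=None):
            # Reuse the lookup that get() implements, then drop the raw entry.
            value = self.get(key, default)
            super().pop(key, None)
            return value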
+ + class DBCache(dict): def __init__(self, advance_query, *args, **kwargs): """ diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index d0187acd..83a9fbcc 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1567,13 +1567,13 @@ def entity_parameter_value_sq(self): self.parameter_definition_sq.c.entity_class_id, self.parameter_definition_sq.c.object_class_id, self.parameter_definition_sq.c.relationship_class_id, - self.entity_class_sq.c.name.label("entity_class_name"), + self.ext_entity_class_sq.c.name.label("entity_class_name"), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), label("object_class_id_list", self._object_class_id_list()), label("object_class_name_list", self._object_class_name_list()), self.parameter_value_sq.c.entity_id, - self.entity_sq.c.name.label("entity_name"), + self.ext_entity_sq.c.name.label("entity_name"), self.parameter_value_sq.c.object_id, self.parameter_value_sq.c.relationship_id, label("object_name", self._object_name()), @@ -1592,13 +1592,17 @@ def entity_parameter_value_sq(self): self.parameter_definition_sq, self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id, ) - .join(self.entity_sq, self.parameter_value_sq.c.entity_id == self.entity_sq.c.id) - .join(self.entity_class_sq, self.parameter_definition_sq.c.entity_class_id == self.entity_class_sq.c.id) + .join(self.ext_entity_sq, self.parameter_value_sq.c.entity_id == self.ext_entity_sq.c.id) + .join( + self.ext_entity_class_sq, + self.parameter_definition_sq.c.entity_class_id == self.ext_entity_class_sq.c.id, + ) .join(self.alternative_sq, self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) .outerjoin( - self.wide_relationship_class_sq, self.wide_relationship_class_sq.c.id == self.entity_class_sq.c.id + self.wide_relationship_class_sq, + self.wide_relationship_class_sq.c.id == self.ext_entity_class_sq.c.id, ) - .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.entity_sq.c.id) + .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.ext_entity_sq.c.id) # object_id_list might be None when objects have been filtered out .filter( or_( @@ -2170,7 +2174,7 @@ def _object_class_name_list(self): ) def _object_name(self): - return case([(self.ext_entity_sq.c.element_id_list == None, self.entity_sq.c.name)], else_=None) + return case([(self.ext_entity_sq.c.element_id_list == None, self.ext_entity_sq.c.name)], else_=None) def _object_id_list(self): return case( diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index eb6de908..6e2943c7 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -17,7 +17,6 @@ # TODO: Review docstrings, they are almost good from contextlib import contextmanager -from itertools import chain from .exception import SpineIntegrityError from .check_functions import ( check_alternative, @@ -499,7 +498,8 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= intgr_error_log, ) as wide_item: if "object_class_id_list" not in wide_item: - wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) + # Use CacheItem.get rather than pop since the former implements the lookup + wide_item["object_class_id_list"] = wide_item.get("dimension_id_list", ()) check_wide_relationship_class(wide_item, relationship_class_ids, object_class_ids) 
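# In other words, callers may pass either the legacy key or its replacement,
# and the aliasing above normalizes the item before it is checked. With plain
# dicts, where .get() is ordinary key access, the effect is:
#
#     wide_item = {"name": "rc1", "dimension_id_list": (1, 2)}
#     if "object_class_id_list" not in wide_item:
#         wide_item["object_class_id_list"] = wide_item.get("dimension_id_list", ())
#     assert wide_item["object_class_id_list"] == (1, 2)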
checked_wide_items.append(wide_item) except SpineIntegrityError as e: @@ -547,16 +547,18 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, wide_item, { ("class_id", "name"): relationship_ids_by_name, - ("class_id", "element_id_list"): relationship_ids_by_obj_lst, + ("class_id", "object_id_list"): relationship_ids_by_obj_lst, }, for_update, cache, intgr_error_log, ) as wide_item: if "object_class_id_list" not in wide_item: - wide_item["object_class_id_list"] = wide_item.pop("dimension_id_list", ()) + # NOTE: Use CacheItem.get rather than pop since the former implements the lookup + wide_item["object_class_id_list"] = wide_item.get("dimension_id_list", ()) if "object_id_list" not in wide_item: - wide_item["object_id_list"] = wide_item.pop("element_id_list", ()) + # NOTE: Use CacheItem.get rather than pop since the former implements the lookup + wide_item["object_id_list"] = wide_item.get("element_id_list", ()) check_wide_relationship( wide_item, relationship_ids_by_name, @@ -630,8 +632,17 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} for item in items: - if "entity_class_id" not in item: - item["entity_class_id"] = item.get("object_class_id") or item.get("relationship_class_id") + object_class_id = item.get("object_class_id") + relationship_class_id = item.get("relationship_class_id") + if object_class_id and relationship_class_id: + e = SpineIntegrityError("Can't associate a parameter to both an object and a relationship class.") + if strict: + raise e + intgr_error_log.append(e) + continue + entity_class_id = object_class_id or relationship_class_id + if "entity_class_id" not in item and entity_class_id is not None: + item["entity_class_id"] = entity_class_id try: if ( for_update @@ -691,8 +702,9 @@ def check_parameter_values(self, *items, for_update=False, strict=False, cache=N list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} alternatives = set(a.id for a in cache.get("alternative", {}).values()) for item in items: - if "entity_id" not in item: - item["entity_id"] = item.get("object_id") or item.get("relationship_id") + entity_id = item.get("object_id") or item.get("relationship_id") + if "entity_id" not in item and entity_id is not None: + item["entity_id"] = entity_id try: with self._manage_stocks( "parameter_value", @@ -846,8 +858,7 @@ def check_entity_metadata(self, *items, for_update=False, strict=False, cache=No cache = self.make_cache({"entity_metadata"}, include_ancestors=True) intgr_error_log = [] checked_items = list() - entities = {x.id for x in cache.get("object", {}).values()} - entities |= {x.id for x in cache.get("relationship", {}).values()} + entities = {x.id for x in cache.get("entity", {}).values()} metadata = {x.id for x in cache.get("metadata", {}).values()} for item in items: try: @@ -952,13 +963,17 @@ def _get_key(item, pk): def _fix_immutable_fields(item_type, current_item, item): immutable_fields = { - "object": ("class_id",), + "entity_class": ("dimension_id_list",), "relationship_class": ("object_class_id_list",), + "object": ("class_id",), "relationship": ("class_id",), + "entity": ("class_id",), "parameter_definition": ("entity_class_id", "object_class_id", "relationship_class_id"), "parameter_value": ("entity_class_id", "object_class_id", 
"relationship_class_id"), }.get(item_type, ()) fixed = [] + # FIXME: we need to be able to identify object_class_id_list as dimension_id_list + # for relationship class items for field in immutable_fields: if current_item.get(field) is None: continue diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index 135a0d65..ac0e2ac8 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -212,13 +212,11 @@ def _make_renaming_entity_class_sq(db_map, state): new_class_name = case(cases, else_=subquery.c.name) # if not in the name map, just keep the original name entity_class_sq = db_map.query( subquery.c.id, - subquery.c.type_id, new_class_name.label("name"), subquery.c.description, subquery.c.display_order, subquery.c.display_icon, subquery.c.hidden, - subquery.c.commit_id, ).subquery() return entity_class_sq diff --git a/spinedb_api/filters/tool_filter.py b/spinedb_api/filters/tool_filter.py index 0ee6e11f..61b0b720 100644 --- a/spinedb_api/filters/tool_filter.py +++ b/spinedb_api/filters/tool_filter.py @@ -17,7 +17,7 @@ """ from functools import partial from uuid import uuid4 -from sqlalchemy import and_, or_, case, func, Table, Column, ForeignKey +from sqlalchemy import and_, or_, case, func, Column, ForeignKey from ..exception import SpineDBAPIError @@ -164,12 +164,10 @@ def active_entity_id_sq(db_map, tool_id): Alias: subquery """ tool_feature_method_sq = _make_ext_tool_feature_method_sq(db_map, tool_id) - method_filter = _make_method_filter( tool_feature_method_sq, db_map.parameter_value_sq, db_map.parameter_definition_sq ) required_filter = _make_required_filter(tool_feature_method_sq, db_map.parameter_value_sq) - return ( db_map.query(db_map.entity_sq.c.id) .outerjoin( diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 144b1753..6c5a80e1 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -191,17 +191,13 @@ def _make_parameter_value_transforming_sq(db_map, state): new_value = case([(temp_sq.c.transformed_value != None, temp_sq.c.transformed_value)], else_=subquery.c.value) new_type = case([(temp_sq.c.transformed_type != None, temp_sq.c.transformed_type)], else_=subquery.c.type) object_class_case = case( - [(db_map.entity_class_sq.c.type_id == db_map.object_class_type, subquery.c.entity_class_id)], else_=None + [(db_map.ext_entity_class_sq.c.dimension_id_list == None, subquery.c.entity_class_id)], else_=None ) rel_class_case = case( - [(db_map.entity_class_sq.c.type_id == db_map.relationship_class_type, subquery.c.entity_class_id)], else_=None - ) - object_entity_case = case( - [(db_map.entity_sq.c.type_id == db_map.object_entity_type, subquery.c.entity_id)], else_=None - ) - rel_entity_case = case( - [(db_map.entity_sq.c.type_id == db_map.relationship_entity_type, subquery.c.entity_id)], else_=None + [(db_map.ext_entity_class_sq.c.dimension_id_list != None, subquery.c.entity_class_id)], else_=None ) + object_entity_case = case([(db_map.ext_entity_sq.c.element_id_list == None, subquery.c.entity_id)], else_=None) + rel_entity_case = case([(db_map.ext_entity_sq.c.element_id_list != None, subquery.c.entity_id)], else_=None) parameter_value_sq = ( db_map.query( subquery.c.id.label("id"), diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 95d43f50..e0113d55 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -2396,7 +2396,7 @@ def 
_get_relationship_parameter_value_metadata_for_import(db_map, data, make_cac } object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} relationship_ids = { - (x.object_id_list, x.class_id): x.id for x in cache.get("entity", {}).values() if x.element_id_list + (x.element_id_list, x.class_id): x.id for x in cache.get("entity", {}).values() if x.element_id_list } parameter_ids = { (x.parameter_name, x.entity_class_id): x.id for x in cache.get("parameter_definition", {}).values() diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index b6d7d0c4..ce9cb6bc 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -375,10 +375,10 @@ def test_export_parameter_indexes(self): expected = [ ["oc", "o1", "p1", "a"], ["oc", "o1", "p1", "b"], - ["oc", "o1", "p2", "c"], - ["oc", "o1", "p2", "d"], ["oc", "o2", "p1", "e"], ["oc", "o2", "p1", "f"], + ["oc", "o1", "p2", "c"], + ["oc", "o1", "p2", "d"], ["oc", "o2", "p2", "g"], ["oc", "o2", "p2", "h"], ] diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index 80cc5b62..01b8b118 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -23,7 +23,6 @@ apply_alternative_filter_to_parameter_value_sq, create_new_spine_database, DatabaseMapping, - DatabaseMapping, import_alternatives, import_object_classes, import_object_parameter_values, @@ -142,8 +141,8 @@ def test_multiple_alternatives(self): alternative_filter_from_dict(self._db_map, config) parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 2) - self.assertEqual(parameters[0].value, b"23.0") - self.assertEqual(parameters[1].value, b"101.1") + self.assertEqual(parameters[0].value, b"101.1") + self.assertEqual(parameters[1].value, b"23.0") def _add_value_in_alternative(self, value, alternative): import_alternatives(self._db_map, [alternative]) diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index d799fbfb..57733db8 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -69,7 +69,7 @@ def test_renaming_singe_entity_class(self): self.assertEqual(len(classes), 1) class_row = classes[0] keys = tuple(class_row.keys()) - expected_keys = ("id", "type_id", "name", "description", "display_order", "display_icon", "hidden", "commit_id") + expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden") self.assertEqual(len(keys), len(expected_keys)) for expected_key in expected_keys: self.assertIn(expected_key, keys) @@ -126,7 +126,7 @@ def test_entity_class_renamer_from_dict(self): self.assertEqual(len(classes), 1) class_row = classes[0] keys = tuple(class_row.keys()) - expected_keys = ("id", "type_id", "name", "description", "display_order", "display_icon", "hidden", "commit_id") + expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden") self.assertEqual(len(keys), len(expected_keys)) for expected_key in expected_keys: self.assertIn(expected_key, keys) diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 2cfc409b..ded05eee 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -23,7 +23,6 @@ apply_scenario_filter_to_subqueries, create_new_spine_database, DatabaseMapping, - DatabaseMapping, import_alternatives, 
import_object_classes, import_object_parameter_values, @@ -107,7 +106,7 @@ def test_scenario_filter_uncommitted_data(self): self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") alternatives = [a._asdict() for a in self._out_map.query(self._out_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) scenarios = [s._asdict() for s in self._out_map.query(self._out_map.wide_scenario_sq).all()] self.assertEqual( scenarios, @@ -119,7 +118,7 @@ def test_scenario_filter_uncommitted_data(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index dc4f5047..69cc30d1 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -19,9 +19,7 @@ import unittest from tempfile import TemporaryDirectory from sqlalchemy.engine.url import URL - from spinedb_api import ( - DatabaseMapping, DatabaseMapping, import_object_classes, import_object_parameter_values, diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 5744874c..314c2987 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -18,14 +18,7 @@ import unittest from unittest.mock import patch from sqlalchemy.engine.url import URL -from spinedb_api import ( - DatabaseMapping, - to_database, - import_functions, - from_database, - SpineDBAPIError, - SpineIntegrityError, -) +from spinedb_api import DatabaseMapping, to_database, import_functions, from_database, SpineDBAPIError IN_MEMORY_DB_URL = "sqlite://" @@ -64,26 +57,14 @@ def test_construction_with_sqlalchemy_url_and_filters(self): mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) - def test_entity_class_type_sq(self): - columns = ["id", "name", "commit_id"] - self.assertEqual(len(self._db_map.entity_class_type_sq.c), len(columns)) - for column_name in columns: - self.assertTrue(hasattr(self._db_map.entity_class_type_sq.c, column_name)) - - def test_entity_type_sq(self): - columns = ["id", "name", "commit_id"] - self.assertEqual(len(self._db_map.entity_type_sq.c), len(columns)) - for column_name in columns: - self.assertTrue(hasattr(self._db_map.entity_type_sq.c, column_name)) - def test_entity_sq(self): - columns = ["id", "type_id", "class_id", "name", "description", "commit_id"] + columns = ["id", "class_id", "name", "description", "commit_id"] self.assertEqual(len(self._db_map.entity_sq.c), len(columns)) for column_name in columns: self.assertTrue(hasattr(self._db_map.entity_sq.c, column_name)) def test_object_class_sq(self): - columns = ["id", "name", "description", "display_order", "display_icon", "hidden", "commit_id"] + columns = ["id", "name", "description", "display_order", "display_icon", "hidden"] self.assertEqual(len(self._db_map.object_class_sq.c), len(columns)) for column_name in columns: self.assertTrue(hasattr(self._db_map.object_class_sq.c, column_name)) @@ -95,7 +76,7 @@ def test_object_sq(self): self.assertTrue(hasattr(self._db_map.object_sq.c, column_name)) def test_relationship_class_sq(self): - columns = ["id", "dimension", "object_class_id", "name", "description", "display_icon", "hidden", "commit_id"] + columns = ["id", 
"dimension", "object_class_id", "name", "description", "display_icon", "hidden"] self.assertEqual(len(self._db_map.relationship_class_sq.c), len(columns)) for column_name in columns: self.assertTrue(hasattr(self._db_map.relationship_class_sq.c, column_name)) @@ -171,22 +152,13 @@ def test_ext_relationship_class_sq(self): "dimension", "object_class_id", "object_class_name", - "commit_id", ] self.assertEqual(len(self._db_map.ext_relationship_class_sq.c), len(columns)) for column_name in columns: self.assertTrue(hasattr(self._db_map.ext_relationship_class_sq.c, column_name)) def test_wide_relationship_class_sq(self): - columns = [ - "id", - "name", - "description", - "display_icon", - "commit_id", - "object_class_id_list", - "object_class_name_list", - ] + columns = ["id", "name", "description", "display_icon", "object_class_id_list", "object_class_name_list"] self.assertEqual(len(self._db_map.wide_relationship_class_sq.c), len(columns)) for column_name in columns: self.assertTrue(hasattr(self._db_map.wide_relationship_class_sq.c, column_name)) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 3122a532..fd72dcfe 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -568,19 +568,15 @@ def test_add_relationships(self): self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) - # FIXME - self._db_map.commit_session("Ok") - diff_table = self._db_map.get_table("entity_element") - ent_els = self._db_map.query(diff_table).all() - diff_table = self._db_map.get_table("entity") - relationships = self._db_map.query(diff_table).filter(diff_table.c.id.in_({x.entity_id for x in ent_els})).all() + ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all() + relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(ent_els), 2) self.assertEqual(len(relationships), 1) self.assertEqual(relationships[0].name, "nemo__pluto") self.assertEqual(ent_els[0].entity_class_id, 3) - self.assertEqual(ent_els[0].member_id, 1) + self.assertEqual(ent_els[0].element_id, 1) self.assertEqual(ent_els[1].entity_class_id, 3) - self.assertEqual(ent_els[1].member_id, 2) + self.assertEqual(ent_els[1].element_id, 2) def test_add_relationship_with_invalid_name(self): """Test that adding object classes with empty name raises error""" @@ -593,11 +589,11 @@ def test_add_relationship_with_invalid_name(self): def test_add_identical_relationships(self): """Test that adding two relationships with the same class and same objects only adds the first one.""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "dimension_id_list": [1, 2]}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - self._db_map.add_wide_relationships( - {"name": "nemo__pluto", "class_id": 3, "element_id_list": [1, 2]}, - {"name": "nemo__pluto_duplicate", "class_id": 3, "element_id_list": [1, 2]}, + x = self._db_map.add_wide_relationships( + {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, + {"name": "nemo__pluto_duplicate", 
"class_id": 3, "object_id_list": [1, 2]}, ) self._db_map.commit_session("Ok") relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() @@ -916,7 +912,7 @@ def test_add_alternative(self): alternatives[0]._asdict(), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} ) self.assertEqual( - alternatives[1]._asdict(), {"id": 2, "name": "my_alternative", "description": None, "commit_id": None} + alternatives[1]._asdict(), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} ) def test_add_scenario(self): @@ -927,7 +923,7 @@ def test_add_scenario(self): self.assertEqual(len(scenarios), 1) self.assertEqual( scenarios[0]._asdict(), - {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": None}, + {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, ) def test_add_scenario_alternative(self): @@ -940,7 +936,7 @@ def test_add_scenario_alternative(self): self.assertEqual(len(scenario_alternatives), 1) self.assertEqual( scenario_alternatives[0]._asdict(), - {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": None}, + {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, ) def test_add_metadata(self): diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 83393d73..c54e863b 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1473,7 +1473,7 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) self.assertEqual( @@ -1487,7 +1487,7 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) @@ -1516,7 +1516,7 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) self.assertEqual( @@ -1530,7 +1530,7 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) From 1bc8aff863d301670878fdaa1baacf7f9a9c6e07 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Mar 2023 17:11:41 +0200 Subject: [PATCH 013/317] Fix more tests --- spinedb_api/db_mapping_check_mixin.py | 8 +++++++- tests/test_DiffDatabaseMapping.py | 3 +-- tests/test_export_functions.py | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 6e2943c7..783e45b3 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -488,6 +488,9 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} object_class_ids = [x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list] for wide_item in wide_items: + object_class_id_list = wide_item.get("object_class_id_list") + if "dimension_id_list" not in wide_item and object_class_id_list is not None: + wide_item["dimension_id_list"] = object_class_id_list try: with self._manage_stocks( "entity_class", @@ -541,13 +544,16 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, if not 
x.element_id_list } for wide_item in wide_items: + object_id_list = wide_item.get("object_id_list") + if "element_id_list" not in wide_item and object_id_list is not None: + wide_item["element_id_list"] = object_id_list try: with self._manage_stocks( "entity", wide_item, { ("class_id", "name"): relationship_ids_by_name, - ("class_id", "object_id_list"): relationship_ids_by_obj_lst, + ("class_id", "element_id_list"): relationship_ids_by_obj_lst, }, for_update, cache, diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index fd72dcfe..b1a840a9 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -591,11 +591,10 @@ def test_add_identical_relationships(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - x = self._db_map.add_wide_relationships( + self._db_map.add_wide_relationships( {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, ) - self._db_map.commit_session("Ok") relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationships), 1) diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 7d3d9ce0..1ffefb4d 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -132,7 +132,7 @@ def test_export_data(self): import_scenarios(self._db_map, ["scenario"]) import_scenario_alternatives(self._db_map, [("scenario", "alternative")]) exported = export_data(self._db_map) - self.assertEqual(len(exported), 12) + self.assertEqual(len(exported), 16) self.assertIn("object_classes", exported) self.assertEqual(exported["object_classes"], [("object_class", None, None)]) self.assertIn("object_parameters", exported) From 3e414552c592217dd2eac594bbeba883151b03bc Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Mar 2023 17:41:30 +0200 Subject: [PATCH 014/317] Fix remove_items and descendant_tablenames to use entity --- spinedb_api/db_mapping_base.py | 6 +- spinedb_api/db_mapping_remove_mixin.py | 93 +++++++++++--------------- 2 files changed, 40 insertions(+), 59 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 83a9fbcc..edbf4122 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -276,10 +276,8 @@ def _descendant_tablenames(self, tablename): child_tablenames = { "alternative": ("parameter_value", "scenario_alternative"), "scenario": ("scenario_alternative",), - "object_class": ("object", "relationship_class", "parameter_definition"), - "object": ("relationship", "parameter_value", "entity_group", "entity_metadata"), - "relationship_class": ("relationship", "parameter_definition"), - "relationship": ("parameter_value", "entity_group", "entity_metadata"), + "entity_class": ("entity", "parameter_definition"), + "entity": ("parameter_value", "entity_group", "entity_metadata"), "parameter_definition": ("parameter_value", "feature"), "parameter_value_list": ("feature",), "parameter_value": ("parameter_value_metadata", "entity_metadata"), diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 292de74f..6f051578 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ 
b/spinedb_api/db_mapping_remove_mixin.py @@ -71,6 +71,15 @@ def cascading_ids(self, cache=None, **kwargs): Returns: cascading_ids (dict): cascading ids keyed by table name """ + for new_tablename, old_tablenames in ( + ("entity_class", {"object_class", "relationship_class"}), + ("entity", {"object", "relationship"}), + ): + for old_tablename in old_tablenames: + ids = kwargs.pop(old_tablename, None) + if ids is not None: + # FIXME: Add deprecation warning + kwargs.setdefault(new_tablename, set()).update(ids) if cache is None: cache = self.make_cache( set(kwargs), @@ -80,10 +89,8 @@ def cascading_ids(self, cache=None, **kwargs): else None, ) ids = {} - self._merge(ids, self._object_class_cascading_ids(kwargs.get("object_class", set()), cache)) - self._merge(ids, self._object_cascading_ids(kwargs.get("object", set()), cache)) - self._merge(ids, self._relationship_class_cascading_ids(kwargs.get("relationship_class", set()), cache)) - self._merge(ids, self._relationship_cascading_ids(kwargs.get("relationship", set()), cache)) + self._merge(ids, self._entity_class_cascading_ids(kwargs.get("entity_class", set()), cache)) + self._merge(ids, self._entity_cascading_ids(kwargs.get("entity", set()), cache)) self._merge(ids, self._entity_group_cascading_ids(kwargs.get("entity_group", set()), cache)) self._merge(ids, self._parameter_definition_cascading_ids(kwargs.get("parameter_definition", set()), cache)) self._merge(ids, self._parameter_value_cascading_ids(kwargs.get("parameter_value", set()), cache)) @@ -102,16 +109,18 @@ def cascading_ids(self, cache=None, **kwargs): ids, self._parameter_value_metadata_cascading_ids(kwargs.get("parameter_value_metadata", set()), cache) ) sorted_ids = {} - tablenames = list(ids) - while tablenames: - tablename = tablenames.pop(0) - ancestors = self.ancestor_tablenames.get(tablename) - if ancestors is None or all(x in sorted_ids for x in ancestors): - sorted_ids[tablename] = ids.pop(tablename) - else: - tablenames.append(tablename) + while ids: + tablename = next(iter(ids)) + self._move(tablename, ids, sorted_ids) return sorted_ids + def _move(self, tablename, unsorted, sorted_): + for ancestor in self.ancestor_tablenames.get(tablename, ()): + self._move(ancestor, unsorted, sorted_) + to_move = unsorted.pop(tablename, None) + if to_move: + sorted_[tablename] = to_move + @staticmethod def _merge(left, right): for tablename, ids in right.items(): @@ -138,61 +147,35 @@ def _scenario_cascading_ids(self, ids, cache): ) return cascading_ids - def _object_class_cascading_ids(self, ids, cache): - """Returns object class cascading ids.""" - cascading_ids = {"entity_class": set(ids), "object_class": set(ids)} - objects = [x for x in dict.values(cache.get("object", {})) if x.class_id in ids] - relationship_classes = ( - x for x in dict.values(cache.get("relationship_class", {})) if set(x.object_class_id_list).intersection(ids) - ) - paramerer_definitions = [ - x for x in dict.values(cache.get("parameter_definition", {})) if x.entity_class_id in ids - ] - self._merge(cascading_ids, self._object_cascading_ids({x.id for x in objects}, cache)) - self._merge(cascading_ids, self._relationship_class_cascading_ids({x.id for x in relationship_classes}, cache)) - self._merge( - cascading_ids, self._parameter_definition_cascading_ids({x.id for x in paramerer_definitions}, cache) - ) - return cascading_ids - - def _object_cascading_ids(self, ids, cache): - """Returns object cascading ids.""" - cascading_ids = {"entity": set(ids), "object": set(ids)} - relationships = ( - x for x in 
dict.values(cache.get("relationship", {})) if set(x.object_id_list).intersection(ids) + def _entity_class_cascading_ids(self, ids, cache): + """Returns entity class cascading ids.""" + if not ids: + return {} + cascading_ids = {"entity_class": set(ids), "entity_class_dimension": set(ids)} + entities = [x for x in dict.values(cache.get("entity", {})) if x.class_id in ids] + entity_classes = ( + x for x in dict.values(cache.get("entity_class", {})) if set(x.dimension_id_list).intersection(ids) ) - parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.entity_id in ids] - groups = [x for x in dict.values(cache.get("entity_group", {})) if {x.group_id, x.member_id}.intersection(ids)] - entity_metadata_ids = {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.entity_id in ids} - self._merge(cascading_ids, self._relationship_cascading_ids({x.id for x in relationships}, cache)) - self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache)) - self._merge(cascading_ids, self._entity_group_cascading_ids({x.id for x in groups}, cache)) - self._merge(cascading_ids, self._entity_metadata_cascading_ids(entity_metadata_ids, cache)) - return cascading_ids - - def _relationship_class_cascading_ids(self, ids, cache): - """Returns relationship class cascading ids.""" - cascading_ids = { - "relationship_class": set(ids), - "entity_class_dimension": set(ids), - "entity_class": set(ids), - } - relationships = [x for x in dict.values(cache.get("relationship", {})) if x.class_id in ids] paramerer_definitions = [ x for x in dict.values(cache.get("parameter_definition", {})) if x.entity_class_id in ids ] - self._merge(cascading_ids, self._relationship_cascading_ids({x.id for x in relationships}, cache)) + self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities}, cache)) + self._merge(cascading_ids, self._entity_class_cascading_ids({x.id for x in entity_classes}, cache)) self._merge( cascading_ids, self._parameter_definition_cascading_ids({x.id for x in paramerer_definitions}, cache) ) return cascading_ids - def _relationship_cascading_ids(self, ids, cache): - """Returns relationship cascading ids.""" - cascading_ids = {"relationship": set(ids), "entity": set(ids), "entity_element": set(ids)} + def _entity_cascading_ids(self, ids, cache): + """Returns entity cascading ids.""" + if not ids: + return {} + cascading_ids = {"entity": set(ids), "entity_element": set(ids)} + entities = (x for x in dict.values(cache.get("entity", {})) if set(x.element_id_list).intersection(ids)) parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.entity_id in ids] groups = [x for x in dict.values(cache.get("entity_group", {})) if {x.group_id, x.member_id}.intersection(ids)] entity_metadata_ids = {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.entity_id in ids} + self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities}, cache)) self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache)) self._merge(cascading_ids, self._entity_group_cascading_ids({x.id for x in groups}, cache)) self._merge(cascading_ids, self._entity_metadata_cascading_ids(entity_metadata_ids, cache)) From c6c9b533e63023bd3ef915c8e3ffa969450ea02f Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Mar 2023 17:45:27 +0200 Subject: [PATCH 015/317] Fix a couple more tests --- spinedb_api/db_mapping_check_mixin.py | 2 -- tests/test_DatabaseMapping.py | 2 +- 
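A minimal sketch of the ancestor-first ordering that cascading_ids() builds via _move() in PATCH 014 above; the ancestor map and the id sets below are simplified assumptions for illustration, not the library's actual tables:

    # Hypothetical, reduced ancestor map: each table lists the tables that
    # must appear before it in the result.
    ancestor_tablenames = {
        "entity": ("entity_class",),
        "parameter_definition": ("entity_class",),
        "parameter_value": ("entity", "parameter_definition"),
    }

    def move(tablename, unsorted, sorted_):
        # Recursively place ancestors first, then the table itself; tables
        # already moved (or never requested) are skipped, since pop() returns
        # None for them.
        for ancestor in ancestor_tablenames.get(tablename, ()):
            move(ancestor, unsorted, sorted_)
        to_move = unsorted.pop(tablename, None)
        if to_move:
            sorted_[tablename] = to_move

    unsorted = {"parameter_value": {10}, "entity_class": {1}, "entity": {5}}
    sorted_ids = {}
    while unsorted:
        move(next(iter(unsorted)), unsorted, sorted_ids)
    # Dicts preserve insertion order, so sorted_ids now orders ancestor
    # tables before their descendants: entity_class, entity, parameter_value.
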
tests/test_DiffDatabaseMapping.py | 8 ++++---- tests/test_helpers.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 783e45b3..a1dc8ef6 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -978,8 +978,6 @@ def _fix_immutable_fields(item_type, current_item, item): "parameter_value": ("entity_class_id", "object_class_id", "relationship_class_id"), }.get(item_type, ()) fixed = [] - # FIXME: we need to be able to identify object_class_id_list as dimension_id_list - # for relationship class items for field in immutable_fields: if current_item.get(field) is None: continue diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 314c2987..5b2c31ee 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -577,7 +577,7 @@ def test_update_wide_relationship_class_does_not_update_member_class_id(self): updated_ids, errors = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "renamed", "object_class_id_list": [2]} ) - self.assertEqual([str(err) for err in errors], ["Can't update fixed fields 'object_class_id_list'"]) + self.assertEqual([str(err) for err in errors], ["Can't update fixed fields 'dimension_id_list'"]) self.assertEqual(updated_ids, {3}) self._db_map.commit_session("Update data.") classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index b1a840a9..dae7b1c3 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -48,9 +48,9 @@ def create_diff_db_map(): class TestDatabaseMappingConstruction(unittest.TestCase): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with mock.patch("spinedb_api.diff_db_mapping.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.diff_db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) db_map.connection.close() @@ -60,9 +60,9 @@ def test_construction_with_filters(self): def test_construction_with_sqlalchemy_url_and_filters(self): db_url = IN_MEMORY_DB_URL + "/?spinedbfilter=fltr1&spinedbfilter=fltr2" sa_url = make_url(db_url) - with mock.patch("spinedb_api.diff_db_mapping.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.diff_db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) db_map.connection.close() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 980c77af..3f4bb4ae 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -38,7 +38,7 @@ def test_different_schema(self): """Test that importing object class works""" engine1 = create_new_spine_database('sqlite://') engine2 = create_new_spine_database('sqlite://') - engine2.execute("drop table entity_type") + engine2.execute("drop table entity") self.assertFalse(compare_schemas(engine1, engine2)) From 
e0532cc4c1bb292b57d07e80c45154390d3a9574 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 30 Mar 2023 16:28:53 +0200 Subject: [PATCH 016/317] Fix details in import entities etc --- spinedb_api/__init__.py | 4 ++++ spinedb_api/check_functions.py | 1 - spinedb_api/db_mapping_add_mixin.py | 2 +- spinedb_api/db_mapping_update_mixin.py | 2 ++ spinedb_api/import_functions.py | 7 ++++--- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 077ac380..f54b09ea 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -48,6 +48,10 @@ from .import_functions import ( import_alternatives, import_data, + import_entity_classes, + import_entities, + import_parameter_definitions, + import_parameter_values, import_object_classes, import_objects, import_object_parameters, diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index bdebbf94..e1d10db6 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -131,7 +131,6 @@ def check_entity(item, current_items_by_name, current_items_by_el_id_lst, entity Raises: SpineIntegrityError: if the insertion of the item violates an integrity constraint. """ - try: name = item["name"] except KeyError: diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 4557b812..18b647f8 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -423,7 +423,7 @@ def _add_entity_classes(self, *items): return self._add_items("entity_class", *items) def _add_entities(self, *items): - return self._add_items("entities", *items) + return self._add_items("entity", *items) def _add_object_classes(self, *items): return self._add_items("object_class", *items) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 927a9b88..3d1d8a88 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -39,6 +39,8 @@ def _update_items(self, tablename, *items): return self._do_update_items(real_tablename, *items) def _do_update_items(self, tablename, *items): + if not items: + return set() if self.committing: self._add_commit_id(*items) table = self._metadata.tables[tablename] diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index e0113d55..0daffc32 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -383,7 +383,7 @@ def _get_entity_classes_for_import(db_map, data, make_cache): if ec_id is not None else {"name": name, "description": None, "display_icon": None} ) - item.update(dict(zip(("description", "display_icon", "dimension_name_list"), optionals))) + item.update(dict(zip(("dimension_name_list", "description", "display_icon"), optionals))) item["dimension_id_list"] = tuple(entity_class_ids.get(x, None) for x in item.get("dimension_name_list", ())) try: check_entity_class(item, entity_class_ids) @@ -474,6 +474,7 @@ def _get_entities_for_import(db_map, data, make_cache): "name": e_name, "class_id": ec_id, "element_id_list": el_ids, + "dimension_id_list": dim_ids, } ) item.update(dict(zip(("description",), optionals))) @@ -682,7 +683,7 @@ def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} parameter_ids = {(p["entity_class_id"], 
p["name"]): p_id for p_id, p in parameters.items()} - entity_ids = {(x["class_id"], x["element_id_list"]): e_id for e_id, x in entities.items()} + entity_ids = {(x["class_id"], x["element_id_list"] or x["name"]): e_id for e_id, x in entities.items()} entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} alternative_ids = set(alternatives.values()) @@ -757,7 +758,7 @@ def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on continue finally: if pv_id is not None: - parameter_value_ids[r_id, p_id, alt_id] = pv_id + parameter_value_ids[e_id, p_id, alt_id] = pv_id checked.add(checked_key) if pv_id is not None: item["id"] = pv_id From 7131aeaa60d8272a2d52b74cd326a2a00352d420 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 31 Mar 2023 14:32:34 +0200 Subject: [PATCH 017/317] Adapt ImportMapping to the entity structure --- spinedb_api/import_functions.py | 8 +- spinedb_api/import_mapping/generator.py | 29 +- spinedb_api/import_mapping/import_mapping.py | 347 ++++++--------- .../import_mapping/import_mapping_compat.py | 45 +- tests/import_mapping/test_generator.py | 90 ++-- tests/import_mapping/test_import_mapping.py | 416 +++++++++--------- tests/spine_io/test_excel_integration.py | 18 +- 7 files changed, 447 insertions(+), 506 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 0daffc32..3bed39e6 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -457,8 +457,12 @@ def _get_entities_for_import(db_map, data, make_cache): for class_name, ent_name_or_el_names, *optionals in data: ec_id = entity_class_ids.get(class_name, None) dim_ids = dimension_id_lists.get(ec_id, ()) - el_ids = tuple(entity_ids.get((name, dim_id), None) for name, dim_id in zip(ent_name_or_el_names, dim_ids)) - e_key = el_ids or ent_name_or_el_names + if isinstance(ent_name_or_el_names, str): + el_ids = () + e_key = ent_name_or_el_names + else: + el_ids = tuple(entity_ids.get((name, dim_id), None) for name, dim_id in zip(ent_name_or_el_names, dim_ids)) + e_key = el_ids if (ec_id, e_key) in checked: continue e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index eee4cf13..502bbac9 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -159,7 +159,7 @@ def get_mapped_data( full_row = non_pivoted_row + unpivoted_row full_row.append(row[column_pos]) mapping.import_row(full_row, read_state, mapped_data) - _make_relationship_classes(mapped_data) + _make_entity_classes(mapped_data) _make_parameter_values(mapped_data, unparse_value) return mapped_data, errors @@ -264,22 +264,22 @@ def _unpivot_rows(rows, data_header, pivoted, non_pivoted, pivoted_from_header, return unpivoted_rows, pivoted_pos, non_pivoted_pos, unpivoted_column_pos -def _make_relationship_classes(mapped_data): - rows = mapped_data.get("relationship_classes") +def _make_entity_classes(mapped_data): + rows = mapped_data.get("entity_classes") if rows is None: return - full_rows = [] - for class_name, object_classes in rows.items(): - full_rows.append((class_name, object_classes)) - mapped_data["relationship_classes"] = full_rows + full_rows = set() + for class_name, dimension_names in rows.items(): + row = (class_name, tuple(dimension_names)) if dimension_names else (class_name,) + full_rows.add(row) + 
mapped_data["entity_classes"] = full_rows def _make_parameter_values(mapped_data, unparse_value): value_pos = 3 - for key in ("object_parameter_values", "relationship_parameter_values"): - rows = mapped_data.get(key) - if rows is None: - continue + key = "parameter_values" + rows = mapped_data.get(key) + if rows is not None: valued_rows = [] for row in rows: raw_value = _make_value(row, value_pos) @@ -291,10 +291,9 @@ def _make_parameter_values(mapped_data, unparse_value): valued_rows.append(row) mapped_data[key] = valued_rows value_pos = 0 - for key in ("object_parameters", "relationship_parameters"): - rows = mapped_data.get(key) - if rows is None: - continue + key = "parameter_definitions" + rows = mapped_data.get(key) + if rows is not None: full_rows = [] for entity_definition, extras in rows.items(): if extras: diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index f792e266..176c49cf 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -23,10 +23,9 @@ @unique class ImportKey(Enum): - CLASS_NAME = auto() - RELATIONSHIP_DIMENSION_COUNT = auto() - OBJECT_CLASS_NAME = auto() - OBJECT_NAME = auto() + DIMENSION_COUNT = auto() + ENTITY_CLASS_NAME = auto() + ENTITY_NAME = auto() GROUP_NAME = auto() MEMBER_NAME = auto() PARAMETER_NAME = auto() @@ -36,9 +35,8 @@ class ImportKey(Enum): PARAMETER_DEFAULT_VALUE_INDEXES = auto() PARAMETER_VALUES = auto() PARAMETER_VALUE_INDEXES = auto() - RELATIONSHIP_CLASS_NAME = auto() - OBJECT_CLASS_NAMES = auto() - OBJECT_NAMES = auto() + DIMENSION_NAMES = auto() + ELEMENT_NAMES = auto() ALTERNATIVE_NAME = auto() SCENARIO_NAME = auto() SCENARIO_ALTERNATIVE = auto() @@ -50,18 +48,16 @@ class ImportKey(Enum): def __str__(self): name = { - self.CLASS_NAME.value: "Class names", - self.OBJECT_CLASS_NAME.value: "Object class names", - self.OBJECT_NAME.value: "Object names", + self.ENTITY_CLASS_NAME.value: "Entity class names", + self.ENTITY_NAME.value: "Entity names", self.GROUP_NAME.value: "Group names", self.MEMBER_NAME.value: "Member names", self.PARAMETER_NAME.value: "Parameter names", self.PARAMETER_DEFINITION.value: "Parameter names", self.PARAMETER_DEFAULT_VALUE_INDEXES.value: "Parameter indexes", self.PARAMETER_VALUE_INDEXES.value: "Parameter indexes", - self.RELATIONSHIP_CLASS_NAME.value: "Relationship class names", - self.OBJECT_CLASS_NAMES.value: "Object class names", - self.OBJECT_NAMES.value: "Object names", + self.DIMENSION_NAMES.value: "Dimension names", + self.ELEMENT_NAMES.value: "Element names", self.PARAMETER_VALUE_LIST_NAME.value: "Parameter value lists", self.SCENARIO_NAME.value: "Scenario names", self.SCENARIO_ALTERNATIVE.value: "Alternative names", @@ -262,7 +258,7 @@ def import_row(self, source_row, state, mapped_data, errors=None): if self.child is not None: self.child.import_row(source_row, state, mapped_data, errors=errors) - def _data(self, source_row): # pylint: disable=arguments-differ + def _data(self, source_row): # pylint: disable=arguments-renamed if source_row is None: return None return source_row[self.position] @@ -313,21 +309,21 @@ def reconstruct(cls, position, value, skip_columns, read_start_row, filter_re, m return mapping -class ImportObjectsMixin: - def __init__(self, position, value=None, skip_columns=None, read_start_row=0, filter_re="", import_objects=False): +class ImportEntitiesMixin: + def __init__(self, position, value=None, skip_columns=None, read_start_row=0, filter_re="", import_entities=False): 
super().__init__(position, value, skip_columns, read_start_row, filter_re) - self.import_objects = import_objects + self.import_entities = import_entities def to_dict(self): d = super().to_dict() - if self.import_objects: - d["import_objects"] = True + if self.import_entities: + d["import_entities"] = True return d @classmethod def reconstruct(cls, position, value, skip_columns, read_start_row, filter_re, mapping_dict): - import_objects = mapping_dict.get("import_objects", False) - mapping = cls(position, value, skip_columns, read_start_row, filter_re, import_objects) + import_entities = mapping_dict.get("import_entities", False) + mapping = cls(position, value, skip_columns, read_start_row, filter_re, import_entities) return mapping @@ -357,158 +353,123 @@ def reconstruct(cls, position, value, skip_columns, read_start_row, filter_re, m return mapping -class ObjectClassMapping(ImportMapping): - """Maps object classes. +class EntityClassMapping(ImportMapping): + """Maps entity classes. Can be used as the topmost mapping. """ - MAP_TYPE = "ObjectClass" + MAP_TYPE = "EntityClass" def _import_row(self, source_data, state, mapped_data): - object_class_name = state[ImportKey.OBJECT_CLASS_NAME] = str(source_data) - object_classes = mapped_data.setdefault("object_classes", set()) - object_classes.add(object_class_name) + dim_count = len([m for m in self.flatten() if isinstance(m, DimensionMapping)]) + state[ImportKey.DIMENSION_COUNT] = dim_count + entity_class_name = state[ImportKey.ENTITY_CLASS_NAME] = str(source_data) + dimension_names = state[ImportKey.DIMENSION_NAMES] = [] + entity_classes = mapped_data.setdefault("entity_classes", {}) + entity_classes[entity_class_name] = dimension_names + if dim_count: + raise KeyError(ImportKey.DIMENSION_NAMES) -class ObjectMapping(ImportMapping): - """Maps objects. +class EntityMapping(ImportMapping): + """Maps entities. - Cannot be used as the topmost mapping; one of the parents must be :class:`ObjectClassMapping`. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`. """ - MAP_TYPE = "Object" + MAP_TYPE = "Entity" + + def import_row(self, source_row, state, mapped_data, errors=None): + state[ImportKey.ELEMENT_NAMES] = () + super().import_row(source_row, state, mapped_data, errors=errors) def _import_row(self, source_data, state, mapped_data): - object_class_name = state[ImportKey.OBJECT_CLASS_NAME] - object_name = state[ImportKey.OBJECT_NAME] = str(source_data) - if isinstance(self.child, ObjectGroupMapping): + if state[ImportKey.DIMENSION_COUNT]: + return + entity_class_name = state[ImportKey.ENTITY_CLASS_NAME] + entity_name = state[ImportKey.ENTITY_NAME] = str(source_data) + if isinstance(self.child, EntityGroupMapping): raise KeyError(ImportKey.MEMBER_NAME) - mapped_data.setdefault("objects", set()).add((object_class_name, object_name)) + mapped_data.setdefault("entities", set()).add((entity_class_name, entity_name)) -class ObjectMetadataMapping(ImportMapping): - """Maps object metadata. +class EntityMetadataMapping(ImportMapping): + """Maps entity metadata. - Cannot be used as the topmost mapping; must have :class:`ObjectClassMapping` and :class:`ObjectMapping` as parents. + Cannot be used as the topmost mapping; must have :class:`EntityClassMapping` and :class:`EntityMapping` as parents. """ - MAP_TYPE = "ObjectMetadata" + MAP_TYPE = "EntityMetadata" def _import_row(self, source_data, state, mapped_data): pass -class ObjectGroupMapping(ImportObjectsMixin, ImportMapping): - """Maps object groups. 
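+# Group memberships are collected under the "entity_groups" key of mapped_data;
+# with import_entities=True, the group and its members are imported as entities.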
+class EntityGroupMapping(ImportEntitiesMixin, ImportMapping): + """Maps entity groups. - Cannot be used as the topmost mapping; must have :class:`ObjectClassMapping` and :class:`ObjectMapping` as parents. + Cannot be used as the topmost mapping; must have :class:`EntityClassMapping` and :class:`EntityMapping` as parents. """ - MAP_TYPE = "ObjectGroup" + MAP_TYPE = "EntityGroup" def _import_row(self, source_data, state, mapped_data): - object_class_name = state[ImportKey.OBJECT_CLASS_NAME] - group_name = state.get(ImportKey.OBJECT_NAME) + entity_class_name = state[ImportKey.ENTITY_CLASS_NAME] + group_name = state.get(ImportKey.ENTITY_NAME) if group_name is None: raise KeyError(ImportKey.GROUP_NAME) member_name = str(source_data) - mapped_data.setdefault("object_groups", set()).add((object_class_name, group_name, member_name)) - if self.import_objects: - objects = (object_class_name, group_name), (object_class_name, member_name) - mapped_data.setdefault("objects", set()).update(objects) + mapped_data.setdefault("entity_groups", set()).add((entity_class_name, group_name, member_name)) + if self.import_entities: + entities = (entity_class_name, group_name), (entity_class_name, member_name) + mapped_data.setdefault("entities", set()).update(entities) raise KeyFix(ImportKey.MEMBER_NAME) -class RelationshipClassMapping(ImportMapping): - """Maps relationships classes. +class DimensionMapping(ImportMapping): + """Maps dimensions. - Can be used as the topmost mapping. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`. """ - MAP_TYPE = "RelationshipClass" + MAP_TYPE = "Dimension" def _import_row(self, source_data, state, mapped_data): - dim_count = len([m for m in self.flatten() if isinstance(m, RelationshipClassObjectClassMapping)]) - state[ImportKey.RELATIONSHIP_DIMENSION_COUNT] = dim_count - relationship_class_name = state[ImportKey.RELATIONSHIP_CLASS_NAME] = str(source_data) - object_class_names = state[ImportKey.OBJECT_CLASS_NAMES] = [] - relationship_classes = mapped_data.setdefault("relationship_classes", dict()) - relationship_classes[relationship_class_name] = object_class_names - raise KeyError(ImportKey.OBJECT_CLASS_NAMES) + _ = state[ImportKey.ENTITY_CLASS_NAME] + dimension_name = str(source_data) + state[ImportKey.DIMENSION_NAMES].append(dimension_name) + dimension_names = state[ImportKey.DIMENSION_NAMES] + if len(dimension_names) == state[ImportKey.DIMENSION_COUNT]: + raise KeyFix(ImportKey.DIMENSION_NAMES) -class RelationshipClassObjectClassMapping(ImportMapping): - """Maps relationship class object classes. +class ElementMapping(ImportEntitiesMixin, ImportMapping): + """Maps elements. - Cannot be used as the topmost mapping; one of the parents must be :class:`RelationshipClassMapping`. - """ - - MAP_TYPE = "RelationshipClassObjectClass" - - def _import_row(self, source_data, state, mapped_data): - _ = state[ImportKey.RELATIONSHIP_CLASS_NAME] - object_class_names = state[ImportKey.OBJECT_CLASS_NAMES] - object_class_name = str(source_data) - object_class_names.append(object_class_name) - if len(object_class_names) == state[ImportKey.RELATIONSHIP_DIMENSION_COUNT]: - raise KeyFix(ImportKey.OBJECT_CLASS_NAMES) - - -class RelationshipMapping(ImportMapping): - """Maps relationships. - - Cannot be used as the topmost mapping; one of the parents must be :class:`RelationshipClassMapping`. 
- """ - - MAP_TYPE = "Relationship" - - def _import_row(self, source_data, state, mapped_data): - # Don't access state[ImportKey.RELATIONSHIP_CLASS_NAME], we don't want to catch errors here - # because this one's invisible. - state[ImportKey.OBJECT_NAMES] = [] - - -class RelationshipObjectMapping(ImportObjectsMixin, ImportMapping): - """Maps relationship's objects. - - Cannot be used as the topmost mapping; must have :class:`RelationshipClassMapping` and :class:`RelationshipMapping` + Cannot be used as the topmost mapping; must have :class:`EntityClassMapping` and :class:`EntityMapping` as parents. """ - MAP_TYPE = "RelationshipObject" - - def _import_row(self, source_data, state, mapped_data): - relationship_class_name = state[ImportKey.RELATIONSHIP_CLASS_NAME] - object_class_names = state[ImportKey.OBJECT_CLASS_NAMES] - if len(object_class_names) != state[ImportKey.RELATIONSHIP_DIMENSION_COUNT]: - raise KeyError(ImportKey.OBJECT_CLASS_NAMES) - object_names = state[ImportKey.OBJECT_NAMES] - object_name = str(source_data) - object_names.append(object_name) - if self.import_objects: - k = len(object_names) - 1 - object_class_name = object_class_names[k] - mapped_data.setdefault("object_classes", set()).add(object_class_name) - mapped_data.setdefault("objects", set()).add((object_class_name, object_name)) - if len(object_names) == state[ImportKey.RELATIONSHIP_DIMENSION_COUNT]: - relationships = mapped_data.setdefault("relationships", set()) - relationships.add((relationship_class_name, tuple(object_names))) - raise KeyFix(ImportKey.OBJECT_NAMES) - raise KeyError(ImportKey.OBJECT_NAMES) - - -class RelationshipMetadataMapping(ImportMapping): - """Maps relationship metadata. - - Cannot be used as the topmost mapping; must have :class:`RelationshipClassMapping`, a :class:`RelationshipMapping` - and one or more :class:`RelationshipObjectMapping` as parents. 
- """ - - MAP_TYPE = "RelationshipMetadata" + MAP_TYPE = "Element" def _import_row(self, source_data, state, mapped_data): - pass + entity_class_name = state[ImportKey.ENTITY_CLASS_NAME] + dimension_names = state[ImportKey.DIMENSION_NAMES] + if len(dimension_names) != state[ImportKey.DIMENSION_COUNT]: + raise KeyError(ImportKey.DIMENSION_NAMES) + element_name = str(source_data) + element_names = state[ImportKey.ELEMENT_NAMES] = state[ImportKey.ELEMENT_NAMES] + (element_name,) + if self.import_entities: + k = len(element_names) - 1 + dimension_name = dimension_names[k] + mapped_data.setdefault("entity_classes", {}).update({dimension_name: ()}) + mapped_data.setdefault("entities", set()).add((dimension_name, element_name)) + if len(element_names) == state[ImportKey.DIMENSION_COUNT]: + entities = mapped_data.setdefault("entities", set()) + entities.add((entity_class_name, tuple(element_names))) + raise KeyFix(ImportKey.ELEMENT_NAMES) + raise KeyError(ImportKey.ELEMENT_NAMES) class ParameterDefinitionMapping(ImportMapping): @@ -520,23 +481,13 @@ class ParameterDefinitionMapping(ImportMapping): MAP_TYPE = "ParameterDefinition" def _import_row(self, source_data, state, mapped_data): - object_class_name = state.get(ImportKey.OBJECT_CLASS_NAME) - if object_class_name is not None: - class_name = object_class_name - map_key = "object_parameters" - else: - relationship_class_name = state.get(ImportKey.RELATIONSHIP_CLASS_NAME) - if relationship_class_name is not None: - class_name = relationship_class_name - map_key = "relationship_parameters" - else: - raise KeyError(ImportKey.CLASS_NAME) + entity_class_name = state.get(ImportKey.ENTITY_CLASS_NAME) parameter_name = state[ImportKey.PARAMETER_NAME] = str(source_data) definition_extras = state[ImportKey.PARAMETER_DEFINITION_EXTRAS] = [] - parameter_definition_key = state[ImportKey.PARAMETER_DEFINITION] = class_name, parameter_name + parameter_definition_key = state[ImportKey.PARAMETER_DEFINITION] = entity_class_name, parameter_name default_values = state.get(ImportKey.PARAMETER_DEFAULT_VALUES) if default_values is None or parameter_definition_key not in default_values: - mapped_data.setdefault(map_key, dict())[parameter_definition_key] = definition_extras + mapped_data.setdefault("parameter_definitions", dict())[parameter_definition_key] = definition_extras class ParameterDefaultValueMapping(ImportMapping): @@ -675,65 +626,35 @@ def _import_row(self, source_data, state, mapped_data): value = source_data if value == "": return - object_class_name = state.get(ImportKey.OBJECT_CLASS_NAME) - relationship_class_name = state.get(ImportKey.RELATIONSHIP_CLASS_NAME) - if object_class_name is not None: - class_name = object_class_name - entity_name = state[ImportKey.OBJECT_NAME] - map_key = "object_parameter_values" - elif relationship_class_name is not None: - object_names = state[ImportKey.OBJECT_NAMES] - if len(object_names) != state[ImportKey.RELATIONSHIP_DIMENSION_COUNT]: - raise KeyError(ImportKey.OBJECT_NAMES) - class_name = relationship_class_name - entity_name = object_names - map_key = "relationship_parameter_values" - else: - raise KeyError(ImportKey.CLASS_NAME) - parameter_name = state[ImportKey.PARAMETER_NAME] - parameter_value = [class_name, entity_name, parameter_name, value] - alternative_name = state.get(ImportKey.ALTERNATIVE_NAME) + entity_class_name, entity_byname, parameter_name, alternative_name = _parameter_value_key(state) + parameter_value = [entity_class_name, entity_byname, parameter_name, value] if alternative_name is not None: 
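+            # The alternative name, when mapped, travels as the optional fifth
+            # element of the parameter value row.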
parameter_value.append(alternative_name) - mapped_data.setdefault(map_key, list()).append(parameter_value) + mapped_data.setdefault("parameter_values", []).append(parameter_value) class ParameterValueTypeMapping(IndexedValueMixin, ImportMapping): MAP_TYPE = "ParameterValueType" def _import_row(self, source_data, state, mapped_data): - parameter_name = state.get(ImportKey.PARAMETER_NAME) - if parameter_name is None: + if ImportKey.PARAMETER_NAME not in state: # Don't catch errors here, this one's invisible return - object_class_name = state.get(ImportKey.OBJECT_CLASS_NAME) + key = _parameter_value_key(state) values = state.setdefault(ImportKey.PARAMETER_VALUES, {}) - if object_class_name is not None: - class_name = object_class_name - entity_name = state[ImportKey.OBJECT_NAME] - map_key = "object_parameter_values" - else: - relationship_class_name = state.get(ImportKey.RELATIONSHIP_CLASS_NAME) - if relationship_class_name is not None: - class_name = relationship_class_name - entity_name = tuple(state[ImportKey.OBJECT_NAMES]) - map_key = "relationship_parameter_values" - else: - raise KeyError(ImportKey.CLASS_NAME) - alternative_name = state.get(ImportKey.ALTERNATIVE_NAME) - key = (class_name, entity_name, parameter_name, alternative_name) if key in values: return + entity_class_name, entity_byname, parameter_name, alternative_name = key value_type = str(source_data) value = values[key] = {"type": value_type} # See import_mapping.generator._parameter_value_from_dict() if self.compress and value_type == "map": value["compress"] = self.compress if self.options and value_type == "time_series": value["options"] = self.options - parameter_value = [class_name, entity_name, parameter_name, value] + parameter_value = [entity_class_name, entity_byname, parameter_name, value] if alternative_name is not None: parameter_value.append(alternative_name) - mapped_data.setdefault(map_key, list()).append(parameter_value) + mapped_data.setdefault("parameter_values", []).append(parameter_value) class ParameterValueMetadataMapping(ImportMapping): @@ -833,7 +754,7 @@ def _import_row(self, source_data, state, mapped_data): if list_value == "": return value_list_name = state[ImportKey.PARAMETER_VALUE_LIST_NAME] - mapped_data.setdefault("parameter_value_lists", list()).append([value_list_name, list_value]) + mapped_data.setdefault("parameter_value_lists", []).append([value_list_name, list_value]) class AlternativeMapping(ImportMapping): @@ -889,7 +810,7 @@ def _import_row(self, source_data, state, mapped_data): return scenario = state[ImportKey.SCENARIO_NAME] scen_alt = state[ImportKey.SCENARIO_ALTERNATIVE] = [scenario, alternative] - mapped_data.setdefault("scenario_alternatives", list()).append(scen_alt) + mapped_data.setdefault("scenario_alternatives", []).append(scen_alt) class ScenarioBeforeAlternativeMapping(ImportMapping): @@ -961,7 +882,7 @@ def _import_row(self, source_data, state, mapped_data): entity_class = str(source_data) tool_feature = [tool, entity_class] state[ImportKey.TOOL_FEATURE] = tool_feature - mapped_data.setdefault("tool_features", list()).append(tool_feature) + mapped_data.setdefault("tool_features", []).append(tool_feature) class ToolFeatureParameterDefinitionMapping(ImportMapping): @@ -1005,7 +926,7 @@ def _import_row(self, source_data, state, mapped_data): entity_class = str(source_data) tool_feature_method = [tool_name, entity_class] state[ImportKey.TOOL_FEATURE_METHOD] = tool_feature_method - mapped_data.setdefault("tool_feature_methods", list()).append(tool_feature_method) + 
mapped_data.setdefault("tool_feature_methods", []).append(tool_feature_method) class ToolFeatureMethodParameterDefinitionMapping(ImportMapping): @@ -1051,15 +972,12 @@ def from_dict(serialized): mappings = { klass.MAP_TYPE: klass for klass in ( - ObjectClassMapping, - ObjectMapping, - ObjectMetadataMapping, - ObjectGroupMapping, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipMapping, - RelationshipObjectMapping, - RelationshipMetadataMapping, + EntityClassMapping, + EntityMapping, + EntityMetadataMapping, + EntityGroupMapping, + DimensionMapping, + ElementMapping, ParameterDefinitionMapping, ParameterDefaultValueMapping, ParameterDefaultValueTypeMapping, @@ -1089,9 +1007,20 @@ def from_dict(serialized): ToolFeatureMethodMethodMapping, ) } - # Legacy - mappings["ParameterIndex"] = ParameterValueIndexMapping - flattened = list() + legacy_mappings = { + "ParameterIndex": ParameterValueIndexMapping, + "ObjectClass": EntityClassMapping, + "Object": EntityMapping, + "ObjectMetadata": EntityMetadataMapping, + "ObjectGroup": EntityGroupMapping, + "RelationshipClass": EntityClassMapping, + "RelationshipClassObjectClass": DimensionMapping, + "Relationship": EntityMapping, + "RelationshipObject": ElementMapping, + "RelationshipMetadata": EntityMetadataMapping, + } + mappings.update(legacy_mappings) + flattened = [] for mapping_dict in serialized: position = mapping_dict["position"] value = mapping_dict.get("value") @@ -1100,6 +1029,9 @@ def from_dict(serialized): filter_re = mapping_dict.get("filter_re", "") if isinstance(position, str): position = Position(position) + if "import_objects" in mapping_dict: + # Legacy + mapping_dict["import_entities"] = mapping_dict.pop("import_objects") flattened.append( mappings[mapping_dict["map_type"]].reconstruct( position, value, skip_columns, read_start_row, filter_re, mapping_dict @@ -1118,24 +1050,19 @@ def _parameter_value_key(state): state (dict): import state Returns: - tuple of str: class name, entity name and parameter name + tuple of str: class name, entity byname, parameter name, and alternative name """ - object_class_name = state.get(ImportKey.OBJECT_CLASS_NAME) - if object_class_name is not None: - class_name = object_class_name - entity_name = state[ImportKey.OBJECT_NAME] + entity_class_name = state.get(ImportKey.ENTITY_CLASS_NAME) + if state.get(ImportKey.DIMENSION_COUNT): + element_names = state[ImportKey.ELEMENT_NAMES] + if len(element_names) != state[ImportKey.DIMENSION_COUNT]: + raise KeyError(ImportKey.ELEMENT_NAMES) + entity_byname = element_names else: - relationship_class_name = state.get(ImportKey.RELATIONSHIP_CLASS_NAME) - if relationship_class_name is None: - raise KeyError(ImportKey.CLASS_NAME) - object_names = state[ImportKey.OBJECT_NAMES] - if len(object_names) != state[ImportKey.RELATIONSHIP_DIMENSION_COUNT]: - raise KeyError(ImportKey.OBJECT_NAMES) - class_name = relationship_class_name - entity_name = tuple(object_names) + entity_byname = state[ImportKey.ENTITY_NAME] parameter_name = state[ImportKey.PARAMETER_NAME] alternative_name = state.get(ImportKey.ALTERNATIVE_NAME) - return class_name, entity_name, parameter_name, alternative_name + return entity_class_name, entity_byname, parameter_name, alternative_name def _default_value_key(state): @@ -1147,10 +1074,4 @@ def _default_value_key(state): Returns: tuple of str: class name and parameter name """ - class_name = state.get(ImportKey.OBJECT_CLASS_NAME) - if class_name is None: - class_name = state.get(ImportKey.RELATIONSHIP_CLASS_NAME) - if 
class_name is None: - raise KeyError(ImportKey.CLASS_NAME) - parameter_name = state[ImportKey.PARAMETER_NAME] - return class_name, parameter_name + return state[ImportKey.ENTITY_CLASS_NAME], state[ImportKey.PARAMETER_NAME] diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index 603f0e65..8308c26e 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -17,14 +17,11 @@ """ from .import_mapping import ( Position, - ObjectClassMapping, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - ObjectMapping, - ObjectMetadataMapping, - RelationshipMapping, - RelationshipObjectMapping, - RelationshipMetadataMapping, + EntityClassMapping, + DimensionMapping, + EntityMapping, + EntityMetadataMapping, + ElementMapping, ParameterDefinitionMapping, ParameterDefaultValueMapping, ParameterDefaultValueTypeMapping, @@ -49,7 +46,7 @@ ToolFeatureMethodEntityClassMapping, ToolFeatureMethodParameterDefinitionMapping, ToolFeatureMethodMethodMapping, - ObjectGroupMapping, + EntityGroupMapping, ParameterValueListMapping, ParameterValueListValueMapping, from_dict as mapping_from_dict, @@ -205,9 +202,9 @@ def _object_class_mapping_from_dict(map_dict): parameters = map_dict.get("parameters") skip_columns = map_dict.get("skip_columns", []) read_start_row = map_dict.get("read_start_row", 0) - root_mapping = ObjectClassMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) - object_mapping = root_mapping.child = ObjectMapping(*_pos_and_val(objects)) - object_metadata_mapping = object_mapping.child = ObjectMetadataMapping(*_pos_and_val(object_metadata)) + root_mapping = EntityClassMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) + object_mapping = root_mapping.child = EntityMapping(*_pos_and_val(objects)) + object_metadata_mapping = object_mapping.child = EntityMetadataMapping(*_pos_and_val(object_metadata)) object_metadata_mapping.child = parameter_mapping_from_dict(parameters) return root_mapping @@ -216,12 +213,12 @@ def _object_group_mapping_from_dict(map_dict): name = map_dict.get("name") groups = map_dict.get("groups") members = map_dict.get("members") - import_objects = map_dict.get("import_objects", False) + import_entities = map_dict.get("import_objects", False) skip_columns = map_dict.get("skip_columns", []) read_start_row = map_dict.get("read_start_row", 0) - root_mapping = ObjectClassMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) - object_mapping = root_mapping.child = ObjectMapping(*_pos_and_val(groups)) - object_mapping.child = ObjectGroupMapping(*_pos_and_val(members), import_objects=import_objects) + root_mapping = EntityClassMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) + object_mapping = root_mapping.child = EntityMapping(*_pos_and_val(groups)) + object_mapping.child = EntityGroupMapping(*_pos_and_val(members), import_entities=import_entities) return root_mapping @@ -235,26 +232,22 @@ def _relationship_class_mapping_from_dict(map_dict): object_classes = [None] relationship_metadata = map_dict.get("relationship_metadata") parameters = map_dict.get("parameters") - import_objects = map_dict.get("import_objects", False) + import_entities = map_dict.get("import_objects", False) skip_columns = map_dict.get("skip_columns", []) read_start_row = map_dict.get("read_start_row", 0) - root_mapping = 
RelationshipClassMapping( - *_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row - ) + root_mapping = EntityClassMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) parent_mapping = root_mapping for klass in object_classes: - class_mapping = RelationshipClassObjectClassMapping(*_pos_and_val(klass)) + class_mapping = DimensionMapping(*_pos_and_val(klass)) parent_mapping.child = class_mapping parent_mapping = class_mapping - relationship_mapping = parent_mapping.child = RelationshipMapping(Position.hidden, value="relationship") + relationship_mapping = parent_mapping.child = EntityMapping(Position.hidden) parent_mapping = relationship_mapping for obj in objects: - object_mapping = RelationshipObjectMapping(*_pos_and_val(obj), import_objects=import_objects) + object_mapping = ElementMapping(*_pos_and_val(obj), import_entities=import_entities) parent_mapping.child = object_mapping parent_mapping = object_mapping - relationship_metadata_mapping = parent_mapping.child = RelationshipMetadataMapping( - *_pos_and_val(relationship_metadata) - ) + relationship_metadata_mapping = parent_mapping.child = EntityMetadataMapping(*_pos_and_val(relationship_metadata)) relationship_metadata_mapping.child = parameter_mapping_from_dict(parameters) return root_mapping diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index 3d10101a..44d0b65f 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -72,10 +72,10 @@ def test_returns_appropriate_error_if_last_row_is_empty(self): mapped_data, { 'alternatives': {'Base'}, - 'object_classes': {'Object'}, - 'object_parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], - 'object_parameters': [('Object', 'Parameter')], - 'objects': {('Object', 'data')}, + 'entity_classes': {('Object',)}, + 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], + 'parameter_definitions': [('Object', 'Parameter')], + 'entities': {('Object', 'data')}, }, ) @@ -104,10 +104,10 @@ def test_convert_functions_get_expanded_over_last_defined_column_in_pivoted_data mapped_data, { 'alternatives': {'Base'}, - 'object_classes': {'Object'}, - 'object_parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], - 'object_parameters': [('Object', 'Parameter')], - 'objects': {('Object', 'data')}, + 'entity_classes': {('Object',)}, + 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], + 'parameter_definitions': [('Object', 'Parameter')], + 'entities': {('Object', 'data')}, }, ) @@ -135,10 +135,10 @@ def test_read_start_row_skips_rows_in_pivoted_data(self): self.assertEqual( mapped_data, { - 'object_classes': {'klass'}, - 'object_parameter_values': [['klass', 'kloss', 'Parameter_2', Map(["T1", "T2"], [2.3, 23.0])]], - 'object_parameters': [('klass', 'Parameter_2')], - 'objects': {('klass', 'kloss')}, + 'entity_classes': {('klass',)}, + 'parameter_values': [['klass', 'kloss', 'Parameter_2', Map(["T1", "T2"], [2.3, 23.0])]], + 'parameter_definitions': [('klass', 'Parameter_2')], + 'entities': {('klass', 'kloss')}, }, ) @@ -189,10 +189,10 @@ def test_map_without_values_is_ignored_and_not_interpreted_as_null(self): mapped_data, { "alternatives": {"base"}, - "object_classes": {"o"}, - "object_parameters": [("o", "parameter_name")], - "object_parameter_values": [], - "objects": {("o", "o1")}, + "entity_classes": 
{("o",)}, + "parameter_definitions": [("o", "parameter_name")], + "parameter_values": [], + "entities": {("o", "o1")}, }, ) @@ -225,12 +225,18 @@ def test_import_object_works_with_multiple_relationship_object_imports(self): mapped_data, { "alternatives": {"base"}, - "object_classes": {"o", "q"}, - "objects": {("o", "o1"), ("o", "o2"), ("q", "q1"), ("q", "q2")}, - "relationship_classes": [("o_to_q", ["o", "q"])], - "relationships": {("o_to_q", ("o1", "q1")), ("o_to_q", ("o1", "q2")), ("o_to_q", ("o2", "q2"))}, - "relationship_parameters": [("o_to_q", "param")], - "relationship_parameter_values": [ + "entity_classes": {("o",), ("q",), ("o_to_q", ("o", "q"))}, + "entities": { + ("o", "o1"), + ("o", "o2"), + ("q", "q1"), + ("q", "q2"), + ("o_to_q", ("o1", "q1")), + ("o_to_q", ("o1", "q2")), + ("o_to_q", ("o2", "q2")), + }, + "parameter_definitions": [("o_to_q", "param")], + "parameter_values": [ ["o_to_q", ("o1", "q1"), "param", Map(["t1", "t2"], [11, 22], index_name="time"), "base"], ["o_to_q", ("o2", "q2"), "param", Map(["t1", "t2"], [33, 44], index_name="time"), "base"], ["o_to_q", ("o1", "q2"), "param", Map(["t1", "t2"], [55, 66], index_name="time"), "base"], @@ -263,10 +269,10 @@ def test_default_convert_function_in_column_convert_functions(self): self.assertEqual( mapped_data, { - "object_classes": {"klass"}, - "object_parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], - "object_parameters": [("klass", "Parameter_2")], - "objects": {("klass", "kloss")}, + "entity_classes": {("klass",)}, + "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], + "parameter_definitions": [("klass", "Parameter_2")], + "entities": {("klass", "kloss")}, }, ) @@ -291,10 +297,10 @@ def test_identity_function_is_used_as_convert_function_when_no_convert_functions self.assertEqual( mapped_data, { - "object_classes": {"klass"}, - "object_parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], ["2.3", "23.0"])]], - "object_parameters": [("klass", "Parameter_2")], - "objects": {("klass", "kloss")}, + "entity_classes": {("klass",)}, + "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], ["2.3", "23.0"])]], + "parameter_definitions": [("klass", "Parameter_2")], + "entities": {("klass", "kloss")}, }, ) @@ -321,10 +327,10 @@ def test_last_convert_function_gets_used_as_default_convert_function_when_no_def self.assertEqual( mapped_data, { - "object_classes": {"klass"}, - "object_parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], - "object_parameters": [("klass", "Parameter_2")], - "objects": {("klass", "kloss")}, + "entity_classes": {("klass",)}, + "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], + "parameter_definitions": [("klass", "Parameter_2")], + "entities": {("klass", "kloss")}, }, ) @@ -354,13 +360,13 @@ def test_array_parameters_get_imported_correctly_when_objects_are_in_header(self mapped_data, { "alternatives": {"Base"}, - "object_classes": {"class"}, - "object_parameter_values": [ + "entity_classes": {("class",)}, + "parameter_values": [ ["class", "object_1", "param", Array([-1.1, 1.1]), "Base"], ["class", "object_2", "param", Array([2.3, -2.3]), "Base"], ], - "object_parameters": [("class", "param")], - "objects": {("class", "object_1"), ("class", "object_2")}, + "parameter_definitions": [("class", "param")], + "entities": {("class", "object_1"), ("class", "object_2")}, }, ) @@ -390,13 +396,13 @@ def 
test_arrays_get_imported_correctly_when_objects_are_in_header_and_alternativ mapped_data, { "alternatives": {"Base"}, - "object_classes": {"Gadget"}, - "object_parameter_values": [ + "entity_classes": {("Gadget",)}, + "parameter_values": [ ["Gadget", "object_1", "data", Array([-1.1, 1.1]), "Base"], ["Gadget", "object_2", "data", Array([2.3, -2.3]), "Base"], ], - "object_parameters": [("Gadget", "data")], - "objects": {("Gadget", "object_1"), ("Gadget", "object_2")}, + "parameter_definitions": [("Gadget", "data")], + "entities": {("Gadget", "object_1"), ("Gadget", "object_2")}, }, ) diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index 5d4334df..49c8e60f 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -21,10 +21,10 @@ from spinedb_api.mapping import Position, to_dict as mapping_to_dict, unflatten from spinedb_api.import_mapping.import_mapping import ( ImportMapping, + EntityClassMapping, + EntityMapping, check_validity, ParameterDefinitionMapping, - ObjectClassMapping, - ObjectMapping, IndexNameMapping, ParameterValueIndexMapping, ExpandedParameterValueMapping, @@ -58,7 +58,11 @@ def test_convert_functions_float(self): param_def_mapping.value = "param" param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) - expected = {'object_classes': {'a'}, 'objects': {('a', 'obj')}, 'object_parameters': [('a', 'param', 1.2)]} + expected = { + 'entity_classes': {('a',)}, + 'entities': {('a', 'obj')}, + 'parameter_definitions': [('a', 'param', 1.2)], + } self.assertEqual(mapped_data, expected) def test_convert_functions_str(self): @@ -74,9 +78,9 @@ def test_convert_functions_str(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'object_classes': {'a'}, - 'objects': {('a', 'obj')}, - 'object_parameters': [('a', 'param', '1111.2222')], + 'entity_classes': {('a',)}, + 'entities': {('a', 'obj')}, + 'parameter_definitions': [('a', 'param', '1111.2222')], } self.assertEqual(mapped_data, expected) @@ -92,7 +96,11 @@ def test_convert_functions_bool(self): param_def_mapping.value = "param" param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) - expected = {'object_classes': {'a'}, 'objects': {('a', 'obj')}, 'object_parameters': [('a', 'param', False)]} + expected = { + 'entity_classes': {('a',)}, + 'entities': {('a', 'obj')}, + 'parameter_definitions': [('a', 'param', False)], + } self.assertEqual(mapped_data, expected) def test_convert_functions_with_error(self): @@ -179,27 +187,21 @@ def test_object_class_mapping(self): mapping = import_mapping_from_dict({"map_type": "ObjectClass"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['ObjectClass', 'Object', 'ObjectMetadata'] + expected = ['EntityClass', 'Entity', 'EntityMetadata'] self.assertEqual(types, expected) def test_relationship_class_mapping(self): mapping = import_mapping_from_dict({"map_type": "RelationshipClass"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = [ - 'RelationshipClass', - 'RelationshipClassObjectClass', - 'Relationship', - 'RelationshipObject', - 'RelationshipMetadata', - ] + expected = ['EntityClass', 'Dimension', 'Entity', 'Element', 'EntityMetadata'] self.assertEqual(types, expected) def 
test_object_group_mapping(self): mapping = import_mapping_from_dict({"map_type": "ObjectGroup"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['ObjectClass', 'Object', 'ObjectGroup'] + expected = ['EntityClass', 'Entity', 'EntityGroup'] self.assertEqual(types, expected) def test_alternative_mapping(self): @@ -268,9 +270,9 @@ def test_ObjectClass_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'ObjectClass', 'position': 0}, - {'map_type': 'Object', 'position': 1}, - {'map_type': 'ObjectMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 0}, + {'map_type': 'Entity', 'position': 1}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, {'map_type': 'ParameterDefinition', 'position': 2}, {'map_type': 'Alternative', 'position': 'hidden'}, {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, @@ -283,9 +285,9 @@ def test_ObjectClass_object_from_dict_to_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'ObjectClass', 'position': 0}, - {'map_type': 'Object', 'position': 1}, - {'map_type': 'ObjectMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 0}, + {'map_type': 'Entity', 'position': 1}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, ] self.assertEqual(out, expected) @@ -294,9 +296,9 @@ def test_ObjectClass_object_from_dict_to_dict2(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'ObjectClass', 'position': 'hidden', 'value': 'cls'}, - {'map_type': 'Object', 'position': 'hidden', 'value': 'obj'}, - {'map_type': 'ObjectMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'cls'}, + {'map_type': 'Entity', 'position': 'hidden', 'value': 'obj'}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, ] self.assertEqual(out, expected) @@ -311,13 +313,13 @@ def test_RelationshipClassMapping_from_dict_to_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'RelationshipClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'RelationshipClassObjectClass', 'position': 0}, - {'map_type': 'RelationshipClassObjectClass', 'position': 1}, - {'map_type': 'Relationship', 'position': 'hidden', 'value': 'relationship'}, - {'map_type': 'RelationshipObject', 'position': 0}, - {'map_type': 'RelationshipObject', 'position': 1}, - {'map_type': 'RelationshipMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, + {'map_type': 'Dimension', 'position': 0}, + {'map_type': 'Dimension', 'position': 1}, + {'map_type': 'Entity', 'position': 'hidden'}, + {'map_type': 'Element', 'position': 0}, + {'map_type': 'Element', 'position': 1}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, {'map_type': 'ParameterDefinition', 'position': 'hidden', 'value': 'pname'}, {'map_type': 'Alternative', 'position': 'hidden'}, {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, @@ -335,13 +337,13 @@ def test_RelationshipClassMapping_from_dict_to_dict2(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'RelationshipClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'RelationshipClassObjectClass', 'position': 'hidden', 'value': 'cls'}, - {'map_type': 'RelationshipClassObjectClass', 
'position': 0}, - {'map_type': 'Relationship', 'position': 'hidden', 'value': 'relationship'}, - {'map_type': 'RelationshipObject', 'position': 'hidden', 'value': 'obj'}, - {'map_type': 'RelationshipObject', 'position': 0}, - {'map_type': 'RelationshipMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, + {'map_type': 'Dimension', 'position': 'hidden', 'value': 'cls'}, + {'map_type': 'Dimension', 'position': 0}, + {'map_type': 'Entity', 'position': 'hidden'}, + {'map_type': 'Element', 'position': 'hidden', 'value': 'obj'}, + {'map_type': 'Element', 'position': 0}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, ] self.assertEqual(out, expected) @@ -360,11 +362,11 @@ def test_RelationshipClassMapping_from_dict_to_dict3(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'RelationshipClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'RelationshipClassObjectClass', 'position': 'hidden'}, - {'map_type': 'Relationship', 'position': 'hidden', 'value': 'relationship'}, - {'map_type': 'RelationshipObject', 'position': 'hidden'}, - {'map_type': 'RelationshipMetadata', 'position': 'hidden'}, + {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, + {'map_type': 'Dimension', 'position': 'hidden'}, + {'map_type': 'Entity', 'position': 'hidden'}, + {'map_type': 'Element', 'position': 'hidden'}, + {'map_type': 'EntityMetadata', 'position': 'hidden'}, {'map_type': 'ParameterDefinition', 'position': 'hidden', 'value': 'pname'}, {'map_type': 'Alternative', 'position': 'hidden'}, {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, @@ -392,9 +394,9 @@ def test_ObjectGroupMapping_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'ObjectClass', 'position': 0}, - {'map_type': 'Object', 'position': 1}, - {'map_type': 'ObjectGroup', 'position': 2}, + {'map_type': 'EntityClass', 'position': 0}, + {'map_type': 'Entity', 'position': 1}, + {'map_type': 'EntityGroup', 'position': 2}, ] self.assertEqual(out, expected) @@ -790,7 +792,7 @@ def test_read_iterator_with_row_with_all_Nones(self): [None, None, None, None], ["oc2", "obj2", "parameter_name2", 2], ] - expected = {"object_classes": {"oc2"}} + expected = {"entity_classes": {("oc2",)}} data = iter(input_data) data_header = next(data) @@ -803,7 +805,7 @@ def test_read_iterator_with_row_with_all_Nones(self): def test_read_iterator_with_None(self): input_data = [["object_class", "object", "parameter", "value"], None, ["oc2", "obj2", "parameter_name2", 2]] - expected = {"object_classes": {"oc2"}} + expected = {"entity_classes": {("oc2",)}} data = iter(input_data) data_header = next(data) @@ -821,10 +823,10 @@ def test_read_flat_file(self): ["oc2", "obj2", "parameter_name2", 2], ] expected = { - "object_classes": {"oc1", "oc2"}, - "objects": {("oc1", "obj1"), ("oc2", "obj2")}, - "object_parameters": [("oc1", "parameter_name1"), ("oc2", "parameter_name2")], - "object_parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], + "entity_classes": {("oc1",), ("oc2",)}, + "entities": {("oc1", "obj1"), ("oc2", "obj2")}, + "parameter_definitions": [("oc1", "parameter_name1"), ("oc2", "parameter_name2")], + "parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], } data = iter(input_data) @@ -848,10 +850,10 @@ def test_read_flat_file_array(self): 
["oc1", "obj1", "parameter_name1", 2], ] expected = { - "object_classes": {"oc1"}, - "objects": {("oc1", "obj1")}, - "object_parameters": [("oc1", "parameter_name1")], - "object_parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], + "entity_classes": {("oc1",)}, + "entities": {("oc1", "obj1")}, + "parameter_definitions": [("oc1", "parameter_name1")], + "parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], } data = iter(input_data) @@ -875,10 +877,10 @@ def test_read_flat_file_array_with_ed(self): ["oc1", "obj1", "parameter_name1", 2, 1], ] expected = { - "object_classes": {"oc1"}, - "objects": {("oc1", "obj1")}, - "object_parameters": [("oc1", "parameter_name1")], - "object_parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], + "entity_classes": {("oc1",)}, + "entities": {("oc1", "obj1")}, + "parameter_definitions": [("oc1", "parameter_name1")], + "parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], } data = iter(input_data) @@ -903,7 +905,7 @@ def test_read_flat_file_array_with_ed(self): def test_read_flat_file_with_column_name_reference(self): input_data = [["object", "parameter", "value"], ["obj1", "parameter_name1", 1], ["obj2", "parameter_name2", 2]] - expected = {"object_classes": {"object"}, "objects": {("object", "obj1"), ("object", "obj2")}} + expected = {"entity_classes": {("object",)}, "entities": {("object", "obj1"), ("object", "obj2")}} data = iter(input_data) data_header = next(data) @@ -916,7 +918,10 @@ def test_read_flat_file_with_column_name_reference(self): def test_read_object_class_from_header_using_string_as_integral_index(self): input_data = [["object_class"], ["obj1"], ["obj2"]] - expected = {"object_classes": {"object_class"}, "objects": {("object_class", "obj1"), ("object_class", "obj2")}} + expected = { + "entity_classes": {("object_class",)}, + "entities": {("object_class", "obj1"), ("object_class", "obj2")}, + } data = iter(input_data) data_header = next(data) @@ -929,7 +934,10 @@ def test_read_object_class_from_header_using_string_as_integral_index(self): def test_read_object_class_from_header_using_string_as_column_header_name(self): input_data = [["object_class"], ["obj1"], ["obj2"]] - expected = {"object_classes": {"object_class"}, "objects": {("object_class", "obj1"), ("object_class", "obj2")}} + expected = { + "entity_classes": {("object_class",)}, + "entities": {("object_class", "obj1"), ("object_class", "obj2")}, + } data = iter(input_data) data_header = next(data) @@ -946,7 +954,7 @@ def test_read_object_class_from_header_using_string_as_column_header_name(self): def test_read_with_list_of_mappings(self): input_data = [["object", "parameter", "value"], ["obj1", "parameter_name1", 1], ["obj2", "parameter_name2", 2]] - expected = {"object_classes": {"object"}, "objects": {("object", "obj1"), ("object", "obj2")}} + expected = {"entity_classes": {("object",)}, "entities": {("object", "obj1"), ("object", "obj2")}} data = iter(input_data) data_header = next(data) @@ -960,10 +968,10 @@ def test_read_with_list_of_mappings(self): def test_read_pivoted_parameters_from_header(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1"), ("object", "obj2")}, - "object_parameters": [("object", "parameter_name1"), ("object", "parameter_name2")], - "object_parameter_values": [ + "entity_classes": {("object",)}, + "entities": {("object", "obj1"), ("object", "obj2")}, + 
"parameter_definitions": [("object", "parameter_name1"), ("object", "parameter_name2")], + "parameter_values": [ ["object", "obj1", "parameter_name1", 0], ["object", "obj1", "parameter_name2", 1], ["object", "obj2", "parameter_name1", 2], @@ -1006,10 +1014,10 @@ def test_read_empty_pivot(self): def test_read_pivoted_parameters_from_data(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1"), ("object", "obj2")}, - "object_parameters": [("object", "parameter_name1"), ("object", "parameter_name2")], - "object_parameter_values": [ + "entity_classes": {("object",)}, + "entities": {("object", "obj1"), ("object", "obj2")}, + "parameter_definitions": [("object", "parameter_name1"), ("object", "parameter_name2")], + "parameter_values": [ ["object", "obj1", "parameter_name1", 0], ["object", "obj1", "parameter_name2", 1], ["object", "obj2", "parameter_name1", 2], @@ -1040,11 +1048,11 @@ def test_pivoted_value_has_actual_position(self): ["obj2", "T2", 22.0], ] expected = { - "object_classes": {"timeline"}, - "objects": {("timeline", "obj1"), ("timeline", "obj2")}, - "object_parameters": [("timeline", "value")], + "entity_classes": {("timeline",)}, + "entities": {("timeline", "obj1"), ("timeline", "obj2")}, + "parameter_definitions": [("timeline", "value")], "alternatives": {"Base"}, - "object_parameter_values": [ + "parameter_values": [ ["timeline", "obj1", "value", Map(["T1", "T2"], [11.0, 12.0], index_name="timestep"), "Base"], ["timeline", "obj2", "value", Map(["T1", "T2"], [21.0, 22.0], index_name="timestep"), "Base"], ], @@ -1070,11 +1078,11 @@ def test_import_objects_from_pivoted_data_when_they_lack_parameter_values(self): """Pivoted mapping works even when last mapping has valid position in columns.""" input_data = [["object", "is_skilled", "has_powers"], ["obj1", "yes", "no"], ["obj2", None, None]] expected = { - "object_classes": {"node"}, - "objects": {("node", "obj1"), ("node", "obj2")}, - "object_parameters": [("node", "is_skilled"), ("node", "has_powers")], + "entity_classes": {("node",)}, + "entities": {("node", "obj1"), ("node", "obj2")}, + "parameter_definitions": [("node", "is_skilled"), ("node", "has_powers")], "alternatives": {"Base"}, - "object_parameter_values": [ + "parameter_values": [ ["node", "obj1", "is_skilled", "yes", "Base"], ["node", "obj1", "has_powers", "no", "Base"], ], @@ -1101,11 +1109,11 @@ def test_import_objects_from_pivoted_data_when_they_lack_map_type_parameter_valu ["obj1", "today", None, "yes"], ] expected = { - "object_classes": {"node"}, - "objects": {("node", "obj1")}, - "object_parameters": [("node", "is_skilled"), ("node", "has_powers")], + "entity_classes": {("node",)}, + "entities": {("node", "obj1")}, + "parameter_definitions": [("node", "is_skilled"), ("node", "has_powers")], "alternatives": {"Base"}, - "object_parameter_values": [ + "parameter_values": [ ["node", "obj1", "has_powers", Map(["yesterday", "today"], ["no", "yes"], index_name="period"), "Base"] ], } @@ -1130,10 +1138,10 @@ def test_read_flat_file_with_extra_value_dimensions(self): input_data = [["object", "time", "parameter_name1"], ["obj1", "2018-01-01", 1], ["obj1", "2018-01-02", 2]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1")}, - "object_parameters": [("object", "parameter_name1")], - "object_parameter_values": [ + "entity_classes": {("object",)}, + "entities": {("object", "obj1")}, + "parameter_definitions": [("object", 
"parameter_name1")], + "parameter_values": [ [ "object", "obj1", @@ -1167,9 +1175,9 @@ def test_read_flat_file_with_parameter_definition(self): input_data = [["object", "time", "parameter_name1"], ["obj1", "2018-01-01", 1], ["obj1", "2018-01-02", 2]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1")}, - "object_parameters": [("object", "parameter_name1")], + "entity_classes": {("object",)}, + "entities": {("object", "obj1")}, + "parameter_definitions": [("object", "parameter_name1")], } data = iter(input_data) @@ -1194,8 +1202,8 @@ def test_read_flat_file_with_parameter_definition(self): def test_read_1dim_relationships(self): input_data = [["unit", "node"], ["u1", "n1"], ["u1", "n2"]] expected = { - "relationship_classes": [("node_group", ["node"])], - "relationships": {("node_group", ("n1",)), ("node_group", ("n2",))}, + "entity_classes": {("node_group", ("node",))}, + "entities": {("node_group", ("n1",)), ("node_group", ("n2",))}, } data = iter(input_data) @@ -1215,8 +1223,8 @@ def test_read_1dim_relationships(self): def test_read_relationships(self): input_data = [["unit", "node"], ["u1", "n1"], ["u1", "n2"]] expected = { - "relationship_classes": [("unit__node", ["unit", "node"])], - "relationships": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, + "entity_classes": {("unit__node", ("unit", "node"))}, + "entities": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, } data = iter(input_data) @@ -1239,12 +1247,12 @@ def test_read_relationships(self): def test_read_relationships_with_parameters(self): input_data = [["unit", "node", "rel_parameter"], ["u1", "n1", 0], ["u1", "n2", 1]] expected = { - "relationship_classes": [("unit__node", ["unit", "node"])], - "relationships": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, - "relationship_parameters": [("unit__node", "rel_parameter")], - "relationship_parameter_values": [ - ["unit__node", ["u1", "n1"], "rel_parameter", 0], - ["unit__node", ["u1", "n2"], "rel_parameter", 1], + "entity_classes": {("unit__node", ("unit", "node"))}, + "entities": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, + "parameter_definitions": [("unit__node", "rel_parameter")], + "parameter_values": [ + ["unit__node", ("u1", "n1"), "rel_parameter", 0], + ["unit__node", ("u1", "n2"), "rel_parameter", 1], ], } @@ -1269,14 +1277,19 @@ def test_read_relationships_with_parameters(self): def test_read_relationships_with_parameters2(self): input_data = [["nuts2", "Capacity", "Fueltype"], ["BE23", 268.0, "Bioenergy"], ["DE11", 14.0, "Bioenergy"]] expected = { - "object_classes": {"nuts2", "fueltype"}, - "objects": {("nuts2", "BE23"), ("fueltype", "Bioenergy"), ("nuts2", "DE11"), ("fueltype", "Bioenergy")}, - "relationship_classes": [("nuts2__fueltype", ["nuts2", "fueltype"])], - "relationships": {("nuts2__fueltype", ("BE23", "Bioenergy")), ("nuts2__fueltype", ("DE11", "Bioenergy"))}, - "relationship_parameters": [("nuts2__fueltype", "capacity")], - "relationship_parameter_values": [ - ["nuts2__fueltype", ["BE23", "Bioenergy"], "capacity", 268.0], - ["nuts2__fueltype", ["DE11", "Bioenergy"], "capacity", 14.0], + "entity_classes": {("nuts2",), ("fueltype",), ("nuts2__fueltype", ("nuts2", "fueltype"))}, + "entities": { + ("nuts2", "BE23"), + ("fueltype", "Bioenergy"), + ("nuts2", "DE11"), + ("fueltype", "Bioenergy"), + ("nuts2__fueltype", ("BE23", "Bioenergy")), + ("nuts2__fueltype", ("DE11", "Bioenergy")), + }, + "parameter_definitions": [("nuts2__fueltype", "capacity")], + "parameter_values": [ 
+ ["nuts2__fueltype", ("BE23", "Bioenergy"), "capacity", 268.0], + ["nuts2__fueltype", ("DE11", "Bioenergy"), "capacity", 14.0], ], } @@ -1309,10 +1322,10 @@ def test_read_relationships_with_parameters2(self): def test_read_parameter_header_with_only_one_parameter(self): input_data = [["object", "parameter_name1"], ["obj1", 0], ["obj2", 2]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1"), ("object", "obj2")}, - "object_parameters": [("object", "parameter_name1")], - "object_parameter_values": [ + "entity_classes": {("object",)}, + "entities": {("object", "obj1"), ("object", "obj2")}, + "parameter_definitions": [("object", "parameter_name1")], + "parameter_values": [ ["object", "obj1", "parameter_name1", 0], ["object", "obj2", "parameter_name1", 2], ], @@ -1335,10 +1348,10 @@ def test_read_parameter_header_with_only_one_parameter(self): def test_read_pivoted_parameters_from_data_with_skipped_column(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "object_classes": {"object"}, - "objects": {("object", "obj1"), ("object", "obj2")}, - "object_parameters": [("object", "parameter_name1")], - "object_parameter_values": [ + "entity_classes": {("object",)}, + "entities": {("object", "obj1"), ("object", "obj2")}, + "parameter_definitions": [("object", "parameter_name1")], + "parameter_values": [ ["object", "obj1", "parameter_name1", 0], ["object", "obj2", "parameter_name1", 2], ], @@ -1361,10 +1374,15 @@ def test_read_pivoted_parameters_from_data_with_skipped_column(self): def test_read_relationships_and_import_objects(self): input_data = [["unit", "node"], ["u1", "n1"], ["u2", "n2"]] expected = { - "relationship_classes": [("unit__node", ["unit", "node"])], - "relationships": {("unit__node", ("u1", "n1")), ("unit__node", ("u2", "n2"))}, - "object_classes": {"unit", "node"}, - "objects": {("unit", "u1"), ("node", "n1"), ("unit", "u2"), ("node", "n2")}, + "entity_classes": {("unit",), ("node",), ("unit__node", ("unit", "node"))}, + "entities": { + ("unit", "u1"), + ("node", "n1"), + ("unit", "u2"), + ("node", "n2"), + ("unit__node", ("u1", "n1")), + ("unit__node", ("u2", "n2")), + }, } data = iter(input_data) @@ -1389,10 +1407,10 @@ def test_read_relationships_parameter_values_with_extra_dimensions(self): input_data = [["", "a", "b"], ["", "c", "d"], ["", "e", "f"], ["a", 2, 3], ["b", 4, 5]] expected = { - "relationship_classes": [("unit__node", ["unit", "node"])], - "relationship_parameters": [("unit__node", "e"), ("unit__node", "f")], - "relationships": {("unit__node", ("a", "c")), ("unit__node", ("b", "d"))}, - "relationship_parameter_values": [ + "entity_classes": {("unit__node", ("unit", "node"))}, + "parameter_definitions": [("unit__node", "e"), ("unit__node", "f")], + "entities": {("unit__node", ("a", "c")), ("unit__node", ("b", "d"))}, + "parameter_values": [ ["unit__node", ("a", "c"), "e", Map(["a", "b"], [2, 4])], ["unit__node", ("b", "d"), "f", Map(["a", "b"], [3, 5])], ], @@ -1426,10 +1444,10 @@ def test_read_data_with_read_start_row(self): ["oc2", "obj2", "parameter_name2", 2], ] expected = { - "object_classes": {"oc1", "oc2"}, - "objects": {("oc1", "obj1"), ("oc2", "obj2")}, - "object_parameters": [("oc1", "parameter_name1"), ("oc2", "parameter_name2")], - "object_parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], + "entity_classes": {("oc1",), ("oc2",)}, + "entities": {("oc1", "obj1"), ("oc2", "obj2")}, + "parameter_definitions": [("oc1", 
"parameter_name1"), ("oc2", "parameter_name2")], + "parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], } data = iter(input_data) @@ -1455,10 +1473,10 @@ def test_read_data_with_two_mappings_with_different_read_start_row(self): ["oc1_obj2", "oc2_obj2", 2, 4], ] expected = { - "object_classes": {"oc1", "oc2"}, - "objects": {("oc1", "oc1_obj1"), ("oc1", "oc1_obj2"), ("oc2", "oc2_obj2")}, - "object_parameters": [("oc1", "parameter_class1"), ("oc2", "parameter_class2")], - "object_parameter_values": [ + "entity_classes": {("oc1",), ("oc2",)}, + "entities": {("oc1", "oc1_obj1"), ("oc1", "oc1_obj2"), ("oc2", "oc2_obj2")}, + "parameter_definitions": [("oc1", "parameter_class1"), ("oc2", "parameter_class2")], + "parameter_values": [ ["oc1", "oc1_obj1", "parameter_class1", 1], ["oc1", "oc1_obj2", "parameter_class1", 2], ["oc2", "oc2_obj2", "parameter_class2", 4], @@ -1498,8 +1516,8 @@ def test_read_object_class_with_table_name_as_class_name(self): } out, errors = get_mapped_data(data, [mapping], data_header, "class name") expected = { - "object_classes": {"class name"}, - "objects": {("class name", "object 1"), ("class name", "object 2")}, + "entity_classes": {("class name",)}, + "entities": {("class name", "object 1"), ("class name", "object 2")}, } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1523,10 +1541,10 @@ def test_read_flat_map_from_columns(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key1", "key2"], [-2, -1]) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", "object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1550,10 +1568,10 @@ def test_read_nested_map_from_columns(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key11", "key21"], [Map(["key12"], [-2]), Map(["key22"], [-1])]) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", "object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1594,10 +1612,10 @@ def test_read_uneven_nested_map_from_columns(self): ], ) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", "object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1636,10 +1654,10 @@ def test_read_nested_map_with_compression(self): ], ) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", 
"object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1799,8 +1817,8 @@ def test_read_object_group_without_parameters(self): mapping = {"map_type": "ObjectGroup", "name": 0, "groups": 1, "members": 2} out, errors = get_mapped_data(data, [mapping], data_header) expected = dict() - expected["object_classes"] = {"class_A"} - expected["object_groups"] = { + expected["entity_classes"] = {("class_A",)} + expected["entity_groups"] = { ("class_A", "group1", "object1"), ("class_A", "group1", "object2"), ("class_A", "group2", "object3"), @@ -1820,13 +1838,13 @@ def test_read_object_group_and_import_objects(self): mapping = {"map_type": "ObjectGroup", "name": 0, "groups": 1, "members": 2, "import_objects": True} out, errors = get_mapped_data(data, [mapping], data_header) expected = dict() - expected["object_groups"] = { + expected["entity_groups"] = { ("class_A", "group1", "object1"), ("class_A", "group1", "object2"), ("class_A", "group2", "object3"), } - expected["object_classes"] = {"class_A"} - expected["objects"] = { + expected["entity_classes"] = {("class_A",)} + expected["entities"] = { ("class_A", "group1"), ("class_A", "object1"), ("class_A", "group1"), @@ -1858,8 +1876,8 @@ def test_read_parameter_definition_with_default_values_and_value_lists(self): } out, errors = get_mapped_data(data, [mapping], data_header) expected = dict() - expected["object_classes"] = {"class_A", "class_A", "class_B"} - expected["object_parameters"] = [ + expected["entity_classes"] = {("class_A",), ("class_A",), ("class_B",)} + expected["parameter_definitions"] = [ ("class_A", "param1", 23.0, "listA"), ("class_A", "param2", 42.0, "listB"), ("class_B", "param3", 5.0, "listA"), @@ -1882,8 +1900,8 @@ def test_map_as_default_parameter_value(self): out, errors = get_mapped_data(data, [mapping]) expected_map = Map(["key1", "key2", "key3"], [-2.3, 5.5, 3.2]) expected = { - "object_classes": {"object_class"}, - "object_parameters": [("object_class", "parameter", expected_map)], + "entity_classes": {("object_class",)}, + "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1904,8 +1922,8 @@ def test_read_parameter_definition_with_nested_map_as_default_value(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key11", "key21"], [Map(["key12"], [-2]), Map(["key22"], [-1])]) expected = { - "object_classes": {"object_class"}, - "object_parameters": [("object_class", "parameter", expected_map)], + "entity_classes": {("object_class",)}, + "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1916,9 +1934,9 @@ def test_read_map_index_names_from_columns(self): data_header = next(data) mapping_root = unflatten( [ - ObjectClassMapping(Position.hidden, value="object_class"), + EntityClassMapping(Position.hidden, value="object_class"), ParameterDefinitionMapping(Position.hidden, value="parameter"), - ObjectMapping(Position.hidden, value="object"), + EntityMapping(Position.hidden, value="object"), ParameterValueTypeMapping(Position.hidden, value="map"), 
IndexNameMapping(Position.header, value=0), ParameterValueIndexMapping(0), @@ -1934,10 +1952,10 @@ def test_read_map_index_names_from_columns(self): index_name="Index 1", ) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", "object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1948,9 +1966,9 @@ def test_missing_map_index_name(self): data_header = next(data) mapping_root = unflatten( [ - ObjectClassMapping(Position.hidden, value="object_class"), + EntityClassMapping(Position.hidden, value="object_class"), ParameterDefinitionMapping(Position.hidden, value="parameter"), - ObjectMapping(Position.hidden, value="object"), + EntityMapping(Position.hidden, value="object"), ParameterValueTypeMapping(Position.hidden, value="map"), IndexNameMapping(Position.hidden, value=None), ParameterValueIndexMapping(0), @@ -1966,10 +1984,10 @@ def test_missing_map_index_name(self): index_name="", ) expected = { - "object_classes": {"object_class"}, - "objects": {("object_class", "object")}, - "object_parameter_values": [["object_class", "object", "parameter", expected_map]], - "object_parameters": [("object_class", "parameter")], + "entity_classes": {("object_class",)}, + "entities": {("object_class", "object")}, + "parameter_values": [["object_class", "object", "parameter", expected_map]], + "parameter_definitions": [("object_class", "parameter")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1980,7 +1998,7 @@ def test_read_default_value_index_names_from_columns(self): data_header = next(data) mapping_root = unflatten( [ - ObjectClassMapping(Position.hidden, value="object_class"), + EntityClassMapping(Position.hidden, value="object_class"), ParameterDefinitionMapping(Position.hidden, value="parameter"), ParameterDefaultValueTypeMapping(Position.hidden, value="map"), DefaultValueIndexNameMapping(Position.header, value=0), @@ -1997,8 +2015,8 @@ def test_read_default_value_index_names_from_columns(self): index_name="Index 1", ) expected = { - "object_classes": {"object_class"}, - "object_parameters": [("object_class", "parameter", expected_map)], + "entity_classes": {("object_class",)}, + "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -2006,27 +2024,27 @@ def test_read_default_value_index_names_from_columns(self): def test_filter_regular_expression_in_root_mapping(self): input_data = [["A", "p"], ["A", "q"], ["B", "r"]] data = iter(input_data) - mapping_root = unflatten([ObjectClassMapping(0, filter_re="B"), ObjectMapping(1)]) + mapping_root = unflatten([EntityClassMapping(0, filter_re="B"), EntityMapping(1)]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"object_classes": {"B"}, "objects": {("B", "r")}} + expected = {"entity_classes": {("B",)}, "entities": {("B", "r")}} self.assertFalse(errors) self.assertEqual(out, expected) def test_filter_regular_expression_in_child_mapping(self): input_data = [["A", "p"], ["A", "q"], ["B", "r"]] data = iter(input_data) - mapping_root = unflatten([ObjectClassMapping(0), ObjectMapping(1, filter_re="q|r")]) + mapping_root = 
unflatten([EntityClassMapping(0), EntityMapping(1, filter_re="q|r")]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"object_classes": {"A", "B"}, "objects": {("A", "q"), ("B", "r")}} + expected = {"entity_classes": {("A",), ("B",)}, "entities": {("A", "q"), ("B", "r")}} self.assertFalse(errors) self.assertEqual(out, expected) def test_filter_regular_expression_in_child_mapping_filters_parent_mappings_too(self): input_data = [["A", "p"], ["A", "q"], ["B", "r"]] data = iter(input_data) - mapping_root = unflatten([ObjectClassMapping(0), ObjectMapping(1, filter_re="q")]) + mapping_root = unflatten([EntityClassMapping(0), EntityMapping(1, filter_re="q")]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"object_classes": {"A"}, "objects": {("A", "q")}} + expected = {"entity_classes": {("A",)}, "entities": {("A", "q")}} self.assertFalse(errors) self.assertEqual(out, expected) @@ -2035,8 +2053,8 @@ def test_arrays_get_imported_to_correct_alternatives(self): data = iter(input_data) mapping_root = unflatten( [ - ObjectClassMapping(Position.hidden, value="class"), - ObjectMapping(1), + EntityClassMapping(Position.hidden, value="class"), + EntityMapping(1), ParameterDefinitionMapping(Position.hidden, value="parameter"), AlternativeMapping(0), ParameterValueTypeMapping(Position.hidden, value="array"), @@ -2045,11 +2063,11 @@ def test_arrays_get_imported_to_correct_alternatives(self): ) out, errors = get_mapped_data(data, [mapping_root]) expected = { - "object_classes": {"class"}, - "objects": {("class", "y")}, - "object_parameters": [("class", "parameter")], + "entity_classes": {("class",)}, + "entities": {("class", "y")}, + "parameter_definitions": [("class", "parameter")], "alternatives": {"Base", "alternative"}, - "object_parameter_values": [ + "parameter_values": [ ["class", "y", "parameter", Array(["p1"]), "Base"], ["class", "y", "parameter", Array(["p1"]), "alternative"], ], @@ -2060,33 +2078,33 @@ def test_arrays_get_imported_to_correct_alternatives(self): class TestHasFilter(unittest.TestCase): def test_mapping_without_filter_doesnt_have_filter(self): - mapping = ObjectClassMapping(0) + mapping = EntityClassMapping(0) self.assertFalse(mapping.has_filter()) def test_hidden_mapping_without_value_doesnt_have_filter(self): - mapping = ObjectClassMapping(Position.hidden, filter_re="a") + mapping = EntityClassMapping(Position.hidden, filter_re="a") self.assertFalse(mapping.has_filter()) def test_hidden_mapping_with_value_has_filter(self): - mapping = ObjectClassMapping(0, value="a", filter_re="b") + mapping = EntityClassMapping(0, value="a", filter_re="b") self.assertTrue(mapping.has_filter()) def test_mapping_without_value_has_filter(self): - mapping = ObjectClassMapping(Position.hidden, value="a", filter_re="b") + mapping = EntityClassMapping(Position.hidden, value="a", filter_re="b") self.assertTrue(mapping.has_filter()) def test_mapping_with_value_but_without_filter_doesnt_have_filter(self): - mapping = ObjectClassMapping(0, value="a") + mapping = EntityClassMapping(0, value="a") self.assertFalse(mapping.has_filter()) def test_child_mapping_with_filter_has_filter(self): - mapping = ObjectClassMapping(0) - mapping.child = ObjectMapping(1, filter_re="a") + mapping = EntityClassMapping(0) + mapping.child = EntityMapping(1, filter_re="a") self.assertTrue(mapping.has_filter()) def test_child_mapping_without_filter_doesnt_have_filter(self): - mapping = ObjectClassMapping(0) - mapping.child = ObjectMapping(1) + mapping = EntityClassMapping(0) + mapping.child = 
EntityMapping(1) self.assertFalse(mapping.has_filter()) diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index 3f91c381..693ffa5b 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -104,10 +104,10 @@ def test_map(self): def _check_parameter_value(self, val): input_data = { - "object_classes": ["dog"], - "objects": [("dog", "pluto")], - "object_parameters": [("dog", "bone")], - "object_parameter_values": [("dog", "pluto", "bone", val)], + "entity_classes": {("dog",)}, + "entities": {("dog", "pluto")}, + "parameter_definitions": [("dog", "bone")], + "parameter_values": [("dog", "pluto", "bone", val)], } db_map = DatabaseMapping("sqlite://", create=True) import_data(db_map, **input_data) @@ -118,11 +118,11 @@ def _check_parameter_value(self, val): output_data, errors = get_mapped_data_from_xlsx(path) db_map.connection.close() self.assertEqual([], errors) - input_obj_param_vals = input_data.pop("object_parameter_values") - output_obj_param_vals = output_data.pop("object_parameter_values") - self.assertEqual(1, len(output_obj_param_vals)) - input_obj_param_val = input_obj_param_vals[0] - output_obj_param_val = output_obj_param_vals[0] + input_param_vals = input_data.pop("parameter_values") + output_param_vals = output_data.pop("parameter_values") + self.assertEqual(1, len(output_param_vals)) + input_obj_param_val = input_param_vals[0] + output_obj_param_val = output_param_vals[0] for input_, output in zip(input_obj_param_val[:3], output_obj_param_val[:3]): self.assertEqual(input_, output) input_val = input_obj_param_val[3] From daaee5c7bc6e176b97b882b9cc3b1b3121ac2e3b Mon Sep 17 00:00:00 2001 From: Manuel Date: Sat, 1 Apr 2023 11:03:32 +0200 Subject: [PATCH 018/317] Adapt part of the ExportMapping to entity design --- spinedb_api/db_mapping_base.py | 44 ++- spinedb_api/export_mapping/__init__.py | 15 +- spinedb_api/export_mapping/export_mapping.py | 324 +++++++---------- spinedb_api/export_mapping/pivot.py | 6 +- spinedb_api/export_mapping/settings.py | 312 ++++++---------- spinedb_api/spine_io/exporters/excel.py | 28 +- tests/export_mapping/test_export_mapping.py | 341 +++++++++--------- tests/export_mapping/test_settings.py | 114 +++--- tests/spine_io/exporters/test_csv_writer.py | 12 +- tests/spine_io/exporters/test_excel_writer.py | 18 +- tests/spine_io/exporters/test_gdx_writer.py | 48 ++- tests/spine_io/exporters/test_sql_writer.py | 28 +- tests/spine_io/exporters/test_writer.py | 6 +- 13 files changed, 589 insertions(+), 707 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index edbf4122..ec5f7aaf 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -605,6 +605,16 @@ def ext_entity_class_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_class_sq is None: + entity_class_dimension_sq = ( + self.query( + self.entity_class_dimension_sq.c.entity_class_id, + self.entity_class_dimension_sq.c.dimension_id, + self.entity_class_dimension_sq.c.position, + self.entity_class_sq.c.name.label("dimension_name"), + ) + .filter(self.entity_class_dimension_sq.c.dimension_id == self.entity_class_sq.c.id) + .subquery() + ) ecd_sq = ( self.query( self.entity_class_sq.c.id, @@ -613,14 +623,15 @@ def ext_entity_class_sq(self): self.entity_class_sq.c.display_order, self.entity_class_sq.c.display_icon, self.entity_class_sq.c.hidden, - self.entity_class_dimension_sq.c.dimension_id, - self.entity_class_dimension_sq.c.position, + 
entity_class_dimension_sq.c.dimension_id, + entity_class_dimension_sq.c.dimension_name, + entity_class_dimension_sq.c.position, ) .outerjoin( - self.entity_class_dimension_sq, - self.entity_class_sq.c.id == self.entity_class_dimension_sq.c.entity_class_id, + entity_class_dimension_sq, + self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id, ) - .order_by(self.entity_class_sq.c.id, self.entity_class_dimension_sq.c.position) + .order_by(self.entity_class_sq.c.id, entity_class_dimension_sq.c.position) .subquery() ) self._ext_entity_class_sq = ( @@ -632,6 +643,7 @@ def ext_entity_class_sq(self): ecd_sq.c.display_icon, ecd_sq.c.hidden, group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"), + group_concat(ecd_sq.c.dimension_name, ecd_sq.c.position).label("dimension_name_list"), ) .group_by( ecd_sq.c.id, @@ -665,6 +677,16 @@ def ext_entity_sq(self): sqlalchemy.sql.expression.Alias """ if self._ext_entity_sq is None: + entity_element_sq = ( + self.query( + self.entity_element_sq.c.entity_id, + self.entity_element_sq.c.element_id, + self.entity_element_sq.c.position, + self.entity_sq.c.name.label("element_name"), + ) + .filter(self.entity_element_sq.c.element_id == self.entity_sq.c.id) + .subquery() + ) ee_sq = ( self.query( self.entity_sq.c.id, @@ -672,14 +694,15 @@ def ext_entity_sq(self): self.entity_sq.c.name, self.entity_sq.c.description, self.entity_sq.c.commit_id, - self.entity_element_sq.c.element_id, - self.entity_element_sq.c.position, + entity_element_sq.c.element_id, + entity_element_sq.c.element_name, + entity_element_sq.c.position, ) .outerjoin( - self.entity_element_sq, - self.entity_sq.c.id == self.entity_element_sq.c.entity_id, + entity_element_sq, + self.entity_sq.c.id == entity_element_sq.c.entity_id, ) - .order_by(self.entity_sq.c.id, self.entity_element_sq.c.position) + .order_by(self.entity_sq.c.id, entity_element_sq.c.position) .subquery() ) self._ext_entity_sq = ( @@ -690,6 +713,7 @@ def ext_entity_sq(self): ee_sq.c.description, ee_sq.c.commit_id, group_concat(ee_sq.c.element_id, ee_sq.c.position).label("element_id_list"), + group_concat(ee_sq.c.element_name, ee_sq.c.position).label("element_name_list"), ) .group_by( ee_sq.c.id, diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index f59ba6ae..d82d28d4 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -19,16 +19,13 @@ from .settings import ( alternative_export, feature_export, - object_export, - object_group_export, - object_parameter_default_value_export, - object_parameter_export, + entity_export, + entity_group_export, + entity_class_parameter_default_value_export, + entity_parameter_export, parameter_value_list_export, - relationship_export, - relationship_object_parameter_default_value_export, - relationship_object_parameter_export, - relationship_parameter_default_value_export, - relationship_parameter_export, + entity_class_dimension_parameter_default_value_export, + entity_element_parameter_export, scenario_alternative_export, scenario_export, tool_export, diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 5c2f1d0f..474b9277 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -627,71 +627,98 @@ def id_field(): return None -class ObjectClassMapping(ExportMapping): - """Maps object classes. +class EntityClassMapping(ExportMapping): + """Maps entity classes. 
Can be used as the topmost mapping. """ - MAP_TYPE = "ObjectClass" + MAP_TYPE = "EntityClass" def add_query_columns(self, db_map, query): return query.add_columns( - db_map.object_class_sq.c.id.label("object_class_id"), - db_map.object_class_sq.c.name.label("object_class_name"), + db_map.ext_entity_class_sq.c.id.label("entity_class_id"), + db_map.ext_entity_class_sq.c.name.label("entity_class_name"), + db_map.ext_entity_class_sq.c.dimension_id_list.label("dimension_id_list"), + db_map.ext_entity_class_sq.c.dimension_name_list.label("dimension_name_list"), ) @staticmethod def name_field(): - return "object_class_name" + return "entity_class_name" @staticmethod def id_field(): # Use the class name here, for the sake of the standard excel export - return "object_class_name" + return "entity_class_name" + + def query_parents(self, what): + if what != "dimension": + return super().query_parents(what) + return -1 + + def _title_state(self, db_row): + state = super()._title_state(db_row) + state["dimension_id_list"] = getattr(db_row, "dimension_id_list") + return state -class ObjectMapping(ExportMapping): - """Maps objects. +class EntityMapping(ExportMapping): + """Maps entities. - Cannot be used as the topmost mapping; one of the parents must be :class:`ObjectClassMapping`. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`. """ - MAP_TYPE = "Object" + MAP_TYPE = "Entity" def add_query_columns(self, db_map, query): - return query.add_columns(db_map.object_sq.c.id.label("object_id"), db_map.object_sq.c.name.label("object_name")) + return query.add_columns( + db_map.ext_entity_sq.c.id.label("entity_id"), + db_map.ext_entity_sq.c.name.label("entity_name"), + db_map.ext_entity_sq.c.element_id_list, + db_map.ext_entity_sq.c.element_name_list, + ) def filter_query(self, db_map, query): - return query.outerjoin(db_map.object_sq, db_map.object_sq.c.class_id == db_map.object_class_sq.c.id) + return query.outerjoin(db_map.ext_entity_sq, db_map.ext_entity_sq.c.class_id == db_map.ext_entity_class_sq.c.id) @staticmethod def name_field(): - return "object_name" + return "entity_name" @staticmethod def id_field(): - return "object_id" + return "entity_id" + + def query_parents(self, what): + if what != "dimension": + return super().query_parents(what) + return -1 + + def _title_state(self, db_row): + state = super()._title_state(db_row) + state["element_id_list"] = getattr(db_row, "element_id_list") + return state @staticmethod def is_buddy(parent): - return isinstance(parent, ObjectClassMapping) + return isinstance(parent, EntityClassMapping) -class ObjectGroupMapping(ExportMapping): - """Maps object groups. +class EntityGroupMapping(ExportMapping): + """Maps entity groups. - Cannot be used as the topmost mapping; one of the parents must be :class:`ObjectClassMapping`. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`. 
""" - MAP_TYPE = "ObjectGroup" + MAP_TYPE = "EntityGroup" def add_query_columns(self, db_map, query): return query.add_columns(db_map.ext_entity_group_sq.c.group_id, db_map.ext_entity_group_sq.c.group_name) def filter_query(self, db_map, query): return query.outerjoin( - db_map.ext_entity_group_sq, db_map.ext_entity_group_sq.c.class_id == db_map.object_class_sq.c.id + db_map.ext_entity_group_sq, db_map.ext_entity_group_sq.c.class_id == db_map.ext_entity_class_sq.c.id ).distinct() @staticmethod @@ -704,81 +731,47 @@ def id_field(): @staticmethod def is_buddy(parent): - return isinstance(parent, ObjectClassMapping) + return isinstance(parent, EntityClassMapping) -class ObjectGroupObjectMapping(ExportMapping): - """Maps objects in object groups. +class EntityGroupEntityMapping(ExportMapping): + """Maps entities in objectentity groups. - Cannot be used as the topmost mapping; one of the parents must be :class:`ObjectGroupMapping`. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityGroupMapping`. """ - MAP_TYPE = "ObjectGroupObject" + MAP_TYPE = "EntityGroupEntity" def add_query_columns(self, db_map, query): - return query.add_columns(db_map.object_sq.c.id.label("object_id"), db_map.object_sq.c.name.label("object_name")) + return query.add_columns( + db_map.ext_entity_sq.c.id.label("entity_id"), db_map.ext_entity_sq.c.name.label("entity_name") + ) def filter_query(self, db_map, query): - return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.object_sq.c.id) + return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.ext_entity_sq.c.id) @staticmethod def name_field(): - return "object_name" + return "entity_name" @staticmethod def id_field(): - return "object_id" + return "entity_id" @staticmethod def is_buddy(parent): - return isinstance(parent, ObjectGroupMapping) - - -class RelationshipClassMapping(ExportMapping): - """Maps relationships classes. - - Can be used as the topmost mapping. - """ - - MAP_TYPE = "RelationshipClass" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.wide_relationship_class_sq.c.id.label("relationship_class_id"), - db_map.wide_relationship_class_sq.c.name.label("relationship_class_name"), - db_map.wide_relationship_class_sq.c.object_class_id_list, - db_map.wide_relationship_class_sq.c.object_class_name_list, - ) - - @staticmethod - def name_field(): - return "relationship_class_name" - - @staticmethod - def id_field(): - # Use the class name here, for the sake of the standard excel export - return "relationship_class_name" - - def query_parents(self, what): - if what != "dimension": - return super().query_parents(what) - return -1 + return isinstance(parent, EntityGroupMapping) - def _title_state(self, db_row): - state = super()._title_state(db_row) - state["object_class_id_list"] = getattr(db_row, "object_class_id_list") - return state +class DimensionHighlightingMapping(EntityClassMapping): + """Maps entity classes. -class RelationshipClassObjectHighlightingMapping(RelationshipClassMapping): - """Maps relationships classes. - - Adds object class dimension chosen by highlight_dimension to the query. + Adds dimension chosen by highlight_dimension to the query. Can be used as the topmost mapping. 
""" - MAP_TYPE = "RelationshipClassObjectHighlightingMapping" + MAP_TYPE = "DimensionHighlighting" def __init__(self, position, value=None, header="", filter_re="", highlight_dimension=0): super().__init__(position, value, header, filter_re) @@ -794,21 +787,21 @@ def highlight_dimension(self, dimension): def add_query_columns(self, db_map, query): query = super().add_query_columns(db_map, query) - return query.add_columns(db_map.object_class_sq.c.id.label("object_class_id")) + return query.add_columns(db_map.entity_class_sq.c.id.label("dimension_id")) def filter_query(self, db_map, query): - highlighted_object_class_qry = db_map.query(db_map.relationship_class_sq).filter( - db_map.relationship_class_sq.c.dimension == self._highlight_dimension + highlighted_dimension_qry = db_map.query(db_map.entity_class_dimension_sq).filter_by( + position=self._highlight_dimension ) conditions = ( - and_(db_map.wide_relationship_class_sq.c.id == x.id, db_map.object_class_sq.c.id == x.object_class_id) - for x in highlighted_object_class_qry + and_(db_map.ext_entity_class_sq.c.id == x.entity_class_id, db_map.entity_class_sq.c.id == x.dimension_id) + for x in highlighted_dimension_qry ) return query.filter(or_(*conditions)) @staticmethod def id_field(): - return "relationship_class_id" + return "entity_class_id" def query_parents(self, what): if what != "highlight_dimension": @@ -828,25 +821,28 @@ def reconstruct(cls, position, value, header, filter_re, ignorable, mapping_dict return mapping -class RelationshipClassObjectClassMapping(ExportMapping): - """Maps relationship class object classes. +class DimensionMapping(ExportMapping): + """Maps dimensions. - Cannot be used as the topmost mapping; one of the parents must be :class:`RelationshipClassMapping`. + Cannot be used as the topmost mapping; one of the parents must be :class:`EntityClassMapping`. """ - MAP_TYPE = "RelationshipClassObjectClass" + MAP_TYPE = "Dimension" _cached_dimension = None @staticmethod def name_field(): - return "object_class_name_list" + return "dimension_name_list" @staticmethod def id_field(): - return "object_class_id_list" + return "dimension_id_list" def _data(self, db_row): - data = super()._data(db_row).split(",") + dimension_name_list = super()._data(db_row) + if dimension_name_list is None: + return None + data = dimension_name_list.split(",") if self._cached_dimension is None: self._cached_dimension = self.query_parents("dimension") try: @@ -861,104 +857,62 @@ def query_parents(self, what): @staticmethod def is_buddy(parent): - return isinstance(parent, RelationshipClassMapping) - - -class RelationshipMapping(ExportMapping): - """Maps relationships. - - Cannot be used as the topmost mapping; one of the parents must be :class:`RelationshipClassMapping`. 
- """ - - MAP_TYPE = "Relationship" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.wide_relationship_sq.c.id.label("relationship_id"), - db_map.wide_relationship_sq.c.name.label("relationship_name"), - db_map.wide_relationship_sq.c.object_id_list, - db_map.wide_relationship_sq.c.object_name_list, - ) - - def filter_query(self, db_map, query): - return query.outerjoin( - db_map.wide_relationship_sq, - db_map.wide_relationship_sq.c.class_id == db_map.wide_relationship_class_sq.c.id, - ) - - @staticmethod - def name_field(): - return "relationship_name" - - @staticmethod - def id_field(): - return "relationship_id" - - def query_parents(self, what): - if what != "dimension": - return super().query_parents(what) - return -1 - - def _title_state(self, db_row): - state = super()._title_state(db_row) - state["object_id_list"] = getattr(db_row, "object_id_list") - return state - - @staticmethod - def is_buddy(parent): - return isinstance(parent, RelationshipClassMapping) + return isinstance(parent, EntityClassMapping) -class RelationshipObjectHighlightingMapping(RelationshipMapping): - """Maps relationships. +class ElementHighlightingMapping(EntityMapping): + """Maps entities. - Adds object dimension chosen by highlight_dimension in relationship class mapping to the query. + Adds elements chosen by highlight_dimension in dimension highlighting mapping to the query. Cannot be used as the topmost mapping; - one of the parents must be :class:`RelationshipClassObjectHighlightingMapping`. + one of the parents must be :class:`DimensionHighlightingMapping`. """ - MAP_TYPE = "RelationshipObjectHighlightingMapping" + MAP_TYPE = "ElementHighlighting" def add_query_columns(self, db_map, query): query = super().add_query_columns(db_map, query) - return query.add_columns(db_map.object_sq.c.id.label("object_id")) + return query.add_columns(db_map.entity_sq.c.id.label("element_id")) def filter_query(self, db_map, query): - highlighted_object_qry = db_map.query(db_map.relationship_sq).filter( - db_map.relationship_sq.c.dimension == self.query_parents("highlight_dimension") + highlighted_element_qry = db_map.query(db_map.entity_element_sq).filter_by( + position=self.query_parents("highlight_dimension") ) conditions = ( - and_(db_map.wide_relationship_sq.c.id == x.id, db_map.object_sq.c.id == x.object_id) - for x in highlighted_object_qry + and_(db_map.ext_entity_sq.c.id == x.entity_id, db_map.entity_sq.c.id == x.element_id) + for x in highlighted_element_qry ) return query.filter(or_(*conditions)) @staticmethod def is_buddy(parent): - return isinstance(parent, RelationshipClassObjectHighlightingMapping) + return isinstance(parent, DimensionHighlightingMapping) -class RelationshipObjectMapping(ExportMapping): - """Maps relationship's objects. +class ElementMapping(ExportMapping): + """Maps elements. - Cannot be used as the topmost mapping; must have :class:`RelationshipClassMapping` and :class:`RelationshipMapping` + Cannot be used as the topmost mapping; must have :class:`EntityClassMapping` and :class:`EntityMapping` as parents. 
""" - MAP_TYPE = "RelationshipObject" + MAP_TYPE = "Element" _cached_dimension = None @staticmethod def name_field(): - return "object_name_list" + return "element_name_list" @staticmethod def id_field(): - return "object_id_list" + return "element_id_list" def _data(self, db_row): - data = super()._data(db_row).split(",") + element_name_list = super()._data(db_row) + if element_name_list is None: + return None + data = element_name_list.split(",") if self._cached_dimension is None: self._cached_dimension = self.query_parents("dimension") try: @@ -973,7 +927,7 @@ def query_parents(self, what): @staticmethod def is_buddy(parent): - return isinstance(parent, RelationshipClassObjectClassMapping) + return isinstance(parent, DimensionMapping) class ParameterDefinitionMapping(ExportMapping): @@ -992,17 +946,12 @@ def add_query_columns(self, db_map, query): def filter_query(self, db_map, query): column_names = {c["name"] for c in query.column_descriptions} - if "object_class_id" in column_names: - return query.outerjoin( - db_map.parameter_definition_sq, - db_map.parameter_definition_sq.c.object_class_id == db_map.object_class_sq.c.id, - ) - if "relationship_class_id" in column_names: - return query.outerjoin( - db_map.parameter_definition_sq, - db_map.parameter_definition_sq.c.relationship_class_id == db_map.wide_relationship_class_sq.c.id, - ) - raise RuntimeError("Logic error: this code should be unreachable.") + # "dimension_id" in column_names means a DimensionHighlightingMapping is acting + entity_class_sq = db_map.entity_class_sq if "dimension_id" in column_names else db_map.ext_entity_class_sq + return query.outerjoin( + db_map.parameter_definition_sq, + db_map.parameter_definition_sq.c.entity_class_id == entity_class_sq.c.id, + ) @staticmethod def name_field(): @@ -1175,21 +1124,13 @@ def filter_query(self, db_map, query): if not self._selects_value: return query column_names = {c["name"] for c in query.column_descriptions} - if "object_id" in column_names: - return query.filter( - and_( - db_map.parameter_value_sq.c.object_id == db_map.object_sq.c.id, - db_map.parameter_value_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id, - ) - ) - if "relationship_id" in column_names: - return query.filter( - and_( - db_map.parameter_value_sq.c.relationship_id == db_map.wide_relationship_sq.c.id, - db_map.parameter_value_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id, - ) + entity_sq = db_map.entity_sq if "element_id" in column_names else db_map.ext_entity_sq + return query.filter( + and_( + db_map.parameter_value_sq.c.entity_id == entity_sq.c.id, + db_map.parameter_value_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id, ) - raise RuntimeError("Logic error: this code should be unreachable.") + ) @staticmethod def name_field(): @@ -1204,7 +1145,7 @@ def _data(self, db_row): @staticmethod def is_buddy(parent): - return isinstance(parent, (ParameterDefinitionMapping, ObjectMapping, RelationshipMapping, AlternativeMapping)) + return isinstance(parent, (ParameterDefinitionMapping, EntityMapping, AlternativeMapping)) class ParameterValueTypeMapping(ParameterValueMapping): @@ -1857,16 +1798,20 @@ def from_dict(serialized): AlternativeDescriptionMapping, AlternativeMapping, DefaultValueIndexNameMapping, + DimensionMapping, + DimensionHighlightingMapping, + ElementHighlightingMapping, + ElementMapping, ExpandedParameterDefaultValueMapping, ExpandedParameterValueMapping, FeatureEntityClassMapping, FeatureParameterDefinitionMapping, FixedValueMapping, 
IndexNameMapping, - ObjectClassMapping, - ObjectGroupMapping, - ObjectGroupObjectMapping, - ObjectMapping, + EntityClassMapping, + EntityGroupMapping, + EntityGroupEntityMapping, + EntityMapping, ParameterDefaultValueIndexMapping, ParameterDefaultValueMapping, ParameterDefaultValueTypeMapping, @@ -1876,12 +1821,6 @@ def from_dict(serialized): ParameterValueListValueMapping, ParameterValueMapping, ParameterValueTypeMapping, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipClassObjectHighlightingMapping, - RelationshipMapping, - RelationshipObjectHighlightingMapping, - RelationshipObjectMapping, ScenarioActiveFlagMapping, ScenarioAlternativeMapping, ScenarioBeforeAlternativeMapping, @@ -1895,8 +1834,21 @@ def from_dict(serialized): ToolFeatureMethodParameterDefinitionMapping, ) } + legacy_mappings = { + "ParameterIndex": ParameterValueIndexMapping, + "ObjectClass": EntityClassMapping, + "ObjectGroup": EntityGroupMapping, + "ObjectGroupObject": EntityGroupEntityMapping, + "Object": EntityMapping, + "RelationshipClass": EntityClassMapping, + "RelationshipClassObjectClass": DimensionMapping, + "Relationship": EntityMapping, + "RelationshipObject": ElementMapping, + "RelationshipClassObjectHighlightingMapping": DimensionHighlightingMapping, + "RelationshipObjectHighlightingMapping": ElementHighlightingMapping, + } + mappings.update(legacy_mappings) # Legacy - mappings["ParameterIndex"] = ParameterValueIndexMapping flattened = list() for mapping_dict in serialized: position = mapping_dict["position"] diff --git a/spinedb_api/export_mapping/pivot.py b/spinedb_api/export_mapping/pivot.py index af505cf2..a8b59ec3 100644 --- a/spinedb_api/export_mapping/pivot.py +++ b/spinedb_api/export_mapping/pivot.py @@ -16,7 +16,7 @@ """ from copy import deepcopy -from .export_mapping import RelationshipMapping +from .export_mapping import EntityMapping from ..mapping import is_regular, is_pivoted, Position, unflatten, value_index from .group_functions import from_str as group_function_from_str, NoGroup @@ -236,5 +236,5 @@ def make_regular(root_mapping): def _is_unhiddable(mapping): - """Returns True if mapping uhiddable for pivoting purposes.""" - return not isinstance(mapping, RelationshipMapping) + """Returns True if mapping unhiddable for pivoting purposes.""" + return not isinstance(mapping, EntityMapping) # FIXME: Maybe also check that dimension_count > 0 ?? 
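Note on the legacy map types above: export mapping specifications serialized before the rename should keep deserializing, because from_dict() now routes the old type names to the new entity-based classes. A minimal sketch of the expected round trip (the import path and the assumption that from_dict() returns the unflattened root mapping are inferred from the hunks above, not verified against this revision):

    from spinedb_api.export_mapping.export_mapping import (
        EntityClassMapping,
        EntityMapping,
        from_dict,
    )

    # A mapping tree saved by an older version still uses the old map type names.
    serialized = [
        {"map_type": "ObjectClass", "position": 0},
        {"map_type": "Object", "position": 1},
    ]
    root = from_dict(serialized)
    # "ObjectClass" and "Object" resolve through legacy_mappings to the new classes.
    assert isinstance(root, EntityClassMapping)
    assert isinstance(root.child, EntityMapping)

Positions keep their old meaning; only the type names and the underlying queries have moved to the entity vocabulary.
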
diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index 7c0170fe..36aa68e2 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -19,14 +19,18 @@ from .export_mapping import ( AlternativeMapping, AlternativeDescriptionMapping, + DimensionMapping, + DimensionHighlightingMapping, + ElementHighlightingMapping, + ElementMapping, ExpandedParameterDefaultValueMapping, ExpandedParameterValueMapping, FeatureEntityClassMapping, FeatureParameterDefinitionMapping, - ObjectGroupMapping, - ObjectGroupObjectMapping, - ObjectMapping, - ObjectClassMapping, + EntityGroupMapping, + EntityGroupEntityMapping, + EntityMapping, + EntityClassMapping, ParameterDefaultValueMapping, ParameterDefaultValueIndexMapping, ParameterDefinitionMapping, @@ -36,12 +40,6 @@ ParameterValueMapping, ParameterValueTypeMapping, Position, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipClassObjectHighlightingMapping, - RelationshipMapping, - RelationshipObjectHighlightingMapping, - RelationshipObjectMapping, ScenarioActiveFlagMapping, ScenarioAlternativeMapping, ScenarioBeforeAlternativeMapping, @@ -61,152 +59,60 @@ from ..mapping import unflatten -def object_export(class_position=Position.hidden, object_position=Position.hidden): - """ - Sets up export mappings for exporting objects without parameters. - - Args: - class_position (int or Position): position of object classes - object_position (int or Position): position of objects - - Returns: - ExportMapping: root mapping - """ - class_ = ObjectClassMapping(class_position) - object_ = ObjectMapping(object_position) - class_.child = object_ - return class_ - - -def object_parameter_default_value_export( - class_position=Position.hidden, - definition_position=Position.hidden, - value_type_position=Position.hidden, - value_position=Position.hidden, - index_name_positions=None, - index_positions=None, -): - """ - Sets up export mappings for exporting objects classes and default parameter values. - - Args: - class_position (int or Position): position of object classes - definition_position (int or Position): position of parameter names - value_type_position (int or Position): position of parameter value types - value_position (int or Position): position of parameter values - index_name_positions (list of int, optional): positions of index names - index_positions (list of int, optional): positions of parameter indexes - - Returns: - ExportMapping: root mapping - """ - class_ = ObjectClassMapping(class_position) - definition = ParameterDefinitionMapping(definition_position) - _generate_default_value_mappings( - definition, value_type_position, value_position, index_name_positions, index_positions - ) - class_.child = definition - return class_ - - -def object_parameter_export( - class_position=Position.hidden, - definition_position=Position.hidden, - value_list_position=Position.hidden, - object_position=Position.hidden, - alternative_position=Position.hidden, - value_type_position=Position.hidden, - value_position=Position.hidden, - index_name_positions=None, - index_positions=None, -): - """ - Sets up export mappings for exporting objects and object parameters. 
- - Args: - class_position (int or Position): position of object classes in a table - definition_position (int or Position): position of parameter names in a table - value_list_position (int or Position): position of parameter value lists - object_position (int or Position): position of objects in a table - alternative_position (int or Position): position of alternatives in a table - value_type_position (int or Position): position of parameter value types in a table - value_position (int or Position): position of parameter values in a table - index_name_positions (list of int, optional): positions of index names - index_positions (list of int, optional): positions of parameter indexes in a table - - Returns: - ExportMapping: root mapping - """ - class_ = ObjectClassMapping(class_position) - definition = ParameterDefinitionMapping(definition_position) - value_list = ParameterValueListMapping(value_list_position) - value_list.set_ignorable(True) - object_ = ObjectMapping(object_position) - _generate_parameter_value_mappings( - object_, alternative_position, value_type_position, value_position, index_name_positions, index_positions - ) - value_list.child = object_ - definition.child = value_list - class_.child = definition - return class_ - - -def object_group_export( - class_position=Position.hidden, group_position=Position.hidden, object_position=Position.hidden +def entity_group_export( + class_position=Position.hidden, group_position=Position.hidden, entity_position=Position.hidden ): """ - Sets up export mappings for exporting object groups. + Sets up export mappings for exporting entity groups. Args: - class_position (int or Position): position of object classes + class_position (int or Position): position of entity classes group_position (int or Position): position of groups - object_position (int or Position): position of objects + entity_position (int or Position): position of entities Returns: ExportMapping: root mapping """ - class_ = ObjectClassMapping(class_position) - group = ObjectGroupMapping(group_position) - object_ = ObjectGroupObjectMapping(object_position) - group.child = object_ + class_ = EntityClassMapping(class_position) + group = EntityGroupMapping(group_position) + entity = EntityGroupEntityMapping(entity_position) + group.child = entity class_.child = group return class_ -def relationship_export( - relationship_class_position=Position.hidden, - relationship_position=Position.hidden, - object_class_positions=None, - object_positions=None, +def entity_export( + entity_class_position=Position.hidden, + entity_position=Position.hidden, + dimension_positions=None, + element_positions=None, ): """ - Sets up export items for exporting relationships without parameters. + Sets up export items for exporting entities without parameters. 
Args:
-        relationship_class_position (int or Position): position of relationship classes in a table
-        relationship_position (int or Position): position of relationships in a table
-        object_class_positions (Iterable, optional): positions of object classes in a table
-        object_positions (Iterable, optional): positions of object in a table
+        entity_class_position (int or Position): position of entity classes in a table
+        entity_position (int or Position): position of entities in a table
+        dimension_positions (Iterable, optional): positions of dimensions in a table
+        element_positions (Iterable, optional): positions of elements in a table
 
     Returns:
         ExportMapping: root mapping
     """
-    if object_class_positions is None:
-        object_class_positions = list()
-    if object_positions is None:
-        object_positions = list()
-    relationship_class = RelationshipClassMapping(relationship_class_position)
-    object_or_relationship_class = _generate_dimensions(
-        relationship_class, RelationshipClassObjectClassMapping, object_class_positions
-    )
-    relationship = RelationshipMapping(relationship_position)
-    object_or_relationship_class.child = relationship
-    _generate_dimensions(relationship, RelationshipObjectMapping, object_positions)
-    return relationship_class
+    if dimension_positions is None:
+        dimension_positions = list()
+    if element_positions is None:
+        element_positions = list()
+    entity_class = EntityClassMapping(entity_class_position)
+    dimension = _generate_dimensions(entity_class, DimensionMapping, dimension_positions)
+    entity = EntityMapping(entity_position)
+    dimension.child = entity
+    _generate_dimensions(entity, ElementMapping, element_positions)
+    return entity_class
 
 
-def relationship_parameter_default_value_export(
-    relationship_class_position=Position.hidden,
+def entity_class_parameter_default_value_export(
+    entity_class_position=Position.hidden,
     definition_position=Position.hidden,
     value_type_position=Position.hidden,
     value_position=Position.hidden,
@@ -214,10 +120,10 @@
     index_positions=None,
 ):
     """
-    Sets up export mappings for exporting relationship classes and default parameter values.
+    Sets up export mappings for exporting entity classes and default parameter values.
Args:
-        relationship_class_position (int or Position): position of relationship classes
+        entity_class_position (int or Position): position of entity classes
         definition_position (int or Position): position of parameter definitions
         value_type_position (int or Position): position of parameter value types
         value_position (int or Position): position of parameter values
@@ -227,22 +133,22 @@
     Returns:
         ExportMapping: root mapping
     """
-    relationship_class = RelationshipClassMapping(relationship_class_position)
+    entity_class = EntityClassMapping(entity_class_position)
     definition = ParameterDefinitionMapping(definition_position)
     _generate_default_value_mappings(
         definition, value_type_position, value_position, index_name_positions, index_positions
     )
-    relationship_class.child = definition
-    return relationship_class
+    entity_class.child = definition
+    return entity_class
 
 
-def relationship_parameter_export(
-    relationship_class_position=Position.hidden,
+def entity_parameter_export(
+    entity_class_position=Position.hidden,
     definition_position=Position.hidden,
     value_list_position=Position.hidden,
-    relationship_position=Position.hidden,
-    object_class_positions=None,
-    object_positions=None,
+    entity_position=Position.hidden,
+    dimension_positions=None,
+    element_positions=None,
     alternative_position=Position.hidden,
     value_type_position=Position.hidden,
     value_position=Position.hidden,
@@ -250,15 +156,15 @@
     index_positions=None,
 ):
     """
-    Sets up export mappings for exporting relationships and relationship parameters.
+    Sets up export mappings for exporting entities and parameter values.
 
     Args:
-        relationship_class_position (int or Position): position of relationship classes
+        entity_class_position (int or Position): position of entity classes
         definition_position (int or Position): position of parameter definitions
         value_list_position (int or Position): position of parameter value lists
-        relationship_position (int or Position): position of relationships
-        object_class_positions (list of int, optional): positions of object classes
-        object_positions (list of int, optional): positions of objects
+        entity_position (int or Position): position of entities
+        dimension_positions (list of int, optional): positions of dimensions
+        element_positions (list of int, optional): positions of elements
         alternative_position (int or Position): positions of alternatives
         value_type_position (int or Position): position of parameter value types
         value_position (int or Position): position of parameter values
@@ -268,37 +174,35 @@
     Returns:
         ExportMapping: root mapping
     """
-    if object_class_positions is None:
-        object_class_positions = list()
-    if object_positions is None:
-        object_positions = list()
-    relationship_class = RelationshipClassMapping(relationship_class_position)
-    object_or_relationship_class = _generate_dimensions(
-        relationship_class, RelationshipClassObjectClassMapping, object_class_positions
-    )
+    if dimension_positions is None:
+        dimension_positions = list()
+    if element_positions is None:
+        element_positions = list()
+    entity_class = EntityClassMapping(entity_class_position)
+    dimension = _generate_dimensions(entity_class, DimensionMapping, dimension_positions)
     value_list = ParameterValueListMapping(value_list_position)
     value_list.set_ignorable(True)
     definition = ParameterDefinitionMapping(definition_position)
-    object_or_relationship_class.child = definition
-    relationship = 
RelationshipMapping(relationship_position)
+    dimension.child = definition
+    relationship = EntityMapping(entity_position)
     definition.child = value_list
     value_list.child = relationship
-    object_or_relationship = _generate_dimensions(relationship, RelationshipObjectMapping, object_positions)
+    element = _generate_dimensions(relationship, ElementMapping, element_positions)
     _generate_parameter_value_mappings(
-        object_or_relationship,
+        element,
         alternative_position,
         value_type_position,
         value_position,
         index_name_positions,
         index_positions,
     )
-    return relationship_class
+    return entity_class
 
 
-def relationship_object_parameter_default_value_export(
-    relationship_class_position=Position.hidden,
+def entity_class_dimension_parameter_default_value_export(
+    entity_class_position=Position.hidden,
     definition_position=Position.hidden,
-    object_class_positions=None,
+    dimension_positions=None,
     value_type_position=Position.hidden,
     value_position=Position.hidden,
     index_name_positions=None,
@@ -306,43 +210,41 @@
     highlight_dimension=0,
 ):
     """
-    Sets up export mappings for exporting relationship classes but with default object parameter values.
+    Sets up export mappings for exporting entity classes with default parameter values from a highlighted dimension.
 
     Args:
-        relationship_class_position (int or Position): position of relationship classes
+        entity_class_position (int or Position): position of entity classes
         definition_position (int or Position): position of parameter definitions
-        object_class_positions (list of int, optional): positions of object classes
+        dimension_positions (list of int, optional): positions of dimensions
         value_type_position (int or Position): position of parameter value types
         value_position (int or Position): position of parameter values
         index_name_positions (list of int, optional): positions of index names
         index_positions (list of int, optional): positions of parameter indexes
-        highlight_dimension (int): selected object class' relationship dimension
+        highlight_dimension (int): selected entity class dimension
 
     Returns:
         ExportMapping: root mapping
     """
     root_mapping = unflatten(
         [
-            RelationshipClassObjectHighlightingMapping(
-                relationship_class_position, highlight_dimension=highlight_dimension
-            ),
+            DimensionHighlightingMapping(entity_class_position, highlight_dimension=highlight_dimension),
             ParameterDefinitionMapping(definition_position),
         ]
     )
-    _generate_dimensions(root_mapping.tail_mapping(), RelationshipClassObjectClassMapping, object_class_positions)
+    _generate_dimensions(root_mapping.tail_mapping(), DimensionMapping, dimension_positions)
     _generate_default_value_mappings(
         root_mapping.tail_mapping(), value_type_position, value_position, index_name_positions, index_positions
     )
     return root_mapping
 
 
-def relationship_object_parameter_export(
-    relationship_class_position=Position.hidden,
+def entity_element_parameter_export(
+    entity_class_position=Position.hidden,
     definition_position=Position.hidden,
     value_list_position=Position.hidden,
-    relationship_position=Position.hidden,
-    object_class_positions=None,
-    object_positions=None,
+    entity_position=Position.hidden,
+    dimension_positions=None,
+    element_positions=None,
     alternative_position=Position.hidden,
    value_type_position=Position.hidden,
     value_position=Position.hidden,
@@ -351,70 +253,64 @@
     highlight_dimension=0,
 ):
     """
-    Sets up export mappings for exporting relationships and relationship parameters. 
+    Sets up export mappings for exporting entities and element parameter values.
 
     Args:
-        relationship_class_position (int or Position): position of relationship classes
+        entity_class_position (int or Position): position of entity classes
         definition_position (int or Position): position of parameter definitions
         value_list_position (int or Position): position of parameter value lists
-        relationship_position (int or Position): position of relationships
-        object_class_positions (list of int, optional): positions of object classes
-        object_positions (list of int, optional): positions of objects
+        entity_position (int or Position): position of entities
+        dimension_positions (list of int, optional): positions of dimensions
+        element_positions (list of int, optional): positions of elements
         alternative_position (int or Position): positions of alternatives
         value_type_position (int or Position): position of parameter value types
         value_position (int or Position): position of parameter values
         index_name_positions (list of int, optional): positions of index names
         index_positions (list of int, optional): positions of parameter indexes
-        highlight_dimension (int): selected object class' relationship dimension
+        highlight_dimension (int): selected entity class dimension
 
     Returns:
         ExportMapping: root mapping
     """
-    if object_class_positions is None:
-        object_class_positions = list()
-    if object_positions is None:
-        object_positions = list()
-    relationship_class = RelationshipClassObjectHighlightingMapping(
-        relationship_class_position, highlight_dimension=highlight_dimension
-    )
-    object_or_relationship_class = _generate_dimensions(
-        relationship_class, RelationshipClassObjectClassMapping, object_class_positions
-    )
+    if dimension_positions is None:
+        dimension_positions = list()
+    if element_positions is None:
+        element_positions = list()
+    entity_class = DimensionHighlightingMapping(entity_class_position, highlight_dimension=highlight_dimension)
+    dimension = _generate_dimensions(entity_class, DimensionMapping, dimension_positions)
     value_list = ParameterValueListMapping(value_list_position)
     value_list.set_ignorable(True)
     definition = ParameterDefinitionMapping(definition_position)
-    object_or_relationship_class.child = definition
-    relationship = RelationshipObjectHighlightingMapping(relationship_position)
+    dimension.child = definition
+    entity = ElementHighlightingMapping(entity_position)
     definition.child = value_list
-    value_list.child = relationship
-    object_or_relationship = _generate_dimensions(relationship, RelationshipObjectMapping, object_positions)
+    value_list.child = entity
+    element = _generate_dimensions(entity, ElementMapping, element_positions)
     _generate_parameter_value_mappings(
-        object_or_relationship,
+        element,
         alternative_position,
         value_type_position,
         value_position,
         index_name_positions,
         index_positions,
     )
-    return relationship_class
+    return entity_class
 
 
-def set_relationship_dimensions(relationship_mapping, dimensions):
+def set_entity_dimensions(entity_mapping, dimensions):
     """
-    Modifies given relationship mapping's dimensions (number of object classes and objects).
+    Modifies given entity mapping's dimensions.
Args: - relationship_mapping (ExportMapping): a relationship mapping + entity_mapping (ExportMapping): an entity mapping dimensions (int): number of dimensions """ - mapping_list = relationship_mapping.flatten() + mapping_list = entity_mapping.flatten() mapping_list = _change_amount_of_consecutive_mappings( - mapping_list, RelationshipClassMapping, RelationshipClassObjectClassMapping, dimensions + mapping_list, EntityClassMapping, DimensionMapping, dimensions ) - if any(isinstance(m, RelationshipMapping) for m in mapping_list): - mapping_list = _change_amount_of_consecutive_mappings( - mapping_list, RelationshipMapping, RelationshipObjectMapping, dimensions - ) + if any(isinstance(m, EntityMapping) for m in mapping_list): + mapping_list = _change_amount_of_consecutive_mappings(mapping_list, EntityMapping, ElementMapping, dimensions) unflatten(mapping_list) diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index 5351cd51..69f69532 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -20,9 +20,9 @@ Position, AlternativeMapping, AlternativeDescriptionMapping, - ObjectClassMapping, - ObjectGroupMapping, - ObjectMapping, + EntityClassMapping, + EntityGroupMapping, + EntityMapping, FixedValueMapping, ScenarioMapping, ScenarioAlternativeMapping, @@ -33,9 +33,7 @@ ParameterValueTypeMapping, ParameterValueMapping, ExpandedParameterValueMapping, - RelationshipClassMapping, - RelationshipMapping, - RelationshipObjectMapping, + ElementMapping, ) from ...parameter_value import from_database_to_dimension_count from .excel_writer import ExcelWriter @@ -58,7 +56,7 @@ def start_table(self, table_name, title_key): def _make_preamble(table_name, title_key): if table_name in ("alternative", "scenario", "scenario_alternative"): return {"sheet_type": table_name} - class_name = title_key.get("object_class_name") or title_key.get("relationship_class_name") + class_name = title_key["entity_class_name"] if table_name.endswith(",group"): return {"sheet_type": "object_group", "class_name": class_name} object_class_id_list = title_key.get("object_class_id_list") @@ -127,10 +125,10 @@ def _make_scenario_alternative_mapping(): def _make_object_group_mappings(db_map): for obj_grp in db_map.query(db_map.ext_entity_group_sq).group_by(db_map.ext_entity_group_sq.c.class_name): - root_mapping = ObjectClassMapping(Position.table_name, filter_re=obj_grp.class_name) + root_mapping = EntityClassMapping(Position.table_name, filter_re=obj_grp.class_name) group_mapping = root_mapping.child = FixedValueMapping(Position.table_name, value="group") - object_mapping = group_mapping.child = ObjectMapping(1, header="member") - object_mapping.child = ObjectGroupMapping(0, header="group") + object_mapping = group_mapping.child = EntityMapping(1, header="member") + object_mapping.child = EntityGroupMapping(0, header="group") yield root_mapping @@ -156,9 +154,9 @@ def _make_indexed_parameter_value_mapping(alt_pos=-2, filter_re="array|time_patt def _make_object_mapping(object_class_name, pivoted=False): - root_mapping = ObjectClassMapping(Position.table_name, filter_re=f"^{object_class_name}$") + root_mapping = EntityClassMapping(Position.table_name, filter_re=f"^{object_class_name}$") pos = 0 if not pivoted else -1 - root_mapping.child = ObjectMapping(pos, header=object_class_name) + root_mapping.child = EntityMapping(pos, header=object_class_name) return root_mapping @@ -185,13 +183,13 @@ def 
_make_object_map_parameter_value_mapping(object_class_name, dim_count): def _make_relationship_mapping(relationship_class_name, object_class_name_list, pivoted=False): - root_mapping = RelationshipClassMapping(Position.table_name, filter_re=f"^{relationship_class_name}$") - relationship_mapping = root_mapping.child = RelationshipMapping(Position.hidden) + root_mapping = EntityClassMapping(Position.table_name, filter_re=f"^{relationship_class_name}$") + relationship_mapping = root_mapping.child = EntityMapping(Position.hidden) parent_mapping = relationship_mapping for d, class_name in enumerate(object_class_name_list): if pivoted: d = -(d + 1) - object_mapping = parent_mapping.child = RelationshipObjectMapping(d, header=class_name) + object_mapping = parent_mapping.child = ElementMapping(d, header=class_name) parent_mapping = object_mapping return root_mapping diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index ce9cb6bc..55820372 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -17,7 +17,6 @@ import unittest from spinedb_api import ( - DatabaseMapping, DatabaseMapping, import_alternatives, import_features, @@ -34,17 +33,15 @@ import_tool_feature_methods, import_tools, Map, - TimeSeriesFixedResolution, ) from spinedb_api.import_functions import import_object_groups from spinedb_api.mapping import Position, to_dict from spinedb_api.export_mapping import ( rows, titles, - object_parameter_default_value_export, - object_parameter_export, - relationship_export, - relationship_parameter_export, + entity_class_parameter_default_value_export, + entity_parameter_export, + entity_export, ) from spinedb_api.export_mapping.export_mapping import ( AlternativeMapping, @@ -55,10 +52,10 @@ FeatureEntityClassMapping, FeatureParameterDefinitionMapping, from_dict, - ObjectGroupMapping, - ObjectGroupObjectMapping, - ObjectMapping, - ObjectClassMapping, + EntityGroupMapping, + EntityGroupEntityMapping, + EntityMapping, + EntityClassMapping, ParameterDefaultValueMapping, ParameterDefaultValueIndexMapping, ParameterDefinitionMapping, @@ -67,12 +64,10 @@ ParameterValueListValueMapping, ParameterValueMapping, ParameterValueTypeMapping, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipClassObjectHighlightingMapping, - RelationshipMapping, - RelationshipObjectHighlightingMapping, - RelationshipObjectMapping, + DimensionMapping, + DimensionHighlightingMapping, + ElementHighlightingMapping, + ElementMapping, ScenarioActiveFlagMapping, ScenarioAlternativeMapping, ScenarioMapping, @@ -90,7 +85,7 @@ class TestExportMapping(unittest.TestCase): def test_export_empty_table(self): db_map = DatabaseMapping("sqlite://", create=True) - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) self.assertEqual(list(rows(object_class_mapping, db_map)), []) db_map.connection.close() @@ -98,7 +93,7 @@ def test_export_single_object_class(self): db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) self.assertEqual(list(rows(object_class_mapping, db_map)), [["object_class"]]) db_map.connection.close() @@ -109,8 +104,8 @@ def test_export_objects(self): db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33")) ) 
db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) - object_class_mapping.child = ObjectMapping(1) + object_class_mapping = EntityClassMapping(0) + object_class_mapping.child = EntityMapping(1) self.assertEqual( list(rows(object_class_mapping, db_map)), [["oc1", "o11"], ["oc1", "o12"], ["oc2", "o21"], ["oc3", "o31"], ["oc3", "o32"], ["oc3", "o33"]], @@ -122,8 +117,8 @@ def test_hidden_tail(self): import_object_classes(db_map, ("oc1",)) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) - object_class_mapping.child = ObjectMapping(Position.hidden) + object_class_mapping = EntityClassMapping(0) + object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(list(rows(object_class_mapping, db_map)), [["oc1"], ["oc1"]]) db_map.connection.close() @@ -132,8 +127,8 @@ def test_pivot_without_values(self): import_object_classes(db_map, ("oc1",)) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(-1) - object_class_mapping.child = ObjectMapping(Position.hidden) + object_class_mapping = EntityClassMapping(-1) + object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(list(rows(object_class_mapping, db_map)), []) db_map.connection.close() @@ -146,9 +141,9 @@ def test_hidden_tail_pivoted(self): db_map.commit_session("Add test data.") root_mapping = unflatten( [ - ObjectClassMapping(0), + EntityClassMapping(0), ParameterDefinitionMapping(-1), - ObjectMapping(1), + EntityMapping(1), AlternativeMapping(2), ParameterValueMapping(Position.hidden), ] @@ -158,13 +153,13 @@ def test_hidden_tail_pivoted(self): db_map.connection.close() def test_hidden_leaf_item_in_regular_table_valid(self): - object_class_mapping = ObjectClassMapping(0) - object_class_mapping.child = ObjectMapping(Position.hidden) + object_class_mapping = EntityClassMapping(0) + object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(object_class_mapping.check_validity(), []) def test_hidden_leaf_item_in_pivot_table_not_valid(self): - object_class_mapping = ObjectClassMapping(-1) - object_class_mapping.child = ObjectMapping(Position.hidden) + object_class_mapping = EntityClassMapping(-1) + object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(object_class_mapping.check_validity(), ["Cannot be pivoted."]) def test_object_groups(self): @@ -173,7 +168,7 @@ def test_object_groups(self): import_objects(db_map, (("oc", "o1"), ("oc", "o2"), ("oc", "o3"), ("oc", "g1"), ("oc", "g2"))) import_object_groups(db_map, (("oc", "g1", "o1"), ("oc", "g1", "o2"), ("oc", "g2", "o3"))) db_map.commit_session("Add test data.") - flattened = [ObjectClassMapping(0), ObjectGroupMapping(1)] + flattened = [EntityClassMapping(0), EntityGroupMapping(1)] mapping = unflatten(flattened) self.assertEqual(list(rows(mapping, db_map)), [["oc", "g1"], ["oc", "g2"]]) db_map.connection.close() @@ -184,7 +179,7 @@ def test_object_groups_with_objects(self): import_objects(db_map, (("oc", "o1"), ("oc", "o2"), ("oc", "o3"), ("oc", "g1"), ("oc", "g2"))) import_object_groups(db_map, (("oc", "g1", "o1"), ("oc", "g1", "o2"), ("oc", "g2", "o3"))) db_map.commit_session("Add test data.") - flattened = [ObjectClassMapping(0), ObjectGroupMapping(1), ObjectGroupObjectMapping(2)] + flattened = [EntityClassMapping(0), EntityGroupMapping(1), EntityGroupEntityMapping(2)] mapping = unflatten(flattened) 
self.assertEqual(list(rows(mapping, db_map)), [["oc", "g1", "o1"], ["oc", "g1", "o2"], ["oc", "g2", "o3"]]) db_map.connection.close() @@ -200,9 +195,9 @@ def test_object_groups_with_parameter_values(self): ) db_map.commit_session("Add test data.") flattened = [ - ObjectClassMapping(0), - ObjectGroupMapping(1), - ObjectGroupObjectMapping(2), + EntityClassMapping(0), + EntityGroupMapping(1), + EntityGroupEntityMapping(2), ParameterDefinitionMapping(Position.hidden), AlternativeMapping(Position.hidden), ParameterValueMapping(3), @@ -220,9 +215,9 @@ def test_export_parameter_definitions(self): import_object_parameters(db_map, (("oc1", "p11"), ("oc1", "p12"), ("oc2", "p21"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(1) - parameter_definition_mapping.child = ObjectMapping(2) + parameter_definition_mapping.child = EntityMapping(2) object_class_mapping.child = parameter_definition_mapping expected = [ ["oc1", "p11", "o11"], @@ -241,11 +236,11 @@ def test_export_single_parameter_value_when_there_are_multiple_objects(self): import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) import_object_parameter_values(db_map, (("oc1", "o11", "p12", -11.0),)) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(1) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(2) + object_mapping = EntityMapping(2) alternative_mapping.child = object_mapping value_mapping = ParameterValueMapping(3) object_mapping.child = value_mapping @@ -268,11 +263,11 @@ def test_export_single_parameter_value_pivoted_by_object_name(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(1) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(-1) + object_mapping = EntityMapping(-1) alternative_mapping.child = object_mapping value_mapping = ParameterValueMapping(-2) object_mapping.child = value_mapping @@ -295,7 +290,9 @@ def test_minimum_pivot_index_need_not_be_minus_one(self): ), ) db_map.commit_session("Add test data.") - mapping = object_parameter_export(1, 2, Position.hidden, 0, -2, Position.hidden, 4, [Position.hidden], [3]) + mapping = entity_parameter_export( + 1, 2, Position.hidden, 0, None, None, -2, Position.hidden, 4, [Position.hidden], [3] + ) expected = [ [None, None, None, None, "Base", "alt"], ["o", "oc", "p", "A", -1.1, -5.5], @@ -319,10 +316,10 @@ def test_pivot_row_order(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(-1) alternative_mapping = AlternativeMapping(Position.hidden) - object_mapping = ObjectMapping(-2) + object_mapping = EntityMapping(-2) value_mapping = ParameterValueMapping(3) mappings = [ object_class_mapping, @@ -363,11 +360,11 @@ def test_export_parameter_indexes(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = 
ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(2) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) alternative_mapping.child = object_mapping index_mapping = ParameterValueIndexMapping(3) object_mapping.child = index_mapping @@ -398,11 +395,11 @@ def test_export_nested_parameter_indexes(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(2) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) alternative_mapping.child = object_mapping index_mapping_1 = ParameterValueIndexMapping(3) index_mapping_2 = ParameterValueIndexMapping(4) @@ -433,9 +430,9 @@ def test_export_nested_map_values_only(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(Position.hidden) + object_class_mapping = EntityClassMapping(Position.hidden) parameter_definition_mapping = ParameterDefinitionMapping(Position.hidden) - object_mapping = ObjectMapping(Position.hidden) + object_mapping = EntityMapping(Position.hidden) parameter_definition_mapping.child = object_mapping alternative_mapping = AlternativeMapping(Position.hidden) object_mapping.child = alternative_mapping @@ -463,11 +460,11 @@ def test_full_pivot_table(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(1) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(-1) + object_mapping = EntityMapping(-1) alternative_mapping.child = object_mapping index_mapping_1 = ParameterValueIndexMapping(2) index_mapping_2 = ParameterValueIndexMapping(3) @@ -495,7 +492,9 @@ def test_full_pivot_table_with_hidden_columns(self): db_map, (("oc", "o1", "p", Map(["A", "B"], [-1.1, -2.2])), ("oc", "o2", "p", Map(["A", "B"], [-5.5, -6.6]))) ) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 2, Position.hidden, -1, 3, Position.hidden, 5, [Position.hidden], [4]) + mapping = entity_parameter_export( + 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 5, [Position.hidden], [4] + ) expected = [ [None, None, None, None, None, "o1", "o2"], ["oc", None, "p", "Base", "A", -1.1, -5.5], @@ -520,7 +519,9 @@ def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): ), ) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 2, Position.hidden, -1, 3, Position.hidden, 5, [Position.hidden], [4]) + mapping = entity_parameter_export( + 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 5, [Position.hidden], [4] + ) expected = [ [None, None, None, None, None, "o1", "o2"], ["oc", None, "p", "Base", "A", -1.1, -5.5], @@ -540,7 +541,9 @@ def test_objects_and_indexes_as_pivot_header(self): db_map, (("oc", "o1", "p", Map(["A", "B"], [-1.1, -2.2])), ("oc", "o2", "p", Map(["A", "B"], [-3.3, -4.4]))) ) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 2, Position.hidden, -1, 3, Position.hidden, 4, [Position.hidden], 
[-2]) + mapping = entity_parameter_export( + 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 4, [Position.hidden], [-2] + ) expected = [ [None, None, None, None, "o1", "o1", "o2", "o2"], [None, None, None, None, "A", "B", "A", "B"], @@ -570,7 +573,9 @@ def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_para ), ) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 1, Position.hidden, -1, -2, Position.hidden, 2, [Position.hidden], [-3]) + mapping = entity_parameter_export( + 0, 1, Position.hidden, -1, None, None, -2, Position.hidden, 2, [Position.hidden], [-3] + ) expected = [ [None, None, "o1", "o1", "o1", "o1", "o2", "o2", "o2", "o2"], [None, None, "Base", "Base", "alt", "alt", "Base", "Base", "alt", "alt"], @@ -588,10 +593,10 @@ def test_empty_column_while_pivoted_handled_gracefully(self): import_object_parameters(db_map, (("oc", "p"),)) import_objects(db_map, (("oc", "o"),)) db_map.commit_session("Add test data.") - mapping = ObjectClassMapping(0) + mapping = EntityClassMapping(0) definition = ParameterDefinitionMapping(1) value_list = ParameterValueListMapping(2) - object_ = ObjectMapping(-1) + object_ = EntityMapping(-1) value_list.child = object_ definition.child = value_list mapping.child = definition @@ -605,8 +610,8 @@ def test_object_classes_as_header_row_and_objects_in_columns(self): db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33")) ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(-1) - object_class_mapping.child = ObjectMapping(0) + object_class_mapping = EntityClassMapping(-1) + object_class_mapping.child = EntityMapping(0) self.assertEqual( list(rows(object_class_mapping, db_map)), [["oc1", "oc2", "oc3"], ["o11", "o21", "o31"], ["o12", None, "o32"], [None, None, "o33"]], @@ -620,8 +625,8 @@ def test_object_classes_as_table_names(self): db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"), ("oc3", "o32"), ("oc3", "o33")) ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(Position.table_name) - object_class_mapping.child = ObjectMapping(0) + object_class_mapping = EntityClassMapping(Position.table_name) + object_class_mapping.child = EntityMapping(0) tables = dict() for title, title_key in titles(object_class_mapping, db_map): tables[title] = list(rows(object_class_mapping, db_map, title_key)) @@ -634,9 +639,9 @@ def test_object_class_and_parameter_definition_as_table_name(self): import_object_parameters(db_map, (("oc1", "p11"), ("oc2", "p21"), ("oc2", "p22"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(Position.table_name) + object_class_mapping = EntityClassMapping(Position.table_name) definition_mapping = ParameterDefinitionMapping(Position.table_name) - object_mapping = ObjectMapping(0) + object_mapping = EntityMapping(0) object_class_mapping.child = definition_mapping definition_mapping.child = object_mapping tables = dict() @@ -654,7 +659,7 @@ def test_object_relationship_name_as_table_name(self): import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) import_relationships(db_map, (("rc", ("o1", "O")), ("rc", ("o2", "O")))) db_map.commit_session("Add test data.") - mappings = relationship_export(0, Position.table_name, [1, 2], [Position.table_name, 3]) + mappings = entity_export(0, Position.table_name, [1, 2], [Position.table_name, 
3]) tables = dict() for title, title_key in titles(mappings, db_map): tables[title] = list(rows(mappings, db_map, title_key)) @@ -669,7 +674,7 @@ def test_parameter_definitions_with_value_lists(self): import_parameter_value_lists(db_map, (("vl1", -1.0), ("vl2", -2.0))) import_object_parameters(db_map, (("oc", "p1", None, "vl1"), ("oc", "p2"))) db_map.commit_session("Add test data.") - class_mapping = ObjectClassMapping(0) + class_mapping = EntityClassMapping(0) definition_mapping = ParameterDefinitionMapping(1) value_list_mapping = ParameterValueListMapping(2) definition_mapping.child = value_list_mapping @@ -689,11 +694,11 @@ def test_parameter_definitions_and_values_and_value_lists(self): import_object_parameter_values(db_map, (("oc", "o", "p1", -1.0), ("oc", "o", "p2", 5.0))) db_map.commit_session("Add test data.") flattened = [ - ObjectClassMapping(0), + EntityClassMapping(0), ParameterDefinitionMapping(1), AlternativeMapping(Position.hidden), ParameterValueListMapping(2), - ObjectMapping(3), + EntityMapping(3), ParameterValueMapping(4), ] mapping = unflatten(flattened) @@ -714,11 +719,11 @@ def test_parameter_definitions_and_values_and_ignorable_value_lists(self): value_list_mapping = ParameterValueListMapping(2) value_list_mapping.set_ignorable(True) flattened = [ - ObjectClassMapping(0), + EntityClassMapping(0), ParameterDefinitionMapping(1), AlternativeMapping(Position.hidden), value_list_mapping, - ObjectMapping(3), + EntityMapping(3), ParameterValueMapping(4), ] mapping = unflatten(flattened) @@ -758,9 +763,9 @@ def test_no_item_declared_as_title_gives_full_table(self): import_object_parameters(db_map, (("oc1", "p11"), ("oc2", "p21"), ("oc2", "p22"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"), ("oc3", "o31"))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(Position.hidden) + object_class_mapping = EntityClassMapping(Position.hidden) definition_mapping = ParameterDefinitionMapping(Position.hidden) - object_mapping = ObjectMapping(0) + object_mapping = EntityMapping(0) object_class_mapping.child = definition_mapping definition_mapping.child = object_mapping tables = dict() @@ -785,9 +790,9 @@ def test_missing_values_for_alternatives(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) definition_mapping = ParameterDefinitionMapping(2) - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) alternative_mapping = AlternativeMapping(3) value_mapping = ParameterValueMapping(4) object_class_mapping.child = definition_mapping @@ -810,7 +815,7 @@ def test_export_relationship_classes(self): db_map, (("rc1", ("oc1",)), ("rc2", ("oc3", "oc2")), ("rc3", ("oc2", "oc3", "oc1"))) ) db_map.commit_session("Add test data.") - relationship_class_mapping = RelationshipClassMapping(0) + relationship_class_mapping = EntityClassMapping(0) self.assertEqual(list(rows(relationship_class_mapping, db_map)), [["rc1"], ["rc2"], ["rc3"]]) db_map.connection.close() @@ -821,8 +826,8 @@ def test_export_relationships(self): import_relationship_classes(db_map, (("rc1", ("oc1",)), ("rc2", ("oc2", "oc1")))) import_relationships(db_map, (("rc1", ("o11",)), ("rc2", ("o21", "o11")), ("rc2", ("o21", "o12")))) db_map.commit_session("Add test data.") - relationship_class_mapping = RelationshipClassMapping(0) - relationship_mapping = RelationshipMapping(1) + relationship_class_mapping = EntityClassMapping(0) + relationship_mapping = EntityMapping(1) 
relationship_class_mapping.child = relationship_mapping expected = [["rc1", "rc1_o11"], ["rc2", "rc2_o21__o11"], ["rc2", "rc2_o21__o12"]] self.assertEqual(list(rows(relationship_class_mapping, db_map)), expected) @@ -839,12 +844,12 @@ def test_relationships_with_different_dimensions(self): (("rc2D", ("o11", "o21")), ("rc2D", ("o11", "o22")), ("rc2D", ("o12", "o21")), ("rc2D", ("o12", "o22"))), ) db_map.commit_session("Add test data.") - relationship_class_mapping = RelationshipClassMapping(0) - object_class_mapping1 = RelationshipClassObjectClassMapping(1) - object_class_mapping2 = RelationshipClassObjectClassMapping(2) - relationship_mapping = RelationshipMapping(Position.hidden) - object_mapping1 = RelationshipObjectMapping(3) - object_mapping2 = RelationshipObjectMapping(4) + relationship_class_mapping = EntityClassMapping(0) + object_class_mapping1 = DimensionMapping(1) + object_class_mapping2 = DimensionMapping(2) + relationship_mapping = EntityMapping(Position.hidden) + object_mapping1 = ElementMapping(3) + object_mapping2 = ElementMapping(4) object_mapping1.child = object_mapping2 relationship_mapping.child = object_mapping1 object_class_mapping2.child = relationship_mapping @@ -869,7 +874,7 @@ def test_default_parameter_values(self): import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11", 3.14), ("oc2", "p21", 14.3), ("oc2", "p22", -1.0))) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) definition_mapping = ParameterDefinitionMapping(1) default_value_mapping = ParameterDefaultValueMapping(2) definition_mapping.child = default_value_mapping @@ -890,7 +895,7 @@ def test_indexed_default_parameter_values(self): ), ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) definition_mapping = ParameterDefinitionMapping(1) index_mapping = ParameterDefaultValueIndexMapping(2) value_mapping = ExpandedParameterDefaultValueMapping(3) @@ -917,11 +922,11 @@ def test_replace_parameter_indexes_by_external_data(self): db_map, (("oc", "o1", "p1", Map(["a", "b"], [5.0, -5.0])), ("oc", "o2", "p1", Map(["a", "b"], [2.0, -2.0]))) ) db_map.commit_session("Add test data.") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(2) alternative_mapping = AlternativeMapping(Position.hidden) parameter_definition_mapping.child = alternative_mapping - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) alternative_mapping.child = object_mapping index_mapping = ParameterValueIndexMapping(3) value_mapping = ExpandedParameterValueMapping(4) @@ -943,7 +948,7 @@ def test_constant_mapping_as_title(self): import_object_classes(db_map, ("oc1", "oc2", "oc3")) db_map.commit_session("Add test data.") constant_mapping = FixedValueMapping(Position.table_name, "title_text") - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) constant_mapping.child = object_class_mapping tables = dict() for title, title_key in titles(constant_mapping, db_map): @@ -1107,21 +1112,21 @@ def test_header(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root = unflatten([ObjectClassMapping(0, header="class"), ObjectMapping(1, header="entity")]) + root = unflatten([EntityClassMapping(0, header="class"), 
EntityMapping(1, header="entity")])
         expected = [["class", "entity"], ["oc", "o1"]]
         self.assertEqual(list(rows(root, db_map)), expected)
         db_map.connection.close()
 
     def test_header_without_data_still_creates_header(self):
         db_map = DatabaseMapping("sqlite://", create=True)
-        root = unflatten([ObjectClassMapping(0, header="class"), ObjectMapping(1, header="object")])
+        root = unflatten([EntityClassMapping(0, header="class"), EntityMapping(1, header="object")])
         expected = [["class", "object"]]
         self.assertEqual(list(rows(root, db_map)), expected)
         db_map.connection.close()
 
     def test_header_in_half_pivot_table_without_data_still_creates_header(self):
         db_map = DatabaseMapping("sqlite://", create=True)
-        root = unflatten([ObjectClassMapping(-1, header="class"), ObjectMapping(9, header="object")])
+        root = unflatten([EntityClassMapping(-1, header="class"), EntityMapping(9, header="object")])
         expected = [["class"]]
         self.assertEqual(list(rows(root, db_map)), expected)
         db_map.connection.close()
@@ -1130,9 +1135,9 @@ def test_header_in_pivot_table_without_data_still_creates_header(self):
         db_map = DatabaseMapping("sqlite://", create=True)
         root = unflatten(
             [
-                ObjectClassMapping(-1, header="class"),
+                EntityClassMapping(-1, header="class"),
                 ParameterDefinitionMapping(0, header="parameter"),
-                ObjectMapping(-2, header="object"),
+                EntityMapping(-2, header="object"),
                 AlternativeMapping(1, header="alternative"),
                 ParameterValueMapping(0),
             ]
         )
@@ -1143,14 +1148,14 @@ def test_header_in_pivot_table_without_data_still_creates_header(self):
 
     def test_disabled_empty_data_header(self):
         db_map = DatabaseMapping("sqlite://", create=True)
-        root = unflatten([ObjectClassMapping(0, header="class"), ObjectMapping(1, header="object")])
+        root = unflatten([EntityClassMapping(0, header="class"), EntityMapping(1, header="object")])
         expected = []
         self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected)
         db_map.connection.close()
 
     def test_disabled_empty_data_header_in_pivot_table(self):
         db_map = DatabaseMapping("sqlite://", create=True)
-        root = unflatten([ObjectClassMapping(-1, header="class"), ObjectMapping(0)])
+        root = unflatten([EntityClassMapping(-1, header="class"), EntityMapping(0)])
         expected = []
         self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected)
         db_map.connection.close()
@@ -1160,7 +1165,7 @@ def test_header_position(self):
         db_map = DatabaseMapping("sqlite://", create=True)
         import_object_classes(db_map, ("oc",))
         import_objects(db_map, (("oc", "o1"),))
         db_map.commit_session("Add test data.")
-        root = unflatten([ObjectClassMapping(Position.header), ObjectMapping(0)])
+        root = unflatten([EntityClassMapping(Position.header), EntityMapping(0)])
         expected = [["oc"], ["o1"]]
         self.assertEqual(list(rows(root, db_map)), expected)
         db_map.connection.close()
@@ -1174,15 +1179,15 @@ def test_header_position_with_relationships(self):
         db_map.commit_session("Add test data.")
         root = unflatten(
             [
-                RelationshipClassMapping(0),
-                RelationshipClassObjectClassMapping(Position.header),
-                RelationshipClassObjectClassMapping(Position.header),
-                RelationshipMapping(1),
-                RelationshipObjectMapping(2),
-                RelationshipObjectMapping(3),
+                EntityClassMapping(0),
+                DimensionMapping(Position.header),
+                DimensionMapping(Position.header),
+                EntityMapping(1),
+                ElementMapping(2),
+                ElementMapping(3),
             ]
         )
         expected = [["", "", "oc1", "oc2"], ["rc", "rc_o11__o21", "o11", "o21"]]
         self.assertEqual(list(rows(root, db_map)), expected)
         db_map.connection.close()
@@ -1193,12 +1201,12 @@ def 
test_header_position_with_relationships_but_no_data(self): db_map.commit_session("Add test data.") root = unflatten( [ - RelationshipClassMapping(0), - RelationshipClassObjectClassMapping(Position.header), - RelationshipClassObjectClassMapping(Position.header), - RelationshipMapping(1), - RelationshipObjectMapping(2), - RelationshipObjectMapping(3), + EntityClassMapping(0), + DimensionMapping(Position.header), + DimensionMapping(Position.header), + EntityMapping(1), + ElementMapping(2), + ElementMapping(3), ] ) expected = [["", "", "oc1", "oc2"]] @@ -1228,9 +1236,9 @@ def test_header_and_pivot(self): db_map.commit_session("Add test data.") mapping = unflatten( [ - ObjectClassMapping(0, header="class"), + EntityClassMapping(0, header="class"), ParameterDefinitionMapping(1, header="parameter"), - ObjectMapping(-1, header="object"), + EntityMapping(-1, header="object"), AlternativeMapping(-2, header="alternative"), ParameterValueIndexMapping(-3, header=""), ExpandedParameterValueMapping(2, header="value"), @@ -1269,9 +1277,9 @@ def test_pivot_without_left_hand_side_has_padding_column_for_headers(self): db_map.commit_session("Add test data.") mapping = unflatten( [ - ObjectClassMapping(Position.header), + EntityClassMapping(Position.header), ParameterDefinitionMapping(Position.hidden, header="parameter"), - ObjectMapping(-1), + EntityMapping(-1), AlternativeMapping(-2, header="alternative"), ParameterValueIndexMapping(-3, header="index"), ExpandedParameterValueMapping(2, header="value"), @@ -1288,42 +1296,42 @@ def test_pivot_without_left_hand_side_has_padding_column_for_headers(self): db_map.connection.close() def test_count_mappings(self): - object_class_mapping = ObjectClassMapping(2) + object_class_mapping = EntityClassMapping(2) parameter_definition_mapping = ParameterDefinitionMapping(0) - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) parameter_definition_mapping.child = object_mapping object_class_mapping.child = parameter_definition_mapping self.assertEqual(object_class_mapping.count_mappings(), 3) def test_flatten(self): - object_class_mapping = ObjectClassMapping(2) + object_class_mapping = EntityClassMapping(2) parameter_definition_mapping = ParameterDefinitionMapping(0) - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) parameter_definition_mapping.child = object_mapping object_class_mapping.child = parameter_definition_mapping mappings = object_class_mapping.flatten() self.assertEqual(mappings, [object_class_mapping, parameter_definition_mapping, object_mapping]) def test_unflatten_sets_last_mappings_child_to_none(self): - object_class_mapping = ObjectClassMapping(2) - object_mapping = ObjectMapping(1) + object_class_mapping = EntityClassMapping(2) + object_mapping = EntityMapping(1) object_class_mapping.child = object_mapping mapping_list = object_class_mapping.flatten() root = unflatten(mapping_list[:1]) self.assertIsNone(root.child) def test_has_titles(self): - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = ParameterDefinitionMapping(Position.table_name) - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) parameter_definition_mapping.child = object_mapping object_class_mapping.child = parameter_definition_mapping self.assertTrue(object_class_mapping.has_titles()) def test_drop_non_positioned_tail(self): - object_class_mapping = ObjectClassMapping(0) + object_class_mapping = EntityClassMapping(0) parameter_definition_mapping = 
ParameterDefinitionMapping(Position.hidden) - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) alternative_mapping = AlternativeMapping(Position.hidden) value_mapping = ParameterValueMapping(Position.hidden) alternative_mapping.child = value_mapping @@ -1337,15 +1345,15 @@ def test_drop_non_positioned_tail(self): def test_serialization(self): highlight_dimension = 5 mappings = [ - ObjectClassMapping(0), - RelationshipClassMapping(Position.table_name), - RelationshipClassObjectHighlightingMapping(Position.header, highlight_dimension=highlight_dimension), - RelationshipClassObjectClassMapping(2), + EntityClassMapping(0), + EntityClassMapping(Position.table_name), + DimensionHighlightingMapping(Position.header, highlight_dimension=highlight_dimension), + DimensionMapping(2), ParameterDefinitionMapping(1), - ObjectMapping(-1), - RelationshipMapping(Position.hidden), - RelationshipObjectHighlightingMapping(9), - RelationshipObjectMapping(-1), + EntityMapping(-1), + EntityMapping(Position.hidden), + ElementHighlightingMapping(9), + ElementMapping(-1), AlternativeMapping(3), ParameterValueMapping(4), ParameterValueIndexMapping(5), @@ -1363,15 +1371,15 @@ def test_serialization(self): for m in deserialized: if isinstance(m, FixedValueMapping): self.assertEqual(m.value, "gaga") - elif isinstance(m, RelationshipClassObjectHighlightingMapping): + elif isinstance(m, DimensionHighlightingMapping): self.assertEqual(m.highlight_dimension, highlight_dimension) def test_setting_ignorable_flag(self): db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) db_map.commit_session("Add test data.") - object_mapping = ObjectMapping(1) - root_mapping = unflatten([ObjectClassMapping(0), object_mapping]) + object_mapping = EntityMapping(1) + root_mapping = unflatten([EntityClassMapping(0), object_mapping]) object_mapping.set_ignorable(True) self.assertTrue(object_mapping.is_ignorable()) expected = [["oc", None]] @@ -1383,8 +1391,8 @@ def test_unsetting_ignorable_flag(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - object_mapping = ObjectMapping(1) - root_mapping = unflatten([ObjectClassMapping(0), object_mapping]) + object_mapping = EntityMapping(1) + root_mapping = unflatten([EntityClassMapping(0), object_mapping]) object_mapping.set_ignorable(True) self.assertTrue(object_mapping.is_ignorable()) object_mapping.set_ignorable(False) @@ -1398,9 +1406,9 @@ def test_filter(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "o2"))) db_map.commit_session("Add test data.") - object_mapping = ObjectMapping(1) + object_mapping = EntityMapping(1) object_mapping.filter_re = "o1" - root_mapping = unflatten([ObjectClassMapping(0), object_mapping]) + root_mapping = unflatten([EntityClassMapping(0), object_mapping]) expected = [["oc", "o1"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) db_map.connection.close() @@ -1410,9 +1418,9 @@ def test_hidden_tail_filter(self): import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o1"), ("oc2", "o2"))) db_map.commit_session("Add test data.") - object_mapping = ObjectMapping(Position.hidden) + object_mapping = EntityMapping(Position.hidden) object_mapping.filter_re = "o1" - root_mapping = unflatten([ObjectClassMapping(0), object_mapping]) + root_mapping = unflatten([EntityClassMapping(0), object_mapping]) expected = [["oc1"]] self.assertEqual(list(rows(root_mapping, db_map)), 
expected) db_map.connection.close() @@ -1424,7 +1432,9 @@ def test_index_names(self): import_objects(db_map, (("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["a"], [5.0], index_name="index")),)) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 2, Position.hidden, 1, 3, Position.hidden, 5, [Position.header], [4]) + mapping = entity_parameter_export( + 0, 2, Position.hidden, 1, None, None, 3, Position.hidden, 5, [Position.header], [4] + ) expected = [["", "", "", "", "index", ""], ["oc", "o", "p", "Base", "a", 5.0]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1436,7 +1446,7 @@ def test_default_value_index_names_with_nested_map(self): db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3], index_name="idx2")], index_name="idx1")),) ) db_map.commit_session("Add test data.") - mapping = object_parameter_default_value_export( + mapping = entity_class_parameter_default_value_export( 0, 1, Position.hidden, 4, [Position.header, Position.header], [2, 3] ) expected = [["", "", "idx1", "idx2", ""], ["oc", "p", "A", "b", 2.3]] @@ -1445,7 +1455,7 @@ def test_default_value_index_names_with_nested_map(self): def test_multiple_index_names_with_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) - mapping = relationship_parameter_export( + mapping = entity_parameter_export( 0, 4, Position.hidden, 1, [2], [3], 5, Position.hidden, 8, [Position.header, Position.header], [6, 7] ) expected = [9 * [""]] @@ -1457,7 +1467,7 @@ def test_parameter_default_value_type(self): import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11", 3.14), ("oc2", "p21", 14.3), ("oc2", "p22", -1.0))) db_map.commit_session("Add test data.") - root_mapping = object_parameter_default_value_export(0, 1, 2, 3, None, None) + root_mapping = entity_class_parameter_default_value_export(0, 1, 2, 3, None, None) expected = [ ["oc1", "p11", "single_value", 3.14], ["oc2", "p21", "single_value", 14.3], @@ -1473,8 +1483,8 @@ def test_map_with_more_dimensions_than_index_mappings(self): import_objects(db_map, (("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["A"], [Map(["b"], [2.3])])),)) db_map.commit_session("Add test data.") - mapping = object_parameter_export( - 0, 1, Position.hidden, 2, Position.hidden, Position.hidden, 4, [Position.hidden], [3] + mapping = entity_parameter_export( + 0, 1, Position.hidden, 2, None, None, Position.hidden, Position.hidden, 4, [Position.hidden], [3] ) expected = [["oc", "p", "o", "A", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) @@ -1485,7 +1495,7 @@ def test_default_map_value_with_more_dimensions_than_index_mappings(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3])])),)) db_map.commit_session("Add test data.") - mapping = object_parameter_default_value_export(0, 1, Position.hidden, 3, [Position.hidden], [2]) + mapping = entity_class_parameter_default_value_export(0, 1, Position.hidden, 3, [Position.hidden], [2]) expected = [["oc", "p", "A", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1497,7 +1507,9 @@ def test_map_with_single_value_mapping(self): import_objects(db_map, (("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["A"], [2.3])),)) db_map.commit_session("Add test data.") - mapping = object_parameter_export(0, 1, Position.hidden, 2, Position.hidden, 
Position.hidden, 3, None, None) + mapping = entity_parameter_export( + 0, 1, Position.hidden, 2, None, None, Position.hidden, Position.hidden, 3, None, None + ) expected = [["oc", "p", "o", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1507,7 +1519,7 @@ def test_default_map_value_with_single_value_mapping(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [2.3])),)) db_map.commit_session("Add test data.") - mapping = object_parameter_default_value_export(0, 1, Position.hidden, 2, None, None) + mapping = entity_class_parameter_default_value_export(0, 1, Position.hidden, 2, None, None) expected = [["oc", "p", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1517,7 +1529,7 @@ def test_table_gets_exported_even_without_parameter_values(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) db_map.commit_session("Add test data.") - mapping = object_parameter_export(Position.header, Position.table_name, object_position=0, value_position=1) + mapping = entity_parameter_export(Position.header, Position.table_name, entity_position=0, value_position=1) tables = dict() for title, title_key in titles(mapping, db_map): tables[title] = list(rows(mapping, db_map, title_key)) @@ -1533,8 +1545,8 @@ def test_relationship_class_object_classes_parameters(self): db_map.commit_session("Add test data") root_mapping = unflatten( [ - RelationshipClassObjectHighlightingMapping(0), - RelationshipClassObjectClassMapping(1), + DimensionHighlightingMapping(0), + DimensionMapping(1), ParameterDefinitionMapping(2), ] ) @@ -1549,12 +1561,7 @@ def test_relationship_class_object_classes_parameters_multiple_dimensions(self): import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) db_map.commit_session("Add test data") root_mapping = unflatten( - [ - RelationshipClassObjectHighlightingMapping(0), - RelationshipClassObjectClassMapping(1), - RelationshipClassObjectClassMapping(3), - ParameterDefinitionMapping(2), - ] + [DimensionHighlightingMapping(0), DimensionMapping(1), DimensionMapping(3), ParameterDefinitionMapping(2)] ) expected = [["rc", "oc1", "p11", "oc2"], ["rc", "oc1", "p12", "oc2"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) @@ -1571,12 +1578,12 @@ def test_highlight_relationship_objects(self): db_map.commit_session("Add test data") root_mapping = unflatten( [ - RelationshipClassObjectHighlightingMapping(0), - RelationshipClassObjectClassMapping(1), - RelationshipClassObjectClassMapping(2), - RelationshipObjectHighlightingMapping(3), - RelationshipObjectMapping(4), - RelationshipObjectMapping(5), + DimensionHighlightingMapping(0), + DimensionMapping(1), + DimensionMapping(2), + ElementHighlightingMapping(3), + ElementMapping(4), + ElementMapping(5), ] ) expected = [ @@ -1597,10 +1604,10 @@ def test_export_object_parameters_while_exporting_relationships(self): db_map.commit_session("Add test data") root_mapping = unflatten( [ - RelationshipClassObjectHighlightingMapping(0), - RelationshipClassObjectClassMapping(1), - RelationshipObjectHighlightingMapping(2), - RelationshipObjectMapping(3), + DimensionHighlightingMapping(0), + DimensionMapping(1), + ElementHighlightingMapping(2), + ElementMapping(3), ParameterDefinitionMapping(4), AlternativeMapping(5), ParameterValueMapping(6), diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index 198760ec..29d09f4f 100644 --- 
a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -31,23 +31,21 @@ ) from spinedb_api.export_mapping import rows from spinedb_api.export_mapping.settings import ( - relationship_export, - set_relationship_dimensions, - object_parameter_export, + entity_export, + set_entity_dimensions, + entity_parameter_export, set_parameter_dimensions, - relationship_parameter_default_value_export, set_parameter_default_value_dimensions, - object_parameter_default_value_export, - relationship_parameter_export, - relationship_object_parameter_default_value_export, - relationship_object_parameter_export, + entity_class_parameter_default_value_export, + entity_class_dimension_parameter_default_value_export, + entity_element_parameter_export, ) from spinedb_api.export_mapping.export_mapping import ( Position, - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipMapping, - RelationshipObjectMapping, + EntityClassMapping, + DimensionMapping, + EntityMapping, + ElementMapping, ExpandedParameterValueMapping, ParameterValueIndexMapping, IndexNameMapping, @@ -61,7 +59,7 @@ ) -class TestRelationshipParameterExport(unittest.TestCase): +class TestEntityParameterExport(unittest.TestCase): def test_export_with_parameter_values(self): db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc1", "oc2")) @@ -87,8 +85,8 @@ def test_export_with_parameter_values(self): ), ) db_map.commit_session("Add test data.") - root_mapping = relationship_parameter_export( - object_positions=[-1, -2], value_position=-3, index_name_positions=[Position.hidden], index_positions=[0] + root_mapping = entity_parameter_export( + element_positions=[-1, -2], value_position=-3, index_name_positions=[Position.hidden], index_positions=[0] ) expected = [ [None, "o1", "o1"], @@ -100,7 +98,7 @@ def test_export_with_parameter_values(self): db_map.connection.close() -class TestRelationshipObjectParameterDefaultValueExport(unittest.TestCase): +class TestEntityClassDimensionParameterDefaultValueExport(unittest.TestCase): def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) @@ -115,10 +113,10 @@ def test_export_with_two_dimensions(self): import_relationship_classes(self._db_map, (("rc", ("oc1", "oc2")),)) import_relationship_parameters(self._db_map, (("rc", "rc_p", "dummy"),)) self._db_map.commit_session("Add test data.") - root_mapping = relationship_object_parameter_default_value_export( - relationship_class_position=0, + root_mapping = entity_class_dimension_parameter_default_value_export( + entity_class_position=0, definition_position=1, - object_class_positions=[2, 3], + dimension_positions=[2, 3], value_position=4, value_type_position=5, index_name_positions=None, @@ -129,7 +127,7 @@ def test_export_with_two_dimensions(self): self.assertEqual(list(rows(root_mapping, self._db_map)), expected) -class TestRelationshipObjectParameterExport(unittest.TestCase): +class TestEntityElementParameterExport(unittest.TestCase): def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) @@ -154,19 +152,19 @@ def test_export_with_two_dimensions(self): import_relationships(self._db_map, (("rc", ("o11", "o21")), ("rc", ("o12", "o21")))) import_relationship_parameter_values(self._db_map, (("rc", ("o11", "o21"), "rc_p", "dummy"),)) self._db_map.commit_session("Add test data.") - root_mapping = relationship_object_parameter_export( - relationship_class_position=0, + root_mapping = entity_element_parameter_export( + entity_class_position=0, 
definition_position=1, value_list_position=Position.hidden, - relationship_position=2, - object_class_positions=[3, 4], - object_positions=[5, 6], + entity_position=2, + dimension_positions=[3, 4], + element_positions=[5, 6], alternative_position=7, value_type_position=8, value_position=9, highlight_dimension=0, ) - set_relationship_dimensions(root_mapping, 2) + set_entity_dimensions(root_mapping, 2) expected = [ ["rc", "p11", "rc_o11__o21", "oc1", "oc2", "o11", "o21", "Base", "single_value", 2.3], ["rc", "p11", "rc_o12__o21", "oc1", "oc2", "o12", "o21", "Base", "single_value", -2.3], @@ -175,72 +173,72 @@ def test_export_with_two_dimensions(self): self.assertEqual(list(rows(root_mapping, self._db_map)), expected) -class TestSetRelationshipDimensions(unittest.TestCase): +class TestSetEntityDimensions(unittest.TestCase): def test_change_dimensions_from_zero_to_one(self): - mapping = relationship_export(0, 1) + mapping = entity_export(0, 1) self.assertEqual(mapping.count_mappings(), 2) - set_relationship_dimensions(mapping, 1) + set_entity_dimensions(mapping, 1) self.assertEqual(mapping.count_mappings(), 4) flattened = mapping.flatten() classes = [type(mapping) for mapping in flattened] self.assertEqual( classes, [ - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipMapping, - RelationshipObjectMapping, + EntityClassMapping, + DimensionMapping, + EntityMapping, + ElementMapping, ], ) positions = [mapping.position for mapping in flattened] self.assertEqual(positions, [0, Position.hidden, 1, Position.hidden]) def test_change_dimension_from_one_to_zero(self): - mapping = relationship_export(0, 1, [2], [3]) + mapping = entity_export(0, 1, [2], [3]) self.assertEqual(mapping.count_mappings(), 4) - set_relationship_dimensions(mapping, 0) + set_entity_dimensions(mapping, 0) self.assertEqual(mapping.count_mappings(), 2) flattened = mapping.flatten() classes = [type(mapping) for mapping in flattened] - self.assertEqual(classes, [RelationshipClassMapping, RelationshipMapping]) + self.assertEqual(classes, [EntityClassMapping, EntityMapping]) positions = [mapping.position for mapping in flattened] self.assertEqual(positions, [0, 1]) def test_increase_dimensions(self): - mapping = relationship_export(0, 1, [2], [3]) + mapping = entity_export(0, 1, [2], [3]) self.assertEqual(mapping.count_mappings(), 4) - set_relationship_dimensions(mapping, 2) + set_entity_dimensions(mapping, 2) self.assertEqual(mapping.count_mappings(), 6) flattened = mapping.flatten() classes = [type(mapping) for mapping in flattened] self.assertEqual( classes, [ - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipClassObjectClassMapping, - RelationshipMapping, - RelationshipObjectMapping, - RelationshipObjectMapping, + EntityClassMapping, + DimensionMapping, + DimensionMapping, + EntityMapping, + ElementMapping, + ElementMapping, ], ) positions = [mapping.position for mapping in flattened] self.assertEqual(positions, [0, 2, Position.hidden, 1, 3, Position.hidden]) def test_decrease_dimensions(self): - mapping = relationship_export(0, 1, [2, 3], [4, 5]) + mapping = entity_export(0, 1, [2, 3], [4, 5]) self.assertEqual(mapping.count_mappings(), 6) - set_relationship_dimensions(mapping, 1) + set_entity_dimensions(mapping, 1) self.assertEqual(mapping.count_mappings(), 4) flattened = mapping.flatten() classes = [type(mapping) for mapping in flattened] self.assertEqual( classes, [ - RelationshipClassMapping, - RelationshipClassObjectClassMapping, - RelationshipMapping, - 
RelationshipObjectMapping, + EntityClassMapping, + DimensionMapping, + EntityMapping, + ElementMapping, ], ) positions = [mapping.position for mapping in flattened] @@ -249,7 +247,7 @@ def test_decrease_dimensions(self): class TestSetParameterDimensions(unittest.TestCase): def test_set_dimensions_from_zero_to_one(self): - root_mapping = object_parameter_export() + root_mapping = entity_parameter_export() set_parameter_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterValueMapping, @@ -261,7 +259,7 @@ def test_set_dimensions_from_zero_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_zero_to_one(self): - root_mapping = relationship_parameter_default_value_export() + root_mapping = entity_class_parameter_default_value_export() set_parameter_default_value_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterDefaultValueMapping, @@ -273,21 +271,21 @@ def test_set_default_value_dimensions_from_zero_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_one_to_zero(self): - root_mapping = relationship_parameter_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_parameter_export(index_name_positions=[0], index_positions=[1]) set_parameter_dimensions(root_mapping, 0) expected_types = [ParameterValueMapping, ParameterValueTypeMapping] for expected_type, mapping in zip(expected_types, reversed(root_mapping.flatten())): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_one_to_zero(self): - root_mapping = object_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_default_value_dimensions(root_mapping, 0) expected_types = [ParameterDefaultValueMapping, ParameterDefaultValueTypeMapping] for expected_type, mapping in zip(expected_types, reversed(root_mapping.flatten())): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_one_to_two(self): - root_mapping = relationship_parameter_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_parameter_export(index_name_positions=[0], index_positions=[1]) set_parameter_dimensions(root_mapping, 2) expected_types = [ ExpandedParameterValueMapping, @@ -301,7 +299,7 @@ def test_set_dimensions_from_one_to_two(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_one_to_two(self): - root_mapping = relationship_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_default_value_dimensions(root_mapping, 2) expected_types = [ ExpandedParameterDefaultValueMapping, @@ -315,7 +313,7 @@ def test_set_default_value_dimensions_from_one_to_two(self): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_two_to_one(self): - root_mapping = relationship_parameter_export(index_name_positions=[0, 2], index_positions=[1, 3]) + root_mapping = entity_parameter_export(index_name_positions=[0, 2], index_positions=[1, 3]) set_parameter_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterValueMapping, @@ -327,7 +325,7 @@ def test_set_dimensions_from_two_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_two_to_one(self): - root_mapping = 
relationship_parameter_default_value_export(index_name_positions=[0, 2], index_positions=[1, 3]) + root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0, 2], index_positions=[1, 3]) set_parameter_default_value_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterDefaultValueMapping, diff --git a/tests/spine_io/exporters/test_csv_writer.py b/tests/spine_io/exporters/test_csv_writer.py index 989299b8..01d9df75 100644 --- a/tests/spine_io/exporters/test_csv_writer.py +++ b/tests/spine_io/exporters/test_csv_writer.py @@ -19,7 +19,7 @@ import unittest from spinedb_api import DatabaseMapping, import_object_classes, import_objects from spinedb_api.mapping import Position -from spinedb_api.export_mapping import object_export +from spinedb_api.export_mapping import entity_export from spinedb_api.spine_io.exporters.writer import write from spinedb_api.spine_io.exporters.csv_writer import CsvWriter @@ -33,7 +33,7 @@ def tearDown(self): def test_write_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) out_path = Path(self._temp_dir.name, "out.csv") writer = CsvWriter(out_path.parent, out_path.name) write(db_map, writer, root_mapping) @@ -47,7 +47,7 @@ def test_write_single_object_class_and_object(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) out_path = Path(self._temp_dir.name, "out.csv") writer = CsvWriter(out_path.parent, out_path.name) write(db_map, writer, root_mapping) @@ -61,7 +61,7 @@ def test_tables_are_written_to_separate_files(self): import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o1"), ("oc2", "o2"))) db_map.commit_session("Add test data.") - root_mapping = object_export(Position.table_name, 0) + root_mapping = entity_export(Position.table_name, 0) out_path = Path(self._temp_dir.name, "out.csv") writer = CsvWriter(out_path.parent, out_path.name) write(db_map, writer, root_mapping) @@ -85,8 +85,8 @@ def test_append_to_table(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root_mapping1 = object_export(0, 1) - root_mapping2 = object_export(0, 1) + root_mapping1 = entity_export(0, 1) + root_mapping2 = entity_export(0, 1) out_path = Path(self._temp_dir.name, "out.csv") writer = CsvWriter(out_path.parent, out_path.name) write(db_map, writer, root_mapping1, root_mapping2) diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index 357f84b7..edb91753 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -20,7 +20,7 @@ from openpyxl import load_workbook from spinedb_api import DatabaseMapping, import_object_classes, import_objects from spinedb_api.mapping import Position -from spinedb_api.export_mapping import object_export +from spinedb_api.export_mapping import entity_export from spinedb_api.spine_io.exporters.writer import write from spinedb_api.spine_io.exporters.excel_writer import ExcelWriter @@ -34,7 +34,7 @@ def tearDown(self): def test_write_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, 
root_mapping) @@ -50,7 +50,7 @@ def test_write_single_object_class_and_object(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, root_mapping) @@ -66,7 +66,7 @@ def test_write_to_existing_sheet(self): import_object_classes(db_map, ("Sheet1",)) import_objects(db_map, (("Sheet1", "o1"), ("Sheet1", "o2"))) db_map.commit_session("Add test data.") - root_mapping = object_export(Position.table_name, 0) + root_mapping = entity_export(Position.table_name, 0) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, root_mapping) @@ -82,7 +82,7 @@ def test_write_to_named_sheets(self): import_object_classes(db_map, ("oc1", ("oc2"))) import_objects(db_map, (("oc1", "o11"), ("oc1", "o12"), ("oc2", "o21"))) db_map.commit_session("Add test data.") - root_mapping = object_export(Position.table_name, 1) + root_mapping = entity_export(Position.table_name, 1) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, root_mapping) @@ -100,8 +100,8 @@ def test_append_to_anonymous_table(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root_mapping1 = object_export(0, 1) - root_mapping2 = object_export(0, 1) + root_mapping1 = entity_export(0, 1) + root_mapping2 = entity_export(0, 1) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, root_mapping1, root_mapping2) @@ -117,8 +117,8 @@ def test_append_to_named_table(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add test data.") - root_mapping1 = object_export(Position.table_name, 0) - root_mapping2 = object_export(Position.table_name, 0) + root_mapping1 = entity_export(Position.table_name, 0) + root_mapping2 = entity_export(Position.table_name, 0) path = os.path.join(self._temp_dir.name, "test.xlsx") writer = ExcelWriter(path) write(db_map, writer, root_mapping1, root_mapping2) diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 57ddeba2..6e99a0ae 100644 --- a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -35,7 +35,7 @@ Map, ) from spinedb_api.mapping import Position, unflatten -from spinedb_api.export_mapping import object_export, object_parameter_export, relationship_export +from spinedb_api.export_mapping import entity_export, entity_parameter_export, entity_export from spinedb_api.export_mapping.export_mapping import FixedValueMapping @@ -45,7 +45,7 @@ class TestGdxWriter(unittest.TestCase): @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_write_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) - root_mapping = object_export(class_position=Position.table_name, object_position=0) + root_mapping = entity_export(entity_class_position=Position.table_name, entity_position=0) root_mapping.child.header = "*" with TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, "test_write_empty_database.gdx") @@ -61,7 +61,7 @@ def test_write_single_object_class_and_object(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"),)) db_map.commit_session("Add 
test data.") - root_mapping = object_export(Position.table_name, 0) + root_mapping = entity_export(Position.table_name, 0) root_mapping.child.header = "*" with TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, "test_write_single_object_class_and_object.gdx") @@ -82,9 +82,7 @@ def test_write_2D_relationship(self): import_relationship_classes(db_map, (("rel", ("oc1", "oc2")),)) import_relationships(db_map, (("rel", ("o1", "o2")),)) db_map.commit_session("Add test data.") - root_mapping = relationship_export( - Position.table_name, Position.hidden, [Position.header, Position.header], [0, 1] - ) + root_mapping = entity_export(Position.table_name, Position.hidden, [Position.header, Position.header], [0, 1]) with TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, "test_write_2D_relationship.gdx") writer = GdxWriter(str(file_path), self._gams_dir) @@ -104,7 +102,9 @@ def test_write_parameters(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", 2.3),)) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export(class_position=Position.table_name, object_position=0, value_position=1) + root_mapping = entity_parameter_export( + entity_class_position=Position.table_name, entity_position=0, value_position=1 + ) mappings = root_mapping.flatten() mappings[3].header = "*" with TemporaryDirectory() as temp_dir: @@ -126,7 +126,9 @@ def test_non_numerical_parameter_value_raises_writer_expection(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", "text"),)) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export(class_position=Position.table_name, object_position=0, value_position=1) + root_mapping = entity_parameter_export( + entity_class_position=Position.table_name, entity_position=0, value_position=1 + ) mappings = root_mapping.flatten() mappings[3].header = "*" with TemporaryDirectory() as temp_dir: @@ -143,7 +145,9 @@ def test_empty_parameter(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", Map([], [], str)),)) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export(class_position=Position.table_name, object_position=0, value_position=1) + root_mapping = entity_parameter_export( + entity_class_position=Position.table_name, entity_position=0, value_position=1 + ) mappings = root_mapping.flatten() mappings[3].header = "*" mappings[-1].filter_re = "single_value" @@ -166,7 +170,7 @@ def test_write_scalars(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", 2.3),)) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export(class_position=Position.table_name, value_position=0) + root_mapping = entity_parameter_export(entity_class_position=Position.table_name, value_position=0) with TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, "test_write_scalars.gdx") writer = GdxWriter(str(file_path), self._gams_dir) @@ -183,7 +187,7 @@ def test_two_tables(self): import_object_classes(db_map, ("oc1", "oc2")) import_objects(db_map, (("oc1", "o"), ("oc2", "p"))) db_map.commit_session("Add test data.") - root_mapping = object_export(class_position=Position.table_name, object_position=0) + root_mapping = entity_export(entity_class_position=Position.table_name, entity_position=0) root_mapping.child.header = "*" with TemporaryDirectory() as temp_dir: file_path = Path(temp_dir, 
"test_two_tables.gdx") @@ -206,12 +210,12 @@ def test_append_to_table(self): import_objects(db_map, (("oc1", "o"), ("oc2", "p"))) db_map.commit_session("Add test data.") root_mapping1 = unflatten( - [FixedValueMapping(Position.table_name, value="set_X")] + object_export(object_position=0).flatten() + [FixedValueMapping(Position.table_name, value="set_X")] + entity_export(entity_position=0).flatten() ) root_mapping1.child.filter_re = "oc1" root_mapping1.child.child.header = "*" root_mapping2 = unflatten( - [FixedValueMapping(Position.table_name, value="set_X")] + object_export(object_position=0).flatten() + [FixedValueMapping(Position.table_name, value="set_X")] + entity_export(entity_position=0).flatten() ) root_mapping2.child.filter_re = "oc2" root_mapping2.child.child.header = "*" @@ -234,8 +238,11 @@ def test_parameter_value_non_convertible_to_float_raises_WriterException(self): import_objects(db_map, (("oc", "o"), ("oc", "p"))) import_object_parameter_values(db_map, (("oc", "o", "param", "text"), ("oc", "p", "param", 2.3))) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export( - class_position=Position.hidden, definition_position=Position.table_name, object_position=0, value_position=1 + root_mapping = entity_parameter_export( + entity_class_position=Position.hidden, + definition_position=Position.table_name, + entity_position=0, + value_position=1, ) root_mapping.child.child.child.header = "*" with TemporaryDirectory() as temp_dir: @@ -252,8 +259,11 @@ def test_non_string_set_element_raises_WriterException(self): import_objects(db_map, (("oc", "o"), ("oc", "p"))) import_object_parameter_values(db_map, (("oc", "o", "param", 2.3), ("oc", "p", "param", "text"))) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export( - class_position=Position.hidden, definition_position=Position.table_name, object_position=0, value_position=1 + root_mapping = entity_parameter_export( + entity_class_position=Position.hidden, + definition_position=Position.table_name, + entity_position=0, + value_position=1, ) root_mapping.child.child.child.header = "*" with TemporaryDirectory() as temp_dir: @@ -280,8 +290,8 @@ def test_special_value_conversions(self): ), ) db_map.commit_session("Add test data.") - root_mapping = object_parameter_export( - class_position=Position.table_name, object_position=0, definition_position=1, value_position=2 + root_mapping = entity_parameter_export( + entity_class_position=Position.table_name, entity_position=0, definition_position=1, value_position=2 ) mappings = root_mapping.flatten() mappings[1].header = mappings[3].header = "*" diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py index cbcb4421..8895badb 100644 --- a/tests/spine_io/exporters/test_sql_writer.py +++ b/tests/spine_io/exporters/test_sql_writer.py @@ -29,12 +29,12 @@ import_object_parameter_values, ) from spinedb_api.mapping import Position, unflatten -from spinedb_api.export_mapping import object_export +from spinedb_api.export_mapping import entity_export from spinedb_api.export_mapping.export_mapping import ( AlternativeMapping, FixedValueMapping, - ObjectClassMapping, - ObjectMapping, + EntityClassMapping, + EntityMapping, ParameterDefinitionMapping, ParameterValueMapping, ) @@ -65,8 +65,8 @@ def test_write_header_only(self): root_mapping = unflatten( [ FixedValueMapping(Position.table_name, "table 1"), - ObjectClassMapping(0, header="classes"), - ObjectMapping(1, header="objects"), + EntityClassMapping(0, 
header="classes"), + EntityMapping(1, header="objects"), ] ) out_path = Path(self._temp_dir.name, "out.sqlite") @@ -97,8 +97,8 @@ def test_write_single_object_class_and_object(self): root_mapping = unflatten( [ FixedValueMapping(Position.table_name, "table 1"), - ObjectClassMapping(0, header="classes"), - ObjectMapping(1, header="objects"), + EntityClassMapping(0, header="classes"), + EntityMapping(1, header="objects"), ] ) out_path = Path(self._temp_dir.name, "out.sqlite") @@ -133,8 +133,8 @@ def test_write_datetime_value(self): root_mapping = unflatten( [ FixedValueMapping(Position.table_name, "table 1"), - ObjectClassMapping(0, header="classes"), - ObjectMapping(1, header="objects"), + EntityClassMapping(0, header="classes"), + EntityMapping(1, header="objects"), ParameterDefinitionMapping(2, header="parameters"), AlternativeMapping(Position.hidden), ParameterValueMapping(3, header="values"), @@ -173,8 +173,8 @@ def test_write_duration_value(self): root_mapping = unflatten( [ FixedValueMapping(Position.table_name, "table 1"), - ObjectClassMapping(0, header="classes"), - ObjectMapping(1, header="objects"), + EntityClassMapping(0, header="classes"), + EntityMapping(1, header="objects"), ParameterDefinitionMapping(2, header="parameters"), AlternativeMapping(Position.hidden), ParameterValueMapping(3, header="values"), @@ -208,10 +208,10 @@ def test_append_to_table(self): import_object_classes(db_map, ("oc",)) import_objects(db_map, (("oc", "o1"), ("oc", "q1"))) db_map.commit_session("Add test data.") - root_mapping1 = object_export(Position.table_name, 0) + root_mapping1 = entity_export(Position.table_name, 0) root_mapping1.child.header = "objects" root_mapping1.child.filter_re = "o1" - root_mapping2 = object_export(Position.table_name, 0) + root_mapping2 = entity_export(Position.table_name, 0) root_mapping2.child.header = "objects" root_mapping2.child.filter_re = "q1" out_path = Path(self._temp_dir.name, "out.sqlite") @@ -254,7 +254,7 @@ def test_appending_to_table_in_existing_database(self): out_connection.execute(object_table.insert(), objects="initial_object") finally: out_connection.close() - root_mapping = object_export(Position.table_name, 0) + root_mapping = entity_export(Position.table_name, 0) root_mapping.child.header = "objects" writer = SqlWriter(str(out_path), overwrite_existing=False) write(db_map, writer, root_mapping) diff --git a/tests/spine_io/exporters/test_writer.py b/tests/spine_io/exporters/test_writer.py index a637c1d4..9f0f1b00 100644 --- a/tests/spine_io/exporters/test_writer.py +++ b/tests/spine_io/exporters/test_writer.py @@ -17,7 +17,7 @@ import unittest from spinedb_api import DatabaseMapping, import_object_classes, import_objects from spinedb_api.spine_io.exporters.writer import Writer, write -from spinedb_api.export_mapping.settings import object_export +from spinedb_api.export_mapping.settings import entity_export class _TableWriter(Writer): @@ -63,7 +63,7 @@ def test_max_rows(self): ) self._db_map.commit_session("Add test data.") writer = _TableWriter() - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) write(self._db_map, writer, root_mapping, max_rows=2) self.assertEqual(writer.tables, {None: [["class1", "obj1"], ["class1", "obj2"]]}) @@ -82,7 +82,7 @@ def test_max_rows_with_filter(self): ) self._db_map.commit_session("Add test data.") writer = _TableWriter() - root_mapping = object_export(0, 1) + root_mapping = entity_export(0, 1) root_mapping.child.filter_re = "obj6" write(self._db_map, writer, root_mapping, max_rows=1) 
self.assertEqual(writer.tables, {None: [["class2", "obj6"]]}) From 676b88ac3e1932c2c496f33bb1a661c94a2cc16b Mon Sep 17 00:00:00 2001 From: Manuel Date: Sat, 1 Apr 2023 11:46:13 +0200 Subject: [PATCH 019/317] Fix a couple more tests --- spinedb_api/export_mapping/export_mapping.py | 4 ++-- tests/export_mapping/test_export_mapping.py | 20 ++++++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 474b9277..911aee9b 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -289,7 +289,7 @@ def _build_header_query(self, db_map, title_state, buddies): flat_buddies = [b for pair in buddies for b in pair] for _ in range(len(mappings)): m = mappings[-1] - if m.position == Position.header or m.position == Position.table_name or m in flat_buddies: + if m.position in (Position.header, Position.table_name) or m in flat_buddies: break mappings.pop(-1) # Start with empty query @@ -560,7 +560,7 @@ def make_header_recursive(self, query, buddies): if buddy is not None: query.rewind() header[buddy.position] = next( - (x for db_row in query for x in self._get_data_iterator(self._data(db_row))), "" + (x for db_row in query for x in self._get_data_iterator(self._data(db_row)) if x), "" ) else: header[self.position] = self.header diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 55820372..d7bc4165 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -816,7 +816,12 @@ def test_export_relationship_classes(self): ) db_map.commit_session("Add test data.") relationship_class_mapping = EntityClassMapping(0) - self.assertEqual(list(rows(relationship_class_mapping, db_map)), [["rc1"], ["rc2"], ["rc3"]]) + dimension_mapping = relationship_class_mapping.child = DimensionMapping(1) + dimension_mapping.child = DimensionMapping(2) + self.assertEqual( + list(rows(relationship_class_mapping, db_map)), + [["rc1", "oc1", ""], ["rc2", "oc3", "oc2"], ["rc3", "oc2", "oc3"]], + ) db_map.connection.close() def test_export_relationships(self): @@ -827,9 +832,16 @@ def test_export_relationships(self): import_relationships(db_map, (("rc1", ("o11",)), ("rc2", ("o21", "o11")), ("rc2", ("o21", "o12")))) db_map.commit_session("Add test data.") relationship_class_mapping = EntityClassMapping(0) - relationship_mapping = EntityMapping(1) - relationship_class_mapping.child = relationship_mapping - expected = [["rc1", "rc1_o11"], ["rc2", "rc2_o21__o11"], ["rc2", "rc2_o21__o12"]] + dimension1_mapping = relationship_class_mapping.child = DimensionMapping(1) + dimension2_mapping = dimension1_mapping.child = DimensionMapping(2) + relationship_mapping = dimension2_mapping.child = EntityMapping(3) + element1_mapping = relationship_mapping.child = ElementMapping(4) + element1_mapping.child = ElementMapping(5) + expected = [ + ['rc1', 'oc1', '', 'rc1_o11', 'o11', ''], + ['rc2', 'oc2', 'oc1', 'rc2_o21__o11', 'o21', 'o11'], + ['rc2', 'oc2', 'oc1', 'rc2_o21__o12', 'o21', 'o12'], + ] self.assertEqual(list(rows(relationship_class_mapping, db_map)), expected) db_map.connection.close() From b54242f70c1a7618ef0eba62feed66e84493ff2c Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 2 Apr 2023 15:48:47 +0200 Subject: [PATCH 020/317] Some renaming and fix export entity classes --- spinedb_api/export_functions.py | 2 +- spinedb_api/export_mapping/__init__.py | 8 ++--- 
spinedb_api/export_mapping/settings.py | 32 +++++++++--------- tests/export_mapping/test_export_mapping.py | 37 ++++++++++----------- tests/export_mapping/test_settings.py | 30 ++++++++--------- tests/import_mapping/test_import_mapping.py | 2 +- tests/spine_io/exporters/test_gdx_writer.py | 16 ++++----- 7 files changed, 63 insertions(+), 64 deletions(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 65bc5008..262e793a 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -168,7 +168,7 @@ def export_parameter_value_lists(db_map, ids=Asterisk, make_cache=None, parse_va def export_entity_classes(db_map, ids=Asterisk, make_cache=None): return sorted( - (x.name, x.description, x.display_icon, x.dimension_name_list) + (x.name, x.dimension_name_list, x.description, x.display_icon) for x in _get_items(db_map, "entity_class", ids, make_cache) ) diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index d82d28d4..896cc0b6 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -21,11 +21,11 @@ feature_export, entity_export, entity_group_export, - entity_class_parameter_default_value_export, - entity_parameter_export, + entity_parameter_default_value_export, + entity_parameter_value_export, parameter_value_list_export, - entity_class_dimension_parameter_default_value_export, - entity_element_parameter_export, + entity_dimension_parameter_default_value_export, + entity_dimension_parameter_value_export, scenario_alternative_export, scenario_export, tool_export, diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index 36aa68e2..dd263118 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -60,20 +60,20 @@ def entity_group_export( - class_position=Position.hidden, group_position=Position.hidden, entity_position=Position.hidden + entity_class_position=Position.hidden, group_position=Position.hidden, entity_position=Position.hidden ): """ Sets up export mappings for exporting entity groups. 
Args: - class_position (int or Position): position of entity classes + entity_class_position (int or Position): position of entity classes group_position (int or Position): position of groups entity_position (int or Position): position of entities Returns: ExportMapping: root mapping """ - class_ = EntityClassMapping(class_position) + class_ = EntityClassMapping(entity_class_position) group = EntityGroupMapping(group_position) entity = EntityGroupEntityMapping(entity_position) group.child = entity @@ -111,7 +111,7 @@ def entity_export( return entity_class -def entity_class_parameter_default_value_export( +def entity_parameter_default_value_export( entity_class_position=Position.hidden, definition_position=Position.hidden, value_type_position=Position.hidden, @@ -142,7 +142,7 @@ def entity_class_parameter_default_value_export( return entity_class -def entity_parameter_export( +def entity_parameter_value_export( entity_class_position=Position.hidden, definition_position=Position.hidden, value_list_position=Position.hidden, @@ -199,7 +199,7 @@ def entity_parameter_export( return entity_class -def entity_class_dimension_parameter_default_value_export( +def entity_dimension_parameter_default_value_export( entity_class_position=Position.hidden, definition_position=Position.hidden, dimension_positions=None, @@ -238,7 +238,7 @@ def entity_class_dimension_parameter_default_value_export( return root_mapping -def entity_element_parameter_export( +def entity_dimension_parameter_value_export( entity_class_position=Position.hidden, definition_position=Position.hidden, value_list_position=Position.hidden, @@ -426,18 +426,18 @@ def set_parameter_default_value_dimensions(mapping, dimensions): ) -def feature_export(class_position=Position.hidden, definition_position=Position.hidden): +def feature_export(entity_class_position=Position.hidden, definition_position=Position.hidden): """ Sets up export mappings for exporting features. 
Args: - class_position (int or Position): position of entity classes + entity_class_position (int or Position): position of entity classes definition_position (int or Position): position of parameter definitions Returns: ExportMapping: root mapping """ - class_ = FeatureEntityClassMapping(class_position) + class_ = FeatureEntityClassMapping(entity_class_position) definition = FeatureParameterDefinitionMapping(definition_position) class_.child = definition return class_ @@ -458,7 +458,7 @@ def tool_export(tool_position=Position.hidden): def tool_feature_export( tool_position=Position.hidden, - class_position=Position.hidden, + entity_class_position=Position.hidden, definition_position=Position.hidden, required_flag_position=Position.hidden, ): @@ -467,7 +467,7 @@ def tool_feature_export( Args: tool_position (int or Position): position of tools - class_position (int or Position): position of entity classes + entity_class_position (int or Position): position of entity classes definition_position (int or Position): position of parameter definitions required_flag_position (int or Position): position of required flags @@ -475,7 +475,7 @@ def tool_feature_export( ExportMapping: root mapping """ tool = ToolMapping(tool_position) - class_ = ToolFeatureEntityClassMapping(class_position) + class_ = ToolFeatureEntityClassMapping(entity_class_position) definition = ToolFeatureParameterDefinitionMapping(definition_position) required_flag = ToolFeatureRequiredFlagMapping(required_flag_position) definition.child = required_flag @@ -486,7 +486,7 @@ def tool_feature_export( def tool_feature_method_export( tool_position=Position.hidden, - class_position=Position.hidden, + entity_class_position=Position.hidden, definition_position=Position.hidden, method_position=Position.hidden, ): @@ -495,7 +495,7 @@ def tool_feature_method_export( Args: tool_position (int or Position): position of tools - class_position (int or Position): position of entity classes + entity_class_position (int or Position): position of entity classes definition_position (int or Position): position of parameter definitions method_position (int or Position): position of methods @@ -503,7 +503,7 @@ def tool_feature_method_export( ExportMapping: root mapping """ tool = ToolMapping(tool_position) - class_ = ToolFeatureMethodEntityClassMapping(class_position) + class_ = ToolFeatureMethodEntityClassMapping(entity_class_position) definition = ToolFeatureMethodParameterDefinitionMapping(definition_position) method = ToolFeatureMethodMethodMapping(method_position) definition.child = method diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index d7bc4165..e828b5ae 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -39,8 +39,8 @@ from spinedb_api.export_mapping import ( rows, titles, - entity_class_parameter_default_value_export, - entity_parameter_export, + entity_parameter_default_value_export, + entity_parameter_value_export, entity_export, ) from spinedb_api.export_mapping.export_mapping import ( @@ -290,7 +290,7 @@ def test_minimum_pivot_index_need_not_be_minus_one(self): ), ) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 1, 2, Position.hidden, 0, None, None, -2, Position.hidden, 4, [Position.hidden], [3] ) expected = [ @@ -492,7 +492,7 @@ def test_full_pivot_table_with_hidden_columns(self): db_map, (("oc", "o1", "p", Map(["A", "B"], [-1.1, -2.2])), ("oc", "o2", 
"p", Map(["A", "B"], [-5.5, -6.6]))) ) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 5, [Position.hidden], [4] ) expected = [ @@ -519,7 +519,7 @@ def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): ), ) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 5, [Position.hidden], [4] ) expected = [ @@ -541,7 +541,7 @@ def test_objects_and_indexes_as_pivot_header(self): db_map, (("oc", "o1", "p", Map(["A", "B"], [-1.1, -2.2])), ("oc", "o2", "p", Map(["A", "B"], [-3.3, -4.4]))) ) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 2, Position.hidden, -1, None, None, 3, Position.hidden, 4, [Position.hidden], [-2] ) expected = [ @@ -573,7 +573,7 @@ def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_para ), ) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 1, Position.hidden, -1, None, None, -2, Position.hidden, 2, [Position.hidden], [-3] ) expected = [ @@ -1200,9 +1200,6 @@ def test_header_position_with_relationships(self): ] ) expected = [["", "", "oc1", "oc2"], ["rc", "rc_o11__o21", "o11", "o21"]] - import pprint - - pprint.pprint(list(rows(root, db_map))) self.assertEqual(list(rows(root, db_map)), expected) db_map.connection.close() @@ -1444,7 +1441,7 @@ def test_index_names(self): import_objects(db_map, (("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["a"], [5.0], index_name="index")),)) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 2, Position.hidden, 1, None, None, 3, Position.hidden, 5, [Position.header], [4] ) expected = [["", "", "", "", "index", ""], ["oc", "o", "p", "Base", "a", 5.0]] @@ -1458,7 +1455,7 @@ def test_default_value_index_names_with_nested_map(self): db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3], index_name="idx2")], index_name="idx1")),) ) db_map.commit_session("Add test data.") - mapping = entity_class_parameter_default_value_export( + mapping = entity_parameter_default_value_export( 0, 1, Position.hidden, 4, [Position.header, Position.header], [2, 3] ) expected = [["", "", "idx1", "idx2", ""], ["oc", "p", "A", "b", 2.3]] @@ -1467,7 +1464,7 @@ def test_default_value_index_names_with_nested_map(self): def test_multiple_index_names_with_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 4, Position.hidden, 1, [2], [3], 5, Position.hidden, 8, [Position.header, Position.header], [6, 7] ) expected = [9 * [""]] @@ -1479,7 +1476,7 @@ def test_parameter_default_value_type(self): import_object_classes(db_map, ("oc1", "oc2", "oc3")) import_object_parameters(db_map, (("oc1", "p11", 3.14), ("oc2", "p21", 14.3), ("oc2", "p22", -1.0))) db_map.commit_session("Add test data.") - root_mapping = entity_class_parameter_default_value_export(0, 1, 2, 3, None, None) + root_mapping = entity_parameter_default_value_export(0, 1, 2, 3, None, None) expected = [ ["oc1", "p11", "single_value", 3.14], ["oc2", "p21", "single_value", 14.3], @@ -1495,7 +1492,7 @@ def test_map_with_more_dimensions_than_index_mappings(self): import_objects(db_map, 
(("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["A"], [Map(["b"], [2.3])])),)) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 1, Position.hidden, 2, None, None, Position.hidden, Position.hidden, 4, [Position.hidden], [3] ) expected = [["oc", "p", "o", "A", "map"]] @@ -1507,7 +1504,7 @@ def test_default_map_value_with_more_dimensions_than_index_mappings(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [Map(["b"], [2.3])])),)) db_map.commit_session("Add test data.") - mapping = entity_class_parameter_default_value_export(0, 1, Position.hidden, 3, [Position.hidden], [2]) + mapping = entity_parameter_default_value_export(0, 1, Position.hidden, 3, [Position.hidden], [2]) expected = [["oc", "p", "A", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1519,7 +1516,7 @@ def test_map_with_single_value_mapping(self): import_objects(db_map, (("oc", "o"),)) import_object_parameter_values(db_map, (("oc", "o", "p", Map(["A"], [2.3])),)) db_map.commit_session("Add test data.") - mapping = entity_parameter_export( + mapping = entity_parameter_value_export( 0, 1, Position.hidden, 2, None, None, Position.hidden, Position.hidden, 3, None, None ) expected = [["oc", "p", "o", "map"]] @@ -1531,7 +1528,7 @@ def test_default_map_value_with_single_value_mapping(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p", Map(["A"], [2.3])),)) db_map.commit_session("Add test data.") - mapping = entity_class_parameter_default_value_export(0, 1, Position.hidden, 2, None, None) + mapping = entity_parameter_default_value_export(0, 1, Position.hidden, 2, None, None) expected = [["oc", "p", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) db_map.connection.close() @@ -1541,7 +1538,9 @@ def test_table_gets_exported_even_without_parameter_values(self): import_object_classes(db_map, ("oc",)) import_object_parameters(db_map, (("oc", "p"),)) db_map.commit_session("Add test data.") - mapping = entity_parameter_export(Position.header, Position.table_name, entity_position=0, value_position=1) + mapping = entity_parameter_value_export( + Position.header, Position.table_name, entity_position=0, value_position=1 + ) tables = dict() for title, title_key in titles(mapping, db_map): tables[title] = list(rows(mapping, db_map, title_key)) diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index 29d09f4f..d57c4e12 100644 --- a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -33,12 +33,12 @@ from spinedb_api.export_mapping.settings import ( entity_export, set_entity_dimensions, - entity_parameter_export, + entity_parameter_value_export, set_parameter_dimensions, set_parameter_default_value_dimensions, - entity_class_parameter_default_value_export, - entity_class_dimension_parameter_default_value_export, - entity_element_parameter_export, + entity_parameter_default_value_export, + entity_dimension_parameter_default_value_export, + entity_dimension_parameter_value_export, ) from spinedb_api.export_mapping.export_mapping import ( Position, @@ -85,7 +85,7 @@ def test_export_with_parameter_values(self): ), ) db_map.commit_session("Add test data.") - root_mapping = entity_parameter_export( + root_mapping = entity_parameter_value_export( element_positions=[-1, -2], value_position=-3, index_name_positions=[Position.hidden], 
index_positions=[0] ) expected = [ @@ -113,7 +113,7 @@ def test_export_with_two_dimensions(self): import_relationship_classes(self._db_map, (("rc", ("oc1", "oc2")),)) import_relationship_parameters(self._db_map, (("rc", "rc_p", "dummy"),)) self._db_map.commit_session("Add test data.") - root_mapping = entity_class_dimension_parameter_default_value_export( + root_mapping = entity_dimension_parameter_default_value_export( entity_class_position=0, definition_position=1, dimension_positions=[2, 3], @@ -152,7 +152,7 @@ def test_export_with_two_dimensions(self): import_relationships(self._db_map, (("rc", ("o11", "o21")), ("rc", ("o12", "o21")))) import_relationship_parameter_values(self._db_map, (("rc", ("o11", "o21"), "rc_p", "dummy"),)) self._db_map.commit_session("Add test data.") - root_mapping = entity_element_parameter_export( + root_mapping = entity_dimension_parameter_value_export( entity_class_position=0, definition_position=1, value_list_position=Position.hidden, @@ -247,7 +247,7 @@ def test_decrease_dimensions(self): class TestSetParameterDimensions(unittest.TestCase): def test_set_dimensions_from_zero_to_one(self): - root_mapping = entity_parameter_export() + root_mapping = entity_parameter_value_export() set_parameter_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterValueMapping, @@ -259,7 +259,7 @@ def test_set_dimensions_from_zero_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_zero_to_one(self): - root_mapping = entity_class_parameter_default_value_export() + root_mapping = entity_parameter_default_value_export() set_parameter_default_value_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterDefaultValueMapping, @@ -271,21 +271,21 @@ def test_set_default_value_dimensions_from_zero_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_one_to_zero(self): - root_mapping = entity_parameter_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_parameter_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_dimensions(root_mapping, 0) expected_types = [ParameterValueMapping, ParameterValueTypeMapping] for expected_type, mapping in zip(expected_types, reversed(root_mapping.flatten())): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_one_to_zero(self): - root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_default_value_dimensions(root_mapping, 0) expected_types = [ParameterDefaultValueMapping, ParameterDefaultValueTypeMapping] for expected_type, mapping in zip(expected_types, reversed(root_mapping.flatten())): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_one_to_two(self): - root_mapping = entity_parameter_export(index_name_positions=[0], index_positions=[1]) + root_mapping = entity_parameter_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_dimensions(root_mapping, 2) expected_types = [ ExpandedParameterValueMapping, @@ -299,7 +299,7 @@ def test_set_dimensions_from_one_to_two(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_one_to_two(self): - root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) + root_mapping = 
entity_parameter_default_value_export(index_name_positions=[0], index_positions=[1]) set_parameter_default_value_dimensions(root_mapping, 2) expected_types = [ ExpandedParameterDefaultValueMapping, @@ -313,7 +313,7 @@ def test_set_default_value_dimensions_from_one_to_two(self): self.assertIsInstance(mapping, expected_type) def test_set_dimensions_from_two_to_one(self): - root_mapping = entity_parameter_export(index_name_positions=[0, 2], index_positions=[1, 3]) + root_mapping = entity_parameter_value_export(index_name_positions=[0, 2], index_positions=[1, 3]) set_parameter_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterValueMapping, @@ -325,7 +325,7 @@ def test_set_dimensions_from_two_to_one(self): self.assertIsInstance(mapping, expected_type) def test_set_default_value_dimensions_from_two_to_one(self): - root_mapping = entity_class_parameter_default_value_export(index_name_positions=[0, 2], index_positions=[1, 3]) + root_mapping = entity_parameter_default_value_export(index_name_positions=[0, 2], index_positions=[1, 3]) set_parameter_default_value_dimensions(root_mapping, 1) expected_types = [ ExpandedParameterDefaultValueMapping, diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index 49c8e60f..c2830af9 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -1506,7 +1506,7 @@ def test_read_data_with_two_mappings_with_different_read_start_row(self): self.assertEqual(out, expected) def test_read_object_class_with_table_name_as_class_name(self): - input_data = [["Object names"], ["object 1"], ["object 2"]] + input_data = [["Entity names"], ["object 1"], ["object 2"]] data = iter(input_data) data_header = next(data) mapping = { diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 6e99a0ae..9be5f6ff 100644 --- a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -35,7 +35,7 @@ Map, ) from spinedb_api.mapping import Position, unflatten -from spinedb_api.export_mapping import entity_export, entity_parameter_export, entity_export +from spinedb_api.export_mapping import entity_export, entity_parameter_value_export, entity_export from spinedb_api.export_mapping.export_mapping import FixedValueMapping @@ -102,7 +102,7 @@ def test_write_parameters(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", 2.3),)) db_map.commit_session("Add test data.") - root_mapping = entity_parameter_export( + root_mapping = entity_parameter_value_export( entity_class_position=Position.table_name, entity_position=0, value_position=1 ) mappings = root_mapping.flatten() @@ -126,7 +126,7 @@ def test_non_numerical_parameter_value_raises_writer_expection(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", "text"),)) db_map.commit_session("Add test data.") - root_mapping = entity_parameter_export( + root_mapping = entity_parameter_value_export( entity_class_position=Position.table_name, entity_position=0, value_position=1 ) mappings = root_mapping.flatten() @@ -145,7 +145,7 @@ def test_empty_parameter(self): import_objects(db_map, (("oc", "o1"),)) import_object_parameter_values(db_map, (("oc", "o1", "p", Map([], [], str)),)) db_map.commit_session("Add test data.") - root_mapping = entity_parameter_export( + root_mapping = entity_parameter_value_export( entity_class_position=Position.table_name, 
entity_position=0, value_position=1
 )
 mappings = root_mapping.flatten()
@@ -170,7 +170,7 @@ def test_write_scalars(self):
 import_objects(db_map, (("oc", "o1"),))
 import_object_parameter_values(db_map, (("oc", "o1", "p", 2.3),))
 db_map.commit_session("Add test data.")
- root_mapping = entity_parameter_export(entity_class_position=Position.table_name, value_position=0)
+ root_mapping = entity_parameter_value_export(entity_class_position=Position.table_name, value_position=0)
 with TemporaryDirectory() as temp_dir:
 file_path = Path(temp_dir, "test_write_scalars.gdx")
 writer = GdxWriter(str(file_path), self._gams_dir)
@@ -238,7 +238,7 @@ def test_parameter_value_non_convertible_to_float_raises_WriterException(self):
 import_objects(db_map, (("oc", "o"), ("oc", "p")))
 import_object_parameter_values(db_map, (("oc", "o", "param", "text"), ("oc", "p", "param", 2.3)))
 db_map.commit_session("Add test data.")
- root_mapping = entity_parameter_export(
+ root_mapping = entity_parameter_value_export(
 entity_class_position=Position.hidden,
 definition_position=Position.table_name,
 entity_position=0,
@@ -259,7 +259,7 @@ def test_non_string_set_element_raises_WriterException(self):
 import_objects(db_map, (("oc", "o"), ("oc", "p")))
 import_object_parameter_values(db_map, (("oc", "o", "param", 2.3), ("oc", "p", "param", "text")))
 db_map.commit_session("Add test data.")
- root_mapping = entity_parameter_export(
+ root_mapping = entity_parameter_value_export(
 entity_class_position=Position.hidden,
 definition_position=Position.table_name,
 entity_position=0,
@@ -290,7 +290,7 @@ def test_special_value_conversions(self):
 ),
 )
 db_map.commit_session("Add test data.")
- root_mapping = entity_parameter_export(
+ root_mapping = entity_parameter_value_export(
 entity_class_position=Position.table_name, entity_position=0, definition_position=1, value_position=2
 )
 mappings = root_mapping.flatten()

From 8e73a8089dd8ff8cab44852277b710c3afffbcc1 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 5 Apr 2023 12:13:43 +0200
Subject: [PATCH 021/317] Introduce id generator and order entities in import
 mapping so that multidimensional imports work

Now that we import objects and relationships together via the
"entities" key, we need to make sure the rows are in the right order,
so that 0-dimensional entities are processed before multidimensional
ones (because the latter may need the former as elements). This means
assigning ids to the 0-dimensional entities *before* they are actually
added to the db, which is done with a context manager.
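For illustration, here is a minimal, self-contained sketch of the
id-reservation pattern this commit introduces. Names mirror the patch
(_IdGenerator, generate_ids), but persisting the final counter back to
the next_id table is reduced to a comment, so this is a simplified
model rather than the actual mixin code:

    from contextlib import contextmanager

    class IdGenerator:
        # Hands out consecutive integer ids, remembering the next free one.
        def __init__(self, next_id):
            self.next_id = next_id

        def __call__(self):
            try:
                return self.next_id
            finally:
                self.next_id += 1

    @contextmanager
    def generate_ids(next_id):
        gen = IdGenerator(next_id)
        try:
            yield gen
        finally:
            # The real mixin writes gen.next_id back to the next_id table
            # here, so concurrent sessions reserve disjoint id ranges.
            pass

    with generate_ids(1) as new_id:
        object_item = {"name": "o1", "id": new_id()}  # id 1, before any INSERT
        relationship_item = {
            "name": "o1__o2",
            "id": new_id(),  # id 2
            "element_id_list": (object_item["id"],),  # can already refer to o1
        }

Because ids are handed out in iteration order, sorting the mapped rows
so that 0-dimensional entities come first guarantees their ids exist by
the time a multidimensional entity builds its element_id_list.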
--- spinedb_api/db_mapping_add_mixin.py | 58 +++++-- spinedb_api/db_mapping_check_mixin.py | 24 +-- spinedb_api/import_functions.py | 149 ++++++++-------- spinedb_api/import_mapping/generator.py | 18 +- spinedb_api/import_mapping/import_mapping.py | 12 +- tests/import_mapping/test_generator.py | 48 +++--- tests/import_mapping/test_import_mapping.py | 172 +++++++++---------- tests/test_DiffDatabaseMapping.py | 16 +- 8 files changed, 269 insertions(+), 228 deletions(-) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 18b647f8..52b6e17b 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -17,6 +17,7 @@ # TODO: improve docstrings from datetime import datetime +from contextlib import contextmanager from sqlalchemy import func, Table, Column, Integer, String, null, select from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError @@ -25,6 +26,20 @@ class DatabaseMappingAddMixin: """Provides methods to perform ``INSERT`` operations over a Spine db.""" + class _IdGenerator: + def __init__(self, next_id): + self._next_id = next_id + + @property + def next_id(self): + return self._next_id + + def __call__(self): + try: + return self._next_id + finally: + self._next_id += 1 + def __init__(self, *args, **kwargs): """Initialize class.""" super().__init__(*args, **kwargs) @@ -59,22 +74,17 @@ def __init__(self, *args, **kwargs): # Some other concurrent process must have beaten us to create the table self._next_id = Table("next_id", self._metadata, autoload=True) - def _add_commit_id_and_ids(self, tablename, *items): - if not items: - return [], set() - ids = self._reserve_ids(tablename, len(items)) - commit_id = self._make_commit_id() - for id_, item in zip(ids, items): - item["commit_id"] = commit_id - item["id"] = id_ + @contextmanager + def generate_ids(self, tablename): + """Manages id generation for new items to be added to the db. - def _reserve_ids(self, tablename, count): - if self.committing: - return self._do_reserve_ids(self.connection, tablename, count) - with self.engine.begin() as connection: - return self._do_reserve_ids(connection, tablename, count) + Args: + tablename (str): the table to which items will be added - def _do_reserve_ids(self, connection, tablename, count): + Yields: + self._IdGenerator: an object that generates a new id every time it is called. 
+ """ + connection = self.connection if self.committing else self.engine.connect() fieldname = { "entity_class": "entity_class_id", "object_class": "entity_class_id", @@ -113,9 +123,23 @@ def _do_reserve_ids(self, connection, tablename, count): select_max_id = select([func.max(getattr(table.c, id_col))]) max_id = connection.execute(select_max_id).scalar() next_id = max_id + 1 if max_id else 1 - new_next_id = next_id + count - connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: new_next_id}) - return range(next_id, new_next_id) + gen = self._IdGenerator(next_id) + try: + yield gen + finally: + connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) + if not self.committing: + connection.close() + + def _add_commit_id_and_ids(self, tablename, *items): + if not items: + return [], set() + commit_id = self._make_commit_id() + with self.generate_ids(tablename) as new_id: + for item in items: + item["commit_id"] = commit_id + if "id" not in item: + item["id"] = new_id() def _readd_items(self, tablename, *items): """Add known items to database.""" diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index a1dc8ef6..205f8067 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -635,21 +635,23 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } entity_class_ids = {x.id for x in cache.get("entity_class", {}).values()} + object_class_ids = {x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} + relationship_class_ids = {x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} for item in items: - object_class_id = item.get("object_class_id") - relationship_class_id = item.get("relationship_class_id") - if object_class_id and relationship_class_id: - e = SpineIntegrityError("Can't associate a parameter to both an object and a relationship class.") - if strict: - raise e - intgr_error_log.append(e) - continue - entity_class_id = object_class_id or relationship_class_id - if "entity_class_id" not in item and entity_class_id is not None: - item["entity_class_id"] = entity_class_id try: + object_class_id = item.get("object_class_id") + relationship_class_id = item.get("relationship_class_id") + if object_class_id and relationship_class_id: + raise SpineIntegrityError("Can't associate a parameter to both an object and a relationship class.") + if object_class_id and object_class_id not in object_class_ids: + raise SpineIntegrityError("Invalid object class id.") + if relationship_class_id and relationship_class_id not in relationship_class_ids: + raise SpineIntegrityError("Invalid relationship class id.") + entity_class_id = object_class_id or relationship_class_id + if "entity_class_id" not in item and entity_class_id is not None: + item["entity_class_id"] = entity_class_id if ( for_update and item["id"] in parameter_definition_ids_with_values diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 3bed39e6..57a4a127 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -366,41 +366,43 @@ def 
import_entity_classes(db_map, data, make_cache=None): def _get_entity_classes_for_import(db_map, data, make_cache): - # FIXME: We need to find a way to set the ids for newly added single dimensional entities - # so that they can be used in this same function for adding multi dimensional ones cache = make_cache({"entity_class"}, include_ancestors=True) entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} checked = set() error_log = [] to_add = [] to_update = [] - for name, *optionals in data: - if name in checked: - continue - ec_id = entity_class_ids.pop(name, None) - item = ( - cache["entity_class"][ec_id]._asdict() - if ec_id is not None - else {"name": name, "description": None, "display_icon": None} - ) - item.update(dict(zip(("dimension_name_list", "description", "display_icon"), optionals))) - item["dimension_id_list"] = tuple(entity_class_ids.get(x, None) for x in item.get("dimension_name_list", ())) - try: - check_entity_class(item, entity_class_ids) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem(f"Could not import entity class '{name}': {e.msg}", db_type="entity_class") + with db_map.generate_ids("entity_class") as new_entity_class_id: + for name, *optionals in data: + if name in checked: + continue + ec_id = entity_class_ids.pop(name, None) + item = ( + cache["entity_class"][ec_id]._asdict() + if ec_id is not None + else {"name": name, "description": None, "display_icon": None} ) - continue - finally: + item.update(dict(zip(("dimension_name_list", "description", "display_icon"), optionals))) + item["dimension_id_list"] = tuple( + entity_class_ids.get(x, None) for x in item.get("dimension_name_list", ()) + ) + try: + check_entity_class(item, entity_class_ids) + except SpineIntegrityError as e: + error_log.append( + ImportErrorLogItem(f"Could not import entity class '{name}': {e.msg}", db_type="entity_class") + ) + continue + finally: + if ec_id is not None: + entity_class_ids[name] = ec_id + checked.add(name) if ec_id is not None: - entity_class_ids[name] = ec_id - checked.add(name) - if ec_id is not None: - item["id"] = ec_id - to_update.append(item) - else: - to_add.append(item) + item["id"] = ec_id + to_update.append(item) + else: + item["id"] = entity_class_ids[name] = new_entity_class_id() + to_add.append(item) return to_add, to_update, error_log @@ -447,57 +449,62 @@ def _get_entities_for_import(db_map, data, make_cache): entity_classes = { x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} for x in cache.get("entity_class", {}).values() } - entity_ids = {(x["name"], x["class_id"]): id_ for id_, x in entities.items()} entity_class_ids = {x["name"]: id_ for id_, x in entity_classes.items()} dimension_id_lists = {id_: x["dimension_id_list"] for id_, x in entity_classes.items()} error_log = [] to_add = [] to_update = [] checked = set() - for class_name, ent_name_or_el_names, *optionals in data: - ec_id = entity_class_ids.get(class_name, None) - dim_ids = dimension_id_lists.get(ec_id, ()) - if isinstance(ent_name_or_el_names, str): - el_ids = () - e_key = ent_name_or_el_names - else: - el_ids = tuple(entity_ids.get((name, dim_id), None) for name, dim_id in zip(ent_name_or_el_names, dim_ids)) - e_key = el_ids - if (ec_id, e_key) in checked: - continue - e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) - if e_id is not None: - e_name = cache["entity"][e_id].name - entity_ids_per_name.pop((e_id, e_name)) - else: - e_name = _make_unique_entity_name(ec_id, class_name, ent_name_or_el_names, 
entity_ids_per_name) - item = ( - cache["entity"][e_id]._asdict() - if e_id is not None - else { - "name": e_name, - "class_id": ec_id, - "element_id_list": el_ids, - "dimension_id_list": dim_ids, - } - ) - item.update(dict(zip(("description",), optionals))) - try: - check_entity(item, entity_ids_per_name, entity_ids_per_el_id_lst, entity_classes, entities) - except SpineIntegrityError as e: - msg = f"Could not import entity {tuple(ent_name_or_el_names)} into '{class_name}': {e.msg}" - error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship")) - continue - finally: + with db_map.generate_ids("entity") as new_entity_id: + for class_name, ent_name_or_el_names, *optionals in data: + ec_id = entity_class_ids.get(class_name, None) + dim_ids = dimension_id_lists.get(ec_id, ()) + if isinstance(ent_name_or_el_names, str): + el_ids = () + e_key = ent_name_or_el_names + else: + el_ids = tuple( + entity_ids_per_name.get((dim_id, name), None) for dim_id, name in zip(dim_ids, ent_name_or_el_names) + ) + e_key = el_ids + if (ec_id, e_key) in checked: + continue + e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) if e_id is not None: - entity_ids_per_el_id_lst[ec_id, el_ids] = e_id - entity_ids_per_name[ec_id, e_name] = e_id - checked.add((ec_id, e_key)) - if e_id is not None: - item["id"] = e_id - to_update.append(item) - else: - to_add.append(item) + e_name = cache["entity"][e_id].name + entity_ids_per_name.pop((e_id, e_name)) + else: + e_name = _make_unique_entity_name(ec_id, class_name, ent_name_or_el_names, entity_ids_per_name) + item = ( + cache["entity"][e_id]._asdict() + if e_id is not None + else { + "name": e_name, + "class_id": ec_id, + "element_id_list": el_ids, + "dimension_id_list": dim_ids, + } + ) + item.update(dict(zip(("description",), optionals))) + print(item) + try: + check_entity(item, entity_ids_per_name, entity_ids_per_el_id_lst, entity_classes, entities) + except SpineIntegrityError as e: + msg = f"Could not import entity {tuple(ent_name_or_el_names)} into '{class_name}': {e.msg}" + error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship")) + continue + finally: + if e_id is not None: + entity_ids_per_el_id_lst[ec_id, el_ids] = entity_ids_per_name[ec_id, e_name] = e_id + checked.add((ec_id, e_key)) + if e_id is not None: + item["id"] = e_id + to_update.append(item) + else: + item["id"] = entity_ids_per_el_id_lst[ec_id, el_ids] = entity_ids_per_name[ + ec_id, e_name + ] = new_entity_id() + to_add.append(item) return to_add, to_update, error_log diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index 502bbac9..e1e5219c 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -18,6 +18,7 @@ """ from copy import deepcopy +from operator import itemgetter from .import_mapping_compat import import_mapping_from_dict from .import_mapping import ImportMapping, check_validity from ..mapping import Position @@ -160,6 +161,7 @@ def get_mapped_data( full_row.append(row[column_pos]) mapping.import_row(full_row, read_state, mapped_data) _make_entity_classes(mapped_data) + _make_entities(mapped_data) _make_parameter_values(mapped_data, unparse_value) return mapped_data, errors @@ -268,11 +270,19 @@ def _make_entity_classes(mapped_data): rows = mapped_data.get("entity_classes") if rows is None: return - full_rows = set() - for class_name, dimension_names in rows.items(): + rows = [(class_name, tuple(dimension_names)) for class_name, dimension_names in rows.items()] + 
rows.sort(key=itemgetter(1)) + mapped_data["entity_classes"] = final_rows = [] + for class_name, dimension_names in rows: row = (class_name, tuple(dimension_names)) if dimension_names else (class_name,) - full_rows.add(row) - mapped_data["entity_classes"] = full_rows + final_rows.append(row) + + +def _make_entities(mapped_data): + rows = mapped_data.get("entities") + if rows is None: + return + mapped_data["entities"] = list(rows) def _make_parameter_values(mapped_data, unparse_value): diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index 176c49cf..b566ab07 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -391,7 +391,7 @@ def _import_row(self, source_data, state, mapped_data): entity_name = state[ImportKey.ENTITY_NAME] = str(source_data) if isinstance(self.child, EntityGroupMapping): raise KeyError(ImportKey.MEMBER_NAME) - mapped_data.setdefault("entities", set()).add((entity_class_name, entity_name)) + mapped_data.setdefault("entities", {})[entity_class_name, entity_name] = None class EntityMetadataMapping(ImportMapping): @@ -422,8 +422,9 @@ def _import_row(self, source_data, state, mapped_data): member_name = str(source_data) mapped_data.setdefault("entity_groups", set()).add((entity_class_name, group_name, member_name)) if self.import_entities: - entities = (entity_class_name, group_name), (entity_class_name, member_name) - mapped_data.setdefault("entities", set()).update(entities) + entities = mapped_data.setdefault("entities", {}) + entities[entity_class_name, group_name] = None + entities[entity_class_name, member_name] = None raise KeyFix(ImportKey.MEMBER_NAME) @@ -464,10 +465,9 @@ def _import_row(self, source_data, state, mapped_data): k = len(element_names) - 1 dimension_name = dimension_names[k] mapped_data.setdefault("entity_classes", {}).update({dimension_name: ()}) - mapped_data.setdefault("entities", set()).add((dimension_name, element_name)) + mapped_data.setdefault("entities", {})[dimension_name, element_name] = None if len(element_names) == state[ImportKey.DIMENSION_COUNT]: - entities = mapped_data.setdefault("entities", set()) - entities.add((entity_class_name, tuple(element_names))) + mapped_data.setdefault("entities", {})[entity_class_name, tuple(element_names)] = None raise KeyFix(ImportKey.ELEMENT_NAMES) raise KeyError(ImportKey.ELEMENT_NAMES) diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index 44d0b65f..299f1b49 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -72,10 +72,10 @@ def test_returns_appropriate_error_if_last_row_is_empty(self): mapped_data, { 'alternatives': {'Base'}, - 'entity_classes': {('Object',)}, + 'entity_classes': [('Object',)], 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], 'parameter_definitions': [('Object', 'Parameter')], - 'entities': {('Object', 'data')}, + 'entities': [('Object', 'data')], }, ) @@ -104,10 +104,10 @@ def test_convert_functions_get_expanded_over_last_defined_column_in_pivoted_data mapped_data, { 'alternatives': {'Base'}, - 'entity_classes': {('Object',)}, + 'entity_classes': [('Object',)], 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], 'parameter_definitions': [('Object', 'Parameter')], - 'entities': {('Object', 'data')}, + 'entities': [('Object', 'data')], }, ) @@ -135,10 +135,10 @@ def 
test_read_start_row_skips_rows_in_pivoted_data(self): self.assertEqual( mapped_data, { - 'entity_classes': {('klass',)}, + 'entity_classes': [('klass',)], 'parameter_values': [['klass', 'kloss', 'Parameter_2', Map(["T1", "T2"], [2.3, 23.0])]], 'parameter_definitions': [('klass', 'Parameter_2')], - 'entities': {('klass', 'kloss')}, + 'entities': [('klass', 'kloss')], }, ) @@ -189,10 +189,10 @@ def test_map_without_values_is_ignored_and_not_interpreted_as_null(self): mapped_data, { "alternatives": {"base"}, - "entity_classes": {("o",)}, + 'entity_classes': [("o",)], "parameter_definitions": [("o", "parameter_name")], "parameter_values": [], - "entities": {("o", "o1")}, + "entities": [("o", "o1")], }, ) @@ -225,16 +225,16 @@ def test_import_object_works_with_multiple_relationship_object_imports(self): mapped_data, { "alternatives": {"base"}, - "entity_classes": {("o",), ("q",), ("o_to_q", ("o", "q"))}, - "entities": { + 'entity_classes': [("o",), ("q",), ("o_to_q", ("o", "q"))], + "entities": [ ("o", "o1"), - ("o", "o2"), ("q", "q1"), - ("q", "q2"), ("o_to_q", ("o1", "q1")), - ("o_to_q", ("o1", "q2")), + ("o", "o2"), + ("q", "q2"), ("o_to_q", ("o2", "q2")), - }, + ("o_to_q", ("o1", "q2")), + ], "parameter_definitions": [("o_to_q", "param")], "parameter_values": [ ["o_to_q", ("o1", "q1"), "param", Map(["t1", "t2"], [11, 22], index_name="time"), "base"], @@ -269,10 +269,10 @@ def test_default_convert_function_in_column_convert_functions(self): self.assertEqual( mapped_data, { - "entity_classes": {("klass",)}, + 'entity_classes': [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], "parameter_definitions": [("klass", "Parameter_2")], - "entities": {("klass", "kloss")}, + "entities": [("klass", "kloss")], }, ) @@ -297,10 +297,10 @@ def test_identity_function_is_used_as_convert_function_when_no_convert_functions self.assertEqual( mapped_data, { - "entity_classes": {("klass",)}, + 'entity_classes': [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], ["2.3", "23.0"])]], "parameter_definitions": [("klass", "Parameter_2")], - "entities": {("klass", "kloss")}, + "entities": [("klass", "kloss")], }, ) @@ -327,10 +327,10 @@ def test_last_convert_function_gets_used_as_default_convert_function_when_no_def self.assertEqual( mapped_data, { - "entity_classes": {("klass",)}, + 'entity_classes': [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], "parameter_definitions": [("klass", "Parameter_2")], - "entities": {("klass", "kloss")}, + "entities": [("klass", "kloss")], }, ) @@ -360,13 +360,13 @@ def test_array_parameters_get_imported_correctly_when_objects_are_in_header(self mapped_data, { "alternatives": {"Base"}, - "entity_classes": {("class",)}, + 'entity_classes': [("class",)], "parameter_values": [ ["class", "object_1", "param", Array([-1.1, 1.1]), "Base"], ["class", "object_2", "param", Array([2.3, -2.3]), "Base"], ], "parameter_definitions": [("class", "param")], - "entities": {("class", "object_1"), ("class", "object_2")}, + "entities": [("class", "object_1"), ("class", "object_2")], }, ) @@ -396,13 +396,13 @@ def test_arrays_get_imported_correctly_when_objects_are_in_header_and_alternativ mapped_data, { "alternatives": {"Base"}, - "entity_classes": {("Gadget",)}, + 'entity_classes': [("Gadget",)], "parameter_values": [ ["Gadget", "object_1", "data", Array([-1.1, 1.1]), "Base"], ["Gadget", "object_2", "data", Array([2.3, -2.3]), "Base"], ], "parameter_definitions": 
[("Gadget", "data")], - "entities": {("Gadget", "object_1"), ("Gadget", "object_2")}, + "entities": [("Gadget", "object_1"), ("Gadget", "object_2")], }, ) diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index c2830af9..79eb4c37 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -59,8 +59,8 @@ def test_convert_functions_float(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': {('a',)}, - 'entities': {('a', 'obj')}, + 'entity_classes': [('a',)], + 'entities': [('a', 'obj')], 'parameter_definitions': [('a', 'param', 1.2)], } self.assertEqual(mapped_data, expected) @@ -78,8 +78,8 @@ def test_convert_functions_str(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': {('a',)}, - 'entities': {('a', 'obj')}, + 'entity_classes': [('a',)], + 'entities': [('a', 'obj')], 'parameter_definitions': [('a', 'param', '1111.2222')], } self.assertEqual(mapped_data, expected) @@ -97,8 +97,8 @@ def test_convert_functions_bool(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': {('a',)}, - 'entities': {('a', 'obj')}, + 'entity_classes': [('a',)], + 'entities': [('a', 'obj')], 'parameter_definitions': [('a', 'param', False)], } self.assertEqual(mapped_data, expected) @@ -792,7 +792,7 @@ def test_read_iterator_with_row_with_all_Nones(self): [None, None, None, None], ["oc2", "obj2", "parameter_name2", 2], ] - expected = {"entity_classes": {("oc2",)}} + expected = {"entity_classes": [("oc2",)]} data = iter(input_data) data_header = next(data) @@ -805,7 +805,7 @@ def test_read_iterator_with_row_with_all_Nones(self): def test_read_iterator_with_None(self): input_data = [["object_class", "object", "parameter", "value"], None, ["oc2", "obj2", "parameter_name2", 2]] - expected = {"entity_classes": {("oc2",)}} + expected = {"entity_classes": [("oc2",)]} data = iter(input_data) data_header = next(data) @@ -823,8 +823,8 @@ def test_read_flat_file(self): ["oc2", "obj2", "parameter_name2", 2], ] expected = { - "entity_classes": {("oc1",), ("oc2",)}, - "entities": {("oc1", "obj1"), ("oc2", "obj2")}, + "entity_classes": [("oc1",), ("oc2",)], + "entities": [("oc1", "obj1"), ("oc2", "obj2")], "parameter_definitions": [("oc1", "parameter_name1"), ("oc2", "parameter_name2")], "parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], } @@ -850,8 +850,8 @@ def test_read_flat_file_array(self): ["oc1", "obj1", "parameter_name1", 2], ] expected = { - "entity_classes": {("oc1",)}, - "entities": {("oc1", "obj1")}, + "entity_classes": [("oc1",)], + "entities": [("oc1", "obj1")], "parameter_definitions": [("oc1", "parameter_name1")], "parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], } @@ -877,8 +877,8 @@ def test_read_flat_file_array_with_ed(self): ["oc1", "obj1", "parameter_name1", 2, 1], ] expected = { - "entity_classes": {("oc1",)}, - "entities": {("oc1", "obj1")}, + "entity_classes": [("oc1",)], + "entities": [("oc1", "obj1")], "parameter_definitions": [("oc1", "parameter_name1")], "parameter_values": [["oc1", "obj1", "parameter_name1", Array([1, 2])]], } @@ -905,7 +905,7 @@ def 
test_read_flat_file_array_with_ed(self): def test_read_flat_file_with_column_name_reference(self): input_data = [["object", "parameter", "value"], ["obj1", "parameter_name1", 1], ["obj2", "parameter_name2", 2]] - expected = {"entity_classes": {("object",)}, "entities": {("object", "obj1"), ("object", "obj2")}} + expected = {"entity_classes": [("object",)], "entities": [("object", "obj1"), ("object", "obj2")]} data = iter(input_data) data_header = next(data) @@ -919,8 +919,8 @@ def test_read_flat_file_with_column_name_reference(self): def test_read_object_class_from_header_using_string_as_integral_index(self): input_data = [["object_class"], ["obj1"], ["obj2"]] expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "obj1"), ("object_class", "obj2")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "obj1"), ("object_class", "obj2")], } data = iter(input_data) @@ -935,8 +935,8 @@ def test_read_object_class_from_header_using_string_as_integral_index(self): def test_read_object_class_from_header_using_string_as_column_header_name(self): input_data = [["object_class"], ["obj1"], ["obj2"]] expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "obj1"), ("object_class", "obj2")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "obj1"), ("object_class", "obj2")], } data = iter(input_data) @@ -954,7 +954,7 @@ def test_read_object_class_from_header_using_string_as_column_header_name(self): def test_read_with_list_of_mappings(self): input_data = [["object", "parameter", "value"], ["obj1", "parameter_name1", 1], ["obj2", "parameter_name2", 2]] - expected = {"entity_classes": {("object",)}, "entities": {("object", "obj1"), ("object", "obj2")}} + expected = {"entity_classes": [("object",)], "entities": [("object", "obj1"), ("object", "obj2")]} data = iter(input_data) data_header = next(data) @@ -968,8 +968,8 @@ def test_read_with_list_of_mappings(self): def test_read_pivoted_parameters_from_header(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1"), ("object", "obj2")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1"), ("object", "obj2")], "parameter_definitions": [("object", "parameter_name1"), ("object", "parameter_name2")], "parameter_values": [ ["object", "obj1", "parameter_name1", 0], @@ -1014,8 +1014,8 @@ def test_read_empty_pivot(self): def test_read_pivoted_parameters_from_data(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1"), ("object", "obj2")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1"), ("object", "obj2")], "parameter_definitions": [("object", "parameter_name1"), ("object", "parameter_name2")], "parameter_values": [ ["object", "obj1", "parameter_name1", 0], @@ -1048,8 +1048,8 @@ def test_pivoted_value_has_actual_position(self): ["obj2", "T2", 22.0], ] expected = { - "entity_classes": {("timeline",)}, - "entities": {("timeline", "obj1"), ("timeline", "obj2")}, + "entity_classes": [("timeline",)], + "entities": [("timeline", "obj1"), ("timeline", "obj2")], "parameter_definitions": [("timeline", "value")], "alternatives": {"Base"}, "parameter_values": [ @@ -1078,8 +1078,8 @@ def test_import_objects_from_pivoted_data_when_they_lack_parameter_values(self): """Pivoted mapping 
works even when last mapping has valid position in columns.""" input_data = [["object", "is_skilled", "has_powers"], ["obj1", "yes", "no"], ["obj2", None, None]] expected = { - "entity_classes": {("node",)}, - "entities": {("node", "obj1"), ("node", "obj2")}, + "entity_classes": [("node",)], + "entities": [("node", "obj1"), ("node", "obj2")], "parameter_definitions": [("node", "is_skilled"), ("node", "has_powers")], "alternatives": {"Base"}, "parameter_values": [ @@ -1109,8 +1109,8 @@ def test_import_objects_from_pivoted_data_when_they_lack_map_type_parameter_valu ["obj1", "today", None, "yes"], ] expected = { - "entity_classes": {("node",)}, - "entities": {("node", "obj1")}, + "entity_classes": [("node",)], + "entities": [("node", "obj1")], "parameter_definitions": [("node", "is_skilled"), ("node", "has_powers")], "alternatives": {"Base"}, "parameter_values": [ @@ -1138,8 +1138,8 @@ def test_read_flat_file_with_extra_value_dimensions(self): input_data = [["object", "time", "parameter_name1"], ["obj1", "2018-01-01", 1], ["obj1", "2018-01-02", 2]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1")], "parameter_definitions": [("object", "parameter_name1")], "parameter_values": [ [ @@ -1175,8 +1175,8 @@ def test_read_flat_file_with_parameter_definition(self): input_data = [["object", "time", "parameter_name1"], ["obj1", "2018-01-01", 1], ["obj1", "2018-01-02", 2]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1")], "parameter_definitions": [("object", "parameter_name1")], } @@ -1202,8 +1202,8 @@ def test_read_flat_file_with_parameter_definition(self): def test_read_1dim_relationships(self): input_data = [["unit", "node"], ["u1", "n1"], ["u1", "n2"]] expected = { - "entity_classes": {("node_group", ("node",))}, - "entities": {("node_group", ("n1",)), ("node_group", ("n2",))}, + "entity_classes": [("node_group", ("node",))], + "entities": [("node_group", ("n1",)), ("node_group", ("n2",))], } data = iter(input_data) @@ -1223,8 +1223,8 @@ def test_read_1dim_relationships(self): def test_read_relationships(self): input_data = [["unit", "node"], ["u1", "n1"], ["u1", "n2"]] expected = { - "entity_classes": {("unit__node", ("unit", "node"))}, - "entities": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, + "entity_classes": [("unit__node", ("unit", "node"))], + "entities": [("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))], } data = iter(input_data) @@ -1247,8 +1247,8 @@ def test_read_relationships(self): def test_read_relationships_with_parameters(self): input_data = [["unit", "node", "rel_parameter"], ["u1", "n1", 0], ["u1", "n2", 1]] expected = { - "entity_classes": {("unit__node", ("unit", "node"))}, - "entities": {("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))}, + "entity_classes": [("unit__node", ("unit", "node"))], + "entities": [("unit__node", ("u1", "n1")), ("unit__node", ("u1", "n2"))], "parameter_definitions": [("unit__node", "rel_parameter")], "parameter_values": [ ["unit__node", ("u1", "n1"), "rel_parameter", 0], @@ -1277,15 +1277,14 @@ def test_read_relationships_with_parameters(self): def test_read_relationships_with_parameters2(self): input_data = [["nuts2", "Capacity", "Fueltype"], ["BE23", 268.0, "Bioenergy"], ["DE11", 14.0, "Bioenergy"]] expected = { - "entity_classes": {("nuts2",), ("fueltype",), ("nuts2__fueltype", ("nuts2", "fueltype"))}, - 
"entities": { + "entity_classes": [("nuts2",), ("fueltype",), ("nuts2__fueltype", ("nuts2", "fueltype"))], + "entities": [ ("nuts2", "BE23"), ("fueltype", "Bioenergy"), - ("nuts2", "DE11"), - ("fueltype", "Bioenergy"), ("nuts2__fueltype", ("BE23", "Bioenergy")), + ("nuts2", "DE11"), ("nuts2__fueltype", ("DE11", "Bioenergy")), - }, + ], "parameter_definitions": [("nuts2__fueltype", "capacity")], "parameter_values": [ ["nuts2__fueltype", ("BE23", "Bioenergy"), "capacity", 268.0], @@ -1322,8 +1321,8 @@ def test_read_relationships_with_parameters2(self): def test_read_parameter_header_with_only_one_parameter(self): input_data = [["object", "parameter_name1"], ["obj1", 0], ["obj2", 2]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1"), ("object", "obj2")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1"), ("object", "obj2")], "parameter_definitions": [("object", "parameter_name1")], "parameter_values": [ ["object", "obj1", "parameter_name1", 0], @@ -1348,8 +1347,8 @@ def test_read_parameter_header_with_only_one_parameter(self): def test_read_pivoted_parameters_from_data_with_skipped_column(self): input_data = [["object", "parameter_name1", "parameter_name2"], ["obj1", 0, 1], ["obj2", 2, 3]] expected = { - "entity_classes": {("object",)}, - "entities": {("object", "obj1"), ("object", "obj2")}, + "entity_classes": [("object",)], + "entities": [("object", "obj1"), ("object", "obj2")], "parameter_definitions": [("object", "parameter_name1")], "parameter_values": [ ["object", "obj1", "parameter_name1", 0], @@ -1374,15 +1373,15 @@ def test_read_pivoted_parameters_from_data_with_skipped_column(self): def test_read_relationships_and_import_objects(self): input_data = [["unit", "node"], ["u1", "n1"], ["u2", "n2"]] expected = { - "entity_classes": {("unit",), ("node",), ("unit__node", ("unit", "node"))}, - "entities": { + "entity_classes": [("unit",), ("node",), ("unit__node", ("unit", "node"))], + "entities": [ ("unit", "u1"), ("node", "n1"), + ("unit__node", ("u1", "n1")), ("unit", "u2"), ("node", "n2"), - ("unit__node", ("u1", "n1")), ("unit__node", ("u2", "n2")), - }, + ], } data = iter(input_data) @@ -1407,9 +1406,9 @@ def test_read_relationships_parameter_values_with_extra_dimensions(self): input_data = [["", "a", "b"], ["", "c", "d"], ["", "e", "f"], ["a", 2, 3], ["b", 4, 5]] expected = { - "entity_classes": {("unit__node", ("unit", "node"))}, + "entity_classes": [("unit__node", ("unit", "node"))], "parameter_definitions": [("unit__node", "e"), ("unit__node", "f")], - "entities": {("unit__node", ("a", "c")), ("unit__node", ("b", "d"))}, + "entities": [("unit__node", ("a", "c")), ("unit__node", ("b", "d"))], "parameter_values": [ ["unit__node", ("a", "c"), "e", Map(["a", "b"], [2, 4])], ["unit__node", ("b", "d"), "f", Map(["a", "b"], [3, 5])], @@ -1444,8 +1443,8 @@ def test_read_data_with_read_start_row(self): ["oc2", "obj2", "parameter_name2", 2], ] expected = { - "entity_classes": {("oc1",), ("oc2",)}, - "entities": {("oc1", "obj1"), ("oc2", "obj2")}, + "entity_classes": [("oc1",), ("oc2",)], + "entities": [("oc1", "obj1"), ("oc2", "obj2")], "parameter_definitions": [("oc1", "parameter_name1"), ("oc2", "parameter_name2")], "parameter_values": [["oc1", "obj1", "parameter_name1", 1], ["oc2", "obj2", "parameter_name2", 2]], } @@ -1473,8 +1472,8 @@ def test_read_data_with_two_mappings_with_different_read_start_row(self): ["oc1_obj2", "oc2_obj2", 2, 4], ] expected = { - "entity_classes": {("oc1",), ("oc2",)}, - "entities": {("oc1", "oc1_obj1"), 
("oc1", "oc1_obj2"), ("oc2", "oc2_obj2")}, + "entity_classes": [("oc1",), ("oc2",)], + "entities": [("oc1", "oc1_obj1"), ("oc1", "oc1_obj2"), ("oc2", "oc2_obj2")], "parameter_definitions": [("oc1", "parameter_class1"), ("oc2", "parameter_class2")], "parameter_values": [ ["oc1", "oc1_obj1", "parameter_class1", 1], @@ -1516,8 +1515,8 @@ def test_read_object_class_with_table_name_as_class_name(self): } out, errors = get_mapped_data(data, [mapping], data_header, "class name") expected = { - "entity_classes": {("class name",)}, - "entities": {("class name", "object 1"), ("class name", "object 2")}, + "entity_classes": [("class name",)], + "entities": [("class name", "object 1"), ("class name", "object 2")], } self.assertFalse(errors) self.assertEqual(out, expected) @@ -1541,8 +1540,8 @@ def test_read_flat_map_from_columns(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key1", "key2"], [-2, -1]) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -1568,8 +1567,8 @@ def test_read_nested_map_from_columns(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key11", "key21"], [Map(["key12"], [-2]), Map(["key22"], [-1])]) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -1612,8 +1611,8 @@ def test_read_uneven_nested_map_from_columns(self): ], ) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -1654,8 +1653,8 @@ def test_read_nested_map_with_compression(self): ], ) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -1817,7 +1816,7 @@ def test_read_object_group_without_parameters(self): mapping = {"map_type": "ObjectGroup", "name": 0, "groups": 1, "members": 2} out, errors = get_mapped_data(data, [mapping], data_header) expected = dict() - expected["entity_classes"] = {("class_A",)} + expected["entity_classes"] = [("class_A",)] expected["entity_groups"] = { ("class_A", "group1", "object1"), ("class_A", "group1", "object2"), @@ -1843,15 +1842,14 @@ def test_read_object_group_and_import_objects(self): ("class_A", "group1", "object2"), ("class_A", "group2", "object3"), } - expected["entity_classes"] = {("class_A",)} - expected["entities"] = { + expected["entity_classes"] = [("class_A",)] + expected["entities"] = [ ("class_A", "group1"), ("class_A", "object1"), - ("class_A", "group1"), ("class_A", "object2"), ("class_A", "group2"), ("class_A", "object3"), - } + ] self.assertFalse(errors) self.assertEqual(out, expected) @@ -1876,7 +1874,7 @@ def 
test_read_parameter_definition_with_default_values_and_value_lists(self): } out, errors = get_mapped_data(data, [mapping], data_header) expected = dict() - expected["entity_classes"] = {("class_A",), ("class_A",), ("class_B",)} + expected["entity_classes"] = [("class_A",), ("class_B",)] expected["parameter_definitions"] = [ ("class_A", "param1", 23.0, "listA"), ("class_A", "param2", 42.0, "listB"), @@ -1900,7 +1898,7 @@ def test_map_as_default_parameter_value(self): out, errors = get_mapped_data(data, [mapping]) expected_map = Map(["key1", "key2", "key3"], [-2.3, 5.5, 3.2]) expected = { - "entity_classes": {("object_class",)}, + "entity_classes": [("object_class",)], "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) @@ -1922,7 +1920,7 @@ def test_read_parameter_definition_with_nested_map_as_default_value(self): out, errors = get_mapped_data(data, [mapping], data_header) expected_map = Map(["key11", "key21"], [Map(["key12"], [-2]), Map(["key22"], [-1])]) expected = { - "entity_classes": {("object_class",)}, + "entity_classes": [("object_class",)], "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) @@ -1952,8 +1950,8 @@ def test_read_map_index_names_from_columns(self): index_name="Index 1", ) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -1984,8 +1982,8 @@ def test_missing_map_index_name(self): index_name="", ) expected = { - "entity_classes": {("object_class",)}, - "entities": {("object_class", "object")}, + "entity_classes": [("object_class",)], + "entities": [("object_class", "object")], "parameter_values": [["object_class", "object", "parameter", expected_map]], "parameter_definitions": [("object_class", "parameter")], } @@ -2015,7 +2013,7 @@ def test_read_default_value_index_names_from_columns(self): index_name="Index 1", ) expected = { - "entity_classes": {("object_class",)}, + "entity_classes": [("object_class",)], "parameter_definitions": [("object_class", "parameter", expected_map)], } self.assertFalse(errors) @@ -2026,7 +2024,7 @@ def test_filter_regular_expression_in_root_mapping(self): data = iter(input_data) mapping_root = unflatten([EntityClassMapping(0, filter_re="B"), EntityMapping(1)]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"entity_classes": {("B",)}, "entities": {("B", "r")}} + expected = {"entity_classes": [("B",)], "entities": [("B", "r")]} self.assertFalse(errors) self.assertEqual(out, expected) @@ -2035,7 +2033,7 @@ def test_filter_regular_expression_in_child_mapping(self): data = iter(input_data) mapping_root = unflatten([EntityClassMapping(0), EntityMapping(1, filter_re="q|r")]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"entity_classes": {("A",), ("B",)}, "entities": {("A", "q"), ("B", "r")}} + expected = {"entity_classes": [("A",), ("B",)], "entities": [("A", "q"), ("B", "r")]} self.assertFalse(errors) self.assertEqual(out, expected) @@ -2044,7 +2042,7 @@ def test_filter_regular_expression_in_child_mapping_filters_parent_mappings_too( data = iter(input_data) mapping_root = unflatten([EntityClassMapping(0), EntityMapping(1, filter_re="q")]) out, errors = get_mapped_data(data, [mapping_root]) - expected = {"entity_classes": {("A",)}, "entities": {("A", 
"q")}} + expected = {"entity_classes": [("A",)], "entities": [("A", "q")]} self.assertFalse(errors) self.assertEqual(out, expected) @@ -2063,8 +2061,8 @@ def test_arrays_get_imported_to_correct_alternatives(self): ) out, errors = get_mapped_data(data, [mapping_root]) expected = { - "entity_classes": {("class",)}, - "entities": {("class", "y")}, + "entity_classes": [("class",)], + "entities": [("class", "y")], "parameter_definitions": [("class", "parameter")], "alternatives": {"Base", "alternative"}, "parameter_values": [ diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index dae7b1c3..6d5853c7 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -492,7 +492,7 @@ def test_add_object_with_invalid_class(self): def test_add_relationship_classes(self): """Test that adding relationship classes works.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]} ) @@ -515,7 +515,7 @@ def test_add_relationship_classes_with_invalid_name(self): def test_add_relationship_classes_with_same_name(self): """Test that adding two relationship classes with the same name only adds one of them.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc1", "object_class_id_list": [1, 2]} ) @@ -564,9 +564,9 @@ def test_add_relationship_class_with_invalid_object_class(self): def test_add_relationships(self): """Test that adding relationships works.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) + self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2]}) + self._db_map.add_objects({"name": "o1", "class_id": 1}, {"name": "o2", "class_id": 2}) self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all() relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() @@ -588,9 +588,9 @@ def test_add_relationship_with_invalid_name(self): def test_add_identical_relationships(self): """Test that adding two relationships with the same class and same objects only adds the first one.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) + self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2]}) + self._db_map.add_objects({"name": "o1", "class_id": 1}, {"name": "o2", "class_id": 2}) self._db_map.add_wide_relationships( {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, {"name": 
"nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, From 703e8c198439593a65b2777bdf44364cda8ec5d9 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 5 Apr 2023 12:15:13 +0200 Subject: [PATCH 022/317] Fix tests --- tests/test_DiffDatabaseMapping.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 6d5853c7..a73c4b0f 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -97,7 +97,9 @@ def test_cascade_remove_relationship(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - ids, _ = self._db_map.add_wide_relationships({"name": "remove_me", "class_id": 3, "object_id_list": [1, 2]}) + ids, _ = self._db_map.add_wide_relationships( + {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} + ) self._db_map.cascade_remove_items(relationship=ids) self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) self._db_map.commit_session("delete") @@ -108,7 +110,9 @@ def test_cascade_remove_relationship_from_committed_session(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - ids, _ = self._db_map.add_wide_relationships({"name": "remove_me", "class_id": 3, "object_id_list": [1, 2]}) + ids, _ = self._db_map.add_wide_relationships( + {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} + ) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 1) self._db_map.cascade_remove_items(relationship=ids) From 0c9365423a35ee9cdfac505f57bfbdec9d768463 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 5 Apr 2023 12:16:23 +0200 Subject: [PATCH 023/317] Remove print --- spinedb_api/import_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 57a4a127..28d389f8 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -486,7 +486,6 @@ def _get_entities_for_import(db_map, data, make_cache): } ) item.update(dict(zip(("description",), optionals))) - print(item) try: check_entity(item, entity_ids_per_name, entity_ids_per_el_id_lst, entity_classes, entities) except SpineIntegrityError as e: From 19571f9ede2d95ed2f4eef102fba5c2198951a89 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 5 Apr 2023 12:36:04 +0200 Subject: [PATCH 024/317] Fix importing entities --- spinedb_api/import_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 28d389f8..7b379adf 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -503,6 +503,7 @@ def _get_entities_for_import(db_map, data, make_cache): item["id"] = entity_ids_per_el_id_lst[ec_id, el_ids] = entity_ids_per_name[ ec_id, e_name ] = new_entity_id() + entities[item["id"]] = item to_add.append(item) return to_add, to_update, error_log From ab63397bc1cbe6aa87e1a94d901b2d8d59aa5209 Mon Sep 17 00:00:00 2001 From: Manuel Date: 
Wed, 5 Apr 2023 12:47:48 +0200
Subject: [PATCH 025/317] Fix one more test

---
 spinedb_api/export_mapping/export_mapping.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py
index 911aee9b..c128c50e 100644
--- a/spinedb_api/export_mapping/export_mapping.py
+++ b/spinedb_api/export_mapping/export_mapping.py
@@ -643,6 +643,11 @@ def add_query_columns(self, db_map, query):
 db_map.ext_entity_class_sq.c.dimension_name_list.label("dimension_name_list"),
 )
+ def filter_query(self, db_map, query):
+ if isinstance(self.child, DimensionMapping):
+ return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list != None)
+ return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list == None)
+
 @staticmethod
 def name_field():
 return "entity_class_name"

From d9073499a3a75385412e07a9b05d66f028de66fe Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 5 Apr 2023 12:55:41 +0200
Subject: [PATCH 026/317] Improve check for dimension count in
 export_mapping.EntityMapping

---
 spinedb_api/export_mapping/export_mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py
index c128c50e..06eb59ef 100644
--- a/spinedb_api/export_mapping/export_mapping.py
+++ b/spinedb_api/export_mapping/export_mapping.py
@@ -644,7 +644,7 @@ def add_query_columns(self, db_map, query):
 )
 def filter_query(self, db_map, query):
- if isinstance(self.child, DimensionMapping):
+ if any(isinstance(m, (DimensionMapping, ElementMapping)) for m in self.flatten()):
 return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list != None)
 return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list == None)

From 9613863461247c8e2be030580cfcca64b1c6c2ab Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 5 Apr 2023 14:28:11 +0200
Subject: [PATCH 027/317] Fix migration script to ignore non-existing
 constraints

---
 ...c61_drop_object_and_relationship_tables.py | 22 ++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
index 703c0547..afb3adda 100644
--- a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
+++ b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py
@@ -9,7 +9,6 @@
 import sqlalchemy as sa
 from spinedb_api.helpers import naming_convention
-
 # revision identifiers, used by Alembic.
 revision = '6b7c994c1c61'
 down_revision = '989fccf80441'
@@ -75,14 +74,20 @@ def upgrade():
 sa.PrimaryKeyConstraint('entity_id', 'position', name=op.f('pk_entity_element')),
 )
 _persist_data()
+ # NOTE: some constraints are only created by the create_new_spine_database() function,
+ # not by the corresponding migration script. Thus we need to check before removing those constraints.
+ # We should avoid this in the future.
+ entity_class_constraints, entity_constraints = _get_constraints() with op.batch_alter_table("entity", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('uq_entity_idclass_id', type_='unique') - batch_op.drop_constraint('uq_entity_idtype_idclass_id', type_='unique') + for cname in ('uq_entity_idclass_id', 'uq_entity_idtype_idclass_id'): + if cname in entity_constraints: + batch_op.drop_constraint(cname, type_='unique') batch_op.drop_constraint('fk_entity_type_id_entity_type', type_='foreignkey') batch_op.drop_column('type_id') with op.batch_alter_table("entity_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('uq_entity_class_idtype_id', type_='unique') - batch_op.drop_constraint('uq_entity_class_type_idname', type_='unique') + for cname in ('uq_entity_class_idtype_id', 'uq_entity_class_type_idname'): + if cname in entity_class_constraints: + batch_op.drop_constraint(cname, type_='unique') batch_op.drop_constraint('fk_entity_class_type_id_entity_class_type', type_='foreignkey') batch_op.drop_constraint('fk_entity_class_commit_id_commit', type_='foreignkey') batch_op.drop_column('commit_id') @@ -98,6 +103,13 @@ def upgrade(): op.drop_table('relationship_entity') +def _get_constraints(): + conn = op.get_bind() + meta = sa.MetaData(conn) + meta.reflect() + return [[c.name for c in meta.tables[tname].constraints] for tname in ["entity_class", "entity"]] + + def _persist_data(): conn = op.get_bind() meta = sa.MetaData(conn) From b9202a8695a7c1e7d8d7824e01159b660e449941 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 5 Apr 2023 14:41:10 +0200 Subject: [PATCH 028/317] Fix import entities --- spinedb_api/import_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 7b379adf..151a683d 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -462,14 +462,15 @@ def _get_entities_for_import(db_map, data, make_cache): if isinstance(ent_name_or_el_names, str): el_ids = () e_key = ent_name_or_el_names + e_id = None else: el_ids = tuple( entity_ids_per_name.get((dim_id, name), None) for dim_id, name in zip(dim_ids, ent_name_or_el_names) ) e_key = el_ids + e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) if (ec_id, e_key) in checked: continue - e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) if e_id is not None: e_name = cache["entity"][e_id].name entity_ids_per_name.pop((e_id, e_name)) From 40edc285d5a85bd1f238d069b10894459badc8a8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 5 Apr 2023 18:00:51 +0200 Subject: [PATCH 029/317] Introduce dry_run argument as replacement for committing --- spinedb_api/db_mapping_add_mixin.py | 78 ++++++++++++++++---------- spinedb_api/db_mapping_base.py | 12 +--- spinedb_api/db_mapping_commit_mixin.py | 15 +++-- spinedb_api/db_mapping_remove_mixin.py | 5 +- spinedb_api/db_mapping_update_mixin.py | 67 +++++++++++++--------- spinedb_api/export_functions.py | 2 +- spinedb_api/import_functions.py | 13 +++-- 7 files changed, 108 insertions(+), 84 deletions(-) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 52b6e17b..54983dc5 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -75,7 +75,7 @@ def __init__(self, *args, **kwargs): self._next_id = Table("next_id", self._metadata, autoload=True) @contextmanager - def generate_ids(self, tablename): + def generate_ids(self, tablename, 
dry_run=False): """Manages id generation for new items to be added to the db. Args: @@ -84,7 +84,7 @@ def generate_ids(self, tablename): Yields: self._IdGenerator: an object that generates a new id every time it is called. """ - connection = self.connection if self.committing else self.engine.connect() + connection = self.engine.connect() if dry_run else self.connection fieldname = { "entity_class": "entity_class_id", "object_class": "entity_class_id", @@ -128,25 +128,19 @@ def generate_ids(self, tablename): yield gen finally: connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) - if not self.committing: + if dry_run: connection.close() - def _add_commit_id_and_ids(self, tablename, *items): + def _add_commit_id_and_ids(self, tablename, *items, dry_run=False): if not items: return [], set() - commit_id = self._make_commit_id() - with self.generate_ids(tablename) as new_id: + commit_id = self._make_commit_id(dry_run=dry_run) + with self.generate_ids(tablename, dry_run=dry_run) as new_id: for item in items: item["commit_id"] = commit_id if "id" not in item: item["id"] = new_id() - def _readd_items(self, tablename, *items): - """Add known items to database.""" - self._make_commit_id() - for _ in self._do_add_items(tablename, *items): - pass - def add_items( self, tablename, @@ -157,6 +151,7 @@ def add_items( return_items=False, cache=None, readd=False, + dry_run=False, ): """Add items to db. @@ -177,7 +172,9 @@ def add_items( list(SpineIntegrityError): found violations """ if readd: - self._readd_items(tablename, *items) + if not dry_run: + for _ in self._do_add_items(tablename, *items): + pass return items if return_items else {x["id"] for x in items}, [] if check: checked_items, intgr_error_log = self.check_items( @@ -185,14 +182,14 @@ def add_items( ) else: checked_items, intgr_error_log = list(items), [] - ids = self._add_items(tablename, *checked_items) + ids = self._add_items(tablename, *checked_items, dry_run=dry_run) if return_items: return checked_items, intgr_error_log if return_dups: ids.update(set(x.id for x in intgr_error_log if x.id)) return ids, intgr_error_log - def _add_items(self, tablename, *items): + def _add_items(self, tablename, *items, dry_run=False): """Add items to database without checking integrity. 
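+
+        When ``dry_run`` is True, ids are still generated and assigned to the items,
+        but nothing is actually inserted into the database.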
Args: @@ -204,9 +201,10 @@ def _add_items(self, tablename, *items): Returns: ids (set): added instances' ids """ - self._add_commit_id_and_ids(tablename, *items) - for _ in self._do_add_items(tablename, *items): - pass + self._add_commit_id_and_ids(tablename, *items, dry_run=dry_run) + if not dry_run: + for _ in self._do_add_items(tablename, *items): + pass return {item["id"] for item in items} def _get_table_for_insert(self, tablename): @@ -224,8 +222,6 @@ def _get_table_for_insert(self, tablename): return self._metadata.tables[tablename] def _do_add_items(self, tablename, *items_to_add): - if not self.committing: - return try: for tablename_, items_to_add_ in self._items_to_add_per_table(tablename, items_to_add): table = self._get_table_for_insert(tablename_) @@ -234,6 +230,8 @@ def _do_add_items(self, tablename, *items_to_add): except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e + else: + self._has_pending_changes = True def _items_to_add_per_table(self, tablename, items_to_add): """ @@ -382,7 +380,7 @@ def add_entity_metadata(self, *items, **kwargs): def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) - def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache): + def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache, dry_run): metadata_ids = {} for entry in cache.get("metadata", {}).values(): metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id @@ -397,7 +395,13 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache): else: item["metadata_id"] = existing_id added_metadata, errors = self.add_items( - "metadata", *metadata_to_add, check=check, strict=strict, return_items=True, cache=cache + "metadata", + *metadata_to_add, + check=check, + strict=strict, + return_items=True, + cache=cache, + dry_run=dry_run, ) for x in added_metadata: cache.table_cache("metadata").add_item(x) @@ -411,36 +415,52 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache): item["metadata_id"] = new_metadata_ids[metadata_name][metadata_value] return added_metadata, errors - def _add_ext_item_metadata(self, table_name, *items, check=True, strict=False, return_items=False, cache=None): + def _add_ext_item_metadata( + self, table_name, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False + ): # Note, that even though return_items can be False, it doesn't make much sense here because we'll be mixing # metadata and entity metadata ids. 
if cache is None: cache = self.make_cache({table_name}, include_ancestors=True) added_metadata, metadata_errors = self._get_or_add_metadata_ids_for_items( - *items, check=check, strict=strict, cache=cache + *items, check=check, strict=strict, cache=cache, dry_run=dry_run ) if metadata_errors: if not return_items: return added_metadata, metadata_errors return {i["id"] for i in added_metadata}, metadata_errors added_item_metadata, item_errors = self.add_items( - table_name, *items, check=check, strict=strict, return_items=True, cache=cache + table_name, *items, check=check, strict=strict, return_items=True, cache=cache, dry_run=dry_run ) errors = metadata_errors + item_errors if not return_items: return {i["id"] for i in added_metadata + added_item_metadata}, errors return added_metadata + added_item_metadata, errors - def add_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, cache=None, readd=False): + def add_ext_entity_metadata( + self, *items, check=True, strict=False, return_items=False, cache=None, readd=False, dry_run=False + ): return self._add_ext_item_metadata( - "entity_metadata", *items, check=check, strict=strict, return_items=return_items, cache=cache + "entity_metadata", + *items, + check=check, + strict=strict, + return_items=return_items, + cache=cache, + dry_run=dry_run, ) def add_ext_parameter_value_metadata( - self, *items, check=True, strict=False, return_items=False, cache=None, readd=False + self, *items, check=True, strict=False, return_items=False, cache=None, readd=False, dry_run=False ): return self._add_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict, return_items=return_items, cache=cache + "parameter_value_metadata", + *items, + check=check, + strict=strict, + return_items=return_items, + cache=cache, + dry_run=dry_run, ) def _add_entity_classes(self, *items): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index ec5f7aaf..080b5fad 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -117,7 +117,6 @@ def __init__( self.username = username if username else "anon" self.codename = self._make_codename(codename) self._memory = memory - self.committing = True self._memory_dirty = False self._original_engine = self.create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout @@ -263,15 +262,6 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_val, _exc_tb): self.connection.close() - @contextmanager - def override_committing(self, new_committing): - committing = self.committing - self.committing = new_committing - try: - yield None - finally: - self.committing = committing - def _descendant_tablenames(self, tablename): child_tablenames = { "alternative": ("parameter_value", "scenario_alternative"), @@ -319,7 +309,7 @@ def get_table(self, tablename): def commit_id(self): return self._commit_id - def _make_commit_id(self): + def _make_commit_id(self, dry_run=False): return None def _check_commit(self, comment): diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 754c4e92..47355f3d 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -31,9 +31,10 @@ def __init__(self, *args, **kwargs): """Initialize class.""" super().__init__(*args, **kwargs) self._commit_id = None + self._has_pending_changes = False def has_pending_changes(self): - return self._commit_id is not None + return self._has_pending_changes def 
_get_sqlite_lock(self): """Commits the session's natural transaction and begins a new locking one.""" @@ -41,14 +42,14 @@ def _get_sqlite_lock(self): self.session.commit() self.session.execute("BEGIN IMMEDIATE") - def _make_commit_id(self): + def _make_commit_id(self, dry_run=False): if self._commit_id is None: - if self.committing: - self._get_sqlite_lock() - self._commit_id = self._do_make_commit_id(self.connection) - else: + if dry_run: with self.engine.begin() as connection: self._commit_id = self._do_make_commit_id(connection) + else: + self._get_sqlite_lock() + self._commit_id = self._do_make_commit_id(self.connection) return self._commit_id def _do_make_commit_id(self, connection): @@ -71,6 +72,7 @@ def commit_session(self, comment): self._checked_execute(upd, dict(user=user, date=date, comment=comment)) self.session.commit() self._commit_id = None + self._has_pending_changes = False if self._memory: self._memory_dirty = True @@ -83,3 +85,4 @@ def reset_session(self): self.session.rollback() self.cache.clear() self._commit_id = None + self._has_pending_changes = False diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 6f051578..ecd766d1 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -40,9 +40,6 @@ def remove_items(self, **kwargs): Args: **kwargs: keyword is table name, argument is list of ids to remove """ - if not self.committing: - return - self._make_commit_id() for tablename, ids in kwargs.items(): if not ids: continue @@ -59,6 +56,8 @@ def remove_items(self, **kwargs): except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e + else: + self._has_pending_changes = True # pylint: disable=redefined-builtin def cascading_ids(self, cache=None, **kwargs): diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 3d1d8a88..c6efbfef 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -27,7 +27,7 @@ def _add_commit_id(self, *items): for item in items: item["commit_id"] = self._make_commit_id() - def _update_items(self, tablename, *items): + def _update_items(self, tablename, *items, dry_run=False): if not items: return set() # Special cases @@ -36,26 +36,28 @@ def _update_items(self, tablename, *items): if tablename == "relationship": return self._update_wide_relationships(*items) real_tablename = self._real_tablename(tablename) - return self._do_update_items(real_tablename, *items) + if not dry_run: + self._do_update_items(real_tablename, *items) + return {x["id"] for x in items} def _do_update_items(self, tablename, *items): if not items: - return set() - if self.committing: - self._add_commit_id(*items) - table = self._metadata.tables[tablename] - upd = table.update() - for k in self._get_primary_key(tablename): - upd = upd.where(getattr(table.c, k) == bindparam(k)) - upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items[0].keys()}) - try: - self._checked_execute(upd, [{**item} for item in items]) - except DBAPIError as e: - msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" - raise SpineDBAPIError(msg) - return {x["id"] for x in items} + return + self._add_commit_id(*items) + table = self._metadata.tables[tablename] + upd = table.update() + for k in self._get_primary_key(tablename): + upd = upd.where(getattr(table.c, k) == bindparam(k)) + upd = upd.values({key: bindparam(key) for key in 
table.columns.keys() & items[0].keys()}) + try: + self._checked_execute(upd, [{**item} for item in items]) + except DBAPIError as e: + msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" + raise SpineDBAPIError(msg) from e + else: + self._has_pending_changes = True - def update_items(self, tablename, *items, check=True, strict=False, return_items=False, cache=None): + def update_items(self, tablename, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False): """Updates items. Args: @@ -78,7 +80,7 @@ def update_items(self, tablename, *items, check=True, strict=False, return_items ) else: checked_items, intgr_error_log = list(items), [] - updated_ids = self._update_items(tablename, *checked_items) + updated_ids = self._update_items(tablename, *checked_items, dry_run=dry_run) if return_items: return checked_items, intgr_error_log return updated_ids, intgr_error_log @@ -135,9 +137,9 @@ def _update_entities(self, *items): "element_id": element_id, } entity_element_items.append(rel_ent_item) - entity_ids = self._do_update_items("entity", *entity_items) + self._do_update_items("entity", *entity_items) self._do_update_items("entity_element", *entity_element_items) - return entity_ids + return {x["id"] for x in entity_items} def update_object_classes(self, *items, **kwargs): return self.update_items("object_class", *items, **kwargs) @@ -185,9 +187,9 @@ def _update_wide_relationships(self, *items): "element_id": element_id, } entity_element_items.append(rel_ent_item) - entity_ids = self._do_update_items("entity", *entity_items) + self._do_update_items("entity", *entity_items) self._do_update_items("entity_element", *entity_element_items) - return entity_ids + return {x["id"] for x in entity_items} def update_parameter_definitions(self, *items, **kwargs): return self.update_items("parameter_definition", *items, **kwargs) @@ -243,23 +245,27 @@ def update_metadata(self, *items, **kwargs): def _update_metadata(self, *items): return self._update_items("metadata", *items) - def update_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, cache=None): + def update_ext_entity_metadata( + self, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False + ): updated_items, errors = self._update_ext_item_metadata( - "entity_metadata", *items, check=check, strict=strict, cache=cache + "entity_metadata", *items, check=check, strict=strict, cache=cache, dry_run=dry_run ) if return_items: return updated_items, errors return {i["id"] for i in updated_items}, errors - def update_ext_parameter_value_metadata(self, *items, check=True, strict=False, return_items=False, cache=None): + def update_ext_parameter_value_metadata( + self, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False + ): updated_items, errors = self._update_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict, cache=cache + "parameter_value_metadata", *items, check=check, strict=strict, cache=cache, dry_run=dry_run ) if return_items: return updated_items, errors return {i["id"] for i in updated_items}, errors - def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False, cache=None): + def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False, cache=None, dry_run=False): if cache is None: cache = self.make_cache({"entity_metadata", "parameter_value_metadata", "metadata"}) metadata_ids = {} @@ -320,7 +326,12 @@ def _update_ext_item_metadata(self, 
metadata_table, *items, check=True, strict=F errors = [] if updatable_metadata_items: updated_metadata, errors = self.update_metadata( - *updatable_metadata_items, check=False, strict=strict, return_items=True, cache=cache + *updatable_metadata_items, + check=False, + strict=strict, + return_items=True, + cache=cache, + dry_run=dry_run, ) all_items += updated_metadata if errors: diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 262e793a..258f9b00 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -204,7 +204,7 @@ def export_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=f ( ( x.entity_class_name, - x.element_name_list or x.name, + x.element_name_list or x.entity_name, x.parameter_name, parse_value(x.value, x.type), x.alternative_name, diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 151a683d..57cb095a 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -186,6 +186,7 @@ def get_data_for_import( make_cache=None, unparse_value=to_database, on_conflict="merge", + dry_run=False, entity_classes=(), entities=(), parameter_definitions=(), @@ -259,7 +260,7 @@ def get_data_for_import( yield ("alternative", _get_alternatives_for_import(alternatives, make_cache)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(scenario_alternatives, make_cache)) if entity_classes: - yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, make_cache)) + yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, make_cache, dry_run)) if object_classes: yield ("object_class", _get_object_classes_for_import(db_map, object_classes, make_cache)) if relationship_classes: @@ -294,7 +295,7 @@ def get_data_for_import( _get_tool_feature_methods_for_import(db_map, tool_feature_methods, make_cache, unparse_value), ) if entities: - yield ("entity", _get_entities_for_import(db_map, entities, make_cache)) + yield ("entity", _get_entities_for_import(db_map, entities, make_cache, dry_run)) if objects: yield ("object", _get_objects_for_import(db_map, objects, make_cache)) if relationships: @@ -365,14 +366,14 @@ def import_entity_classes(db_map, data, make_cache=None): return import_data(db_map, entity_classes=data, make_cache=make_cache) -def _get_entity_classes_for_import(db_map, data, make_cache): +def _get_entity_classes_for_import(db_map, data, make_cache, dry_run): cache = make_cache({"entity_class"}, include_ancestors=True) entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} checked = set() error_log = [] to_add = [] to_update = [] - with db_map.generate_ids("entity_class") as new_entity_class_id: + with db_map.generate_ids("entity_class", dry_run=dry_run) as new_entity_class_id: for name, *optionals in data: if name in checked: continue @@ -439,7 +440,7 @@ def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_i return name -def _get_entities_for_import(db_map, data, make_cache): +def _get_entities_for_import(db_map, data, make_cache, dry_run): cache = make_cache({"entity"}, include_ancestors=True) entities = {x.id: x for x in cache.get("entity", {}).values()} entity_ids_per_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} @@ -455,7 +456,7 @@ def _get_entities_for_import(db_map, data, make_cache): to_add = [] to_update = [] checked = set() - with db_map.generate_ids("entity") as new_entity_id: + with db_map.generate_ids("entity", 
dry_run=dry_run) as new_entity_id:
         for class_name, ent_name_or_el_names, *optionals in data:
             ec_id = entity_class_ids.get(class_name, None)
             dim_ids = dimension_id_lists.get(ec_id, ())

From 9d1c2d5d4bb0d900b7652f8edc5fd37a646df289 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 12 Apr 2023 18:14:26 +0200
Subject: [PATCH 030/317] Fix typo

---
 spinedb_api/import_functions.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py
index 57cb095a..cbbeb547 100644
--- a/spinedb_api/import_functions.py
+++ b/spinedb_api/import_functions.py
@@ -1282,9 +1282,7 @@ def import_object_classes(db_map, data, make_cache=None):
 
 def _get_object_classes_for_import(db_map, data, make_cache):
     cache = make_cache({"entity_class"}, include_ancestors=True)
-    object_class_ids = {
-        oc.name: oc.id for oc in cache.get("_get_object_classes_for_import", {}).values() if not oc.dimension_id_list
-    }
+    object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
     checked = set()
     to_add = []
     to_update = []

From 05caf49f86614adc5c82ac65095be8f621714e2a Mon Sep 17 00:00:00 2001
From: Manuel
Date: Mon, 24 Apr 2023 10:54:16 +0200
Subject: [PATCH 031/317] Sort entities and classes by dimensionality in
 export data

This is so import_data can consume the data in the right order, that is,
import zero-dimensional items before multi-dimensional ones (whose elements
must already exist).
---
 spinedb_api/export_functions.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py
index a8fdc05c..823c045e 100644
--- a/spinedb_api/export_functions.py
+++ b/spinedb_api/export_functions.py
@@ -166,15 +166,21 @@ def export_parameter_value_lists(db_map, ids=Asterisk, make_cache=None, parse_va
 
 def export_entity_classes(db_map, ids=Asterisk, make_cache=None):
     return sorted(
-        (x.name, x.dimension_name_list, x.description, x.display_icon)
-        for x in _get_items(db_map, "entity_class", ids, make_cache)
+        (
+            (x.name, x.dimension_name_list, x.description, x.display_icon)
+            for x in _get_items(db_map, "entity_class", ids, make_cache)
+        ),
+        key=lambda x: (len(x[1]), x[0]),
     )
 
 
 def export_entities(db_map, ids=Asterisk, make_cache=None):
     return sorted(
-        (x.class_name, x.element_name_list or x.name, x.description)
-        for x in _get_items(db_map, "entity", ids, make_cache)
+        (
+            (x.class_name, x.element_name_list or x.name, x.description)
+            for x in _get_items(db_map, "entity", ids, make_cache)
+        ),
+        key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0]),
     )

From b73b4b61479be4cd7bc87048a383d24cd344e472 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Tue, 2 May 2023 10:08:29 +0200
Subject: [PATCH 032/317] WIP: introduce entity_alternative

---
 spinedb_api/__init__.py | 9 -
 spinedb_api/check_functions.py | 91 -----
 spinedb_api/db_cache.py | 91 +----
 spinedb_api/db_mapping_add_mixin.py | 130 ++++----
 spinedb_api/db_mapping_base.py | 186 ++---------
 spinedb_api/db_mapping_check_mixin.py | 161 ---------
 spinedb_api/db_mapping_remove_mixin.py | 38 +---
 spinedb_api/db_mapping_update_mixin.py | 167 +++++-----
 spinedb_api/diff_db_mapping.py | 4 +-
 spinedb_api/diff_db_mapping_base.py | 2 +-
 spinedb_api/diff_db_mapping_commit_mixin.py | 6 +-
 spinedb_api/export_functions.py | 39 ---
 spinedb_api/export_mapping/__init__.py | 4 -
 spinedb_api/export_mapping/export_mapping.py | 238 +-------------
 spinedb_api/export_mapping/settings.py | 95 ------
 spinedb_api/filters/scenario_filter.py | 35 ++
spinedb_api/filters/tool_filter.py | 287 ---------------- spinedb_api/filters/tools.py | 17 +- spinedb_api/filters/value_transformer.py | 14 +- spinedb_api/helpers.py | 81 +---- spinedb_api/import_functions.py | 311 +----------------- spinedb_api/import_mapping/import_mapping.py | 143 +------- .../import_mapping/import_mapping_compat.py | 70 +--- spinedb_api/spine_db_server.py | 5 +- 24 files changed, 255 insertions(+), 1969 deletions(-) delete mode 100644 spinedb_api/filters/tool_filter.py diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index f54b09ea..9899c056 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -63,10 +63,6 @@ import_relationships, import_scenarios, import_scenario_alternatives, - import_tools, - import_features, - import_tool_features, - import_tool_feature_methods, import_metadata, import_object_metadata, import_relationship_metadata, @@ -88,10 +84,6 @@ export_relationships, export_scenario_alternatives, export_scenarios, - export_tools, - export_features, - export_tool_features, - export_tool_feature_methods, ) from .import_mapping.import_mapping_compat import import_mapping_from_dict from .import_mapping.generator import get_mapped_data @@ -118,7 +110,6 @@ ) from .filters.alternative_filter import apply_alternative_filter_to_parameter_value_sq from .filters.scenario_filter import apply_scenario_filter_to_subqueries -from .filters.tool_filter import apply_tool_filter_to_entity_sq from .filters.execution_filter import apply_execution_filter from .filters.renamer import apply_renaming_to_parameter_definition_sq, apply_renaming_to_entity_class_sq from .filters.tools import ( diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py index 45e45113..a75e8b5e 100644 --- a/spinedb_api/check_functions.py +++ b/spinedb_api/check_functions.py @@ -560,97 +560,6 @@ def check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value raise SpineIntegrityError(f"'{list_name}' already has the value '{from_database(value, type_)}'.", id=dup_id) -def check_tool(item, current_items): - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError("Missing tool name.") - if name in current_items: - raise SpineIntegrityError(f"There can't be more than one tool called '{name}'.", id=current_items[name]) - - -def check_feature(item, current_items, parameter_definitions): - try: - parameter_definition_id = item["parameter_definition_id"] - except KeyError: - raise SpineIntegrityError("Missing parameter identifier.") - try: - parameter_value_list_id = item["parameter_value_list_id"] - except KeyError: - raise SpineIntegrityError("Missing parameter value list identifier.") - try: - parameter_definition = parameter_definitions[parameter_definition_id] - except KeyError: - raise SpineIntegrityError("Parameter not found.") - if parameter_value_list_id is None: - raise SpineIntegrityError(f"Parameter '{parameter_definition['name']}' doesn't have a value list.") - if parameter_value_list_id != parameter_definition["parameter_value_list_id"]: - raise SpineIntegrityError("Parameter definition and value list don't match.") - if parameter_definition_id in current_items: - raise SpineIntegrityError( - f"There's already a feature defined for parameter '{parameter_definition['name']}'.", - id=current_items[parameter_definition_id], - ) - - -def check_tool_feature(item, current_items, tools, features): - try: - tool_id = item["tool_id"] - except KeyError: - raise SpineIntegrityError("Missing tool identifier.") - try: - feature_id 
= item["feature_id"] - except KeyError: - raise SpineIntegrityError("Missing feature identifier.") - try: - parameter_value_list_id = item["parameter_value_list_id"] - except KeyError: - raise SpineIntegrityError("Missing parameter value list identifier.") - try: - tool = tools[tool_id] - except KeyError: - raise SpineIntegrityError("Tool not found.") - try: - feature = features[feature_id] - except KeyError: - raise SpineIntegrityError("Feature not found.") - dup_id = current_items.get((tool_id, feature_id)) - if dup_id is not None: - raise SpineIntegrityError(f"Tool '{tool['name']}' already has feature '{feature['name']}'.", id=dup_id) - if parameter_value_list_id != feature["parameter_value_list_id"]: - raise SpineIntegrityError("Feature and parameter value list don't match.") - - -def check_tool_feature_method(item, current_items, tool_features, parameter_value_lists): - try: - tool_feature_id = item["tool_feature_id"] - except KeyError: - raise SpineIntegrityError("Missing tool feature identifier.") - try: - parameter_value_list_id = item["parameter_value_list_id"] - except KeyError: - raise SpineIntegrityError("Missing parameter value list identifier.") - try: - method_index = item["method_index"] - except KeyError: - raise SpineIntegrityError("Missing method index.") - try: - tool_feature = tool_features[tool_feature_id] - except KeyError: - raise SpineIntegrityError("Tool feature not found.") - try: - parameter_value_list = parameter_value_lists[parameter_value_list_id] - except KeyError: - raise SpineIntegrityError("Parameter value list not found.") - dup_id = current_items.get((tool_feature_id, method_index)) - if dup_id is not None: - raise SpineIntegrityError("Tool feature already has the given method.", id=dup_id) - if parameter_value_list_id != tool_feature["parameter_value_list_id"]: - raise SpineIntegrityError("Feature and parameter value list don't match.") - if method_index not in parameter_value_list["value_index_list"]: - raise SpineIntegrityError("Invalid method for tool feature.") - - def check_metadata(item, metadata): """Check whether the entity metadata item violates an integrity constraint. 
diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index d9d7d445..a2287f34 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -70,9 +70,6 @@ def make_item(self, item_type, item): "entity_group": EntityGroupItem, "scenario": ScenarioItem, "scenario_alternative": ScenarioAlternativeItem, - "feature": FeatureItem, - "tool_feature": ToolFeatureItem, - "tool_feature_method": ToolFeatureMethodItem, "parameter_value_list": ParameterValueListItem, }.get(item_type, CacheItem) return factory(self, item_type, **item) @@ -106,6 +103,9 @@ def update_item(self, item): current_item.cascade_update() def remove_item(self, id_): + if self._item_type == "alternative" and id_ == 1: + # Do not remove the Base alternative + return CacheItem(self._db_cache, self._item_type) current_item = self.get(id_) if current_item: current_item.cascade_remove() @@ -325,6 +325,8 @@ def __getitem__(self, key): return tuple(self._get_ref("entity", id_, key).get("name") for id_ in self["element_id_list"]) if key == "byname": return self["element_name_list"] or (self["name"],) + if key == "alternative_name": + return self._get_ref("alternative", self["alternative_id"], key).get("name") return super().__getitem__(key) def _reference_keys(self): @@ -333,6 +335,7 @@ def _reference_keys(self): "dimension_id_list", "dimension_name_list", "element_name_list", + "alternative_name", ) @@ -465,88 +468,6 @@ def _reference_keys(self): return super()._reference_keys() + ("scenario_name", "alternative_name") -class FeatureItem(CacheItem): - def __getitem__(self, key): - if key == "parameter_definition_name": - return self._get_ref("parameter_definition", self["parameter_definition_id"], key).get("name") - if key in ("entity_class_id", "entity_class_name"): - return self._get_ref("parameter_definition", self["parameter_definition_id"], key).get(key) - if key == "parameter_value_list_name": - return self._get_ref("parameter_value_list", self["parameter_value_list_id"], key).get("name") - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ( - "entity_class_id", - "entity_class_name", - "parameter_definition_name", - "parameter_value_list_name", - ) - - -class ToolFeatureItem(CacheItem): - def __getitem__(self, key): - if key in ("entity_class_id", "entity_class_name", "parameter_definition_id", "parameter_definition_name"): - return self._get_ref("feature", self["feature_id"], key).get(key) - if key == "tool_name": - return self._get_ref("tool", self["tool_id"], key).get("name") - if key == "parameter_value_list_name": - return self._get_ref("parameter_value_list", self["parameter_value_list_id"], key).get("name") - if key == "required": - return dict.get(self, "required", False) - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ( - "tool_name", - "entity_class_id", - "entity_class_name", - "parameter_definition_id", - "parameter_definition_name", - "parameter_value_list_name", - ) - - -class ToolFeatureMethodItem(CacheItem): - def __getitem__(self, key): - if key in ( - "tool_id", - "tool_name", - "feature_id", - "entity_class_id", - "entity_class_name", - "parameter_definition_id", - "parameter_definition_name", - "parameter_value_list_id", - "parameter_value_list_name", - ): - return self._get_ref("tool_feature", self["tool_feature_id"], key).get(key) - if key == "method": - value_list = self._get_ref("parameter_value_list", self["parameter_value_list_id"], key) - if not value_list: - return None - try: - 
list_value_id = value_list["value_id_list"][self["method_index"]] - return self._get_ref("list_value", list_value_id, key).get("value") - except IndexError: - return None - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ( - "tool_id", - "tool_name", - "feature_id", - "entity_class_id", - "entity_class_name", - "parameter_definition_id", - "parameter_definition_name", - "parameter_value_list_id", - "parameter_value_list_name", - "method", - ) - - class ParameterValueListItem(CacheItem): def _sorted_list_values(self, key): return sorted( diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index d69a94c4..bf1c141c 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -58,10 +58,6 @@ def __init__(self, *args, **kwargs): Column("alternative_id", Integer, server_default=null()), Column("scenario_id", Integer, server_default=null()), Column("scenario_alternative_id", Integer, server_default=null()), - Column("tool_id", Integer, server_default=null()), - Column("feature_id", Integer, server_default=null()), - Column("tool_feature_id", Integer, server_default=null()), - Column("tool_feature_method_id", Integer, server_default=null()), Column("metadata_id", Integer, server_default=null()), Column("parameter_value_metadata_id", Integer, server_default=null()), Column("entity_metadata_id", Integer, server_default=null()), @@ -98,10 +94,6 @@ def generate_ids(self, tablename, dry_run=False): "alternative": "alternative_id", "scenario": "scenario_id", "scenario_alternative": "scenario_alternative_id", - "tool": "tool_id", - "feature": "feature_id", - "tool_feature": "tool_feature_id", - "tool_feature_method": "tool_feature_method_id", "metadata": "metadata_id", "parameter_value_metadata": "parameter_value_metadata_id", "entity_metadata": "entity_metadata_id", @@ -117,8 +109,8 @@ def generate_ids(self, tablename, dry_run=False): if next_id is None: real_tablename = self._real_tablename(tablename) table = self._metadata.tables[real_tablename] - id_col = self.table_ids.get(real_tablename, "id") - select_max_id = select([func.max(getattr(table.c, id_col))]) + id_field = self._id_fields.get(real_tablename, "id") + select_max_id = select([func.max(getattr(table.c, id_field))]) max_id = connection.execute(select_max_id).scalar() next_id = max_id + 1 if max_id else 1 gen = self._IdGenerator(next_id) @@ -231,7 +223,8 @@ def _do_add_items(self, tablename, *items_to_add): else: self._has_pending_changes = True - def _items_to_add_per_table(self, tablename, items_to_add): + @staticmethod + def _items_to_add_per_table(tablename, items_to_add): """ Yields tuples of string tablename, list of items to insert. Needed because some insert queries actually need to insert records to more than one table. 
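+        For example, adding an "entity" also yields rows for the entity_element
+        and entity_alternative tables.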
@@ -244,61 +237,72 @@ def _items_to_add_per_table(self, tablename, items_to_add): tuple: database table name, items to add """ if tablename == "entity_class": - ecd_items_to_add = [] - for item in items_to_add: - ecd_items_to_add += [ - {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} - for position, dimension_id in enumerate(item["dimension_id_list"]) - ] + ecd_items_to_add = [ + {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} + for item in items_to_add + for position, dimension_id in enumerate(item["dimension_id_list"]) + ] yield ("entity_class", items_to_add) yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "entity": - ee_items_to_add = [] - for item in items_to_add: - ee_items_to_add += [ - { - "entity_id": item["id"], - "entity_class_id": item["class_id"], - "position": position, - "element_id": element_id, - "dimension_id": dimension_id, - } - for position, (element_id, dimension_id) in enumerate( - zip(item["element_id_list"], item["dimension_id_list"]) - ) - ] + ee_items_to_add = [ + { + "entity_id": item["id"], + "entity_class_id": item["class_id"], + "position": position, + "element_id": element_id, + "dimension_id": dimension_id, + } + for item in items_to_add + for position, (element_id, dimension_id) in enumerate( + zip(item["element_id_list"], item["dimension_id_list"]) + ) + ] + ea_items_to_add = [ + {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + for item in items_to_add + ] yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) + yield ("entity_alternative", ea_items_to_add) elif tablename == "object_class": yield ("entity_class", items_to_add) elif tablename == "object": + ea_items_to_add = [ + {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + for item in items_to_add + ] yield ("entity", items_to_add) + yield ("entity_alternative", ea_items_to_add) elif tablename == "relationship_class": - ecd_items_to_add = [] - for item in items_to_add: - ecd_items_to_add += [ - {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} - for position, dimension_id in enumerate(item["object_class_id_list"]) - ] + ecd_items_to_add = [ + {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} + for item in items_to_add + for position, dimension_id in enumerate(item["object_class_id_list"]) + ] yield ("entity_class", items_to_add) yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "relationship": - ee_items_to_add = [] - for item in items_to_add: - ee_items_to_add += [ - { - "entity_id": item["id"], - "entity_class_id": item["class_id"], - "position": position, - "element_id": element_id, - "dimension_id": dimension_id, - } - for position, (element_id, dimension_id) in enumerate( - zip(item["object_id_list"], item["object_class_id_list"]) - ) - ] + ee_items_to_add = [ + { + "entity_id": item["id"], + "entity_class_id": item["class_id"], + "position": position, + "element_id": element_id, + "dimension_id": dimension_id, + } + for item in items_to_add + for position, (element_id, dimension_id) in enumerate( + zip(item["object_id_list"], item["object_class_id_list"]) + ) + ] + ea_items_to_add = [ + {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + for item in items_to_add + ] yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) + yield ("entity_alternative", 
ea_items_to_add) elif tablename == "parameter_definition": for item in items_to_add: item["entity_class_id"] = ( @@ -345,18 +349,6 @@ def add_parameter_value_lists(self, *items, **kwargs): def add_list_values(self, *items, **kwargs): return self.add_items("list_value", *items, **kwargs) - def add_features(self, *items, **kwargs): - return self.add_items("feature", *items, **kwargs) - - def add_tools(self, *items, **kwargs): - return self.add_items("tool", *items, **kwargs) - - def add_tool_features(self, *items, **kwargs): - return self.add_items("tool_feature", *items, **kwargs) - - def add_tool_feature_methods(self, *items, **kwargs): - return self.add_items("tool_feature_method", *items, **kwargs) - def add_alternatives(self, *items, **kwargs): return self.add_items("alternative", *items, **kwargs) @@ -491,18 +483,6 @@ def _add_parameter_value_lists(self, *items): def _add_list_values(self, *items): return self._add_items("list_value", *items) - def _add_features(self, *items): - return self._add_items("feature", *items) - - def _add_tools(self, *items): - return self._add_items("tool", *items) - - def _add_tool_features(self, *items): - return self._add_items("tool_feature", *items) - - def _add_tool_feature_methods(self, *items): - return self._add_items("tool_feature_method", *items) - def _add_alternatives(self, *items): return self._add_items("alternative", *items) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 74723971..cc55c4bc 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -20,7 +20,6 @@ import time from collections import Counter from types import MethodType -from contextlib import contextmanager from sqlalchemy import create_engine, case, MetaData, Table, Column, false, and_, func, inspect, cast, Integer, or_ from sqlalchemy.sql.expression import label, Alias from sqlalchemy.engine.url import make_url, URL @@ -68,10 +67,6 @@ class DatabaseMappingBase: "alternative", "scenario", "scenario_alternative", - "feature", - "tool", - "tool_feature", - "tool_feature_method", "metadata", "entity_metadata", "parameter_value_metadata", @@ -148,14 +143,9 @@ def __init__( self._parameter_value_sq = None self._parameter_value_list_sq = None self._list_value_sq = None - self._feature_sq = None - self._tool_sq = None - self._tool_feature_sq = None - self._tool_feature_method_sq = None self._metadata_sq = None self._parameter_value_metadata_sq = None self._entity_metadata_sq = None - self._clean_parameter_value_sq = None # Special convenience subqueries that join two or more tables self._ext_entity_class_sq = None self._ext_entity_sq = None @@ -180,9 +170,6 @@ def __init__( self._entity_parameter_value_sq = None self._object_parameter_value_sq = None self._relationship_parameter_value_sq = None - self._ext_feature_sq = None - self._ext_tool_feature_sq = None - self._ext_tool_feature_method_sq = None self._ext_parameter_value_metadata_sq = None self._ext_entity_metadata_sq = None # Import alternative suff @@ -190,26 +177,22 @@ def __init__( self._import_alternative_name = None self._table_to_sq_attr = {} # Table primary ids map: - self.table_ids = { + self._id_fields = { "object_class": "entity_class_id", "relationship_class": "entity_class_id", "entity_class_dimension": "entity_class_id", "object": "entity_id", "relationship": "entity_id", - "entity_element": "entity_id", } self.composite_pks = { "entity_element": ("entity_id", "position"), + "entity_alternative": ("entity_id", "alternative_id"), "entity_class_dimension": 
("entity_class_id", "position"), } # Subqueries used to populate cache self.cache_sqs = { "entity_class": "ext_entity_class_sq", "entity": "ext_entity_sq", - "feature": "feature_sq", - "tool": "tool_sq", - "tool_feature": "tool_feature_sq", - "tool_feature_method": "tool_feature_method_sq", "parameter_value_list": "parameter_value_list_sq", "list_value": "list_value_sq", "alternative": "alternative_sq", @@ -217,16 +200,13 @@ def __init__( "scenario_alternative": "scenario_alternative_sq", "entity_group": "entity_group_sq", "parameter_definition": "parameter_definition_sq", - "parameter_value": "clean_parameter_value_sq", + "parameter_value": "parameter_value_sq", "metadata": "metadata_sq", "entity_metadata": "ext_entity_metadata_sq", "parameter_value_metadata": "ext_parameter_value_metadata_sq", "commit": "commit_sq", } self.ancestor_tablenames = { - "feature": ("parameter_definition",), - "tool_feature": ("tool", "feature"), - "tool_feature_method": ("tool_feature", "parameter_value_list", "list_value"), "scenario_alternative": ("scenario", "alternative"), "entity": ("entity_class",), "entity_group": ("entity_class", "entity"), @@ -266,12 +246,9 @@ def _descendant_tablenames(self, tablename): "scenario": ("scenario_alternative",), "entity_class": ("entity", "parameter_definition"), "entity": ("parameter_value", "entity_group", "entity_metadata"), - "parameter_definition": ("parameter_value", "feature"), - "parameter_value_list": ("feature",), + "parameter_definition": ("parameter_value",), + "parameter_value_list": (), "parameter_value": ("parameter_value_metadata", "entity_metadata"), - "feature": ("tool_feature",), - "tool": ("tool_feature",), - "tool_feature": ("tool_feature_method",), "entity_metadata": ("metadata",), "parameter_value_metadata": ("metadata",), } @@ -879,21 +856,6 @@ def parameter_value_sq(self): self._parameter_value_sq = self._make_parameter_value_sq() return self._parameter_value_sq - @property - def clean_parameter_value_sq(self): - """A subquery of the parameter_value table that excludes rows with filtered entities. - This yields the correct results whenever there are both a scenario filter that filters some parameter values, - and a tool filter that then filters some entities based on the value of some their parameters - after the scenario filtering. Mildly insane. 
- """ - if self._clean_parameter_value_sq is None: - self._clean_parameter_value_sq = ( - self.query(self.parameter_value_sq) - .join(self.entity_sq, self.entity_sq.c.id == self.parameter_value_sq.c.entity_id) - .subquery() - ) - return self._clean_parameter_value_sq - @property def parameter_value_list_sq(self): """A subquery of the form: @@ -915,30 +877,6 @@ def list_value_sq(self): self._list_value_sq = self._subquery("list_value") return self._list_value_sq - @property - def feature_sq(self): - if self._feature_sq is None: - self._feature_sq = self._subquery("feature") - return self._feature_sq - - @property - def tool_sq(self): - if self._tool_sq is None: - self._tool_sq = self._subquery("tool") - return self._tool_sq - - @property - def tool_feature_sq(self): - if self._tool_feature_sq is None: - self._tool_feature_sq = self._subquery("tool_feature") - return self._tool_feature_sq - - @property - def tool_feature_method_sq(self): - if self._tool_feature_method_sq is None: - self._tool_feature_method_sq = self._subquery("tool_feature_method") - return self._tool_feature_method_sq - @property def metadata_sq(self): if self._metadata_sq is None: @@ -1691,92 +1629,6 @@ def relationship_parameter_value_sq(self): ) return self._relationship_parameter_value_sq - @property - def ext_feature_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ - if self._ext_feature_sq is None: - self._ext_feature_sq = ( - self.query( - self.feature_sq.c.id.label("id"), - self.entity_class_sq.c.id.label("entity_class_id"), - self.entity_class_sq.c.name.label("entity_class_name"), - self.feature_sq.c.parameter_definition_id.label("parameter_definition_id"), - self.parameter_definition_sq.c.name.label("parameter_definition_name"), - self.parameter_value_list_sq.c.id.label("parameter_value_list_id"), - self.parameter_value_list_sq.c.name.label("parameter_value_list_name"), - self.feature_sq.c.description.label("description"), - self.feature_sq.c.commit_id.label("commit_id"), - ) - .filter(self.feature_sq.c.parameter_definition_id == self.parameter_definition_sq.c.id) - .filter(self.parameter_definition_sq.c.parameter_value_list_id == self.parameter_value_list_sq.c.id) - .filter(self.parameter_definition_sq.c.entity_class_id == self.entity_class_sq.c.id) - .subquery() - ) - return self._ext_feature_sq - - @property - def ext_tool_feature_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ - if self._ext_tool_feature_sq is None: - self._ext_tool_feature_sq = ( - self.query( - self.tool_feature_sq.c.id.label("id"), - self.tool_feature_sq.c.tool_id.label("tool_id"), - self.tool_sq.c.name.label("tool_name"), - self.tool_feature_sq.c.feature_id.label("feature_id"), - self.ext_feature_sq.c.entity_class_id.label("entity_class_id"), - self.ext_feature_sq.c.entity_class_name.label("entity_class_name"), - self.ext_feature_sq.c.parameter_definition_id.label("parameter_definition_id"), - self.ext_feature_sq.c.parameter_definition_name.label("parameter_definition_name"), - self.ext_feature_sq.c.parameter_value_list_id.label("parameter_value_list_id"), - self.ext_feature_sq.c.parameter_value_list_name.label("parameter_value_list_name"), - self.tool_feature_sq.c.required.label("required"), - self.tool_feature_sq.c.commit_id.label("commit_id"), - ) - .filter(self.tool_feature_sq.c.tool_id == self.tool_sq.c.id) - .filter(self.tool_feature_sq.c.feature_id == self.ext_feature_sq.c.id) - .subquery() - ) - return self._ext_tool_feature_sq - - @property - def ext_tool_feature_method_sq(self): - """ - 
Returns: - sqlalchemy.sql.expression.Alias - """ - if self._ext_tool_feature_method_sq is None: - self._ext_tool_feature_method_sq = ( - self.query( - self.tool_feature_method_sq.c.id, - self.ext_tool_feature_sq.c.id.label("tool_feature_id"), - self.ext_tool_feature_sq.c.tool_id, - self.ext_tool_feature_sq.c.tool_name, - self.ext_tool_feature_sq.c.feature_id, - self.ext_tool_feature_sq.c.entity_class_id, - self.ext_tool_feature_sq.c.entity_class_name, - self.ext_tool_feature_sq.c.parameter_definition_id, - self.ext_tool_feature_sq.c.parameter_definition_name, - self.ext_tool_feature_sq.c.parameter_value_list_id, - self.ext_tool_feature_sq.c.parameter_value_list_name, - self.tool_feature_method_sq.c.method_index, - self.list_value_sq.c.value.label("method"), - self.tool_feature_method_sq.c.commit_id, - ) - .filter(self.tool_feature_method_sq.c.tool_feature_id == self.ext_tool_feature_sq.c.id) - .filter(self.ext_tool_feature_sq.c.parameter_value_list_id == self.parameter_value_list_sq.c.id) - .filter(self.parameter_value_list_sq.c.id == self.list_value_sq.c.parameter_value_list_id) - .filter(self.tool_feature_method_sq.c.method_index == self.list_value_sq.c.index) - .subquery() - ) - return self._ext_tool_feature_method_sq - @property def ext_parameter_value_metadata_sq(self): """ @@ -1835,7 +1687,9 @@ def _make_entity_sq(self): Returns: Alias: an entity subquery """ - return self._subquery("entity") + e_sq = self._subquery("entity") + ea_sq = self._subquery("entity_alternative") + return self.query(e_sq, ea_sq).filter(e_sq.c.id == ea_sq.c.entity_id).subquery() def _make_entity_class_sq(self): """ @@ -2031,16 +1885,6 @@ def restore_parameter_value_sq_maker(self): self._make_parameter_value_sq = MethodType(DatabaseMappingBase._make_parameter_value_sq, self) self._clear_subqueries("parameter_value") - def override_create_import_alternative(self, method): - """ - Overrides the ``_create_import_alternative`` function. - - Args: - method (Callable) - """ - self._create_import_alternative = MethodType(method, self) - self._import_alternative_id = None - def override_alternative_sq_maker(self, method): """ Overrides the function that creates the ``alternative_sq`` property. @@ -2089,6 +1933,16 @@ def restore_scenario_alternative_sq_maker(self): self._make_scenario_alternative_sq = MethodType(DatabaseMappingBase._make_scenario_alternative_sq, self) self._clear_subqueries("scenario_alternative") + def override_create_import_alternative(self, method): + """ + Overrides the ``_create_import_alternative`` function. 
+ + Args: + method (Callable) + """ + self._create_import_alternative = MethodType(method, self) + self._import_alternative_id = None + def _checked_execute(self, stmt, items): if not items: return @@ -2097,8 +1951,8 @@ def _checked_execute(self, stmt, items): def _get_primary_key(self, tablename): pk = self.composite_pks.get(tablename) if pk is None: - table_id = self.table_ids.get(tablename, "id") - pk = (table_id,) + id_field = self._id_fields.get(tablename, "id") + pk = (id_field,) return pk def _reset_mapping(self): diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 03bec8d6..5efe5b6f 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -31,10 +31,6 @@ check_parameter_value, check_parameter_value_list, check_list_value, - check_feature, - check_tool, - check_tool_feature, - check_tool_feature_method, check_entity_metadata, check_metadata, check_parameter_value_metadata, @@ -64,168 +60,11 @@ def check_items(self, tablename, *items, for_update=False, strict=False, cache=N "parameter_value": self.check_parameter_values, "parameter_value_list": self.check_parameter_value_lists, "list_value": self.check_list_values, - "feature": self.check_features, - "tool": self.check_tools, - "tool_feature": self.check_tool_features, - "tool_feature_method": self.check_tool_feature_methods, "metadata": self.check_metadata, "entity_metadata": self.check_entity_metadata, "parameter_value_metadata": self.check_parameter_value_metadata, }[tablename](*items, for_update=for_update, strict=strict, cache=cache) - def check_features(self, *items, for_update=False, strict=False, cache=None): - """Check whether features passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - if cache is None: - cache = self.make_cache({"feature"}, include_ancestors=True) - intgr_error_log = [] - checked_items = list() - feature_ids = {x.parameter_definition_id: x.id for x in cache.get("feature", {}).values()} - parameter_definitions = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - for item in items: - try: - with self._manage_stocks( - "feature", item, {("parameter_definition_id",): feature_ids}, for_update, cache, intgr_error_log - ) as item: - check_feature(item, feature_ids, parameter_definitions) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_tools(self, *items, for_update=False, strict=False, cache=None): - """Check whether tools passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - if cache is None: - cache = self.make_cache({"tool"}, include_ancestors=True) - intgr_error_log = [] - checked_items = list() - tool_ids = {x.name: x.id for x in cache.get("tool", {}).values()} - for item in items: - try: - with self._manage_stocks( - "tool", item, {("name",): tool_ids}, for_update, cache, intgr_error_log - ) as item: - check_tool(item, tool_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_tool_features(self, *items, for_update=False, strict=False, cache=None): - """Check whether tool features passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - if cache is None: - cache = self.make_cache({"tool_feature"}, include_ancestors=True) - intgr_error_log = [] - checked_items = list() - tool_feature_ids = {(x.tool_id, x.feature_id): x.id for x in cache.get("tool_feature", {}).values()} - tools = {x.id: x._asdict() for x in cache.get("tool", {}).values()} - features = { - x.id: { - "name": x.entity_class_name + "/" + x.parameter_definition_name, - "parameter_value_list_id": x.parameter_value_list_id, - } - for x in cache.get("feature", {}).values() - } - for item in items: - try: - with self._manage_stocks( - "tool_feature", - item, - {("tool_id", "feature_id"): tool_feature_ids}, - for_update, - cache, - intgr_error_log, - ) as item: - check_tool_feature(item, tool_feature_ids, tools, features) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_tool_feature_methods(self, *items, for_update=False, strict=False, cache=None): - """Check whether tool feature methods passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - if cache is None: - cache = self.make_cache({"tool_feature_method"}, include_ancestors=True) - intgr_error_log = [] - checked_items = list() - tool_feature_method_ids = { - (x.tool_feature_id, x.method_index): x.id for x in cache.get("tool_feature_method", {}).values() - } - tool_features = {x.id: x._asdict() for x in cache.get("tool_feature", {}).values()} - parameter_value_lists = { - x.id: {"name": x.name, "value_index_list": x.value_index_list} - for x in cache.get("parameter_value_list", {}).values() - } - for item in items: - try: - with self._manage_stocks( - "tool_feature_method", - item, - {("tool_feature_id", "method_index"): tool_feature_method_ids}, - for_update, - cache, - intgr_error_log, - ) as item: - check_tool_feature_method(item, tool_feature_method_ids, tool_features, parameter_value_lists) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - def check_alternatives(self, *items, for_update=False, strict=False, cache=None): """Check whether alternatives passed as argument respect integrity constraints. diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index f3a1cd89..6ab2e806 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -39,12 +39,15 @@ def remove_items(self, **kwargs): **kwargs: keyword is table name, argument is list of ids to remove """ for tablename, ids in kwargs.items(): + if tablename == "alternative": + # Do not remove the Base alternative + ids -= {1} if not ids: continue real_tablename = self._real_tablename(tablename) - table_id = self.table_ids.get(real_tablename, "id") + id_field = self._id_fields.get(real_tablename, "id") table = self._metadata.tables[real_tablename] - delete = table.delete().where(self.in_(getattr(table.c, table_id), ids)) + delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) try: self.connection.execute(delete) table_cache = self.cache.get(tablename) @@ -96,10 +99,6 @@ def cascading_ids(self, cache=None, **kwargs): self._merge(ids, self._alternative_cascading_ids(kwargs.get("alternative", set()), cache)) self._merge(ids, self._scenario_cascading_ids(kwargs.get("scenario", set()), cache)) self._merge(ids, self._scenario_alternatives_cascading_ids(kwargs.get("scenario_alternative", set()), cache)) - self._merge(ids, self._feature_cascading_ids(kwargs.get("feature", set()), cache)) - self._merge(ids, self._tool_cascading_ids(kwargs.get("tool", set()), cache)) - self._merge(ids, self._tool_feature_cascading_ids(kwargs.get("tool_feature", set()), cache)) - self._merge(ids, self._tool_feature_method_cascading_ids(kwargs.get("tool_feature_method", set()), cache)) self._merge(ids, self._metadata_cascading_ids(kwargs.get("metadata", set()), cache)) self._merge(ids, self._entity_metadata_cascading_ids(kwargs.get("entity_metadata", set()), cache)) self._merge( @@ -186,9 +185,7 @@ def _parameter_definition_cascading_ids(self, ids, cache): """Returns parameter definition cascading ids.""" cascading_ids = {"parameter_definition": set(ids)} parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.parameter_id in ids] - features = [x for x in dict.values(cache.get("feature", {})) if x.parameter_definition_id in ids] self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache)) - self._merge(cascading_ids, self._feature_cascading_ids({x.id for x in features}, 
cache))
         return cascading_ids
 
     def _parameter_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
@@ -203,8 +200,6 @@ def _parameter_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self
     def _parameter_value_list_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
         """Returns parameter value list cascading ids and adds them to the given dictionaries."""
         cascading_ids = {"parameter_value_list": set(ids)}
-        features = [x for x in dict.values(cache.get("feature", {})) if x.parameter_value_list_id in ids]
-        self._merge(cascading_ids, self._feature_cascading_ids({x.id for x in features}, cache))
         return cascading_ids
 
     def _list_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
@@ -214,29 +209,6 @@ def _list_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
     def _scenario_alternatives_cascading_ids(self, ids, cache):
         return {"scenario_alternative": set(ids)}
 
-    def _feature_cascading_ids(self, ids, cache):
-        cascading_ids = {"feature": set(ids)}
-        tool_features = [x for x in dict.values(cache.get("tool_feature", {})) if x.feature_id in ids]
-        self._merge(cascading_ids, self._tool_feature_cascading_ids({x.id for x in tool_features}, cache))
-        return cascading_ids
-
-    def _tool_cascading_ids(self, ids, cache):
-        cascading_ids = {"tool": set(ids)}
-        tool_features = [x for x in dict.values(cache.get("tool_feature", {})) if x.tool_id in ids]
-        self._merge(cascading_ids, self._tool_feature_cascading_ids({x.id for x in tool_features}, cache))
-        return cascading_ids
-
-    def _tool_feature_cascading_ids(self, ids, cache):
-        cascading_ids = {"tool_feature": set(ids)}
-        tool_feature_methods = [
-            x for x in dict.values(cache.get("tool_feature_method", {})) if x.tool_feature_id in ids
-        ]
-        self._merge(cascading_ids, self._tool_feature_method_cascading_ids({x.id for x in tool_feature_methods}, cache))
-        return cascading_ids
-
-    def _tool_feature_method_cascading_ids(self, ids, cache):
-        return {"tool_feature_method": set(ids)}
-
     def _metadata_cascading_ids(self, ids, cache):
         cascading_ids = {"metadata": set(ids)}
         entity_metadata = {
diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py
index 0c4f5a94..2223edc0 100644
--- a/spinedb_api/db_mapping_update_mixin.py
+++ b/spinedb_api/db_mapping_update_mixin.py
@@ -28,15 +28,93 @@ def _add_commit_id(self, *items):
     def _update_items(self, tablename, *items, dry_run=False):
         if not items:
             return set()
+        if dry_run:
+            return {x["id"] for x in items}
         # Special cases
         if tablename == "entity":
-            return self._update_entities(*items)
+            return self._do_update_entities(*items)
+        if tablename == "object":
+            return self._do_update_objects(*items)
         if tablename == "relationship":
-            return self._update_wide_relationships(*items)
+            return self._do_update_wide_relationships(*items)
         real_tablename = self._real_tablename(tablename)
-        if not dry_run:
-            self._do_update_items(real_tablename, *items)
-        return {x["id"] for x in items}
+        self._do_update_items(real_tablename, *items)
+        return {x["id"] for x in items}
+
+    def _do_update_entities(self, *items):
+        entity_items = []
+        entity_element_items = []
+        entity_alternative_items = []
+        for item in items:
+            entity_id = item["id"]
+            class_id = item["class_id"]
+            dimension_id_list = item["dimension_id_list"]
+            element_id_list = item["element_id_list"]
+            entity_items.append(
+                {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")}
+            )
+            entity_element_items.extend(
+                [
+                    {
+                        "entity_class_id": class_id,
+                        "entity_id":
entity_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, + } + for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)) + ] + ) + entity_alternative_items.append( + {"entity_id": entity_id, "alternative_id": item["alternative_id"], "active": item["active"]} + ) + self._do_update_items("entity", *entity_items) + self._do_update_items("entity_element", *entity_element_items) + self._do_update_items("entity_alternative", *entity_alternative_items) + return {x["id"] for x in entity_items} + + def _do_update_objects(self, *items): + entity_items = [] + entity_alternative_items = [] + for item in items: + entity_id = item["id"] + class_id = item["class_id"] + entity_items.append( + {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")} + ) + entity_alternative_items.append( + {"entity_id": entity_id, "alternative_id": item["alternative_id"], "active": item["active"]} + ) + self._do_update_items("entity", *entity_items) + self._do_update_items("entity_alternative", *entity_alternative_items) + return {x["id"] for x in entity_items} + + def _do_update_wide_relationships(self, *items): + entity_items = [] + entity_element_items = [] + for item in items: + entity_id = item["id"] + class_id = item["class_id"] + object_class_id_list = item["object_class_id_list"] + object_id_list = item["object_id_list"] + entity_items.append( + {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")} + ) + entity_element_items.extend( + [ + { + "entity_class_id": class_id, + "entity_id": entity_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, + } + for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)) + ] + ) + self._do_update_items("entity", *entity_items) + self._do_update_items("entity_element", *entity_element_items) + return {x["id"] for x in entity_items} def _do_update_items(self, tablename, *items): if not items: @@ -111,33 +188,7 @@ def update_entities(self, *items, **kwargs): return self.update_items("entity", *items, **kwargs) def _update_entities(self, *items): - entity_items = [] - entity_element_items = [] - for item in items: - entity_id = item["id"] - class_id = item["class_id"] - ent_item = { - "id": entity_id, - "class_id": class_id, - "name": item["name"], - "description": item.get("description"), - } - entity_items.append(ent_item) - dimension_id_list = item["dimension_id_list"] - element_id_list = item["element_id_list"] - for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)): - rel_ent_item = { - "id": None, # Need to have an "id" field to make _update_items() happy. 
- "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, - } - entity_element_items.append(rel_ent_item) - self._do_update_items("entity", *entity_items) - self._do_update_items("entity_element", *entity_element_items) - return {x["id"] for x in entity_items} + return self._update_items("entity", *items) def update_object_classes(self, *items, **kwargs): return self.update_items("object_class", *items, **kwargs) @@ -161,33 +212,7 @@ def update_wide_relationships(self, *items, **kwargs): return self.update_items("relationship", *items, **kwargs) def _update_wide_relationships(self, *items): - entity_items = [] - entity_element_items = [] - for item in items: - entity_id = item["id"] - class_id = item["class_id"] - ent_item = { - "id": entity_id, - "class_id": class_id, - "name": item["name"], - "description": item.get("description"), - } - entity_items.append(ent_item) - object_class_id_list = item["object_class_id_list"] - object_id_list = item["object_id_list"] - for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)): - rel_ent_item = { - "id": None, # Need to have an "id" field to make _update_items() happy. - "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, - } - entity_element_items.append(rel_ent_item) - self._do_update_items("entity", *entity_items) - self._do_update_items("entity_element", *entity_element_items) - return {x["id"] for x in entity_items} + return self._update_items("relationship", *items) def update_parameter_definitions(self, *items, **kwargs): return self.update_items("parameter_definition", *items, **kwargs) @@ -201,30 +226,6 @@ def update_parameter_values(self, *items, **kwargs): def _update_parameter_values(self, *items): return self._update_items("parameter_value", *items) - def update_features(self, *items, **kwargs): - return self.update_items("feature", *items, **kwargs) - - def _update_features(self, *items): - return self._update_items("feature", *items) - - def update_tools(self, *items, **kwargs): - return self.update_items("tool", *items, **kwargs) - - def _update_tools(self, *items): - return self._update_items("tool", *items) - - def update_tool_features(self, *items, **kwargs): - return self.update_items("tool_feature", *items, **kwargs) - - def _update_tool_features(self, *items): - return self._update_items("tool_feature", *items) - - def update_tool_feature_methods(self, *items, **kwargs): - return self.update_items("tool_feature_method", *items, **kwargs) - - def _update_tool_feature_methods(self, *items): - return self._update_items("tool_feature_method", *items) - def update_parameter_value_lists(self, *items, **kwargs): return self.update_items("parameter_value_list", *items, **kwargs) diff --git a/spinedb_api/diff_db_mapping.py b/spinedb_api/diff_db_mapping.py index 526464c6..7f06f1b8 100644 --- a/spinedb_api/diff_db_mapping.py +++ b/spinedb_api/diff_db_mapping.py @@ -79,7 +79,7 @@ def _get_items_for_update_and_insert(self, tablename, checked_items): items_for_insert = list() dirty_ids = set() updated_ids = set() - id_field = self.table_ids.get(tablename, "id") + id_field = self._id_fields.get(tablename, "id") for item in checked_items: id_ = item[id_field] updated_ids.add(id_) @@ -162,7 +162,7 @@ def remove_items(self, **kwargs): """ if self.committing: for tablename, ids in kwargs.items(): - table_id = 
self.table_ids.get(tablename, "id") + table_id = self._id_fields.get(tablename, "id") diff_table = self._diff_table(tablename) delete = diff_table.delete().where(self.in_(getattr(diff_table.c, table_id), ids)) try: diff --git a/spinedb_api/diff_db_mapping_base.py b/spinedb_api/diff_db_mapping_base.py index ad663671..b866ea20 100644 --- a/spinedb_api/diff_db_mapping_base.py +++ b/spinedb_api/diff_db_mapping_base.py @@ -85,7 +85,7 @@ def _subquery(self, tablename): SELECT * FROM diff_table """ orig_table = self._metadata.tables[tablename] - table_id = self.table_ids.get(tablename, "id") + table_id = self._id_fields.get(tablename, "id") qry = self.query(*labelled_columns(orig_table)).filter( ~self.in_(getattr(orig_table.c, table_id), self.dirty_item_id[tablename]) ) diff --git a/spinedb_api/diff_db_mapping_commit_mixin.py b/spinedb_api/diff_db_mapping_commit_mixin.py index 9a03632e..35c570b5 100644 --- a/spinedb_api/diff_db_mapping_commit_mixin.py +++ b/spinedb_api/diff_db_mapping_commit_mixin.py @@ -42,13 +42,13 @@ def commit_session(self, comment): if not ids: continue table = self._metadata.tables[tablename] - id_col = self.table_ids.get(tablename, "id") + id_col = self._id_fields.get(tablename, "id") self.query(table).filter(self.in_(getattr(table.c, id_col), ids)).delete(synchronize_session=False) # Update for tablename, ids in self.updated_item_id.items(): if not ids: continue - id_col = self.table_ids.get(tablename, "id") + id_col = self._id_fields.get(tablename, "id") orig_table = self._metadata.tables[tablename] diff_table = self._diff_table(tablename) updated_items = [] @@ -65,7 +65,7 @@ def commit_session(self, comment): for tablename, ids in self.added_item_id.items(): if not ids: continue - id_col = self.table_ids.get(tablename, "id") + id_col = self._id_fields.get(tablename, "id") orig_table = self._metadata.tables[tablename] diff_table = self._diff_table(tablename) new_items = [] diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 823c045e..db011fe4 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -40,10 +40,6 @@ def export_data( alternative_ids=Asterisk, scenario_ids=Asterisk, scenario_alternative_ids=Asterisk, - tool_ids=Asterisk, - feature_ids=Asterisk, - tool_feature_ids=Asterisk, - tool_feature_method_ids=Asterisk, make_cache=None, parse_value=from_database, ): @@ -64,10 +60,6 @@ def export_data( alternative_ids (Iterable, optional): A collection of ids to pick from the database table scenario_ids (Iterable, optional): A collection of ids to pick from the database table scenario_alternative_ids (Iterable, optional): A collection of ids to pick from the database table - tool_ids (Iterable, optional): A collection of ids to pick from the database table - feature_ids (Iterable, optional): A collection of ids to pick from the database table - tool_feature_ids (Iterable, optional): A collection of ids to pick from the database table - tool_feature_method_ids (Iterable, optional): A collection of ids to pick from the database table Returns: dict: exported data @@ -105,12 +97,6 @@ def export_data( "alternatives": export_alternatives(db_map, alternative_ids, make_cache=make_cache), "scenarios": export_scenarios(db_map, scenario_ids, make_cache=make_cache), "scenario_alternatives": export_scenario_alternatives(db_map, scenario_alternative_ids, make_cache=make_cache), - "tools": export_tools(db_map, tool_ids, make_cache=make_cache), - "features": export_features(db_map, feature_ids, make_cache=make_cache), - 
"tool_features": export_tool_features(db_map, tool_feature_ids, make_cache=make_cache), - "tool_feature_methods": export_tool_feature_methods( - db_map, tool_feature_method_ids, make_cache=make_cache, parse_value=parse_value - ), } return {key: value for key, value in data.items() if value} @@ -366,28 +352,3 @@ def export_scenario_alternatives(db_map, ids=Asterisk, make_cache=None): ), key=itemgetter(0), ) - - -def export_tools(db_map, ids=Asterisk, make_cache=None): - return sorted((x.name, x.description) for x in _get_items(db_map, "tool", ids, make_cache)) - - -def export_features(db_map, ids=Asterisk, make_cache=None): - return sorted( - (x.entity_class_name, x.parameter_definition_name, x.parameter_value_list_name, x.description) - for x in _get_items(db_map, "feature", ids, make_cache) - ) - - -def export_tool_features(db_map, ids=Asterisk, make_cache=None): - return sorted( - (x.tool_name, x.entity_class_name, x.parameter_definition_name, x.required) - for x in _get_items(db_map, "tool_feature", ids, make_cache) - ) - - -def export_tool_feature_methods(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): - return sorted( - (x.tool_name, x.entity_class_name, x.parameter_definition_name, parse_value(x.method, None)) - for x in _get_items(db_map, "tool_feature_method", ids, make_cache) - ) diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index 0b8e847a..c75b202c 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -16,7 +16,6 @@ from .generator import rows, titles from .settings import ( alternative_export, - feature_export, entity_export, entity_group_export, entity_parameter_default_value_export, @@ -26,7 +25,4 @@ entity_dimension_parameter_value_export, scenario_alternative_export, scenario_export, - tool_export, - tool_feature_export, - tool_feature_method_export, ) diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 67534837..8a4066e7 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -1439,227 +1439,6 @@ def is_buddy(parent): return isinstance(parent, ScenarioAlternativeMapping) -class FeatureEntityClassMapping(ExportMapping): - """Maps feature entity classes. - - Can be used as the topmost mapping. - """ - - MAP_TYPE = "FeatureEntityClass" - - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.ext_feature_sq.c.entity_class_id, db_map.ext_feature_sq.c.entity_class_name) - - @staticmethod - def name_field(): - return "entity_class_name" - - @staticmethod - def id_field(): - return "entity_class_id" - - -class FeatureParameterDefinitionMapping(ExportMapping): - """Maps feature parameter definitions. - - Cannot be used as the topmost mapping; must have a :class:`FeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "FeatureParameterDefinition" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.ext_feature_sq.c.parameter_definition_id, db_map.ext_feature_sq.c.parameter_definition_name - ) - - @staticmethod - def name_field(): - return "parameter_definition_name" - - @staticmethod - def id_field(): - return "parameter_definition_id" - - @staticmethod - def is_buddy(parent): - return isinstance(parent, FeatureEntityClassMapping) - - -class ToolMapping(ExportMapping): - """Maps tools. - - Can be used as the topmost mapping. 
- """ - - MAP_TYPE = "Tool" - - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.tool_sq.c.id.label("tool_id"), db_map.tool_sq.c.name.label("tool_name")) - - @staticmethod - def name_field(): - return "tool_name" - - @staticmethod - def id_field(): - return "tool_id" - - -class ToolFeatureEntityClassMapping(ExportMapping): - """Maps tool feature entity classes. - - Cannot be used as the topmost mapping; must have :class:`ToolMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureEntityClass" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.ext_tool_feature_sq.c.entity_class_id, db_map.ext_tool_feature_sq.c.entity_class_name - ) - - def filter_query(self, db_map, query): - return query.outerjoin(db_map.ext_tool_feature_sq, db_map.ext_tool_feature_sq.c.tool_id == db_map.tool_sq.c.id) - - @staticmethod - def name_field(): - return "entity_class_name" - - @staticmethod - def id_field(): - return "entity_class_id" - - @staticmethod - def is_buddy(parent): - return isinstance(parent, ToolMapping) - - -class ToolFeatureParameterDefinitionMapping(ExportMapping): - """Maps tool feature parameter definitions. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureParameterDefinition" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.ext_tool_feature_sq.c.parameter_definition_id, db_map.ext_tool_feature_sq.c.parameter_definition_name - ) - - @staticmethod - def name_field(): - return "parameter_definition_name" - - @staticmethod - def id_field(): - return "parameter_definition_id" - - @staticmethod - def is_buddy(parent): - return isinstance(parent, ToolFeatureEntityClassMapping) - - -class ToolFeatureRequiredFlagMapping(ExportMapping): - """Maps tool feature required flags. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureRequiredFlag" - - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.ext_tool_feature_sq.c.required) - - @staticmethod - def name_field(): - return "required" - - @staticmethod - def id_field(): - return "required" - - -class ToolFeatureMethodEntityClassMapping(ExportMapping): - """Maps tool feature method entity classes. - - Cannot be used as the topmost mapping; must have :class:`ToolMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureMethodEntityClass" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.ext_tool_feature_sq.c.entity_class_id, db_map.ext_tool_feature_sq.c.entity_class_name - ) - - def filter_query(self, db_map, query): - return query.outerjoin(db_map.ext_tool_feature_sq, db_map.ext_tool_feature_sq.c.tool_id == db_map.tool_sq.c.id) - - @staticmethod - def name_field(): - return "entity_class_name" - - @staticmethod - def id_field(): - return "entity_class_id" - - -class ToolFeatureMethodParameterDefinitionMapping(ExportMapping): - """Maps tool feature method parameter definitions. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureMethodEntityClassMapping` as parent. 
- """ - - MAP_TYPE = "ToolFeatureMethodParameterDefinition" - - def add_query_columns(self, db_map, query): - return query.add_columns( - db_map.ext_tool_feature_sq.c.parameter_definition_id, db_map.ext_tool_feature_sq.c.parameter_definition_name - ) - - @staticmethod - def name_field(): - return "parameter_definition_name" - - @staticmethod - def id_field(): - return "parameter_definition_id" - - -class ToolFeatureMethodMethodMapping(ExportMapping): - """Maps tool feature method methods. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureMethodEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureMethodMethod" - - def add_query_columns(self, db_map, query): - return query.add_columns(db_map.ext_tool_feature_method_sq.c.method) - - def filter_query(self, db_map, query): - return query.outerjoin( - db_map.ext_tool_feature_method_sq, - and_( - db_map.ext_tool_feature_method_sq.c.tool_id == db_map.ext_tool_feature_sq.c.tool_id, - db_map.ext_tool_feature_method_sq.c.feature_id == db_map.ext_tool_feature_sq.c.feature_id, - ), - ) - - @staticmethod - def name_field(): - return "method" - - @staticmethod - def id_field(): - return "method" - - def _data(self, db_row): - data = super()._data(db_row) - return from_database_to_single_value(data, None) - - class _DescriptionMappingBase(ExportMapping): """Maps descriptions.""" @@ -1807,8 +1586,6 @@ def from_dict(serialized): ElementMapping, ExpandedParameterDefaultValueMapping, ExpandedParameterValueMapping, - FeatureEntityClassMapping, - FeatureParameterDefinitionMapping, FixedValueMapping, IndexNameMapping, EntityClassMapping, @@ -1829,12 +1606,15 @@ def from_dict(serialized): ScenarioBeforeAlternativeMapping, ScenarioDescriptionMapping, ScenarioMapping, - ToolMapping, - ToolFeatureEntityClassMapping, - ToolFeatureParameterDefinitionMapping, - ToolFeatureRequiredFlagMapping, - ToolFeatureMethodEntityClassMapping, - ToolFeatureMethodParameterDefinitionMapping, + # FIXME + # FeatureEntityClassMapping, + # FeatureParameterDefinitionMapping, + # ToolMapping, + # ToolFeatureEntityClassMapping, + # ToolFeatureParameterDefinitionMapping, + # ToolFeatureRequiredFlagMapping, + # ToolFeatureMethodEntityClassMapping, + # ToolFeatureMethodParameterDefinitionMapping, ) } legacy_mappings = { diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index f54f287c..47afaae3 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -23,8 +23,6 @@ ElementMapping, ExpandedParameterDefaultValueMapping, ExpandedParameterValueMapping, - FeatureEntityClassMapping, - FeatureParameterDefinitionMapping, EntityGroupMapping, EntityGroupEntityMapping, EntityMapping, @@ -43,13 +41,6 @@ ScenarioBeforeAlternativeMapping, ScenarioMapping, ScenarioDescriptionMapping, - ToolFeatureEntityClassMapping, - ToolFeatureMethodMethodMapping, - ToolFeatureMethodEntityClassMapping, - ToolFeatureMethodParameterDefinitionMapping, - ToolFeatureParameterDefinitionMapping, - ToolFeatureRequiredFlagMapping, - ToolMapping, IndexNameMapping, DefaultValueIndexNameMapping, ParameterDefaultValueTypeMapping, @@ -424,92 +415,6 @@ def set_parameter_default_value_dimensions(mapping, dimensions): ) -def feature_export(entity_class_position=Position.hidden, definition_position=Position.hidden): - """ - Sets up export mappings for exporting features. 
- - Args: - entity_class_position (int or Position): position of entity classes - definition_position (int or Position): position of parameter definitions - - Returns: - ExportMapping: root mapping - """ - class_ = FeatureEntityClassMapping(entity_class_position) - definition = FeatureParameterDefinitionMapping(definition_position) - class_.child = definition - return class_ - - -def tool_export(tool_position=Position.hidden): - """ - Sets up export mappings for exporting tools. - - Args: - tool_position (int or Position): position of tools - - Returns: - ExportMapping: root mapping - """ - return ToolMapping(tool_position) - - -def tool_feature_export( - tool_position=Position.hidden, - entity_class_position=Position.hidden, - definition_position=Position.hidden, - required_flag_position=Position.hidden, -): - """ - Sets up export mappings for exporting tool features. - - Args: - tool_position (int or Position): position of tools - entity_class_position (int or Position): position of entity classes - definition_position (int or Position): position of parameter definitions - required_flag_position (int or Position): position of required flags - - Returns: - ExportMapping: root mapping - """ - tool = ToolMapping(tool_position) - class_ = ToolFeatureEntityClassMapping(entity_class_position) - definition = ToolFeatureParameterDefinitionMapping(definition_position) - required_flag = ToolFeatureRequiredFlagMapping(required_flag_position) - definition.child = required_flag - class_.child = definition - tool.child = class_ - return tool - - -def tool_feature_method_export( - tool_position=Position.hidden, - entity_class_position=Position.hidden, - definition_position=Position.hidden, - method_position=Position.hidden, -): - """ - Sets up export mappings for exporting tool feature methods. - - Args: - tool_position (int or Position): position of tools - entity_class_position (int or Position): position of entity classes - definition_position (int or Position): position of parameter definitions - method_position (int or Position): position of methods - - Returns: - ExportMapping: root mapping - """ - tool = ToolMapping(tool_position) - class_ = ToolFeatureMethodEntityClassMapping(entity_class_position) - definition = ToolFeatureMethodParameterDefinitionMapping(definition_position) - method = ToolFeatureMethodMethodMapping(method_position) - definition.child = method - class_.child = definition - tool.child = class_ - return tool - - def _generate_dimensions(parent, cls, positions): """ Nests mappings of same type as children of given ``parent``. diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 4028a6b9..bd4bba6a 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -31,6 +31,8 @@ def apply_scenario_filter_to_subqueries(db_map, scenario): scenario (str or int): scenario name or id """ state = _ScenarioFilterState(db_map, scenario) + make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) + db_map.override_entity_sq_maker(make_entity_sq) make_parameter_value_sq = partial(_make_scenario_filtered_parameter_value_sq, state=state) db_map.override_parameter_value_sq_maker(make_parameter_value_sq) make_alternative_sq = partial(_make_scenario_filtered_alternative_sq, state=state) @@ -112,6 +114,7 @@ class _ScenarioFilterState: Internal state for :func:`_make_scenario_filtered_parameter_value_sq`. 
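
With the new `override_entity_sq_maker` call in `apply_scenario_filter_to_subqueries`, a scenario filter now narrows entities as well as parameter values and alternatives. A usage sketch; the URL and scenario name are hypothetical:

    from spinedb_api import DatabaseMapping
    from spinedb_api.filters.scenario_filter import apply_scenario_filter_to_subqueries

    db_map = DatabaseMapping("sqlite:///example.sqlite")
    apply_scenario_filter_to_subqueries(db_map, "my_scenario")
    # entity_sq now only yields entities whose highest-rank activation is active.
    entities = db_map.query(db_map.entity_sq).all()
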
 Attributes:
+        original_entity_sq (Alias): previous ``entity_sq``
         original_alternative_sq (Alias): previous ``alternative_sq``
         original_parameter_value_sq (Alias): previous ``parameter_value_sq``
         original_scenario_alternative_sq (Alias): previous ``scenario_alternative_sq``
@@ -126,6 +129,7 @@ def __init__(self, db_map, scenario):
             db_map (DatabaseMappingBase): database the state applies to
             scenario (str or int): scenario name or id
         """
+        self.original_entity_sq = db_map.entity_sq
         self.original_parameter_value_sq = db_map.parameter_value_sq
         self.original_scenario_sq = db_map.scenario_sq
         self.original_scenario_alternative_sq = db_map.scenario_alternative_sq
@@ -179,6 +183,37 @@ def _scenario_alternative_ids(self, db_map):
     return scenario_alternative_ids, alternative_ids
 
 
+def _make_scenario_filtered_entity_sq(db_map, state):
+    """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_sq`.
+
+    This function can be used as replacement for entity subquery maker in :class:`DatabaseMappingBase`.
+
+    Args:
+        db_map (DatabaseMappingBase): a database map
+        state (_ScenarioFilterState): a state bound to ``db_map``
+
+    Returns:
+        Alias: a subquery for entity filtered by selected scenario
+    """
+    ext_entity_sq = (
+        db_map.query(
+            state.original_entity_sq,
+            func.row_number()
+            .over(
+                partition_by=[state.original_entity_sq.c.id],
+                order_by=desc(db_map.scenario_alternative_sq.c.rank),
+            )
+            .label("max_rank_row_number"),
+            db_map.entity_alternative_sq.c.active.label("active"),
+        )
+        .filter(state.original_entity_sq.c.id == db_map.entity_alternative_sq.c.entity_id)
+        .filter(db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id)
+        .filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id)
+    ).subquery()
+    # TODO: Maybe we want to filter multi-dimensional entities involving filtered entities right here too?
+    return db_map.query(ext_entity_sq).filter_by(max_rank_row_number=1, active=True).subquery()
+
+
 def _make_scenario_filtered_parameter_value_sq(db_map, state):
     """
     Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.parameter_value_sq`.
diff --git a/spinedb_api/filters/tool_filter.py b/spinedb_api/filters/tool_filter.py
deleted file mode 100644
index e0149b50..00000000
--- a/spinedb_api/filters/tool_filter.py
+++ /dev/null
@@ -1,287 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see .
-######################################################################################################################
-
-"""
-Provides functions to apply filtering based on tools to entity subqueries.
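
The window function in `_make_scenario_filtered_entity_sq` above implements: per entity, keep the `entity_alternative` row whose scenario alternative has the highest rank, then require that row to be active. The same selection in plain Python on made-up rows:

    rows = [  # (entity_id, rank, active)
        (1, 1, True),
        (1, 2, False),  # the highest rank wins, so entity 1 ends up hidden
        (2, 1, True),
    ]
    best = {}
    for entity_id, rank, active in rows:
        if entity_id not in best or rank > best[entity_id][0]:
            best[entity_id] = (rank, active)
    visible = {entity_id for entity_id, (_, active) in best.items() if active}
    assert visible == {2}
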
- -""" -from functools import partial -from uuid import uuid4 -from sqlalchemy import and_, or_, case, func, Column, ForeignKey -from ..exception import SpineDBAPIError - - -TOOL_FILTER_TYPE = "tool_filter" -TOOL_SHORTHAND_TAG = "tool" - - -def apply_tool_filter_to_entity_sq(db_map, tool): - """ - Replaces entity subquery properties in ``db_map`` such that they return only values of given tool. - - Args: - db_map (DatabaseMappingBase): a database map to alter - tool (str or int): tool name or id - """ - state = _ToolFilterState(db_map, tool) - filtering = partial(_make_tool_filtered_entity_sq, state=state) - db_map.override_entity_sq_maker(filtering) - - -def tool_filter_config(tool): - """ - Creates a config dict for tool filter. - - Args: - tool (str): tool name - - Returns: - dict: filter configuration - """ - return {"type": TOOL_FILTER_TYPE, "tool": tool} - - -def tool_filter_from_dict(db_map, config): - """ - Applies tool filter to given database map. - - Args: - db_map (DatabaseMappingBase): target database map - config (dict): tool filter configuration - """ - apply_tool_filter_to_entity_sq(db_map, config["tool"]) - - -def tool_name_from_dict(config): - """ - Returns tool's name from filter config. - - Args: - config (dict): tool filter configuration - - Returns: - str: tool name or None if ``config`` is not a valid tool filter configuration - """ - if config["type"] != TOOL_FILTER_TYPE: - return None - return config["tool"] - - -def tool_filter_config_to_shorthand(config): - """ - Makes a shorthand string from tool filter configuration. - - Args: - config (dict): tool filter configuration - - Returns: - str: a shorthand string - """ - return TOOL_SHORTHAND_TAG + ":" + config["tool"] - - -def tool_filter_shorthand_to_config(shorthand): - """ - Makes configuration dictionary out of a shorthand string. - - Args: - shorthand (str): a shorthand string - - Returns: - dict: tool filter configuration - """ - _, _, tool = shorthand.partition(":") - return tool_filter_config(tool) - - -class _ToolFilterState: - """ - Internal state for :func:`_make_tool_filtered_entity_sq` - - Attributes: - original_entity_sq (Alias): previous ``entity_sq`` - table (Table): temporary table containing cached entity ids that passed the filter - """ - - def __init__(self, db_map, tool): - """ - Args: - db_map (DatabaseMappingBase): database the state applies to - tool (str or int): tool name or id - """ - self.original_entity_sq = db_map.entity_sq - tool_id = self._tool_id(db_map, tool) - table_name = "tool_filter_cache_" + uuid4().hex - column = Column("entity_id", ForeignKey("entity.id")) - self.table = db_map.make_temporary_table(table_name, column) - statement = self.table.insert().from_select(["entity_id"], self.active_entity_id_sq(db_map, tool_id)) - db_map.connection.execute(statement) - - @staticmethod - def _tool_id(db_map, tool): - """ - Finds id for given tool. 
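
The deleted `_tool_id` helper follows the same name-or-id resolution idiom that the scenario filter keeps using: strings are looked up by name, anything else is validated as an id. Distilled, with illustrative data:

    from spinedb_api.exception import SpineDBAPIError

    def resolve_id(name_or_id, ids_by_name):
        if isinstance(name_or_id, str):
            try:
                return ids_by_name[name_or_id]
            except KeyError:
                raise SpineDBAPIError(f"'{name_or_id}' not found.")
        if name_or_id not in set(ids_by_name.values()):
            raise SpineDBAPIError(f"id {name_or_id} not found.")
        return name_or_id

    assert resolve_id("base", {"base": 1}) == 1
    assert resolve_id(1, {"base": 1}) == 1
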
- - Args: - db_map (DatabaseMappingBase): a database map - tool (str or int): tool name or id - - Returns: - int or NoneType: tool id - """ - if isinstance(tool, str): - tool_name = tool - tool_id = db_map.query(db_map.tool_sq.c.id).filter(db_map.tool_sq.c.name == tool_name).scalar() - if tool_id is None: - raise SpineDBAPIError(f"Tool '{tool_name}' not found.") - return tool_id - tool_id = tool - tool = db_map.query(db_map.tool_sq).filter(db_map.tool_sq.c.id == tool_id).one_or_none() - if tool is None: - raise SpineDBAPIError(f"Tool id {tool_id} not found.") - return tool_id - - @staticmethod - def active_entity_id_sq(db_map, tool_id): - """ - Creates a subquery that returns entity ids that pass the tool filter. - - Args: - db_map (DatabaseMappingBase): database mapping - tool_id (int): tool identifier - - Returns: - Alias: subquery - """ - tool_feature_method_sq = _make_ext_tool_feature_method_sq(db_map, tool_id) - method_filter = _make_method_filter( - tool_feature_method_sq, db_map.parameter_value_sq, db_map.parameter_definition_sq - ) - required_filter = _make_required_filter(tool_feature_method_sq, db_map.parameter_value_sq) - return ( - db_map.query(db_map.entity_sq.c.id) - .outerjoin( - db_map.parameter_definition_sq, - db_map.parameter_definition_sq.c.entity_class_id == db_map.entity_sq.c.class_id, - ) - .outerjoin( - db_map.parameter_value_sq, - and_( - db_map.parameter_value_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id, - db_map.parameter_value_sq.c.entity_id == db_map.entity_sq.c.id, - ), - ) - .outerjoin( - tool_feature_method_sq, - tool_feature_method_sq.c.parameter_definition_id == db_map.parameter_definition_sq.c.id, - ) - .group_by(db_map.entity_sq.c.id) - .having(and_(func.min(method_filter).is_(True), func.min(required_filter).is_(True))) - ).subquery() - - -def _make_ext_tool_feature_method_sq(db_map, tool_id): - """ - Returns an extended tool_feature_method subquery that has ``None`` whenever no method is specified. 
- Used by ``_make_tool_filtered_entity_sq`` - - Args: - db_map (DatabaseMappingBase): a database map - tool_id (int): tool id - - Returns: - Alias: a subquery for tool_feature_method - """ - return ( - db_map.query( - db_map.ext_tool_feature_sq.c.tool_id, - db_map.ext_tool_feature_sq.c.parameter_definition_id, - db_map.ext_tool_feature_sq.c.required, - db_map.list_value_sq.c.id.label("method_list_value_id"), - ) - .outerjoin( - db_map.tool_feature_method_sq, - db_map.tool_feature_method_sq.c.tool_feature_id == db_map.ext_tool_feature_sq.c.id, - ) - .outerjoin( - db_map.list_value_sq, - and_( - db_map.ext_tool_feature_sq.c.parameter_value_list_id == db_map.list_value_sq.c.parameter_value_list_id, - db_map.tool_feature_method_sq.c.method_index == db_map.list_value_sq.c.index, - ), - ) - .filter(db_map.ext_tool_feature_sq.c.tool_id == tool_id) - ).subquery() - - -def _make_method_filter(tool_feature_method_sq, parameter_value_sq, parameter_definition_sq): - # Filter passes if either: - # 1) parameter definition is not a feature for the tool - # 2) method is not specified - # 3) value is equal to method - # 4) value is not specified, but default value is equal to method - return case( - [ - ( - or_( - tool_feature_method_sq.c.parameter_definition_id.is_(None), - tool_feature_method_sq.c.method_list_value_id.is_(None), - parameter_value_sq.c.list_value_id == tool_feature_method_sq.c.method_list_value_id, - and_( - parameter_value_sq.c.value.is_(None), - parameter_definition_sq.c.list_value_id == tool_feature_method_sq.c.method_list_value_id, - ), - ), - True, - ) - ], - else_=False, - ) - - -def _make_required_filter(tool_feature_method_sq, parameter_value_sq): - # Filter passes if either: - # 1) parameter definition is not a feature for the tool - # 2) value is specified - # 3) method is not required - return case( - [ - ( - or_( - tool_feature_method_sq.c.parameter_definition_id.is_(None), - parameter_value_sq.c.value.isnot(None), - tool_feature_method_sq.c.required.is_(False), - ), - True, - ) - ], - else_=False, - ) - - -def _make_tool_filtered_entity_sq(db_map, state): - """ - Returns a tool filtering subquery similar to :func:`DatabaseMappingBase.entity_sq`. - - This function can be used as replacement for entity subquery maker in :class:`DatabaseMappingBase`. 
- - Args: - db_map (DatabaseMappingBase): a database map - state (_ScenarioFilterState): a state bound to ``db_map`` - - Returns: - Alias: a subquery for entity filtered by selected tool - """ - return ( - db_map.query(state.original_entity_sq) - .join(state.table, state.original_entity_sq.c.id == state.table.c.entity_id) - .subquery() - ) diff --git a/spinedb_api/filters/tools.py b/spinedb_api/filters/tools.py index e7408115..d8af50fe 100644 --- a/spinedb_api/filters/tools.py +++ b/spinedb_api/filters/tools.py @@ -45,15 +45,6 @@ scenario_filter_shorthand_to_config, scenario_name_from_dict, ) -from .tool_filter import ( - TOOL_SHORTHAND_TAG, - TOOL_FILTER_TYPE, - tool_filter_config, - tool_filter_config_to_shorthand, - tool_filter_from_dict, - tool_filter_shorthand_to_config, - tool_name_from_dict, -) from .value_transformer import ( VALUE_TRANSFORMER_SHORTHAND_TAG, VALUE_TRANSFORMER_TYPE, @@ -88,7 +79,6 @@ def apply_filter_stack(db_map, stack): EXECUTION_FILTER_TYPE: execution_filter_from_dict, PARAMETER_RENAMER_TYPE: parameter_renamer_from_dict, SCENARIO_FILTER_TYPE: scenario_filter_from_dict, - TOOL_FILTER_TYPE: tool_filter_from_dict, VALUE_TRANSFORMER_TYPE: value_transformer_from_dict, } for filter_ in stack: @@ -139,7 +129,6 @@ def filter_config(filter_type, value): """ return { SCENARIO_FILTER_TYPE: scenario_filter_config, - TOOL_FILTER_TYPE: tool_filter_config, ALTERNATIVE_FILTER_TYPE: alternative_filter_config, EXECUTION_FILTER_TYPE: execution_filter_config, }[filter_type](value) @@ -288,7 +277,6 @@ def config_to_shorthand(config): ENTITY_CLASS_RENAMER_TYPE: entity_class_renamer_config_to_shorthand, PARAMETER_RENAMER_TYPE: parameter_renamer_config_to_shorthand, SCENARIO_FILTER_TYPE: scenario_filter_config_to_shorthand, - TOOL_FILTER_TYPE: tool_filter_config_to_shorthand, EXECUTION_FILTER_TYPE: execution_filter_config_to_shorthand, VALUE_TRANSFORMER_TYPE: value_transformer_config_to_shorthand, } @@ -310,7 +298,6 @@ def _parse_shorthand(shorthand): ENTITY_CLASS_RENAMER_SHORTHAND_TAG: entity_class_renamer_shorthand_to_config, PARAMETER_RENAMER_SHORTHAND_TAG: parameter_renamer_shorthand_to_config, SCENARIO_SHORTHAND_TAG: scenario_filter_shorthand_to_config, - TOOL_SHORTHAND_TAG: tool_filter_shorthand_to_config, EXECUTION_SHORTHAND_TAG: execution_filter_shorthand_to_config, VALUE_TRANSFORMER_SHORTHAND_TAG: value_transformer_shorthand_to_config, } @@ -320,7 +307,7 @@ def _parse_shorthand(shorthand): def name_from_dict(config): """ - Returns scenario or tool name from filter config. + Returns scenario name from filter config. 
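
After the tool filter removal, `name_from_dict` dispatches on the scenario filter type only. The scenario shorthand helpers imported at the top of this file still round-trip configs the same way the deleted tool helpers did; a sketch, assuming the usual config/shorthand symmetry:

    from spinedb_api.filters.scenario_filter import (
        scenario_filter_config,
        scenario_filter_config_to_shorthand,
        scenario_filter_shorthand_to_config,
        scenario_name_from_dict,
    )

    config = scenario_filter_config("my_scenario")
    shorthand = scenario_filter_config_to_shorthand(config)
    assert scenario_filter_shorthand_to_config(shorthand) == config
    assert scenario_name_from_dict(config) == "my_scenario"
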
Args: config (dict): filter configuration @@ -328,7 +315,7 @@ def name_from_dict(config): Returns: str: name or None if ``config`` is not a valid 'name' filter configuration """ - func = {SCENARIO_FILTER_TYPE: scenario_name_from_dict, TOOL_FILTER_TYPE: tool_name_from_dict}.get(config["type"]) + func = {SCENARIO_FILTER_TYPE: scenario_name_from_dict}.get(config["type"]) if func is None: return None return func(config) diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index a364feb8..0d35d256 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -16,7 +16,7 @@ from functools import partial from numbers import Number from sqlalchemy import case, literal, Integer, LargeBinary, String -from sqlalchemy.sql.expression import label, select, cast, union_all +from sqlalchemy.sql.expression import select, cast, union_all from ..exception import SpineDBAPIError from ..helpers import LONGTEXT_LENGTH @@ -188,24 +188,12 @@ def _make_parameter_value_transforming_sq(db_map, state): temp_sq = union_all(*statements).alias("transformed_values") new_value = case([(temp_sq.c.transformed_value != None, temp_sq.c.transformed_value)], else_=subquery.c.value) new_type = case([(temp_sq.c.transformed_type != None, temp_sq.c.transformed_type)], else_=subquery.c.type) - object_class_case = case( - [(db_map.ext_entity_class_sq.c.dimension_id_list == None, subquery.c.entity_class_id)], else_=None - ) - rel_class_case = case( - [(db_map.ext_entity_class_sq.c.dimension_id_list != None, subquery.c.entity_class_id)], else_=None - ) - object_entity_case = case([(db_map.ext_entity_sq.c.element_id_list == None, subquery.c.entity_id)], else_=None) - rel_entity_case = case([(db_map.ext_entity_sq.c.element_id_list != None, subquery.c.entity_id)], else_=None) parameter_value_sq = ( db_map.query( subquery.c.id.label("id"), subquery.c.parameter_definition_id, subquery.c.entity_class_id, subquery.c.entity_id, - label("object_class_id", object_class_case), - label("relationship_class_id", rel_class_case), - label("object_id", object_entity_case), - label("relationship_id", rel_entity_case), new_value.label("value"), new_type.label("type"), subquery.c.list_value_id, diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 577bc3d5..9d5f6b53 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -41,6 +41,7 @@ func, inspect, null, + true, select, ) from sqlalchemy.ext.automap import generate_relationship @@ -423,6 +424,21 @@ def create_spine_metadata(): ("member_id", "entity_class_id"), ("entity.id", "entity.class_id"), onupdate="CASCADE", ondelete="CASCADE" ), ) + Table( + "entity_alternative", + meta, + Column("id", Integer, primary_key=True), + Column("entity_id", Integer, ForeignKey("entity.id", onupdate="CASCADE", ondelete="CASCADE"), nullable=False), + Column( + "alternative_id", + Integer, + ForeignKey("alternative.id", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, + ), + Column("active", Boolean(name="active"), server_default=true(), nullable=False), + Column("commit_id", Integer, ForeignKey("commit.id")), + UniqueConstraint("entity_id", "alternative_id"), + ) Table( "parameter_definition", meta, @@ -505,71 +521,6 @@ def create_spine_metadata(): Column("commit_id", Integer, ForeignKey("commit.id")), UniqueConstraint("parameter_value_list_id", "index"), ) - Table( - "tool", - meta, - Column("id", Integer, primary_key=True), - Column("name", String(155), nullable=False), - Column("description", Text(), 
server_default=null()), - Column("commit_id", Integer, ForeignKey("commit.id")), - ) - Table( - "feature", - meta, - Column("id", Integer, primary_key=True), - Column("parameter_definition_id", Integer, nullable=False), - Column("parameter_value_list_id", Integer, nullable=False), - Column("description", Text(), server_default=null()), - Column("commit_id", Integer, ForeignKey("commit.id")), - UniqueConstraint("parameter_definition_id", "parameter_value_list_id"), - UniqueConstraint("id", "parameter_value_list_id"), - ForeignKeyConstraint( - ("parameter_definition_id", "parameter_value_list_id"), - ("parameter_definition.id", "parameter_definition.parameter_value_list_id"), - onupdate="CASCADE", - ondelete="CASCADE", - ), - ) - Table( - "tool_feature", - meta, - Column("id", Integer, primary_key=True), - Column("tool_id", Integer, ForeignKey("tool.id")), - Column("feature_id", Integer, nullable=False), - Column("parameter_value_list_id", Integer, nullable=False), - Column("required", Boolean(name="required"), server_default=false(), nullable=False), - Column("commit_id", Integer, ForeignKey("commit.id")), - UniqueConstraint("tool_id", "feature_id"), - UniqueConstraint("id", "parameter_value_list_id"), - ForeignKeyConstraint( - ("feature_id", "parameter_value_list_id"), - ("feature.id", "feature.parameter_value_list_id"), - onupdate="CASCADE", - ondelete="CASCADE", - ), - ) - Table( - "tool_feature_method", - meta, - Column("id", Integer, primary_key=True), - Column("tool_feature_id", Integer, nullable=False), - Column("parameter_value_list_id", Integer, nullable=False), - Column("method_index", Integer), - Column("commit_id", Integer, ForeignKey("commit.id")), - UniqueConstraint("tool_feature_id", "method_index"), - ForeignKeyConstraint( - ("tool_feature_id", "parameter_value_list_id"), - ("tool_feature.id", "tool_feature.parameter_value_list_id"), - onupdate="CASCADE", - ondelete="CASCADE", - ), - ForeignKeyConstraint( - ("parameter_value_list_id", "method_index"), - ("list_value.parameter_value_list_id", "list_value.index"), - onupdate="CASCADE", - ondelete="CASCADE", - ), - ) Table( "metadata", meta, diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index ec3591b9..7ba8a770 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -19,10 +19,6 @@ from .check_functions import ( check_entity_class, check_entity, - check_tool, - check_feature, - check_tool_feature, - check_tool_feature_method, check_alternative, check_object_class, check_object, @@ -72,7 +68,6 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= alternatives = [['example_alternative', 'An example']] scenarios = [['example_scenario', 'An example']] scenario_alternatives = [('scenario', 'alternative1'), ('scenario', 'alternative0', 'alternative1')] - tools = [('tool1', 'Tool one description'), ('tool2', 'Tool two description']] import_data(db_map, object_classes=object_c, @@ -86,8 +81,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= relationship_parameter_values=rel_p_values, alternatives=alternatives, scenarios=scenarios, - scenario_alternatives=scenario_alternatives - tools=tools) + scenario_alternatives=scenario_alternatives) Args: db_map (spinedb_api.DiffDatabaseMapping): database mapping @@ -127,10 +121,6 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "parameter_value_list": db_map._add_parameter_value_lists, "list_value": db_map._add_list_values, 
"parameter_definition": db_map._add_parameter_definitions, - "feature": db_map._add_features, - "tool": db_map._add_tools, - "tool_feature": db_map._add_tool_features, - "tool_feature_method": db_map._add_tool_feature_methods, "entity": db_map._add_entities, "object": db_map._add_objects, "relationship": db_map._add_wide_relationships, @@ -150,9 +140,6 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= "parameter_value_list": db_map._update_parameter_value_lists, "list_value": db_map._update_list_values, "parameter_definition": db_map._update_parameter_definitions, - "feature": db_map._update_features, - "tool": db_map._update_tools, - "tool_feature": db_map._update_tool_features, "entity": db_map._update_entities, "object": db_map._update_objects, "parameter_value": db_map._update_parameter_values, @@ -203,10 +190,6 @@ def get_data_for_import( alternatives=(), scenarios=(), scenario_alternatives=(), - features=(), - tools=(), - tool_features=(), - tool_feature_methods=(), metadata=(), object_metadata=(), relationship_metadata=(), @@ -281,17 +264,6 @@ def get_data_for_import( "parameter_definition", _get_relationship_parameters_for_import(db_map, relationship_parameters, make_cache, unparse_value), ) - if features: - yield ("feature", _get_features_for_import(db_map, features, make_cache)) - if tools: - yield ("tool", _get_tools_for_import(db_map, tools, make_cache)) - if tool_features: - yield ("tool_feature", _get_tool_features_for_import(db_map, tool_features, make_cache)) - if tool_feature_methods: - yield ( - "tool_feature_method", - _get_tool_feature_methods_for_import(db_map, tool_feature_methods, make_cache, unparse_value), - ) if entities: yield ("entity", _get_entities_for_import(db_map, entities, make_cache, dry_run)) if objects: @@ -779,287 +751,6 @@ def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on return to_add, to_update, error_log -def import_features(db_map, data, make_cache=None): - """ - Imports features. 
- - Example: - - data = [('class', 'parameter'), ('another_class', 'another_parameter', 'description')] - import_features(db_map, data) - - Args: - db_map (DiffDatabaseMapping): mapping for database to insert into - data (Iterable): an iterable of lists/tuples with class name, parameter name, and optionally description - - Returns: - tuple of int and list: Number of successfully inserted features, list of errors - """ - return import_data(db_map, features=data, make_cache=make_cache) - - -def _get_features_for_import(db_map, data, make_cache): - cache = make_cache({"feature"}, include_ancestors=True) - feature_ids = {x.parameter_definition_id: x.id for x in cache.get("feature", {}).values()} - parameter_ids = { - (x.entity_class_name, x.parameter_name): (x.id, x.value_list_id) - for x in cache.get("parameter_definition", {}).values() - } - parameter_definitions = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - checked = set() - to_add = [] - to_update = [] - error_log = [] - for class_name, parameter_name, *optionals in data: - parameter_definition_id, parameter_value_list_id = parameter_ids.get((class_name, parameter_name), (None, None)) - if parameter_definition_id in checked: - continue - feature_id = feature_ids.pop(parameter_definition_id, None) - item = ( - cache["feature"][feature_id]._asdict() - if feature_id is not None - else { - "parameter_definition_id": parameter_definition_id, - "parameter_value_list_id": parameter_value_list_id, - "description": None, - } - ) - item.update(dict(zip(("description",), optionals))) - try: - check_feature(item, feature_ids, parameter_definitions) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import feature '{class_name, parameter_name}': {e.msg}", db_type="feature" - ) - ) - continue - finally: - if feature_id is not None: - feature_ids[parameter_definition_id] = feature_id - checked.add(parameter_definition_id) - if feature_id is not None: - item["id"] = feature_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - -def import_tools(db_map, data, make_cache=None): - """ - Imports tools. 
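
Like `_get_features_for_import` above it, the removed `_get_tools_for_import` pops the item's current id before checking, so an updated item is not flagged as a duplicate of itself, and then restores the id in a `finally` block. The idiom in isolation, with made-up data:

    ids_by_name = {"tool_a": 5}  # illustrative cache of existing ids
    name = "tool_a"
    current_id = ids_by_name.pop(name, None)
    try:
        # An integrity check here sees every other item, but not the one being updated.
        assert name not in ids_by_name
    finally:
        if current_id is not None:
            ids_by_name[name] = current_id
    assert ids_by_name["tool_a"] == 5
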
- - Example: - - data = ['tool', ('another_tool', 'description')] - import_tools(db_map, data) - - Args: - db_map (DiffDatabaseMapping): mapping for database to insert into - data (Iterable): an iterable of tool names, - or of lists/tuples with tool names and optional descriptions - - Returns: - tuple of int and list: Number of successfully inserted tools, list of errors - """ - return import_data(db_map, tools=data, make_cache=make_cache) - - -def _get_tools_for_import(db_map, data, make_cache): - cache = make_cache({"tool"}, include_ancestors=True) - tool_ids = {tool.name: tool.id for tool in cache.get("tool", {}).values()} - checked = set() - to_add = [] - to_update = [] - error_log = [] - for tool in data: - if isinstance(tool, str): - tool = (tool,) - name, *optionals = tool - if name in checked: - continue - tool_id = tool_ids.pop(name, None) - item = cache["tool"][tool_id]._asdict() if tool_id is not None else {"name": name, "description": None} - item.update(dict(zip(("description",), optionals))) - try: - check_tool(item, tool_ids) - except SpineIntegrityError as e: - error_log.append(ImportErrorLogItem(msg=f"Could not import tool '{name}': {e.msg}", db_type="tool")) - continue - finally: - if tool_id is not None: - tool_ids[name] = tool_id - checked.add(name) - if tool_id is not None: - item["id"] = tool_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - -def import_tool_features(db_map, data, make_cache=None): - """ - Imports tool features. - - Example: - - data = [('tool', 'class', 'parameter'), ('another_tool', 'another_class', 'another_parameter', 'required')] - import_tool_features(db_map, data) - - Args: - db_map (DiffDatabaseMapping): mapping for database to insert into - data (Iterable): an iterable of lists/tuples with tool name, class name, parameter name, - and optionally description - - Returns: - tuple of int and list: Number of successfully inserted tool features, list of errors - """ - return import_data(db_map, tool_features=data, make_cache=make_cache) - - -def _get_tool_features_for_import(db_map, data, make_cache): - cache = make_cache({"tool_feature"}, include_ancestors=True) - tool_feature_ids = {(x.tool_id, x.feature_id): x.id for x in cache.get("tool_feature", {}).values()} - tool_ids = {x.name: x.id for x in cache.get("tool", {}).values()} - feature_ids = { - (x.entity_class_name, x.parameter_definition_name): (x.id, x.parameter_value_list_id) - for x in cache.get("feature", {}).values() - } - tools = {x.id: x._asdict() for x in cache.get("tool", {}).values()} - features = { - x.id: { - "name": x.entity_class_name + "/" + x.parameter_definition_name, - "parameter_value_list_id": x.parameter_value_list_id, - } - for x in cache.get("feature", {}).values() - } - checked = set() - to_add = [] - to_update = [] - error_log = [] - for tool_name, class_name, parameter_name, *optionals in data: - tool_id = tool_ids.get(tool_name) - feature_id, parameter_value_list_id = feature_ids.get((class_name, parameter_name), (None, None)) - if (tool_id, feature_id) in checked: - continue - tool_feature_id = tool_feature_ids.pop((tool_id, feature_id), None) - item = ( - cache["tool_feature"][tool_feature_id]._asdict() - if tool_feature_id is not None - else { - "tool_id": tool_id, - "feature_id": feature_id, - "parameter_value_list_id": parameter_value_list_id, - "required": False, - } - ) - item.update(dict(zip(("required",), optionals))) - try: - check_tool_feature(item, tool_feature_ids, tools, features) - except 
SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import tool feature '{tool_name, class_name, parameter_name}': {e.msg}", - db_type="tool_feature", - ) - ) - continue - finally: - if tool_feature_id is not None: - tool_feature_ids[tool_id, feature_id] = tool_feature_id - checked.add((tool_id, feature_id)) - if tool_feature_id is not None: - item["id"] = tool_feature_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - -def import_tool_feature_methods(db_map, data, make_cache=None, unparse_value=to_database): - """ - Imports tool feature methods. - - Example: - - data = [('tool', 'class', 'parameter', 'method'), ('another_tool', 'another_class', 'another_parameter', 'another_method')] - import_tool_features(db_map, data) - - Args: - db_map (DiffDatabaseMapping): mapping for database to insert into - data (Iterable): an iterable of lists/tuples with tool name, class name, parameter name, and method - - Returns: - tuple of int and list: Number of successfully inserted tool features, list of errors - """ - return import_data(db_map, tool_feature_methods=data, make_cache=make_cache, unparse_value=unparse_value) - - -def _get_tool_feature_methods_for_import(db_map, data, make_cache, unparse_value): - cache = make_cache({"tool_feature_method"}, include_ancestors=True) - tool_feature_method_ids = { - (x.tool_feature_id, x.method_index): x.id for x in cache.get("tool_feature_method", {}).values() - } - tool_feature_ids = { - (x.tool_name, x.entity_class_name, x.parameter_definition_name): (x.id, x.parameter_value_list_id) - for x in cache.get("tool_feature", {}).values() - } - tool_features = {x.id: x._asdict() for x in cache.get("tool_feature", {}).values()} - parameter_value_lists = { - x.id: {"name": x.name, "value_index_list": x.value_index_list} - for x in cache.get("parameter_value_list", {}).values() - } - list_values = { - (x.parameter_value_list_id, x.index): from_database(x.value, x.type) - for x in cache.get("list_value", {}).values() - } - seen = set() - to_add = [] - error_log = [] - for tool_name, class_name, parameter_name, method in data: - tool_feature_id, parameter_value_list_id = tool_feature_ids.get( - (tool_name, class_name, parameter_name), (None, None) - ) - parameter_value_list = parameter_value_lists.get(parameter_value_list_id, {}) - value_index_list = parameter_value_list.get("value_index_list", []) - method = from_database(*unparse_value(method)) - method_index = next( - iter(index for index in value_index_list if list_values.get((parameter_value_list_id, index)) == method), - None, - ) - if (tool_feature_id, method_index) in seen | tool_feature_method_ids.keys(): - continue - item = { - "tool_feature_id": tool_feature_id, - "parameter_value_list_id": parameter_value_list_id, - "method_index": method_index, - } - try: - check_tool_feature_method(item, tool_feature_method_ids, tool_features, parameter_value_lists) - to_add.append(item) - seen.add((tool_feature_id, method_index)) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=( - f"Could not import tool feature method '{tool_name, class_name, parameter_name, method}':" - f" {e.msg}" - ), - db_type="tool_feature_method", - ) - ) - return to_add, [], error_log - - def import_alternatives(db_map, data, make_cache=None): """ Imports alternatives. 
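With tool and feature support removed, get_data_for_import no longer accepts the features, tools, tool_features or tool_feature_methods keyword arguments, and import_data only dispatches the remaining item types. A minimal sketch of an import after this change (the database URL and item names are illustrative, not taken from the patch):

    from spinedb_api import DiffDatabaseMapping, import_data

    db_map = DiffDatabaseMapping("sqlite:///example.sqlite", create=True)
    num_imported, errors = import_data(
        db_map,
        object_classes=["unit"],
        objects=[("unit", "power_plant_a")],
        alternatives=["high_demand"],
    )
    db_map.commit_session("Initial import")

As before, the call returns the number of successfully inserted items together with a list of import errors.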
diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index 6a1fddad..6c88459b 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -38,10 +38,6 @@ class ImportKey(Enum): ALTERNATIVE_NAME = auto() SCENARIO_NAME = auto() SCENARIO_ALTERNATIVE = auto() - FEATURE = auto() - TOOL_NAME = auto() - TOOL_FEATURE = auto() - TOOL_FEATURE_METHOD = auto() PARAMETER_VALUE_LIST_NAME = auto() def __str__(self): @@ -59,10 +55,6 @@ def __str__(self): self.PARAMETER_VALUE_LIST_NAME.value: "Parameter value lists", self.SCENARIO_NAME.value: "Scenario names", self.SCENARIO_ALTERNATIVE.value: "Alternative names", - self.TOOL_NAME.value: "Tool names", - self.FEATURE.value: "Entity class names", - self.TOOL_FEATURE.value: "Entity class names", - self.TOOL_FEATURE_METHOD.value: "Entity class names", }.get(self.value) if name is not None: return name @@ -839,124 +831,6 @@ def _import_row(self, source_data, state, mapped_data): mapped_data.setdefault("tools", set()).add(tool) -class FeatureEntityClassMapping(ImportMapping): - """Maps feature entity classes. - - Can be used as the topmost mapping. - """ - - MAP_TYPE = "FeatureEntityClass" - - def _import_row(self, source_data, state, mapped_data): - entity_class = str(source_data) - state[ImportKey.FEATURE] = [entity_class] - - -class FeatureParameterDefinitionMapping(ImportMapping): - """Maps feature parameter definitions. - - Cannot be used as the topmost mapping; must have a :class:`FeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "FeatureParameterDefinition" - - def _import_row(self, source_data, state, mapped_data): - feature = state[ImportKey.FEATURE] - parameter = str(source_data) - feature.append(parameter) - mapped_data.setdefault("features", set()).add(tuple(feature)) - - -class ToolFeatureEntityClassMapping(ImportMapping): - """Maps tool feature entity classes. - - Cannot be used as the topmost mapping; must have :class:`ToolMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureEntityClass" - - def _import_row(self, source_data, state, mapped_data): - tool = state[ImportKey.TOOL_NAME] - entity_class = str(source_data) - tool_feature = [tool, entity_class] - state[ImportKey.TOOL_FEATURE] = tool_feature - mapped_data.setdefault("tool_features", []).append(tool_feature) - - -class ToolFeatureParameterDefinitionMapping(ImportMapping): - """Maps tool feature parameter definitions. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureParameterDefinition" - - def _import_row(self, source_data, state, mapped_data): - tool_feature = state[ImportKey.TOOL_FEATURE] - parameter = str(source_data) - tool_feature.append(parameter) - - -class ToolFeatureRequiredFlagMapping(ImportMapping): - """Maps tool feature required flags. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureRequiredFlag" - - def _import_row(self, source_data, state, mapped_data): - required = bool(strtobool(str(source_data))) - tool_feature = state[ImportKey.TOOL_FEATURE] - tool_feature.append(required) - - -class ToolFeatureMethodEntityClassMapping(ImportMapping): - """Maps tool feature method entity classes. - - Cannot be used as the topmost mapping; must have :class:`ToolMapping` as parent. 
- """ - - MAP_TYPE = "ToolFeatureMethodEntityClass" - - def _import_row(self, source_data, state, mapped_data): - tool_name = state[ImportKey.TOOL_NAME] - entity_class = str(source_data) - tool_feature_method = [tool_name, entity_class] - state[ImportKey.TOOL_FEATURE_METHOD] = tool_feature_method - mapped_data.setdefault("tool_feature_methods", []).append(tool_feature_method) - - -class ToolFeatureMethodParameterDefinitionMapping(ImportMapping): - """Maps tool feature method parameter definitions. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureMethodEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureMethodParameterDefinition" - - def _import_row(self, source_data, state, mapped_data): - tool_feature_method = state[ImportKey.TOOL_FEATURE_METHOD] - parameter = str(source_data) - tool_feature_method.append(parameter) - - -class ToolFeatureMethodMethodMapping(ImportMapping): - """Maps tool feature method methods. - - Cannot be used as the topmost mapping; must have :class:`ToolFeatureMethodEntityClassMapping` as parent. - """ - - MAP_TYPE = "ToolFeatureMethodMethod" - - def _import_row(self, source_data, state, mapped_data): - method = source_data - if method == "": - return - tool_feature_method = state[ImportKey.TOOL_FEATURE_METHOD] - tool_feature_method.append(method) - - def from_dict(serialized): """ Deserializes mappings. @@ -995,14 +869,15 @@ def from_dict(serialized): ScenarioAlternativeMapping, ScenarioBeforeAlternativeMapping, ToolMapping, - FeatureEntityClassMapping, - FeatureParameterDefinitionMapping, - ToolFeatureEntityClassMapping, - ToolFeatureParameterDefinitionMapping, - ToolFeatureRequiredFlagMapping, - ToolFeatureMethodEntityClassMapping, - ToolFeatureMethodParameterDefinitionMapping, - ToolFeatureMethodMethodMapping, + # FIXME + # FeatureEntityClassMapping, + # FeatureParameterDefinitionMapping, + # ToolFeatureEntityClassMapping, + # ToolFeatureParameterDefinitionMapping, + # ToolFeatureRequiredFlagMapping, + # ToolFeatureMethodEntityClassMapping, + # ToolFeatureMethodParameterDefinitionMapping, + # ToolFeatureMethodMethodMapping, ) } legacy_mappings = { diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index aed96a0b..a400bffd 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -35,15 +35,6 @@ ScenarioActiveFlagMapping, ScenarioAlternativeMapping, ScenarioBeforeAlternativeMapping, - ToolMapping, - FeatureEntityClassMapping, - FeatureParameterDefinitionMapping, - ToolFeatureEntityClassMapping, - ToolFeatureParameterDefinitionMapping, - ToolFeatureRequiredFlagMapping, - ToolFeatureMethodEntityClassMapping, - ToolFeatureMethodParameterDefinitionMapping, - ToolFeatureMethodMethodMapping, EntityGroupMapping, ParameterValueListMapping, ParameterValueListValueMapping, @@ -84,10 +75,11 @@ def import_mapping_from_dict(map_dict): "Alternative": _alternative_mapping_from_dict, "Scenario": _scenario_mapping_from_dict, "ScenarioAlternative": _scenario_alternative_mapping_from_dict, - "Tool": _tool_mapping_from_dict, - "Feature": _feature_mapping_from_dict, - "ToolFeature": _tool_feature_mapping_from_dict, - "ToolFeatureMethod": _tool_feature_method_mapping_from_dict, + # FIXME + # "Tool": _tool_mapping_from_dict, + # "Feature": _feature_mapping_from_dict, + # "ToolFeature": _tool_feature_mapping_from_dict, + # "ToolFeatureMethod": _tool_feature_method_mapping_from_dict, "ObjectGroup": 
_object_group_mapping_from_dict, "ParameterValueList": _parameter_value_list_mapping_from_dict, } @@ -141,58 +133,6 @@ def _scenario_alternative_mapping_from_dict(map_dict): return root_mapping -def _tool_mapping_from_dict(map_dict): - name = map_dict.get("name") - skip_columns = map_dict.get("skip_columns", []) - read_start_row = map_dict.get("read_start_row", 0) - root_mapping = ToolMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) - return root_mapping - - -def _feature_mapping_from_dict(map_dict): - entity_class_name = map_dict.get("entity_class_name") - parameter_definition_name = map_dict.get("parameter_definition_name") - skip_columns = map_dict.get("skip_columns", []) - read_start_row = map_dict.get("read_start_row", 0) - root_mapping = FeatureEntityClassMapping( - *_pos_and_val(entity_class_name), skip_columns=skip_columns, read_start_row=read_start_row - ) - root_mapping.child = FeatureParameterDefinitionMapping(*_pos_and_val(parameter_definition_name)) - return root_mapping - - -def _tool_feature_mapping_from_dict(map_dict): - name = map_dict.get("name") - entity_class_name = map_dict.get("entity_class_name") - parameter_definition_name = map_dict.get("parameter_definition_name") - required = map_dict.get("required", "false") - skip_columns = map_dict.get("skip_columns", []) - read_start_row = map_dict.get("read_start_row", 0) - root_mapping = ToolMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) - root_mapping.child = ent_class_mapping = ToolFeatureEntityClassMapping(*_pos_and_val(entity_class_name)) - ent_class_mapping.child = param_def_mapping = ToolFeatureParameterDefinitionMapping( - *_pos_and_val(parameter_definition_name) - ) - param_def_mapping.child = ToolFeatureRequiredFlagMapping(*_pos_and_val(required)) - return root_mapping - - -def _tool_feature_method_mapping_from_dict(map_dict): - name = map_dict.get("name") - entity_class_name = map_dict.get("entity_class_name") - parameter_definition_name = map_dict.get("parameter_definition_name") - method = map_dict.get("method") - skip_columns = map_dict.get("skip_columns", []) - read_start_row = map_dict.get("read_start_row", 0) - root_mapping = ToolMapping(*_pos_and_val(name), skip_columns=skip_columns, read_start_row=read_start_row) - root_mapping.child = ent_class_mapping = ToolFeatureMethodEntityClassMapping(*_pos_and_val(entity_class_name)) - ent_class_mapping.child = param_def_mapping = ToolFeatureMethodParameterDefinitionMapping( - *_pos_and_val(parameter_definition_name) - ) - param_def_mapping.child = ToolFeatureMethodMethodMapping(*_pos_and_val(method)) - return root_mapping - - def _object_class_mapping_from_dict(map_dict): name = map_dict.get("name") objects = map_dict.get("objects", map_dict.get("object")) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 0a80a51d..708ddb92 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -33,7 +33,6 @@ from .parameter_value import dump_db_value from .server_client_helpers import ReceiveAllMixing, encode, decode from .filters.scenario_filter import scenario_filter_config -from .filters.tool_filter import tool_filter_config from .filters.alternative_filter import alternative_filter_config from .filters.tools import apply_filter_stack from .spine_db_client import SpineDBClient @@ -462,9 +461,7 @@ def call_method(self, method_name, *args, **kwargs): def apply_filters(self, filters): configs = [ - {"scenario": scenario_filter_config, "tool": 
tool_filter_config, "alternatives": alternative_filter_config}[
-                key
-            ](value)
+            {"scenario": scenario_filter_config, "alternatives": alternative_filter_config}[key](value)
             for key, value in filters.items()
         ]
         return _db_manager.apply_filters(self.server_address, configs)

From 091b39415dbdb87b27409483b55bc7da94d6c8f7 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 3 May 2023 16:31:47 +0200
Subject: [PATCH 033/317] Rationalize the use of the cache

---
 spinedb_api/db_cache.py                 |  88 +++----
 spinedb_api/db_mapping_add_mixin.py     |  73 +++---
 spinedb_api/db_mapping_base.py          | 105 ++++-----
 spinedb_api/db_mapping_check_mixin.py   | 140 +++++------
 spinedb_api/db_mapping_remove_mixin.py  | 137 +++++------
 spinedb_api/db_mapping_update_mixin.py  | 143 ++++++-----
 spinedb_api/export_functions.py         | 142 +++++------
 spinedb_api/filters/execution_filter.py |   2 +-
 spinedb_api/import_functions.py         | 301 +++++++++++++-----------
 9 files changed, 542 insertions(+), 589 deletions(-)

diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py
index a2287f34..3109bfdf 100644
--- a/spinedb_api/db_cache.py
+++ b/spinedb_api/db_cache.py
@@ -13,7 +13,6 @@
 """
 from contextlib import suppress
-from operator import itemgetter
 
 # TODO: Implement CacheItem.pop() to do lookup?
 
@@ -25,8 +24,8 @@ def __init__(self, advance_query, *args, **kwargs):
         A dictionary that maps table names to ids to items.
         Used to store and retrieve database contents.
         Args:
-            advance_query (function): A function to call when references aren't found.
-                It receives a table name (a.k.a item type) and should bring more items of that type into this cache.
+            advance_query (function): A function that receives a table name (a.k.a. item type) as input and returns
+                more items of that type to be added to this cache.
         """
         super().__init__(*args, **kwargs)
         self._advance_query = advance_query
@@ -41,12 +40,25 @@ def get_item(self, item_type, id_):
             return {}
         return item
 
+    def fetch_more(self, item_type):
+        items = self._advance_query(item_type)
+        if not items:
+            return False
+        table_cache = self.table_cache(item_type)
+        for item in items:
+            table_cache.add_item(item._asdict())
+        return True
+
+    def fetch_all(self, item_type):
+        while self.fetch_more(item_type):
+            pass
+
     def fetch_ref(self, item_type, id_):
-        while self._advance_query(item_type):
+        while self.fetch_more(item_type):
             with suppress(KeyError):
                 return self[item_type][id_]
         # It is possible that fetching was completed between deciding to call this function
-        # and starting the while loop above resulting in self._advance_query() to return False immediately.
+        # and starting the while loop above, resulting in self.fetch_more() returning False immediately.
         # Therefore, we should try one last time if the ref is available.
         with suppress(KeyError):
             return self[item_type][id_]
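The new fetch_more/fetch_all pair turns the cache into an incrementally fillable store: fetch_more pulls one batch of rows from advance_query and returns False once the query is exhausted, while fetch_all simply drains it. A minimal sketch of the contract follows; the DBCache class name and import path are assumptions (the class statement is outside this hunk), and the rows are namedtuples because fetch_more calls _asdict() on each item:

    from collections import namedtuple

    from spinedb_api.db_cache import DBCache  # class name and module path assumed

    Row = namedtuple("Row", ["id", "name"])
    batches = [[Row(1, "Base")]]  # a single batch, then exhaustion

    def advance_query(item_type):
        # Return the next batch of rows for item_type; an empty list ends fetching.
        return batches.pop(0) if batches else []

    cache = DBCache(advance_query)
    cache.fetch_all("alternative")  # calls fetch_more until advance_query returns nothing
    print(cache.get_item("alternative", 1)["name"])  # -> "Base"

On the mapping side, _advance_cache_query (further below in this commit) plays the advance_query role: it returns the full query result the first time it is called for a table and an empty list afterwards, so the first fetch_more for a table brings in everything at once.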
@@ -65,12 +77,10 @@ def make_item(self, item_type, item):
         factory = {
             "entity_class": EntityClassItem,
             "entity": EntityItem,
+            "entity_group": EntityGroupItem,
             "parameter_definition": ParameterDefinitionItem,
             "parameter_value": ParameterValueItem,
-            "entity_group": EntityGroupItem,
-            "scenario": ScenarioItem,
             "scenario_alternative": ScenarioAlternativeItem,
-            "parameter_value_list": ParameterValueListItem,
         }.get(item_type, CacheItem)
         return factory(self, item_type, **item)
 
@@ -325,8 +335,6 @@ def __getitem__(self, key):
             return tuple(self._get_ref("entity", id_, key).get("name") for id_ in self["element_id_list"])
         if key == "byname":
             return self["element_name_list"] or (self["name"],)
-        if key == "alternative_name":
-            return self._get_ref("alternative", self["alternative_id"], key).get("name")
         return super().__getitem__(key)
 
     def _reference_keys(self):
@@ -335,7 +343,6 @@
             "dimension_id_list",
             "dimension_name_list",
             "element_name_list",
-            "alternative_name",
         )
 
@@ -428,60 +435,24 @@ def _reference_keys(self):
         return super()._reference_keys() + ("class_name", "group_name", "member_name", "dimension_id_list")
 
 
-class ScenarioItem(CacheItem):
-    @property
-    def _sorted_scen_alts(self):
-        return sorted(
-            (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]),
-            key=itemgetter("rank"),
-        )
-
-    def __getitem__(self, key):
-        if key == "active":
-            return dict.get(self, "active", False)
-        if key == "alternative_id_list":
-            return tuple(x.get("alternative_id") for x in self._sorted_scen_alts)
-        if key == "alternative_name_list":
-            return tuple(x.get("alternative_name") for x in self._sorted_scen_alts)
-        return super().__getitem__(key)
-
-
 class ScenarioAlternativeItem(CacheItem):
     def __getitem__(self, key):
         if key == "scenario_name":
             return self._get_ref("scenario", self["scenario_id"], key).get("name")
         if key == "alternative_name":
             return self._get_ref("alternative", self["alternative_id"], key).get("name")
-        scen_key = {
-            "before_alternative_id": "alternative_id_list",
-            "before_alternative_name": "alternative_name_list",
-        }.get(key)
-        if scen_key is not None:
-            scenario = self._get_ref("scenario", self["scenario_id"], key)
-            try:
-                return scenario[scen_key][self["rank"]]
-            except IndexError:
-                return None
-        return super().__getitem__(key)
+        if key == "before_alternative_name":
+            return self._get_ref("alternative", self["before_alternative_id"], key).get("name")
+        if key == "before_alternative_id":
+            return next(
+                (
+                    x
+                    for x in self._db_cache.get("scenario_alternative", {}).values()
+                    if x["scenario_id"] == self["scenario_id"] and x["rank"] == self["rank"] - 1
+                ),
+                {},
+            ).get("alternative_id")
+        return super().__getitem__(key)
 
     def _reference_keys(self):
         return super()._reference_keys() + ("scenario_name", "alternative_name")
-
-
-class ParameterValueListItem(CacheItem):
-    def _sorted_list_values(self, key):
-        return sorted(
-            (
-                self._get_ref("list_value", x["id"], key)
-                for x in self._db_cache.get("list_value", {}).values()
-                if x["parameter_value_list_id"] == self["id"]
-            ),
-            key=itemgetter("index"),
-        )
-
-    def __getitem__(self, key):
-        if key == "value_index_list":
-            return tuple(x.get("index") for x in self._sorted_list_values(key))
-        if key == "value_id_list":
-            return tuple(x.get("id") for x in self._sorted_list_values(key))
-        return super().__getitem__(key)
diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py
index bf1c141c..71a2d775 100644
---
a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -139,7 +139,6 @@ def add_items( strict=False, return_dups=False, return_items=False, - cache=None, readd=False, dry_run=False, ): @@ -153,8 +152,6 @@ def add_items( if the insertion of one of the items violates an integrity constraint. return_dups (bool): Whether or not already existing and duplicated entries should also be returned. return_items (bool): Return full items rather than just ids - cache (dict, optional): A dict mapping table names to a list of dictionary items, to use as db replacement - for queries readd (bool): Readds items directly Returns: @@ -167,9 +164,7 @@ def add_items( pass return items if return_items else {x["id"] for x in items}, [] if check: - checked_items, intgr_error_log = self.check_items( - tablename, *items, for_update=False, strict=strict, cache=cache - ) + checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=False, strict=strict) else: checked_items, intgr_error_log = list(items), [] ids = self._add_items(tablename, *checked_items, dry_run=dry_run) @@ -259,8 +254,13 @@ def _items_to_add_per_table(tablename, items_to_add): ) ] ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} + for item in items_to_add + for alternative_id in item["active_alternative_id_list"] + ] + [ + {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} for item in items_to_add + for alternative_id in item["inactive_alternative_id_list"] ] yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) @@ -269,8 +269,13 @@ def _items_to_add_per_table(tablename, items_to_add): yield ("entity_class", items_to_add) elif tablename == "object": ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} + for item in items_to_add + for alternative_id in item["active_alternative_id_list"] + ] + [ + {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} for item in items_to_add + for alternative_id in item["inactive_alternative_id_list"] ] yield ("entity", items_to_add) yield ("entity_alternative", ea_items_to_add) @@ -297,8 +302,13 @@ def _items_to_add_per_table(tablename, items_to_add): ) ] ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": item["alternative_id"], "active": item["active"]} + {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} for item in items_to_add + for alternative_id in item["active_alternative_id_list"] + ] + [ + {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} + for item in items_to_add + for alternative_id in item["inactive_alternative_id_list"] ] yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) @@ -370,7 +380,8 @@ def add_entity_metadata(self, *items, **kwargs): def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) - def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache, dry_run): + def _get_or_add_metadata_ids_for_items(self, *items, check, strict, dry_run): + cache = self.cache metadata_ids = {} for entry in cache.get("metadata", {}).values(): metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id @@ -385,13 +396,7 @@ def 
_get_or_add_metadata_ids_for_items(self, *items, check, strict, cache, dry_r else: item["metadata_id"] = existing_id added_metadata, errors = self.add_items( - "metadata", - *metadata_to_add, - check=check, - strict=strict, - return_items=True, - cache=cache, - dry_run=dry_run, + "metadata", *metadata_to_add, check=check, strict=strict, return_items=True, dry_run=dry_run ) for x in added_metadata: cache.table_cache("metadata").add_item(x) @@ -405,52 +410,36 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict, cache, dry_r item["metadata_id"] = new_metadata_ids[metadata_name][metadata_value] return added_metadata, errors - def _add_ext_item_metadata( - self, table_name, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False - ): + def _add_ext_item_metadata(self, table_name, *items, check=True, strict=False, return_items=False, dry_run=False): # Note, that even though return_items can be False, it doesn't make much sense here because we'll be mixing # metadata and entity metadata ids. - if cache is None: - cache = self.make_cache({table_name}, include_ancestors=True) + self.fetch_all({table_name}, include_ancestors=True) + cache = self.cache added_metadata, metadata_errors = self._get_or_add_metadata_ids_for_items( - *items, check=check, strict=strict, cache=cache, dry_run=dry_run + *items, check=check, strict=strict, dry_run=dry_run ) if metadata_errors: if not return_items: return added_metadata, metadata_errors return {i["id"] for i in added_metadata}, metadata_errors added_item_metadata, item_errors = self.add_items( - table_name, *items, check=check, strict=strict, return_items=True, cache=cache, dry_run=dry_run + table_name, *items, check=check, strict=strict, return_items=True, dry_run=dry_run ) errors = metadata_errors + item_errors if not return_items: return {i["id"] for i in added_metadata + added_item_metadata}, errors return added_metadata + added_item_metadata, errors - def add_ext_entity_metadata( - self, *items, check=True, strict=False, return_items=False, cache=None, readd=False, dry_run=False - ): + def add_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, readd=False, dry_run=False): return self._add_ext_item_metadata( - "entity_metadata", - *items, - check=check, - strict=strict, - return_items=return_items, - cache=cache, - dry_run=dry_run, + "entity_metadata", *items, check=check, strict=strict, return_items=return_items, dry_run=dry_run ) def add_ext_parameter_value_metadata( - self, *items, check=True, strict=False, return_items=False, cache=None, readd=False, dry_run=False + self, *items, check=True, strict=False, return_items=False, readd=False, dry_run=False ): return self._add_ext_item_metadata( - "parameter_value_metadata", - *items, - check=check, - strict=strict, - return_items=return_items, - cache=cache, - dry_run=dry_run, + "parameter_value_metadata", *items, check=check, strict=strict, return_items=return_items, dry_run=dry_run ) def _add_entity_classes(self, *items): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index cc55c4bc..c0579a1d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -134,6 +134,7 @@ def __init__( self._entity_sq = None self._entity_class_dimension_sq = None self._entity_element_sq = None + self._entity_alternative_sq = None self._object_class_sq = None self._object_sq = None self._relationship_class_sq = None @@ -441,8 +442,8 @@ def _clear_subqueries(self, *tablenames): """ tablenames = 
list(tablenames) for tablename in tablenames: - if self.cache.pop(tablename, None): - self._do_advance_cache_query(tablename) + if self.cache.pop(tablename, False): + self.cache.fetch_all(tablename) attr_names = set(attr for tablename in tablenames for attr in self._get_table_to_sq_attr().get(tablename, [])) for attr_name in attr_names: setattr(self, attr_name, None) @@ -535,6 +536,12 @@ def entity_element_sq(self): self._entity_element_sq = self._subquery("entity_element") return self._entity_element_sq + @property + def entity_alternative_sq(self): + if self._entity_alternative_sq is None: + self._entity_alternative_sq = self._subquery("entity_alternative") + return self._entity_alternative_sq + @property def entity_sq(self): """A subquery of the form: @@ -643,26 +650,12 @@ def ext_entity_sq(self): """ if self._ext_entity_sq is None: entity_element_sq = ( - self.query( - self.entity_element_sq.c.entity_id, - self.entity_element_sq.c.element_id, - self.entity_element_sq.c.position, - self.entity_sq.c.name.label("element_name"), - ) + self.query(self.entity_element_sq, self.entity_sq.c.name.label("element_name")) .filter(self.entity_element_sq.c.element_id == self.entity_sq.c.id) .subquery() ) - ee_sq = ( - self.query( - self.entity_sq.c.id, - self.entity_sq.c.class_id, - self.entity_sq.c.name, - self.entity_sq.c.description, - self.entity_sq.c.commit_id, - entity_element_sq.c.element_id, - entity_element_sq.c.element_name, - entity_element_sq.c.position, - ) + entity_sq = ( + self.query(self.entity_sq, entity_element_sq) .outerjoin( entity_element_sq, self.entity_sq.c.id == entity_element_sq.c.entity_id, @@ -672,20 +665,20 @@ def ext_entity_sq(self): ) self._ext_entity_sq = ( self.query( - ee_sq.c.id, - ee_sq.c.class_id, - ee_sq.c.name, - ee_sq.c.description, - ee_sq.c.commit_id, - group_concat(ee_sq.c.element_id, ee_sq.c.position).label("element_id_list"), - group_concat(ee_sq.c.element_name, ee_sq.c.position).label("element_name_list"), + entity_sq.c.id, + entity_sq.c.class_id, + entity_sq.c.name, + entity_sq.c.description, + entity_sq.c.commit_id, + group_concat(entity_sq.c.element_id, entity_sq.c.position).label("element_id_list"), + group_concat(entity_sq.c.element_name, entity_sq.c.position).label("element_name_list"), ) .group_by( - ee_sq.c.id, - ee_sq.c.class_id, - ee_sq.c.name, - ee_sq.c.description, - ee_sq.c.commit_id, + entity_sq.c.id, + entity_sq.c.class_id, + entity_sq.c.name, + entity_sq.c.description, + entity_sq.c.commit_id, ) .subquery() ) @@ -1687,9 +1680,7 @@ def _make_entity_sq(self): Returns: Alias: an entity subquery """ - e_sq = self._subquery("entity") - ea_sq = self._subquery("entity_alternative") - return self.query(e_sq, ea_sq).filter(e_sq.c.id == ea_sq.c.entity_id).subquery() + return self._subquery("entity") def _make_entity_class_sq(self): """ @@ -1798,23 +1789,26 @@ def _make_scenario_alternative_sq(self): """ return self._subquery("scenario_alternative") - def get_import_alternative(self, cache=None): + def get_import_alternative(self): """Returns the id of the alternative to use as default for all import operations. 
Returns: int, str """ if self._import_alternative_id is None: - self._create_import_alternative(cache=cache) + self._create_import_alternative() return self._import_alternative_id, self._import_alternative_name - def _create_import_alternative(self, cache=None): + def _create_import_alternative(self): """Creates the alternative to be used as default for all import operations.""" - if "alternative" not in cache: - cache = self.make_cache({"alternative"}) + self.fetch_all({"alternative"}) self._import_alternative_name = "Base" self._import_alternative_id = next( - (id_ for id_, alt in cache.get("alternative", {}).items() if alt.name == self._import_alternative_name), + ( + id_ + for id_, alt in self.cache.get("alternative", {}).items() + if alt.name == self._import_alternative_name + ), None, ) if not self._import_alternative_id: @@ -1964,7 +1958,7 @@ def _reset_mapping(self): self.connection.execute(table.delete()) self.connection.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") - def make_cache(self, tablenames, include_descendants=False, include_ancestors=False, force_tablenames=None): + def fetch_all(self, tablenames, include_descendants=False, include_ancestors=False, force_tablenames=None): if include_descendants: tablenames |= { descendant for tablename in tablenames for descendant in self.descendant_tablenames.get(tablename, ()) @@ -1976,22 +1970,12 @@ def make_cache(self, tablenames, include_descendants=False, include_ancestors=Fa if force_tablenames: tablenames |= force_tablenames for tablename in tablenames & self.cache_sqs.keys(): - self._do_advance_cache_query(tablename) - return self.cache - - def _advance_cache_query(self, tablename, callback=None): - advanced = False - if tablename not in self.cache: - advanced = True - self._do_advance_cache_query(tablename) - if callback is not None: - callback() - return advanced - - def _do_advance_cache_query(self, tablename): - table_cache = self.cache.table_cache(tablename) - for x in self.query(getattr(self, self.cache_sqs[tablename])).yield_per(1000).enable_eagerloads(False): - table_cache.add_item(x._asdict()) + self.cache.fetch_all(tablename) + + def _advance_cache_query(self, tablename): + if tablename in self.cache: + return [] + return self.query(getattr(self, self.cache_sqs[tablename])).yield_per(1000).enable_eagerloads(False).all() def _object_class_id(self): return case([(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.id)], else_=None) @@ -2050,16 +2034,13 @@ def _object_name_list(self): [(self.ext_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None ) - @staticmethod - def _metadata_usage_counts(cache): + def _metadata_usage_counts(self): """Counts references to metadata name, value pairs in entity_metadata and parameter_value_metadata tables. 
- Args: - cache (dict): database cache - Returns: Counter: usage counts keyed by metadata id """ + cache = self.cache usage_counts = Counter() for entry in cache.get("entity_metadata", {}).values(): usage_counts[entry.metadata_id] += 1 diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py index 5efe5b6f..174eaf64 100644 --- a/spinedb_api/db_mapping_check_mixin.py +++ b/spinedb_api/db_mapping_check_mixin.py @@ -44,7 +44,7 @@ class DatabaseMappingCheckMixin: """Provides methods to check whether insert and update operations violate Spine db integrity constraints.""" - def check_items(self, tablename, *items, for_update=False, strict=False, cache=None): + def check_items(self, tablename, *items, for_update=False, strict=False): return { "alternative": self.check_alternatives, "scenario": self.check_scenarios, @@ -63,9 +63,9 @@ def check_items(self, tablename, *items, for_update=False, strict=False, cache=N "metadata": self.check_metadata, "entity_metadata": self.check_entity_metadata, "parameter_value_metadata": self.check_parameter_value_metadata, - }[tablename](*items, for_update=for_update, strict=strict, cache=cache) + }[tablename](*items, for_update=for_update, strict=strict) - def check_alternatives(self, *items, for_update=False, strict=False, cache=None): + def check_alternatives(self, *items, for_update=False, strict=False): """Check whether alternatives passed as argument respect integrity constraints. Args: @@ -77,15 +77,15 @@ def check_alternatives(self, *items, for_update=False, strict=False, cache=None) list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"alternative"}, include_ancestors=True) + self.fetch_all({"alternative"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() alternative_ids = {x.name: x.id for x in cache.get("alternative", {}).values()} for item in items: try: with self._manage_stocks( - "alternative", item, {("name",): alternative_ids}, for_update, cache, intgr_error_log + "alternative", item, {("name",): alternative_ids}, for_update, intgr_error_log ) as item: check_alternative(item, alternative_ids) checked_items.append(item) @@ -95,7 +95,7 @@ def check_alternatives(self, *items, for_update=False, strict=False, cache=None) intgr_error_log.append(e) return checked_items, intgr_error_log - def check_scenarios(self, *items, for_update=False, strict=False, cache=None): + def check_scenarios(self, *items, for_update=False, strict=False): """Check whether scenarios passed as argument respect integrity constraints. Args: @@ -107,15 +107,15 @@ def check_scenarios(self, *items, for_update=False, strict=False, cache=None): list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
""" - if cache is None: - cache = self.make_cache({"scenario"}, include_ancestors=True) + self.fetch_all({"scenario"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() scenario_ids = {x.name: x.id for x in cache.get("scenario", {}).values()} for item in items: try: with self._manage_stocks( - "scenario", item, {("name",): scenario_ids}, for_update, cache, intgr_error_log + "scenario", item, {("name",): scenario_ids}, for_update, intgr_error_log ) as item: check_scenario(item, scenario_ids) checked_items.append(item) @@ -125,7 +125,7 @@ def check_scenarios(self, *items, for_update=False, strict=False, cache=None): intgr_error_log.append(e) return checked_items, intgr_error_log - def check_scenario_alternatives(self, *items, for_update=False, strict=False, cache=None): + def check_scenario_alternatives(self, *items, for_update=False, strict=False): """Check whether scenario alternatives passed as argument respect integrity constraints. Args: @@ -137,8 +137,8 @@ def check_scenario_alternatives(self, *items, for_update=False, strict=False, ca list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"scenario_alternative"}, include_ancestors=True) + self.fetch_all({"scenario_alternative"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() ids_by_alt_id = {} @@ -155,7 +155,6 @@ def check_scenario_alternatives(self, *items, for_update=False, strict=False, ca item, {("scenario_id", "alternative_id"): ids_by_alt_id, ("scenario_id", "rank"): ids_by_rank}, for_update, - cache, intgr_error_log, ) as item: check_scenario_alternative(item, ids_by_alt_id, ids_by_rank, scenario_names, alternative_names) @@ -166,7 +165,7 @@ def check_scenario_alternatives(self, *items, for_update=False, strict=False, ca intgr_error_log.append(e) return checked_items, intgr_error_log - def check_entity_classes(self, *items, for_update=False, strict=False, cache=None): + def check_entity_classes(self, *items, for_update=False, strict=False): """Check whether entity classes passed as argument respect integrity constraints. Args: @@ -178,15 +177,15 @@ def check_entity_classes(self, *items, for_update=False, strict=False, cache=Non list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity_class"}, include_ancestors=True) + self.fetch_all({"entity_class"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} for item in items: try: with self._manage_stocks( - "entity_class", item, {("name",): entity_class_ids}, for_update, cache, intgr_error_log + "entity_class", item, {("name",): entity_class_ids}, for_update, intgr_error_log ) as item: check_entity_class(item, entity_class_ids) checked_items.append(item) @@ -196,7 +195,7 @@ def check_entity_classes(self, *items, for_update=False, strict=False, cache=Non intgr_error_log.append(e) return checked_items, intgr_error_log - def check_entities(self, *items, for_update=False, strict=False, cache=None): + def check_entities(self, *items, for_update=False, strict=False): """Check whether entities passed as argument respect integrity constraints. 
Args: @@ -208,8 +207,8 @@ def check_entities(self, *items, for_update=False, strict=False, cache=None): list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity"}, include_ancestors=True) + self.fetch_all({"entity"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() entity_ids_by_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} @@ -229,7 +228,6 @@ def check_entities(self, *items, for_update=False, strict=False, cache=None): ("class_id", "element_id_list"): entity_ids_by_el_id_lst, }, for_update, - cache, intgr_error_log, ) as item: check_entity( @@ -246,7 +244,7 @@ def check_entities(self, *items, for_update=False, strict=False, cache=None): intgr_error_log.append(e) return checked_items, intgr_error_log - def check_object_classes(self, *items, for_update=False, strict=False, cache=None): + def check_object_classes(self, *items, for_update=False, strict=False): """Check whether object classes passed as argument respect integrity constraints. Args: @@ -258,15 +256,15 @@ def check_object_classes(self, *items, for_update=False, strict=False, cache=Non list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity_class"}, include_ancestors=True) + self.fetch_all({"entity_class"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} for item in items: try: with self._manage_stocks( - "entity_class", item, {("name",): object_class_ids}, for_update, cache, intgr_error_log + "entity_class", item, {("name",): object_class_ids}, for_update, intgr_error_log ) as item: check_object_class(item, object_class_ids) checked_items.append(item) @@ -276,7 +274,7 @@ def check_object_classes(self, *items, for_update=False, strict=False, cache=Non intgr_error_log.append(e) return checked_items, intgr_error_log - def check_objects(self, *items, for_update=False, strict=False, cache=None): + def check_objects(self, *items, for_update=False, strict=False): """Check whether objects passed as argument respect integrity constraints. Args: items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. @@ -287,8 +285,8 @@ def check_objects(self, *items, for_update=False, strict=False, cache=None): list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
""" - if cache is None: - cache = self.make_cache({"entity"}, include_ancestors=True) + self.fetch_all({"entity"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() object_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} @@ -296,7 +294,7 @@ def check_objects(self, *items, for_update=False, strict=False, cache=None): for item in items: try: with self._manage_stocks( - "entity", item, {("class_id", "name"): object_ids}, for_update, cache, intgr_error_log + "entity", item, {("class_id", "name"): object_ids}, for_update, intgr_error_log ) as item: check_object(item, object_ids, object_class_ids) checked_items.append(item) @@ -306,7 +304,7 @@ def check_objects(self, *items, for_update=False, strict=False, cache=None): intgr_error_log.append(e) return checked_items, intgr_error_log - def check_wide_relationship_classes(self, *wide_items, for_update=False, strict=False, cache=None): + def check_wide_relationship_classes(self, *wide_items, for_update=False, strict=False): """Check whether relationship classes passed as argument respect integrity constraints. Args: @@ -318,8 +316,8 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity_class"}, include_ancestors=True) + self.fetch_all({"entity_class"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_wide_items = list() relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} @@ -334,7 +332,6 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= wide_item, {("name",): relationship_class_ids}, for_update, - cache, intgr_error_log, ) as wide_item: if "object_class_id_list" not in wide_item: @@ -348,7 +345,7 @@ def check_wide_relationship_classes(self, *wide_items, for_update=False, strict= intgr_error_log.append(e) return checked_wide_items, intgr_error_log - def check_wide_relationships(self, *wide_items, for_update=False, strict=False, cache=None): + def check_wide_relationships(self, *wide_items, for_update=False, strict=False): """Check whether relationships passed as argument respect integrity constraints. Args: @@ -360,8 +357,8 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity"}, include_ancestors=True) + self.fetch_all({"entity"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_wide_items = list() relationship_ids_by_name = { @@ -393,7 +390,6 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, ("class_id", "element_id_list"): relationship_ids_by_obj_lst, }, for_update, - cache, intgr_error_log, ) as wide_item: if "object_class_id_list" not in wide_item: @@ -416,7 +412,7 @@ def check_wide_relationships(self, *wide_items, for_update=False, strict=False, intgr_error_log.append(e) return checked_wide_items, intgr_error_log - def check_entity_groups(self, *items, for_update=False, strict=False, cache=None): + def check_entity_groups(self, *items, for_update=False, strict=False): """Check whether entity groups passed as argument respect integrity constraints. 
Args: @@ -428,8 +424,8 @@ def check_entity_groups(self, *items, for_update=False, strict=False, cache=None list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity_group"}, include_ancestors=True) + self.fetch_all({"entity_group"}, include_ancestors=True) + cache = self.cache intgr_error_log = list() checked_items = list() current_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} @@ -439,7 +435,7 @@ def check_entity_groups(self, *items, for_update=False, strict=False, cache=None for item in items: try: with self._manage_stocks( - "entity_group", item, {("entity_id", "member_id"): current_ids}, for_update, cache, intgr_error_log + "entity_group", item, {("entity_id", "member_id"): current_ids}, for_update, intgr_error_log ) as item: check_entity_group(item, current_ids, entities) checked_items.append(item) @@ -449,7 +445,7 @@ def check_entity_groups(self, *items, for_update=False, strict=False, cache=None intgr_error_log.append(e) return checked_items, intgr_error_log - def check_parameter_definitions(self, *items, for_update=False, strict=False, cache=None): + def check_parameter_definitions(self, *items, for_update=False, strict=False): """Check whether parameter definitions passed as argument respect integrity constraints. Args: @@ -461,8 +457,8 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"parameter_definition", "parameter_value"}, include_ancestors=True) + self.fetch_all({"parameter_definition", "parameter_value"}, include_ancestors=True) + cache = self.cache parameter_definition_ids_with_values = { value.parameter_id for value in cache.get("parameter_value", {}).values() } @@ -502,7 +498,6 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca item, {("entity_class_id", "name"): parameter_definition_ids}, for_update, - cache, intgr_error_log, ) as full_item: check_parameter_definition( @@ -515,7 +510,7 @@ def check_parameter_definitions(self, *items, for_update=False, strict=False, ca intgr_error_log.append(e) return checked_items, intgr_error_log - def check_parameter_values(self, *items, for_update=False, strict=False, cache=None): + def check_parameter_values(self, *items, for_update=False, strict=False): """Check whether parameter values passed as argument respect integrity constraints. Args: @@ -527,8 +522,8 @@ def check_parameter_values(self, *items, for_update=False, strict=False, cache=N list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
""" - if cache is None: - cache = self.make_cache({"parameter_value"}, include_ancestors=True) + self.fetch_all({"parameter_value"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() parameter_value_ids = { @@ -556,7 +551,6 @@ def check_parameter_values(self, *items, for_update=False, strict=False, cache=N item, {("entity_id", "parameter_definition_id", "alternative_id"): parameter_value_ids}, for_update, - cache, intgr_error_log, ) as item: check_parameter_value( @@ -575,7 +569,7 @@ def check_parameter_values(self, *items, for_update=False, strict=False, cache=N intgr_error_log.append(e) return checked_items, intgr_error_log - def check_parameter_value_lists(self, *items, for_update=False, strict=False, cache=None): + def check_parameter_value_lists(self, *items, for_update=False, strict=False): """Check whether parameter value-lists passed as argument respect integrity constraints. Args: @@ -587,8 +581,8 @@ def check_parameter_value_lists(self, *items, for_update=False, strict=False, ca list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"parameter_value_list"}, include_ancestors=True) + self.fetch_all({"parameter_value_list"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() parameter_value_list_ids = {x.name: x.id for x in cache.get("parameter_value_list", {}).values()} @@ -599,7 +593,6 @@ def check_parameter_value_lists(self, *items, for_update=False, strict=False, ca item, {("name",): parameter_value_list_ids}, for_update, - cache, intgr_error_log, ) as item: check_parameter_value_list(item, parameter_value_list_ids) @@ -610,7 +603,7 @@ def check_parameter_value_lists(self, *items, for_update=False, strict=False, ca intgr_error_log.append(e) return checked_items, intgr_error_log - def check_list_values(self, *items, for_update=False, strict=False, cache=None): + def check_list_values(self, *items, for_update=False, strict=False): """Check whether list values passed as argument respect integrity constraints. Args: @@ -622,8 +615,8 @@ def check_list_values(self, *items, for_update=False, strict=False, cache=None): list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"list_value"}, include_ancestors=True) + self.fetch_all({"list_value"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() list_value_ids_by_index = { @@ -643,7 +636,6 @@ def check_list_values(self, *items, for_update=False, strict=False, cache=None): ("parameter_value_list_id", "type", "value"): list_value_ids_by_value, }, for_update, - cache, intgr_error_log, ) as item: check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value) @@ -654,28 +646,27 @@ def check_list_values(self, *items, for_update=False, strict=False, cache=None): intgr_error_log.append(e) return checked_items, intgr_error_log - def check_metadata(self, *items, for_update=False, strict=False, cache=None): + def check_metadata(self, *items, for_update=False, strict=False): """Checks whether metadata respects integrity constraints. Args: *items: One or more Python :class:`dict` objects representing the items to be checked. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if one of the items violates an integrity constraint. 
- cache (dict, optional): Database cache Returns list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"metadata"}) + self.fetch_all({"metadata"}) + cache = self.cache intgr_error_log = [] checked_items = list() metadata = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} for item in items: try: with self._manage_stocks( - "metadata", item, {("name", "value"): metadata}, for_update, cache, intgr_error_log + "metadata", item, {("name", "value"): metadata}, for_update, intgr_error_log ) as item: check_metadata(item, metadata) if (item["name"], item["value"]) not in metadata: @@ -686,28 +677,27 @@ def check_metadata(self, *items, for_update=False, strict=False, cache=None): intgr_error_log.append(e) return checked_items, intgr_error_log - def check_entity_metadata(self, *items, for_update=False, strict=False, cache=None): + def check_entity_metadata(self, *items, for_update=False, strict=False): """Checks whether entity metadata respects integrity constraints. Args: *items: One or more Python :class:`dict` objects representing the items to be checked. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if one of the items violates an integrity constraint. - cache (dict, optional): Database cache Returns list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. """ - if cache is None: - cache = self.make_cache({"entity_metadata"}, include_ancestors=True) + self.fetch_all({"entity_metadata"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() entities = {x.id for x in cache.get("entity", {}).values()} metadata = {x.id for x in cache.get("metadata", {}).values()} for item in items: try: - with self._manage_stocks("entity_metadata", item, {}, for_update, cache, intgr_error_log) as item: + with self._manage_stocks("entity_metadata", item, {}, for_update, intgr_error_log) as item: check_entity_metadata(item, entities, metadata) checked_items.append(item) except SpineIntegrityError as e: @@ -716,30 +706,27 @@ def check_entity_metadata(self, *items, for_update=False, strict=False, cache=No intgr_error_log.append(e) return checked_items, intgr_error_log - def check_parameter_value_metadata(self, *items, for_update=False, strict=False, cache=None): + def check_parameter_value_metadata(self, *items, for_update=False, strict=False): """Checks whether parameter value metadata respects integrity constraints. Args: *items: One or more Python :class:`dict` objects representing the items to be checked. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if one of the items violates an integrity constraint. - cache (dict, optional): Database cache Returns list: items that passed the check. list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
""" - if cache is None: - cache = self.make_cache({"parameter_value_metadata"}, include_ancestors=True) + self.fetch_all({"parameter_value_metadata"}, include_ancestors=True) + cache = self.cache intgr_error_log = [] checked_items = list() values = {x.id for x in cache.get("parameter_value", {}).values()} metadata = {x.id for x in cache.get("metadata", {}).values()} for item in items: try: - with self._manage_stocks( - "parameter_value_metadata", item, {}, for_update, cache, intgr_error_log - ) as item: + with self._manage_stocks("parameter_value_metadata", item, {}, for_update, intgr_error_log) as item: check_parameter_value_metadata(item, values, metadata) checked_items.append(item) except SpineIntegrityError as e: @@ -749,7 +736,8 @@ def check_parameter_value_metadata(self, *items, for_update=False, strict=False, return checked_items, intgr_error_log @contextmanager - def _manage_stocks(self, item_type, item, existing_ids_by_pk, for_update, cache, intgr_error_log): + def _manage_stocks(self, item_type, item, existing_ids_by_pk, for_update, intgr_error_log): + cache = self.cache if for_update: try: id_ = item["id"] diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 6ab2e806..b9d4f564 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -23,13 +23,13 @@ class DatabaseMappingRemoveMixin: """Provides the :meth:`remove_items` method to stage ``REMOVE`` operations over a Spine db.""" # pylint: disable=redefined-builtin - def cascade_remove_items(self, cache=None, **kwargs): + def cascade_remove_items(self, **kwargs): """Removes items by id in cascade. Args: **kwargs: keyword is table name, argument is list of ids to remove """ - cascading_ids = self.cascading_ids(cache=cache, **kwargs) + cascading_ids = self.cascading_ids(**kwargs) self.remove_items(**cascading_ids) def remove_items(self, **kwargs): @@ -61,11 +61,10 @@ def remove_items(self, **kwargs): self._has_pending_changes = True # pylint: disable=redefined-builtin - def cascading_ids(self, cache=None, **kwargs): + def cascading_ids(self, **kwargs): """Returns cascading ids. 
Keyword args: - cache (dict, optional) **kwargs: set of ids keyed by table name to be removed Returns: @@ -80,30 +79,27 @@ def cascading_ids(self, cache=None, **kwargs): if ids is not None: # FIXME: Add deprecation warning kwargs.setdefault(new_tablename, set()).update(ids) - if cache is None: - cache = self.make_cache( - set(kwargs), - include_descendants=True, - force_tablenames={"entity_metadata", "parameter_value_metadata"} - if any(x in kwargs for x in ("entity_metadata", "parameter_value_metadata", "metadata")) - else None, - ) - ids = {} - self._merge(ids, self._entity_class_cascading_ids(kwargs.get("entity_class", set()), cache)) - self._merge(ids, self._entity_cascading_ids(kwargs.get("entity", set()), cache)) - self._merge(ids, self._entity_group_cascading_ids(kwargs.get("entity_group", set()), cache)) - self._merge(ids, self._parameter_definition_cascading_ids(kwargs.get("parameter_definition", set()), cache)) - self._merge(ids, self._parameter_value_cascading_ids(kwargs.get("parameter_value", set()), cache)) - self._merge(ids, self._parameter_value_list_cascading_ids(kwargs.get("parameter_value_list", set()), cache)) - self._merge(ids, self._list_value_cascading_ids(kwargs.get("list_value", set()), cache)) - self._merge(ids, self._alternative_cascading_ids(kwargs.get("alternative", set()), cache)) - self._merge(ids, self._scenario_cascading_ids(kwargs.get("scenario", set()), cache)) - self._merge(ids, self._scenario_alternatives_cascading_ids(kwargs.get("scenario_alternative", set()), cache)) - self._merge(ids, self._metadata_cascading_ids(kwargs.get("metadata", set()), cache)) - self._merge(ids, self._entity_metadata_cascading_ids(kwargs.get("entity_metadata", set()), cache)) - self._merge( - ids, self._parameter_value_metadata_cascading_ids(kwargs.get("parameter_value_metadata", set()), cache) + self.fetch_all( + set(kwargs), + include_descendants=True, + force_tablenames={"entity_metadata", "parameter_value_metadata"} + if any(x in kwargs for x in ("entity_metadata", "parameter_value_metadata", "metadata")) + else None, ) + ids = {} + self._merge(ids, self._entity_class_cascading_ids(kwargs.get("entity_class", set()))) + self._merge(ids, self._entity_cascading_ids(kwargs.get("entity", set()))) + self._merge(ids, self._entity_group_cascading_ids(kwargs.get("entity_group", set()))) + self._merge(ids, self._parameter_definition_cascading_ids(kwargs.get("parameter_definition", set()))) + self._merge(ids, self._parameter_value_cascading_ids(kwargs.get("parameter_value", set()))) + self._merge(ids, self._parameter_value_list_cascading_ids(kwargs.get("parameter_value_list", set()))) + self._merge(ids, self._list_value_cascading_ids(kwargs.get("list_value", set()))) + self._merge(ids, self._alternative_cascading_ids(kwargs.get("alternative", set()))) + self._merge(ids, self._scenario_cascading_ids(kwargs.get("scenario", set()))) + self._merge(ids, self._scenario_alternatives_cascading_ids(kwargs.get("scenario_alternative", set()))) + self._merge(ids, self._metadata_cascading_ids(kwargs.get("metadata", set()))) + self._merge(ids, self._entity_metadata_cascading_ids(kwargs.get("entity_metadata", set()))) + self._merge(ids, self._parameter_value_metadata_cascading_ids(kwargs.get("parameter_value_metadata", set()))) sorted_ids = {} while ids: tablename = next(iter(ids)) @@ -122,31 +118,32 @@ def _merge(left, right): for tablename, ids in right.items(): left.setdefault(tablename, set()).update(ids) - def _alternative_cascading_ids(self, ids, cache): + def _alternative_cascading_ids(self, 
ids): """Returns alternative cascading ids.""" + cache = self.cache cascading_ids = {"alternative": set(ids)} - parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.alternative_id in ids] - scenario_alternatives = [ + entity_alternatives = (x for x in dict.values(cache.get("entity_alternative", {})) if x.alternative_id in ids) + parameter_values = (x for x in dict.values(cache.get("parameter_value", {})) if x.alternative_id in ids) + scenario_alternatives = ( x for x in dict.values(cache.get("scenario_alternative", {})) if x.alternative_id in ids - ] - self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache)) - self._merge( - cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives}, cache) ) + self._merge(cascading_ids, self._entity_alternative_cascading_ids({x.id for x in entity_alternatives})) + self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values})) + self._merge(cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives})) return cascading_ids - def _scenario_cascading_ids(self, ids, cache): + def _scenario_cascading_ids(self, ids): + cache = self.cache cascading_ids = {"scenario": set(ids)} scenario_alternatives = [x for x in dict.values(cache.get("scenario_alternative", {})) if x.scenario_id in ids] - self._merge( - cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives}, cache) - ) + self._merge(cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives})) return cascading_ids - def _entity_class_cascading_ids(self, ids, cache): + def _entity_class_cascading_ids(self, ids): """Returns entity class cascading ids.""" if not ids: return {} + cache = self.cache cascading_ids = {"entity_class": set(ids), "entity_class_dimension": set(ids)} entities = [x for x in dict.values(cache.get("entity", {})) if x.class_id in ids] entity_classes = ( @@ -155,61 +152,68 @@ def _entity_class_cascading_ids(self, ids, cache): paramerer_definitions = [ x for x in dict.values(cache.get("parameter_definition", {})) if x.entity_class_id in ids ] - self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities}, cache)) - self._merge(cascading_ids, self._entity_class_cascading_ids({x.id for x in entity_classes}, cache)) - self._merge( - cascading_ids, self._parameter_definition_cascading_ids({x.id for x in paramerer_definitions}, cache) - ) + self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities})) + self._merge(cascading_ids, self._entity_class_cascading_ids({x.id for x in entity_classes})) + self._merge(cascading_ids, self._parameter_definition_cascading_ids({x.id for x in paramerer_definitions})) return cascading_ids - def _entity_cascading_ids(self, ids, cache): + def _entity_cascading_ids(self, ids): """Returns entity cascading ids.""" if not ids: return {} + cache = self.cache cascading_ids = {"entity": set(ids), "entity_element": set(ids)} entities = (x for x in dict.values(cache.get("entity", {})) if set(x.element_id_list).intersection(ids)) - parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.entity_id in ids] - groups = [x for x in dict.values(cache.get("entity_group", {})) if {x.group_id, x.member_id}.intersection(ids)] + entity_alternatives = (x for x in dict.values(cache.get("entity_alternative", {})) if x.entity_id in ids) + parameter_values = (x for x in 
dict.values(cache.get("parameter_value", {})) if x.entity_id in ids)
+        groups = (x for x in dict.values(cache.get("entity_group", {})) if {x.group_id, x.member_id}.intersection(ids))
         entity_metadata_ids = {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.entity_id in ids}
-        self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities}, cache))
-        self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache))
-        self._merge(cascading_ids, self._entity_group_cascading_ids({x.id for x in groups}, cache))
-        self._merge(cascading_ids, self._entity_metadata_cascading_ids(entity_metadata_ids, cache))
+        self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities}))
+        self._merge(cascading_ids, self._entity_alternative_cascading_ids({x.id for x in entity_alternatives}))
+        self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}))
+        self._merge(cascading_ids, self._entity_group_cascading_ids({x.id for x in groups}))
+        self._merge(cascading_ids, self._entity_metadata_cascading_ids(entity_metadata_ids))
         return cascading_ids
 
-    def _entity_group_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
+    def _entity_alternative_cascading_ids(self, ids):
+        return {"entity_alternative": set(ids)}
+
+    def _entity_group_cascading_ids(self, ids):  # pylint: disable=no-self-use
         """Returns entity group cascading ids."""
         return {"entity_group": set(ids)}
 
-    def _parameter_definition_cascading_ids(self, ids, cache):
+    def _parameter_definition_cascading_ids(self, ids):
         """Returns parameter definition cascading ids."""
+        cache = self.cache
         cascading_ids = {"parameter_definition": set(ids)}
         parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.parameter_id in ids]
-        self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}, cache))
+        self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values}))
         return cascading_ids
 
-    def _parameter_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
+    def _parameter_value_cascading_ids(self, ids):
         """Returns parameter value cascading ids."""
+        cache = self.cache
         cascading_ids = {"parameter_value": set(ids)}
         value_metadata_ids = {
             x.id for x in dict.values(cache.get("parameter_value_metadata", {})) if x.parameter_value_id in ids
         }
-        self._merge(cascading_ids, self._parameter_value_metadata_cascading_ids(value_metadata_ids, cache))
+        self._merge(cascading_ids, self._parameter_value_metadata_cascading_ids(value_metadata_ids))
         return cascading_ids
 
-    def _parameter_value_list_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
+    def _parameter_value_list_cascading_ids(self, ids):  # pylint: disable=no-self-use
         """Returns parameter value list cascading ids and adds them to the given dictionaries."""
         cascading_ids = {"parameter_value_list": set(ids)}
         return cascading_ids
 
-    def _list_value_cascading_ids(self, ids, cache):  # pylint: disable=no-self-use
+    def _list_value_cascading_ids(self, ids):  # pylint: disable=no-self-use
         """Returns parameter value list value cascading ids."""
         return {"list_value": set(ids)}
 
-    def _scenario_alternatives_cascading_ids(self, ids, cache):
+    def _scenario_alternatives_cascading_ids(self, ids):
         return {"scenario_alternative": set(ids)}
 
-    def _metadata_cascading_ids(self, ids, cache):
+    def _metadata_cascading_ids(self, ids):
+        cache = self.cache
         cascading_ids = 
{"metadata": set(ids)} entity_metadata = { "entity_metadata": {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.metadata_id in ids} @@ -223,8 +227,9 @@ def _metadata_cascading_ids(self, ids, cache): self._merge(cascading_ids, value_metadata) return cascading_ids - def _non_referenced_metadata_ids(self, ids, metadata_table_name, cache): - metadata_id_counts = self._metadata_usage_counts(cache) + def _non_referenced_metadata_ids(self, ids, metadata_table_name): + cache = self.cache + metadata_id_counts = self._metadata_usage_counts() cascading_ids = {} metadata = cache.get(metadata_table_name, {}) for id_ in ids: @@ -233,12 +238,12 @@ def _non_referenced_metadata_ids(self, ids, metadata_table_name, cache): self._merge(cascading_ids, {"metadata": zero_count_metadata_ids}) return cascading_ids - def _entity_metadata_cascading_ids(self, ids, cache): + def _entity_metadata_cascading_ids(self, ids): cascading_ids = {"entity_metadata": set(ids)} - cascading_ids.update(self._non_referenced_metadata_ids(ids, "entity_metadata", cache)) + cascading_ids.update(self._non_referenced_metadata_ids(ids, "entity_metadata")) return cascading_ids - def _parameter_value_metadata_cascading_ids(self, ids, cache): + def _parameter_value_metadata_cascading_ids(self, ids): cascading_ids = {"parameter_value_metadata": set(ids)} - cascading_ids.update(self._non_referenced_metadata_ids(ids, "parameter_value_metadata", cache)) + cascading_ids.update(self._non_referenced_metadata_ids(ids, "parameter_value_metadata")) return cascading_ids diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 2223edc0..bb612e0a 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -33,6 +33,8 @@ def _update_items(self, tablename, *items, dry_run=False): # Special cases if tablename == "entity": return self._do_update_entities(*items) + if tablename == "scenario": + return self._do_update_scenarios(*items) if tablename == "object": return self._do_update_objects(*items) if tablename == "relationship": @@ -64,14 +66,56 @@ def _do_update_entities(self, *items): for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)) ] ) - entity_alternative_items.append( - {"entity_id": entity_id, "alternative_id": item["alternative_id"], "active": item["active"]} + entity_alternative_items.extend( + [ + {"entity_id": entity_id, "alternative_id": alt_id, "active": True} + for alt_id in item["active_alternative_id_list"] + ] + + [ + {"entity_id": entity_id, "alternative_id": alt_id, "active": False} + for alt_id in item["inactive_alternative_id_list"] + ] ) self._do_update_items("entity", *entity_items) self._do_update_items("entity_element", *entity_element_items) self._do_update_items("entity_alternative", *entity_alternative_items) return {x["id"] for x in entity_items} + def _do_update_scenarios(self, *items): + """Returns data to add and remove, in order to set wide scenario alternatives. + + Args: + *items: One or more wide scenario :class:`dict` objects to set. + Each item must include the following keys: + + - "id": integer scenario id + - "alternative_id_list": list of alternative ids for that scenario + + Returns + list: narrow scenario_alternative :class:`dict` objects to add. 
+            are reset to match its "alternative_id_list"
+        """
+        self.fetch_all({"scenario_alternative", "scenario"})
+        cache = self.cache
+        current_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()}
+        scenario_alternative_ids = {
+            (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values()
+        }
+        scen_alts_to_add = []
+        scen_alt_ids_to_remove = set()
+        for item in items:
+            scenario_id = item["id"]
+            alternative_id_list = item["alternative_id_list"]
+            current_alternative_id_list = current_alternative_id_lists[scenario_id]
+            for k, alternative_id in enumerate(alternative_id_list):
+                item_to_add = {"scenario_id": scenario_id, "alternative_id": alternative_id, "rank": k + 1}
+                scen_alts_to_add.append(item_to_add)
+            for alternative_id in current_alternative_id_list:
+                scen_alt_ids_to_remove.add(scenario_alternative_ids[scenario_id, alternative_id])
+        self.remove_items(scenario_alternative=scen_alt_ids_to_remove)
+        self.add_items("scenario_alternative", *scen_alts_to_add)
+        return self._do_update_items("scenario", *items)
+
     def _do_update_objects(self, *items):
         entity_items = []
         entity_alternative_items = []
@@ -81,8 +125,15 @@ def _do_update_objects(self, *items):
         entity_items.append(
             {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")}
         )
-        entity_alternative_items.append(
-            {"entity_id": entity_id, "alternative_id": item["alternative_id"], "active": item["active"]}
+        entity_alternative_items.extend(
+            [
+                {"entity_id": entity_id, "alternative_id": alt_id, "active": True}
+                for alt_id in item["active_alternative_id_list"]
+            ]
+            + [
+                {"entity_id": entity_id, "alternative_id": alt_id, "active": False}
+                for alt_id in item["inactive_alternative_id_list"]
+            ]
         )
         self._do_update_items("entity", *entity_items)
         self._do_update_items("entity_alternative", *entity_alternative_items)
@@ -91,6 +142,7 @@ def _do_update_wide_relationships(self, *items):
         entity_items = []
         entity_element_items = []
+        entity_alternative_items = []
         for item in items:
             entity_id = item["id"]
             class_id = item["class_id"]
@@ -111,8 +163,19 @@ def _do_update_wide_relationships(self, *items):
                 for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list))
             ]
         )
+        entity_alternative_items.extend(
+            [
+                {"entity_id": entity_id, "alternative_id": alt_id, "active": True}
+                for alt_id in item["active_alternative_id_list"]
+            ]
+            + [
+                {"entity_id": entity_id, "alternative_id": alt_id, "active": False}
+                for alt_id in item["inactive_alternative_id_list"]
+            ]
+        )
         self._do_update_items("entity", *entity_items)
         self._do_update_items("entity_element", *entity_element_items)
+        self._do_update_items("entity_alternative", *entity_alternative_items)
         return {x["id"] for x in entity_items}
 
     def _do_update_items(self, tablename, *items):
@@ -132,7 +195,7 @@ def _do_update_items(self, tablename, *items):
         else:
             self._has_pending_changes = True
 
-    def update_items(self, tablename, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False):
+    def update_items(self, tablename, *items, check=True, strict=False, return_items=False, dry_run=False):
         """Updates items.
 
         Args:
@@ -142,17 +205,13 @@ def update_items(self, tablename, *items, check=True, strict=False, return_items
             strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
                 if the insertion of one of the items violates an integrity constraint.
return_items (bool): Return full items rather than just ids - cache (dict): A dict mapping table names to a list of dictionary items, to use as db replacement - for queries Returns: set: ids or items successfully updated list(SpineIntegrityError): found violations """ if check: - checked_items, intgr_error_log = self.check_items( - tablename, *items, for_update=True, strict=strict, cache=cache - ) + checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=True, strict=strict) else: checked_items, intgr_error_log = list(items), [] updated_ids = self._update_items(tablename, *checked_items, dry_run=dry_run) @@ -244,34 +303,30 @@ def update_metadata(self, *items, **kwargs): def _update_metadata(self, *items): return self._update_items("metadata", *items) - def update_ext_entity_metadata( - self, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False - ): + def update_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, dry_run=False): updated_items, errors = self._update_ext_item_metadata( - "entity_metadata", *items, check=check, strict=strict, cache=cache, dry_run=dry_run + "entity_metadata", *items, check=check, strict=strict, dry_run=dry_run ) if return_items: return updated_items, errors return {i["id"] for i in updated_items}, errors - def update_ext_parameter_value_metadata( - self, *items, check=True, strict=False, return_items=False, cache=None, dry_run=False - ): + def update_ext_parameter_value_metadata(self, *items, check=True, strict=False, return_items=False, dry_run=False): updated_items, errors = self._update_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict, cache=cache, dry_run=dry_run + "parameter_value_metadata", *items, check=check, strict=strict, dry_run=dry_run ) if return_items: return updated_items, errors return {i["id"] for i in updated_items}, errors - def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False, cache=None, dry_run=False): - if cache is None: - cache = self.make_cache({"entity_metadata", "parameter_value_metadata", "metadata"}) + def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False, dry_run=False): + self.fetch_all({"entity_metadata", "parameter_value_metadata", "metadata"}) + cache = self.cache metadata_ids = {} for entry in cache.get("metadata", {}).values(): metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id item_metadata_cache = cache[metadata_table] - metadata_usage_counts = self._metadata_usage_counts(cache) + metadata_usage_counts = self._metadata_usage_counts() updatable_items = [] homeless_items = [] for item in items: @@ -325,12 +380,7 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F errors = [] if updatable_metadata_items: updated_metadata, errors = self.update_metadata( - *updatable_metadata_items, - check=False, - strict=strict, - return_items=True, - cache=cache, - dry_run=dry_run, + *updatable_metadata_items, check=False, strict=strict, return_items=True, dry_run=dry_run ) all_items += updated_metadata if errors: @@ -341,7 +391,7 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F added_metadata = [] if addable_metadata: added_metadata, metadata_add_errors = self.add_metadata( - *addable_metadata, check=False, strict=strict, return_items=True, cache=cache + *addable_metadata, check=False, strict=strict, return_items=True ) all_items += added_metadata errors += metadata_add_errors @@ -354,43 
+404,10 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F item["metadata_id"] = added_metadata_ids[item["metadata_name"]][item["metadata_value"]] updatable_items.append(item) if updatable_items: - # Force-clear cache before updating item metadata to ensure that added/updated metadata is found. + # FIXME: Force-clear cache before updating item metadata to ensure that added/updated metadata is found. updated_item_metadata, item_metadata_errors = self.update_items( metadata_table, *updatable_items, check=check, strict=strict, return_items=True ) all_items += updated_item_metadata errors += item_metadata_errors return all_items, errors - - def get_data_to_set_scenario_alternatives(self, *items, cache=None): - """Returns data to add and remove, in order to set wide scenario alternatives. - - Args: - *items: One or more wide scenario :class:`dict` objects to set. - Each item must include the following keys: - - - "id": integer scenario id - - "alternative_id_list": list of alternative ids for that scenario - - Returns - list: narrow scenario_alternative :class:`dict` objects to add. - set: integer scenario_alternative ids to remove - """ - if cache is None: - cache = self.make_cache("scenario_alternative", "scenario") - current_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()} - scenario_alternative_ids = { - (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values() - } - items_to_add = list() - ids_to_remove = set() - for item in items: - scenario_id = item["id"] - alternative_id_list = item["alternative_id_list"] - current_alternative_id_list = current_alternative_id_lists[scenario_id] - for k, alternative_id in enumerate(alternative_id_list): - item_to_add = {"scenario_id": scenario_id, "alternative_id": alternative_id, "rank": k + 1} - items_to_add.append(item_to_add) - for alternative_id in current_alternative_id_list: - ids_to_remove.add(scenario_alternative_ids[scenario_id, alternative_id]) - return items_to_add, ids_to_remove diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index db011fe4..673d937c 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -40,7 +40,6 @@ def export_data( alternative_ids=Asterisk, scenario_ids=Asterisk, scenario_alternative_ids=Asterisk, - make_cache=None, parse_value=from_database, ): """ @@ -65,50 +64,44 @@ def export_data( dict: exported data """ data = { - "entity_classes": export_entity_classes(db_map, entity_class_ids, make_cache=make_cache), - "entities": export_entities(db_map, entity_ids, make_cache=make_cache), - "entity_groups": export_entity_groups(db_map, entity_group_ids, make_cache=make_cache), + "entity_classes": export_entity_classes(db_map, entity_class_ids), + "entities": export_entities(db_map, entity_ids), + "entity_groups": export_entity_groups(db_map, entity_group_ids), "parameter_definitions": export_parameter_definitions( - db_map, parameter_definition_ids, make_cache=make_cache, parse_value=parse_value + db_map, parameter_definition_ids, parse_value=parse_value ), - "parameter_values": export_parameter_values( - db_map, parameter_value_ids, make_cache=make_cache, parse_value=parse_value - ), - "object_classes": export_object_classes(db_map, object_class_ids, make_cache=make_cache), - "relationship_classes": export_relationship_classes(db_map, relationship_class_ids, make_cache=make_cache), + "parameter_values": export_parameter_values(db_map, 
parameter_value_ids, parse_value=parse_value), + "object_classes": export_object_classes(db_map, object_class_ids), + "relationship_classes": export_relationship_classes(db_map, relationship_class_ids), "parameter_value_lists": export_parameter_value_lists( - db_map, parameter_value_list_ids, make_cache=make_cache, parse_value=parse_value - ), - "object_parameters": export_object_parameters( - db_map, object_parameter_ids, make_cache=make_cache, parse_value=parse_value + db_map, parameter_value_list_ids, parse_value=parse_value ), + "object_parameters": export_object_parameters(db_map, object_parameter_ids, parse_value=parse_value), "relationship_parameters": export_relationship_parameters( - db_map, relationship_parameter_ids, make_cache=make_cache, parse_value=parse_value + db_map, relationship_parameter_ids, parse_value=parse_value ), - "objects": export_objects(db_map, object_ids, make_cache=make_cache), - "relationships": export_relationships(db_map, relationship_ids, make_cache=make_cache), - "object_groups": export_object_groups(db_map, object_group_ids, make_cache=make_cache), + "objects": export_objects(db_map, object_ids), + "relationships": export_relationships(db_map, relationship_ids), + "object_groups": export_object_groups(db_map, object_group_ids), "object_parameter_values": export_object_parameter_values( - db_map, object_parameter_value_ids, make_cache=make_cache, parse_value=parse_value + db_map, object_parameter_value_ids, parse_value=parse_value ), "relationship_parameter_values": export_relationship_parameter_values( - db_map, relationship_parameter_value_ids, make_cache=make_cache, parse_value=parse_value + db_map, relationship_parameter_value_ids, parse_value=parse_value ), - "alternatives": export_alternatives(db_map, alternative_ids, make_cache=make_cache), - "scenarios": export_scenarios(db_map, scenario_ids, make_cache=make_cache), - "scenario_alternatives": export_scenario_alternatives(db_map, scenario_alternative_ids, make_cache=make_cache), + "alternatives": export_alternatives(db_map, alternative_ids), + "scenarios": export_scenarios(db_map, scenario_ids), + "scenario_alternatives": export_scenario_alternatives(db_map, scenario_alternative_ids), } return {key: value for key, value in data.items() if value} -def _get_items(db_map, tablename, ids, make_cache): +def _get_items(db_map, tablename, ids): if not ids: return () - if make_cache is None: - make_cache = db_map.make_cache - cache = make_cache({tablename}, include_ancestors=True) - _process_item = _make_item_processor(tablename, make_cache) - for item in _get_items_from_cache(cache, tablename, ids): + db_map.fetch_all({tablename}, include_ancestors=True) + _process_item = _make_item_processor(db_map.cache, tablename) + for item in _get_items_from_cache(db_map.cache, tablename, ids): yield from _process_item(item) @@ -123,15 +116,15 @@ def _get_items_from_cache(cache, tablename, ids): yield item -def _make_item_processor(tablename, make_cache): +def _make_item_processor(cache, tablename): if tablename == "parameter_value_list": - return _ParameterValueListProcessor(make_cache) + return _ParameterValueListProcessor(cache) return lambda item: (item,) class _ParameterValueListProcessor: - def __init__(self, make_cache): - self._cache = make_cache({"list_value"}) + def __init__(self, cache): + self._list_value_by_id = cache.get("list_value", {}) def __call__(self, item): fields = ["name", "value", "type"] @@ -139,44 +132,39 @@ def __call__(self, item): yield KeyedTuple([item.name, None, None], fields) return 
for value_id in item.value_id_list: - val = self._cache["list_value"][value_id] + val = self._list_value_by_id[value_id] yield KeyedTuple([item.name, val.value, val.type], fields) -def export_parameter_value_lists(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_parameter_value_lists(db_map, ids=Asterisk, parse_value=from_database): return sorted( - ((x.name, parse_value(x.value, x.type)) for x in _get_items(db_map, "parameter_value_list", ids, make_cache)), + ((x.name, parse_value(x.value, x.type)) for x in _get_items(db_map, "parameter_value_list", ids)), key=itemgetter(0), ) -def export_entity_classes(db_map, ids=Asterisk, make_cache=None): +def export_entity_classes(db_map, ids=Asterisk): return sorted( ( (x.name, x.dimension_name_list, x.description, x.display_icon) - for x in _get_items(db_map, "entity_class", ids, make_cache) + for x in _get_items(db_map, "entity_class", ids) ), key=lambda x: (len(x[1]), x[0]), ) -def export_entities(db_map, ids=Asterisk, make_cache=None): +def export_entities(db_map, ids=Asterisk): return sorted( - ( - (x.class_name, x.element_name_list or x.name, x.description) - for x in _get_items(db_map, "entity", ids, make_cache) - ), + ((x.class_name, x.element_name_list or x.name, x.description) for x in _get_items(db_map, "entity", ids)), key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0]), ) -def export_entity_groups(db_map, ids=Asterisk, make_cache=None): - return sorted( - (x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids, make_cache) - ) +def export_entity_groups(db_map, ids=Asterisk): + return sorted((x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids)) -def export_parameter_definitions(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_parameter_definitions(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( x.entity_class_name, @@ -185,11 +173,11 @@ def export_parameter_definitions(db_map, ids=Asterisk, make_cache=None, parse_va x.value_list_name, x.description, ) - for x in _get_items(db_map, "parameter_definition", ids, make_cache) + for x in _get_items(db_map, "parameter_definition", ids) ) -def export_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_parameter_values(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( ( @@ -199,51 +187,47 @@ def export_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=f parse_value(x.value, x.type), x.alternative_name, ) - for x in _get_items(db_map, "parameter_value", ids, make_cache) + for x in _get_items(db_map, "parameter_value", ids) ), key=lambda x: x[:3] + (x[-1],), ) -def export_object_classes(db_map, ids=Asterisk, make_cache=None): +def export_object_classes(db_map, ids=Asterisk): return sorted( (x.name, x.description, x.display_icon) - for x in _get_items(db_map, "entity_class", ids, make_cache) + for x in _get_items(db_map, "entity_class", ids) if not x.dimension_id_list ) -def export_relationship_classes(db_map, ids=Asterisk, make_cache=None): +def export_relationship_classes(db_map, ids=Asterisk): return sorted( (x.name, x.dimension_name_list, x.description, x.display_icon) - for x in _get_items(db_map, "entity_class", ids, make_cache) + for x in _get_items(db_map, "entity_class", ids) if x.dimension_id_list ) -def export_objects(db_map, ids=Asterisk, make_cache=None): +def export_objects(db_map, ids=Asterisk): return sorted( - (x.class_name, x.name, 
x.description) - for x in _get_items(db_map, "entity", ids, make_cache) - if not x.element_id_list + (x.class_name, x.name, x.description) for x in _get_items(db_map, "entity", ids) if not x.element_id_list ) -def export_relationships(db_map, ids=Asterisk, make_cache=None): - return sorted( - (x.class_name, x.element_name_list) for x in _get_items(db_map, "entity", ids, make_cache) if x.element_id_list - ) +def export_relationships(db_map, ids=Asterisk): + return sorted((x.class_name, x.element_name_list) for x in _get_items(db_map, "entity", ids) if x.element_id_list) -def export_object_groups(db_map, ids=Asterisk, make_cache=None): +def export_object_groups(db_map, ids=Asterisk): return sorted( (x.class_name, x.group_name, x.member_name) - for x in _get_items(db_map, "entity_group", ids, make_cache) + for x in _get_items(db_map, "entity_group", ids) if not x.dimension_id_list ) -def export_object_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_object_parameters(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( x.entity_class_name, @@ -252,12 +236,12 @@ def export_object_parameters(db_map, ids=Asterisk, make_cache=None, parse_value= x.value_list_name, x.description, ) - for x in _get_items(db_map, "parameter_definition", ids, make_cache) + for x in _get_items(db_map, "parameter_definition", ids) if not x.dimension_id_list ) -def export_relationship_parameters(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_relationship_parameters(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( x.entity_class_name, @@ -266,23 +250,23 @@ def export_relationship_parameters(db_map, ids=Asterisk, make_cache=None, parse_ x.value_list_name, x.description, ) - for x in _get_items(db_map, "parameter_definition", ids, make_cache) + for x in _get_items(db_map, "parameter_definition", ids) if x.dimension_id_list ) -def export_object_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_object_parameter_values(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( (x.entity_class_name, x.entity_name, x.parameter_name, parse_value(x.value, x.type), x.alternative_name) - for x in _get_items(db_map, "parameter_value", ids, make_cache) + for x in _get_items(db_map, "parameter_value", ids) if not x.element_id_list ), key=lambda x: x[:3] + (x[-1],), ) -def export_relationship_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value=from_database): +def export_relationship_parameter_values(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( ( @@ -292,14 +276,14 @@ def export_relationship_parameter_values(db_map, ids=Asterisk, make_cache=None, parse_value(x.value, x.type), x.alternative_name, ) - for x in _get_items(db_map, "parameter_value", ids, make_cache) + for x in _get_items(db_map, "parameter_value", ids) if x.element_id_list ), key=lambda x: x[:3] + (x[-1],), ) -def export_alternatives(db_map, ids=Asterisk, make_cache=None): +def export_alternatives(db_map, ids=Asterisk): """ Exports alternatives from database. 
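
The export functions above no longer take a cache argument; each one now calls `db_map.fetch_all(...)` and reads the result from `db_map.cache` internally. A minimal usage sketch of the new calling convention, with an illustrative database URL and error handling omitted::

    from spinedb_api import DatabaseMapping
    from spinedb_api.export_functions import export_alternatives, export_entity_classes

    db_map = DatabaseMapping("sqlite:///example.sqlite")
    # Each call fetches what it needs into db_map.cache before exporting.
    entity_classes = export_entity_classes(db_map)  # [(name, dimension_name_list, description, display_icon), ...]
    alternatives = export_alternatives(db_map)      # [(name, description), ...]
    db_map.connection.close()
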
@@ -312,10 +296,10 @@ def export_alternatives(db_map, ids=Asterisk, make_cache=None): Returns: Iterable: tuples of two elements: name of alternative and description """ - return sorted((x.name, x.description) for x in _get_items(db_map, "alternative", ids, make_cache)) + return sorted((x.name, x.description) for x in _get_items(db_map, "alternative", ids)) -def export_scenarios(db_map, ids=Asterisk, make_cache=None): +def export_scenarios(db_map, ids=Asterisk): """ Exports scenarios from database. @@ -328,10 +312,10 @@ def export_scenarios(db_map, ids=Asterisk, make_cache=None): Returns: Iterable: tuples of two elements: name of scenario and description """ - return sorted((x.name, x.active, x.description) for x in _get_items(db_map, "scenario", ids, make_cache)) + return sorted((x.name, x.active, x.description) for x in _get_items(db_map, "scenario", ids)) -def export_scenario_alternatives(db_map, ids=Asterisk, make_cache=None): +def export_scenario_alternatives(db_map, ids=Asterisk): """ Exports scenario alternatives from database. @@ -348,7 +332,7 @@ def export_scenario_alternatives(db_map, ids=Asterisk, make_cache=None): return sorted( ( (x.scenario_name, x.alternative_name, x.before_alternative_name) - for x in _get_items(db_map, "scenario_alternative", ids, make_cache) + for x in _get_items(db_map, "scenario_alternative", ids) ), key=itemgetter(0), ) diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index a58ed37a..b4741f62 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -143,7 +143,7 @@ def _parse_execution_descriptor(self, execution): return execution_item, scenarios, timestamp -def _create_import_alternative(db_map, state, cache=None): +def _create_import_alternative(db_map, state): """ Creates an alternative to use as default for all import operations on the given db_map. diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 7ba8a770..f032e385 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -50,7 +50,7 @@ def __repr__(self): return self.msg -def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict="merge", **kwargs): +def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs): """Imports data into a Spine database using name references (rather than id references). Example:: @@ -147,7 +147,7 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= error_log = [] num_imports = 0 for tablename, (to_add, to_update, errors) in get_data_for_import( - db_map, make_cache=make_cache, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs + db_map, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs ): update_items = update_items_by_tablename.get(tablename, lambda *args, **kwargs: ()) try: @@ -168,7 +168,6 @@ def import_data(db_map, make_cache=None, unparse_value=to_database, on_conflict= def get_data_for_import( db_map, - make_cache=None, unparse_value=to_database, on_conflict="merge", dry_run=False, @@ -225,95 +224,89 @@ def get_data_for_import( Returns: dict(str, list) """ - if make_cache is None: - make_cache = db_map.make_cache # NOTE: The order is important, because of references. 
E.g., we want to import alternatives before parameter_values
     if alternatives:
-        yield ("alternative", _get_alternatives_for_import(alternatives, make_cache))
+        yield ("alternative", _get_alternatives_for_import(db_map, alternatives))
     if scenarios:
-        yield ("scenario", _get_scenarios_for_import(scenarios, make_cache))
+        yield ("scenario", _get_scenarios_for_import(db_map, scenarios))
     if scenario_alternatives:
         if not scenarios:
             scenarios = (item[0] for item in scenario_alternatives)
-            yield ("scenario", _get_scenarios_for_import(scenarios, make_cache))
+            yield ("scenario", _get_scenarios_for_import(db_map, scenarios))
         if not alternatives:
             alternatives = (item[1] for item in scenario_alternatives)
-            yield ("alternative", _get_alternatives_for_import(alternatives, make_cache))
-        yield ("scenario_alternative", _get_scenario_alternatives_for_import(scenario_alternatives, make_cache))
+        yield ("alternative", _get_alternatives_for_import(db_map, alternatives))
+        yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives))
     if entity_classes:
-        yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, make_cache, dry_run))
+        yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, dry_run))
     if object_classes:
-        yield ("object_class", _get_object_classes_for_import(db_map, object_classes, make_cache))
+        yield ("object_class", _get_object_classes_for_import(db_map, object_classes))
     if relationship_classes:
-        yield ("relationship_class", _get_relationship_classes_for_import(db_map, relationship_classes, make_cache))
+        yield ("relationship_class", _get_relationship_classes_for_import(db_map, relationship_classes))
     if parameter_value_lists:
-        yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists, make_cache))
-        yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, make_cache, unparse_value))
+        yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists))
+        yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, unparse_value))
     if parameter_definitions:
         yield (
             "parameter_definition",
-            _get_parameter_definitions_for_import(db_map, parameter_definitions, make_cache, unparse_value),
+            _get_parameter_definitions_for_import(db_map, parameter_definitions, unparse_value),
         )
     if object_parameters:
         yield (
             "parameter_definition",
-            _get_object_parameters_for_import(db_map, object_parameters, make_cache, unparse_value),
+            _get_object_parameters_for_import(db_map, object_parameters, unparse_value),
         )
     if relationship_parameters:
         yield (
             "parameter_definition",
-            _get_relationship_parameters_for_import(db_map, relationship_parameters, make_cache, unparse_value),
+            _get_relationship_parameters_for_import(db_map, relationship_parameters, unparse_value),
         )
     if entities:
-        yield ("entity", _get_entities_for_import(db_map, entities, make_cache, dry_run))
+        yield ("entity", _get_entities_for_import(db_map, entities, dry_run))
     if objects:
-        yield ("object", _get_objects_for_import(db_map, objects, make_cache))
+        yield ("object", _get_objects_for_import(db_map, objects))
     if relationships:
-        yield ("relationship", _get_relationships_for_import(db_map, relationships, make_cache))
+        yield ("relationship", _get_relationships_for_import(db_map, relationships))
     if entity_groups:
-        yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups, make_cache))
+        yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups))
     if 
object_groups: - yield ("entity_group", _get_object_groups_for_import(db_map, object_groups, make_cache)) + yield ("entity_group", _get_object_groups_for_import(db_map, object_groups)) if parameter_values: yield ( "parameter_value", - _get_parameter_values_for_import(db_map, parameter_values, make_cache, unparse_value, on_conflict), + _get_parameter_values_for_import(db_map, parameter_values, unparse_value, on_conflict), ) if object_parameter_values: yield ( "parameter_value", - _get_object_parameter_values_for_import( - db_map, object_parameter_values, make_cache, unparse_value, on_conflict - ), + _get_object_parameter_values_for_import(db_map, object_parameter_values, unparse_value, on_conflict), ) if relationship_parameter_values: yield ( "parameter_value", _get_relationship_parameter_values_for_import( - db_map, relationship_parameter_values, make_cache, unparse_value, on_conflict + db_map, relationship_parameter_values, unparse_value, on_conflict ), ) if metadata: - yield ("metadata", _get_metadata_for_import(db_map, metadata, make_cache)) + yield ("metadata", _get_metadata_for_import(db_map, metadata)) if object_metadata: - yield ("entity_metadata", _get_object_metadata_for_import(db_map, object_metadata, make_cache)) + yield ("entity_metadata", _get_object_metadata_for_import(db_map, object_metadata)) if relationship_metadata: - yield ("entity_metadata", _get_relationship_metadata_for_import(db_map, relationship_metadata, make_cache)) + yield ("entity_metadata", _get_relationship_metadata_for_import(db_map, relationship_metadata)) if object_parameter_value_metadata: yield ( "parameter_value_metadata", - _get_object_parameter_value_metadata_for_import(db_map, object_parameter_value_metadata, make_cache), + _get_object_parameter_value_metadata_for_import(db_map, object_parameter_value_metadata), ) if relationship_parameter_value_metadata: yield ( "parameter_value_metadata", - _get_relationship_parameter_value_metadata_for_import( - db_map, relationship_parameter_value_metadata, make_cache - ), + _get_relationship_parameter_value_metadata_for_import(db_map, relationship_parameter_value_metadata), ) -def import_entity_classes(db_map, data, make_cache=None): +def import_entity_classes(db_map, data): """Imports entity classes. Example:: @@ -333,11 +326,12 @@ def import_entity_classes(db_map, data, make_cache=None): Returns: tuple of int and list: Number of successfully inserted object classes, list of errors """ - return import_data(db_map, entity_classes=data, make_cache=make_cache) + return import_data(db_map, entity_classes=data) -def _get_entity_classes_for_import(db_map, data, make_cache, dry_run): - cache = make_cache({"entity_class"}, include_ancestors=True) +def _get_entity_classes_for_import(db_map, data, dry_run): + db_map.fetch_all({"entity_class"}, include_ancestors=True) + cache = db_map.cache entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} checked = set() error_log = [] @@ -377,7 +371,7 @@ def _get_entity_classes_for_import(db_map, data, make_cache, dry_run): return to_add, to_update, error_log -def import_entities(db_map, data, make_cache=None): +def import_entities(db_map, data): """Imports entities. 
Example:: @@ -397,7 +391,7 @@ def import_entities(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted entities, list of errors """ - return import_data(db_map, entities=data, make_cache=make_cache) + return import_data(db_map, entities=data) def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_id_name_tuples): @@ -410,8 +404,9 @@ def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_i return name -def _get_entities_for_import(db_map, data, make_cache, dry_run): - cache = make_cache({"entity"}, include_ancestors=True) +def _get_entities_for_import(db_map, data, dry_run): + db_map.fetch_all({"entity"}, include_ancestors=True) + cache = db_map.cache entities = {x.id: x for x in cache.get("entity", {}).values()} entity_ids_per_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} entity_ids_per_el_id_lst = { @@ -480,7 +475,7 @@ def _get_entities_for_import(db_map, data, make_cache, dry_run): return to_add, to_update, error_log -def import_entity_groups(db_map, data, make_cache=None): +def import_entity_groups(db_map, data): """Imports list of entity groups by name with associated class name into given database mapping: Ignores duplicate and existing (group, member) tuples. @@ -500,11 +495,12 @@ def import_entity_groups(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted entity groups, list of errors """ - return import_data(db_map, entity_groups=data, make_cache=make_cache) + return import_data(db_map, entity_groups=data) -def _get_entity_groups_for_import(db_map, data, make_cache): - cache = make_cache({"entity_group"}, include_ancestors=True) +def _get_entity_groups_for_import(db_map, data): + db_map.fetch_all({"entity_group"}, include_ancestors=True) + cache = db_map.cache entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} entity_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} entities = {} @@ -535,7 +531,7 @@ def _get_entity_groups_for_import(db_map, data, make_cache): return to_add, [], error_log -def import_parameter_definitions(db_map, data, make_cache=None, unparse_value=to_database): +def import_parameter_definitions(db_map, data, unparse_value=to_database): """Imports list of parameter definitions: Example:: @@ -554,11 +550,12 @@ def import_parameter_definitions(db_map, data, make_cache=None, unparse_value=to Returns: (Int, List) Number of successful inserted parameter definitions, list of errors """ - return import_data(db_map, parameter_definitions=data, make_cache=make_cache, unparse_value=unparse_value) + return import_data(db_map, parameter_definitions=data, unparse_value=unparse_value) -def _get_parameter_definitions_for_import(db_map, data, make_cache, unparse_value): - cache = make_cache({"parameter_definition"}, include_ancestors=True) +def _get_parameter_definitions_for_import(db_map, data, unparse_value): + db_map.fetch_all({"parameter_definition"}, include_ancestors=True) + cache = db_map.cache parameter_definition_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } @@ -620,7 +617,7 @@ def _get_parameter_definitions_for_import(db_map, data, make_cache, unparse_valu return to_add, to_update, error_log -def import_parameter_values(db_map, data, make_cache=None, unparse_value=to_database, on_conflict="merge"): +def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): """Imports parameter 
values:
 
     Example::
 
@@ -640,13 +637,12 @@ def import_parameter_values(db_map, data, make_cache=None, unparse_value=to_data
     Returns:
         (Int, List) Number of successful inserted parameter values, list of errors
     """
-    return import_data(
-        db_map, parameter_values=data, make_cache=make_cache, unparse_value=unparse_value, on_conflict=on_conflict
-    )
+    return import_data(db_map, parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict)
 
 
-def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict):
-    cache = make_cache({"parameter_value"}, include_ancestors=True)
+def _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict):
+    db_map.fetch_all({"parameter_value"}, include_ancestors=True)
+    cache = db_map.cache
     dimension_id_lists = {x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values()}
     parameter_value_ids = {
         (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values()
@@ -696,7 +692,7 @@ def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on
             )
             continue
         else:
-            alt_id, alternative_name = db_map.get_import_alternative(cache=cache)
+            alt_id, alternative_name = db_map.get_import_alternative()
             alternative_ids.add(alt_id)
         checked_key = (e_id, p_id, alt_id)
         if checked_key in checked:
@@ -751,7 +747,7 @@ def _get_parameter_values_for_import(db_map, data, make_cache, unparse_value, on
     return to_add, to_update, error_log
 
 
-def import_alternatives(db_map, data, make_cache=None):
+def import_alternatives(db_map, data):
     """
     Imports alternatives.
 
@@ -768,11 +764,12 @@ def import_alternatives(db_map, data, make_cache=None):
     Returns:
         tuple of int and list: Number of successfully inserted alternatives, list of errors
     """
-    return import_data(db_map, alternatives=data, make_cache=make_cache)
+    return import_data(db_map, alternatives=data)
 
 
-def _get_alternatives_for_import(data, make_cache):
-    cache = make_cache({"alternative"}, include_ancestors=True)
+def _get_alternatives_for_import(db_map, data):
+    db_map.fetch_all({"alternative"}, include_ancestors=True)
+    cache = db_map.cache
     alternative_ids = {alternative.name: alternative.id for alternative in cache.get("alternative", {}).values()}
     checked = set()
     to_add = []
@@ -810,7 +807,7 @@ def _get_alternatives_for_import(data, make_cache):
     return to_add, to_update, error_log
 
 
-def import_scenarios(db_map, data, make_cache=None):
+def import_scenarios(db_map, data):
     """
     Imports scenarios.
 
@@ -829,11 +826,12 @@ def import_scenarios(db_map, data, make_cache=None):
     Returns:
         tuple of int and list: Number of successfully inserted scenarios, list of errors
     """
-    return import_data(db_map, scenarios=data, make_cache=make_cache)
+    return import_data(db_map, scenarios=data)
 
 
-def _get_scenarios_for_import(data, make_cache):
-    cache = make_cache({"scenario"}, include_ancestors=True)
+def _get_scenarios_for_import(db_map, data):
+    db_map.fetch_all({"scenario"}, include_ancestors=True)
+    cache = db_map.cache
     scenario_ids = {scenario.name: scenario.id for scenario in cache.get("scenario", {}).values()}
     checked = set()
     to_add = []
@@ -869,7 +867,7 @@ def _get_scenarios_for_import(data, make_cache):
     return to_add, to_update, error_log
 
 
-def import_scenario_alternatives(db_map, data, make_cache=None):
+def import_scenario_alternatives(db_map, data):
     """
     Imports scenario alternatives.
 
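
Every `_get_*_for_import` helper now follows the same pattern: fetch the relevant table (and its ancestors) into the mapping's cache, then resolve existing names to ids from `db_map.cache`. A schematic sketch of that pattern, using a made-up `widget` table purely for illustration::

    def _get_widgets_for_import(db_map, data):
        # Make sure the cache holds the hypothetical "widget" table and its ancestors.
        db_map.fetch_all({"widget"}, include_ancestors=True)
        cache = db_map.cache
        # Resolve existing names to ids; unknown names get added, known ones updated.
        widget_ids = {x.name: x.id for x in cache.get("widget", {}).values()}
        to_add, to_update, error_log = [], [], []
        for name in data:
            if name in widget_ids:
                to_update.append({"id": widget_ids[name], "name": name})
            else:
                to_add.append({"name": name})
        return to_add, to_update, error_log

Because the helpers fetch through the mapping rather than through a passed-in `make_cache` callable, they all need the mapping itself, which is why `db_map` is their first parameter.
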
@@ -888,11 +886,12 @@ def import_scenario_alternatives(db_map, data, make_cache=None):
     Returns:
         tuple of int and list: Number of successfully inserted scenario alternatives, list of errors
     """
-    return import_data(db_map, scenario_alternatives=data, make_cache=make_cache)
+    return import_data(db_map, scenario_alternatives=data)
 
 
-def _get_scenario_alternatives_for_import(data, make_cache):
-    cache = make_cache({"scenario_alternative"}, include_ancestors=True)
+def _get_scenario_alternatives_for_import(db_map, data):
+    db_map.fetch_all({"scenario_alternative"}, include_ancestors=True)
+    cache = db_map.cache
     scenario_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()}
     scenario_alternative_ids = {
         (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values()
@@ -950,7 +949,7 @@ def _get_scenario_alternatives_for_import(data, make_cache):
     return to_add, to_update, error_log
 
 
-def import_object_classes(db_map, data, make_cache=None):
+def import_object_classes(db_map, data):
     """Imports object classes.
 
     Example::
@@ -966,11 +965,12 @@ def import_object_classes(db_map, data, make_cache=None):
     Returns:
         tuple of int and list: Number of successfully inserted object classes, list of errors
     """
-    return import_data(db_map, object_classes=data, make_cache=make_cache)
+    return import_data(db_map, object_classes=data)
 
 
-def _get_object_classes_for_import(db_map, data, make_cache):
-    cache = make_cache({"entity_class"}, include_ancestors=True)
+def _get_object_classes_for_import(db_map, data):
+    db_map.fetch_all({"entity_class"}, include_ancestors=True)
+    cache = db_map.cache
     object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
     checked = set()
     to_add = []
@@ -1008,7 +1008,7 @@ def _get_object_classes_for_import(db_map, data, make_cache):
     return to_add, to_update, error_log
 
 
-def import_relationship_classes(db_map, data, make_cache=None):
+def import_relationship_classes(db_map, data):
     """Imports relationship classes.
 
     Example::
 
@@ -1027,11 +1027,12 @@ def import_relationship_classes(db_map, data, make_cache=None):
     Returns:
         (Int, List) Number of successful inserted objects, list of errors
     """
-    return import_data(db_map, relationship_classes=data, make_cache=make_cache)
+    return import_data(db_map, relationship_classes=data)
 
 
-def _get_relationship_classes_for_import(db_map, data, make_cache):
-    cache = make_cache({"entity_class"}, include_ancestors=True)
+def _get_relationship_classes_for_import(db_map, data):
+    db_map.fetch_all({"entity_class"}, include_ancestors=True)
+    cache = db_map.cache
     object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
     relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list}
     checked = set()
@@ -1076,7 +1077,7 @@ def _get_relationship_classes_for_import(db_map, data, make_cache):
     return to_add, to_update, error_log
 
 
-def import_objects(db_map, data, make_cache=None):
+def import_objects(db_map, data):
     """Imports list of object by name with associated object class name into given database mapping:
     Ignores duplicate names and existing names.
 
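
For callers of the public importers the only change is dropping the `make_cache` argument. A usage sketch (the database URL and the imported names are illustrative)::

    from spinedb_api import DatabaseMapping, import_object_classes, import_objects

    db_map = DatabaseMapping("sqlite:///example.sqlite")
    count, errors = import_object_classes(db_map, ["fish"])
    count, errors = import_objects(db_map, [("fish", "nemo")])
    db_map.commit_session("Add a fish.")
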
@@ -1095,11 +1096,12 @@ def import_objects(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, objects=data, make_cache=make_cache) + return import_data(db_map, objects=data) -def _get_objects_for_import(db_map, data, make_cache): - cache = make_cache({"entity"}, include_ancestors=True) +def _get_objects_for_import(db_map, data): + db_map.fetch_all({"entity"}, include_ancestors=True) + cache = db_map.cache object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} checked = set() @@ -1138,7 +1140,7 @@ def _get_objects_for_import(db_map, data, make_cache): return to_add, to_update, error_log -def import_object_groups(db_map, data, make_cache=None): +def import_object_groups(db_map, data): """Imports list of object groups by name with associated object class name into given database mapping: Ignores duplicate and existing (group, member) tuples. @@ -1158,11 +1160,12 @@ def import_object_groups(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, object_groups=data, make_cache=make_cache) + return import_data(db_map, object_groups=data) -def _get_object_groups_for_import(db_map, data, make_cache): - cache = make_cache({"entity_group"}, include_ancestors=True) +def _get_object_groups_for_import(db_map, data): + db_map.fetch_all({"entity_group"}, include_ancestors=True) + cache = db_map.cache object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} objects = {} @@ -1194,7 +1197,7 @@ def _get_object_groups_for_import(db_map, data, make_cache): return to_add, [], error_log -def import_relationships(db_map, data, make_cache=None): +def import_relationships(db_map, data): """Imports relationships. 
Example:: @@ -1210,7 +1213,7 @@ def import_relationships(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, relationships=data, make_cache=make_cache) + return import_data(db_map, relationships=data) def _make_unique_relationship_name(class_id, class_name, object_names, class_id_name_tuples): @@ -1221,8 +1224,9 @@ def _make_unique_relationship_name(class_id, class_name, object_names, class_id_ return name -def _get_relationships_for_import(db_map, data, make_cache): - cache = make_cache({"entity"}, include_ancestors=True) +def _get_relationships_for_import(db_map, data): + db_map.fetch_all({"entity"}, include_ancestors=True) + cache = db_map.cache relationships = {x.name: x for x in cache.get("entity", {}).values() if x.element_id_list} relationship_ids_per_name = {(x.class_id, x.name): x.id for x in relationships.values()} relationship_ids_per_obj_lst = {(x.class_id, x.element_id_list): x.id for x in relationships.values()} @@ -1291,7 +1295,7 @@ def _get_relationships_for_import(db_map, data, make_cache): return to_add, to_update, error_log -def import_object_parameters(db_map, data, make_cache=None, unparse_value=to_database): +def import_object_parameters(db_map, data, unparse_value=to_database): """Imports list of object class parameters: Example:: @@ -1310,11 +1314,12 @@ def import_object_parameters(db_map, data, make_cache=None, unparse_value=to_dat Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, object_parameters=data, make_cache=make_cache, unparse_value=unparse_value) + return import_data(db_map, object_parameters=data, unparse_value=unparse_value) -def _get_object_parameters_for_import(db_map, data, make_cache, unparse_value): - cache = make_cache({"parameter_definition"}, include_ancestors=True) +def _get_object_parameters_for_import(db_map, data, unparse_value): + db_map.fetch_all({"parameter_definition"}, include_ancestors=True) + cache = db_map.cache parameter_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } @@ -1376,7 +1381,7 @@ def _get_object_parameters_for_import(db_map, data, make_cache, unparse_value): return to_add, to_update, error_log -def import_relationship_parameters(db_map, data, make_cache=None, unparse_value=to_database): +def import_relationship_parameters(db_map, data, unparse_value=to_database): """Imports list of relationship class parameters: Example:: @@ -1395,11 +1400,12 @@ def import_relationship_parameters(db_map, data, make_cache=None, unparse_value= Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, relationship_parameters=data, make_cache=make_cache, unparse_value=unparse_value) + return import_data(db_map, relationship_parameters=data, unparse_value=unparse_value) -def _get_relationship_parameters_for_import(db_map, data, make_cache, unparse_value): - cache = make_cache({"parameter_definition"}, include_ancestors=True) +def _get_relationship_parameters_for_import(db_map, data, unparse_value): + db_map.fetch_all({"parameter_definition"}, include_ancestors=True) + cache = db_map.cache parameter_ids = { (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() } @@ -1462,7 +1468,7 @@ def _get_relationship_parameters_for_import(db_map, data, make_cache, unparse_va return to_add, to_update, error_log -def import_object_parameter_values(db_map, data, 
make_cache=None, unparse_value=to_database, on_conflict="merge"): +def import_object_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): """Imports object parameter values: Example:: @@ -1484,14 +1490,14 @@ def import_object_parameter_values(db_map, data, make_cache=None, unparse_value= return import_data( db_map, object_parameter_values=data, - make_cache=make_cache, unparse_value=unparse_value, on_conflict=on_conflict, ) -def _get_object_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict): - cache = make_cache({"parameter_value"}, include_ancestors=True) +def _get_object_parameter_values_for_import(db_map, data, unparse_value, on_conflict): + db_map.fetch_all({"parameter_value"}, include_ancestors=True) + cache = db_map.cache object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} parameter_value_ids = { (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() @@ -1539,7 +1545,7 @@ def _get_object_parameter_values_for_import(db_map, data, make_cache, unparse_va ) continue else: - alt_id, alternative_name = db_map.get_import_alternative(cache=cache) + alt_id, alternative_name = db_map.get_import_alternative() alternative_ids.add(alt_id) checked_key = (o_id, p_id, alt_id) if checked_key in checked: @@ -1591,7 +1597,7 @@ def _get_object_parameter_values_for_import(db_map, data, make_cache, unparse_va return to_add, to_update, error_log -def import_relationship_parameter_values(db_map, data, make_cache=None, unparse_value=to_database, on_conflict="merge"): +def import_relationship_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): """Imports relationship parameter values: Example:: @@ -1613,14 +1619,14 @@ def import_relationship_parameter_values(db_map, data, make_cache=None, unparse_ return import_data( db_map, relationship_parameter_values=data, - make_cache=make_cache, unparse_value=unparse_value, on_conflict=on_conflict, ) -def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unparse_value, on_conflict): - cache = make_cache({"parameter_value"}, include_ancestors=True) +def _get_relationship_parameter_values_for_import(db_map, data, unparse_value, on_conflict): + db_map.fetch_all({"parameter_value"}, include_ancestors=True) + cache = db_map.cache object_class_id_lists = { x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list } @@ -1677,7 +1683,7 @@ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unpa ) continue else: - alt_id, alternative_name = db_map.get_import_alternative(cache=cache) + alt_id, alternative_name = db_map.get_import_alternative() alternative_ids.add(alt_id) checked_key = (r_id, p_id, alt_id) if checked_key in checked: @@ -1735,7 +1741,7 @@ def _get_relationship_parameter_values_for_import(db_map, data, make_cache, unpa return to_add, to_update, error_log -def import_parameter_value_lists(db_map, data, make_cache=None, unparse_value=to_database): +def import_parameter_value_lists(db_map, data, unparse_value=to_database): """Imports list of parameter value lists: Example:: @@ -1754,11 +1760,12 @@ def import_parameter_value_lists(db_map, data, make_cache=None, unparse_value=to Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data(db_map, parameter_value_lists=data, make_cache=make_cache, unparse_value=unparse_value) + return import_data(db_map, 
parameter_value_lists=data, unparse_value=unparse_value)


-def _get_parameter_value_lists_for_import(db_map, data, make_cache):
-    cache = make_cache({"parameter_value_list"}, include_ancestors=True)
+def _get_parameter_value_lists_for_import(db_map, data):
+    db_map.fetch_all({"parameter_value_list"}, include_ancestors=True)
+    cache = db_map.cache
    parameter_value_list_ids = {x.name: x.id for x in cache.get("parameter_value_list", {}).values()}
    error_log = []
    to_add = []
@@ -1772,9 +1779,18 @@ def _get_parameter_value_lists_for_import(db_map, data, make_cache):
    return to_add, [], error_log


-def _get_list_values_for_import(db_map, data, make_cache, unparse_value):
-    cache = make_cache({"list_value"}, include_ancestors=True)
-    value_lists_by_name = {x.name: (x.id, x.value_index_list) for x in cache.get("parameter_value_list", {}).values()}
+def _get_list_values_for_import(db_map, data, unparse_value):
+    db_map.fetch_all({"list_value"}, include_ancestors=True)
+    cache = db_map.cache
+    value_lists_by_name = {
+        x.name: (
+            x.id,
+            max(
+                (y.index for y in cache.get("list_value", {}).values() if y.parameter_value_list_id == x.id), default=-1
+            ),
+        )
+        for x in cache.get("parameter_value_list", {}).values()
+    }
    list_value_ids_by_index = {(x.parameter_value_list_id, x.index): x.id for x in cache.get("list_value", {}).values()}
    list_value_ids_by_value = {
        (x.parameter_value_list_id, x.type, x.value): x.id for x in cache.get("list_value", {}).values()
@@ -1787,7 +1803,7 @@ def _get_list_values_for_import(db_map, data, make_cache, unparse_value):
    max_indexes = dict()
    for list_name, value in data:
        try:
-            list_id, value_index_list = value_lists_by_name.get(list_name)
+            list_id, current_max_index = value_lists_by_name.get(list_name)
        except TypeError:
            # cannot unpack non-iterable NoneType object
            error_log.append(
@@ -1809,10 +1825,8 @@ def _get_list_values_for_import(db_map, data, make_cache, unparse_value):
        max_index = max_indexes.get(list_id)
        if max_index is not None:
            index = max_index + 1
-        elif not value_index_list:
-            index = 0
        else:
-            index = max(value_index_list) + 1
+            index = current_max_index + 1
        item = {"parameter_value_list_id": list_id, "value": val, "type": type_, "index": index}
        try:
            check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value)
@@ -1830,7 +1844,7 @@ def _get_list_values_for_import(db_map, data, make_cache, unparse_value):
    return to_add, to_update, error_log


-def import_metadata(db_map, data, make_cache=None):
+def import_metadata(db_map, data=None):
    """Imports metadata. Ignores duplicates.

    Example::

@@ -1845,11 +1859,12 @@ def import_metadata(db_map, data, make_cache=None):
    Returns:
        (Int, List) Number of successfully inserted metadata items, list of errors
    """
-    return import_data(db_map, metadata=data, make_cache=make_cache)
+    return import_data(db_map, metadata=data)


-def _get_metadata_for_import(db_map, data, make_cache):
-    cache = make_cache({"metadata"}, include_ancestors=True)
+def _get_metadata_for_import(db_map, data):
+    db_map.fetch_all({"metadata"}, include_ancestors=True)
+    cache = db_map.cache
    seen = {(x.name, x.value) for x in cache.get("metadata", {}).values()}
    to_add = []
    for metadata in data:
@@ -1865,7 +1880,7 @@ def _get_metadata_for_import(db_map, data, make_cache):
    return to_add, [], error_log


# TODO: import_entity_metadata, import_parameter_value_metadata


-def import_object_metadata(db_map, data, make_cache=None):
+def import_object_metadata(db_map, data):
    """Imports object metadata. Ignores duplicates.
Example:: @@ -1881,11 +1896,12 @@ def import_object_metadata(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted items, list of errors """ - return import_data(db_map, object_metadata=data, make_cache=make_cache) + return import_data(db_map, object_metadata=data) -def _get_object_metadata_for_import(db_map, data, make_cache): - cache = make_cache({"object", "entity_metadata"}, include_ancestors=True) +def _get_object_metadata_for_import(db_map, data): + db_map.fetch_all({"object", "entity_metadata"}, include_ancestors=True) + cache = db_map.cache object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} @@ -1922,7 +1938,7 @@ def _get_object_metadata_for_import(db_map, data, make_cache): return to_add, [], error_log -def import_relationship_metadata(db_map, data, make_cache=None): +def import_relationship_metadata(db_map, data): """Imports relationship metadata. Ignores duplicates. Example:: @@ -1941,11 +1957,12 @@ def import_relationship_metadata(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted items, list of errors """ - return import_data(db_map, relationship_metadata=data, make_cache=make_cache) + return import_data(db_map, relationship_metadata=data) -def _get_relationship_metadata_for_import(db_map, data, make_cache): - cache = make_cache({"relationship", "entity_metadata"}, include_ancestors=True) +def _get_relationship_metadata_for_import(db_map, data): + db_map.fetch_all({"relationship", "entity_metadata"}, include_ancestors=True) + cache = db_map.cache relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} object_class_id_lists = { x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list @@ -1992,7 +2009,7 @@ def _get_relationship_metadata_for_import(db_map, data, make_cache): return to_add, [], error_log -def import_object_parameter_value_metadata(db_map, data, make_cache=None): +def import_object_parameter_value_metadata(db_map, data): """Imports object parameter value metadata. Ignores duplicates. 
Example:: @@ -2011,11 +2028,12 @@ def import_object_parameter_value_metadata(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted items, list of errors """ - return import_data(db_map, object_parameter_value_metadata=data, make_cache=make_cache) + return import_data(db_map, object_parameter_value_metadata=data) -def _get_object_parameter_value_metadata_for_import(db_map, data, make_cache): - cache = make_cache({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) +def _get_object_parameter_value_metadata_for_import(db_map, data): + db_map.fetch_all({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) + cache = db_map.cache object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} parameter_ids = { @@ -2037,7 +2055,7 @@ def _get_object_parameter_value_metadata_for_import(db_map, data, make_cache): alternative_name = optionals[0] alt_id = alternative_ids.get(alternative_name, None) else: - alt_id, alternative_name = db_map.get_import_alternative(cache=cache) + alt_id, alternative_name = db_map.get_import_alternative() pv_id = parameter_value_ids.get((o_id, p_id, alt_id), None) if pv_id is None: msg = ( @@ -2067,7 +2085,7 @@ def _get_object_parameter_value_metadata_for_import(db_map, data, make_cache): return to_add, [], error_log -def import_relationship_parameter_value_metadata(db_map, data, make_cache=None): +def import_relationship_parameter_value_metadata(db_map, data): """Imports relationship parameter value metadata. Ignores duplicates. Example:: @@ -2086,11 +2104,12 @@ def import_relationship_parameter_value_metadata(db_map, data, make_cache=None): Returns: (Int, List) Number of successful inserted items, list of errors """ - return import_data(db_map, relationship_parameter_value_metadata=data, make_cache=make_cache) + return import_data(db_map, relationship_parameter_value_metadata=data) -def _get_relationship_parameter_value_metadata_for_import(db_map, data, make_cache): - cache = make_cache({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) +def _get_relationship_parameter_value_metadata_for_import(db_map, data): + db_map.fetch_all({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) + cache = db_map.cache relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} object_class_id_lists = { x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list @@ -2120,7 +2139,7 @@ def _get_relationship_parameter_value_metadata_for_import(db_map, data, make_cac alternative_name = optionals[0] alt_id = alternative_ids.get(alternative_name, None) else: - alt_id, alternative_name = db_map.get_import_alternative(cache=cache) + alt_id, alternative_name = db_map.get_import_alternative() pv_id = parameter_value_ids.get((r_id, p_id, alt_id), None) if pv_id is None: msg = ( From c91c7b1a1ce1e156c5cc02d2e1a090a314d51980 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 5 May 2023 17:13:01 +0200 Subject: [PATCH 034/317] Move fetching to DBCache, get rid of Diff, commit from cache --- spinedb_api/__init__.py | 1 - spinedb_api/db_cache.py | 173 +++++++--- spinedb_api/db_mapping.py | 2 - spinedb_api/db_mapping_add_mixin.py | 206 ++++-------- spinedb_api/db_mapping_base.py | 106 ++++--- spinedb_api/db_mapping_commit_mixin.py | 37 +-- 
spinedb_api/db_mapping_query_mixin.py | 289 ----------------- spinedb_api/db_mapping_remove_mixin.py | 51 +-- spinedb_api/db_mapping_update_mixin.py | 335 +++++++++----------- spinedb_api/diff_db_mapping.py | 177 ----------- spinedb_api/diff_db_mapping_base.py | 143 --------- spinedb_api/diff_db_mapping_commit_mixin.py | 104 ------ spinedb_api/export_functions.py | 22 +- spinedb_api/import_functions.py | 36 ++- spinedb_api/purge.py | 6 +- 15 files changed, 471 insertions(+), 1217 deletions(-) delete mode 100644 spinedb_api/db_mapping_query_mixin.py delete mode 100644 spinedb_api/diff_db_mapping.py delete mode 100644 spinedb_api/diff_db_mapping_base.py delete mode 100644 spinedb_api/diff_db_mapping_commit_mixin.py diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 9899c056..d6d7069a 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -10,7 +10,6 @@ ###################################################################################################################### from .db_mapping import DatabaseMapping -from .diff_db_mapping import DiffDatabaseMapping from .exception import ( SpineDBAPIError, SpineIntegrityError, diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 3109bfdf..88bd5859 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -13,22 +13,93 @@ """ from contextlib import suppress - +from operator import itemgetter # TODO: Implement CacheItem.pop() to do lookup? class DBCache(dict): - def __init__(self, advance_query, *args, **kwargs): + """A dictionary that maps table names to ids to items. Used to store and retrieve database contents.""" + + def __init__(self, db_map, chunk_size=None): """ - A dictionary that maps table names to ids to items. Used to store and retrieve database contents. + Args: + db_map (DatabaseMapping) + """ + super().__init__() + self._db_map = db_map + self._offsets = {} + self._fetched_item_types = set() + self._chunk_size = chunk_size + + def to_change(self): + to_add = {} + to_update = {} + to_remove = {} + for item_type, table_cache in self.items(): + new = [x for x in table_cache.values() if x.new] + dirty = [x for x in table_cache.values() if x.dirty and not x.new] + removed = {x.id for x in dict.values(table_cache) if x.removed} + if new: + to_add[item_type] = new + if dirty: + to_update[item_type] = dirty + if removed: + to_remove[item_type] = removed + return to_add, to_update, to_remove + + @property + def fetched_item_types(self): + return self._fetched_item_types + + def reset_queries(self): + """Resets queries and clears caches.""" + self._offsets.clear() + self._fetched_item_types.clear() + + def advance_query(self, item_type): + """Schedules an advance of the DB query that fetches items of given type. Args: - advance_query (function): A function that receives a table name (a.k.a item type) as input and returns - more items of that type to be added to this cache. 
+ item_type (str) + + Returns: + Future """ - super().__init__(*args, **kwargs) - self._advance_query = advance_query + return self._db_map.executor.submit(self.do_advance_query, item_type) + + def _get_next_chunk(self, item_type): + try: + sq_name = self._db_map.cache_sqs[item_type] + qry = self._db_map.query(getattr(self._db_map, sq_name)) + except KeyError: + return [] + if not self._chunk_size: + self._fetched_item_types.add(item_type) + return [x._asdict() for x in qry.yield_per(1000).enable_eagerloads(False)] + offset = self._offsets.setdefault(item_type, 0) + chunk = [x._asdict() for x in qry.limit(self._chunk_size).offset(offset)] + self._offsets[item_type] += len(chunk) + return chunk + + def do_advance_query(self, item_type): + """Advances the DB query that fetches items of given type and caches the results. + + Args: + item_type (str) + + Returns: + list: items fetched from the DB + """ + chunk = self._get_next_chunk(item_type) + if not chunk: + self._fetched_item_types.add(item_type) + return [] + table_cache = self.table_cache(item_type) + for item in chunk: + # FIXME: This will overwrite working changes after a refresh + table_cache.add_item(item) + return chunk def table_cache(self, item_type): return self.setdefault(item_type, TableCache(self, item_type)) @@ -41,13 +112,9 @@ def get_item(self, item_type, id_): return item def fetch_more(self, item_type): - items = self._advance_query(item_type) - if not items: + if item_type in self._fetched_item_types: return False - table_cache = self.table_cache(item_type) - for item in items: - table_cache.add_item(item._asdict()) - return True + return bool(self.advance_query(item_type).result()) def fetch_all(self, item_type): while self.fetch_more(item_type): @@ -80,6 +147,7 @@ def make_item(self, item_type, item): "entity_group": EntityGroupItem, "parameter_definition": ParameterDefinitionItem, "parameter_value": ParameterValueItem, + "scenario": ScenarioItem, "scenario_alternative": ScenarioAlternativeItem, }.get(item_type, CacheItem) return factory(self, item_type, **item) @@ -99,35 +167,32 @@ def __init__(self, db_cache, item_type, *args, **kwargs): def values(self): return (x for x in super().values() if x.is_valid()) - def add_item(self, item, keep_existing=False): - if keep_existing: - existing_item = self.get(item["id"]) - if existing_item is not None: - return existing_item + def add_item(self, item, new=False): self[item["id"]] = new_item = self._db_cache.make_item(self._item_type, item) + new_item.new = new return new_item def update_item(self, item): current_item = self[item["id"]] + current_item.dirty = True current_item.update(item) current_item.cascade_update() def remove_item(self, id_): - if self._item_type == "alternative" and id_ == 1: - # Do not remove the Base alternative - return CacheItem(self._db_cache, self._item_type) current_item = self.get(id_) - if current_item: + if current_item is not None: current_item.cascade_remove() return current_item + def restore_item(self, id_): + current_item = self.get(id_) + if current_item is not None: + current_item.cascade_restore() + return current_item -class CacheItem(dict): - """A dictionary that behaves kinda like a row from a query result. - It is used to store items in a cache, so we can access them as if they were rows from a query result. - This is mainly because we want to use the cache as a replacement for db queries in some methods. 
- """ +class CacheItem(dict): + """A dictionary that represents an db item.""" def __init__(self, db_cache, item_type, *args, **kwargs): """ @@ -139,13 +204,19 @@ def __init__(self, db_cache, item_type, *args, **kwargs): self._item_type = item_type self._referrers = {} self._weak_referrers = {} - self.readd_callbacks = set() + self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() self._to_remove = False self._removed = False self._corrupted = False self._valid = None + self.new = False + self.dirty = False + + @property + def removed(self): + return self._removed @property def item_type(self): @@ -187,11 +258,11 @@ def _get_ref(self, ref_type, ref_id, source_key): def _handle_ref(self, ref, source_key): if source_key in self._reference_keys(): ref.add_referrer(self) - if ref.is_removed(): + if ref.removed: self._to_remove = True else: ref.add_weak_referrer(self) - if ref.is_removed(): + if ref.removed: return {} return ref @@ -221,9 +292,6 @@ def is_valid(self): self._valid = not self._removed and not self._corrupted return self._valid - def is_removed(self): - return self._removed - def add_referrer(self, referrer): if referrer.key is None: return @@ -235,19 +303,19 @@ def add_weak_referrer(self, referrer): if referrer.key not in self._referrers: self._weak_referrers[referrer.key] = referrer - def cascade_readd(self): + def cascade_restore(self): if not self._removed: return self._removed = False for referrer in self._referrers.values(): - referrer.cascade_readd() + referrer.cascade_restore() for weak_referrer in self._weak_referrers.values(): weak_referrer.call_update_callbacks() obsolete = set() - for callback in self.readd_callbacks: + for callback in self.restore_callbacks: if not callback(self): obsolete.add(callback) - self.readd_callbacks -= obsolete + self.restore_callbacks -= obsolete def cascade_remove(self): if self._removed: @@ -435,6 +503,23 @@ def _reference_keys(self): return super()._reference_keys() + ("class_name", "group_name", "member_name", "dimension_id_list") +class ScenarioItem(CacheItem): + @property + def sorted_scenario_alternatives(self): + self._db_cache.fetch_all("scenario_alternative") + return sorted( + (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), + key=itemgetter("rank"), + ) + + def __getitem__(self, key): + if key == "alternative_id_list": + return [x["alternative_id"] for x in self.sorted_scenario_alternatives] + if key == "alternative_name_list": + return [x["alternative_name"] for x in self.sorted_scenario_alternatives] + return super().__getitem__(key) + + class ScenarioAlternativeItem(CacheItem): def __getitem__(self, key): if key == "scenario_name": @@ -444,14 +529,12 @@ def __getitem__(self, key): if key == "before_alternative_name": return self._get_ref("alternative", self["before_alternative_id"], key).get("name") if key == "before_alternative_id": - return next( - ( - x - for x in self._db_cache.get("scenario_alternative", {}).values() - if x["scenario_id"] == self["scenario_id"] and x["rank"] == self["rank"] - 1 - ), - {}, - ).get("alternative_id") + scenario = self._get_ref("scenario", self["scenario_id"], None) + try: + return scenario["alternative_id_list"][self["rank"]] + except IndexError: + return None + return super().__getitem__(key) def _reference_keys(self): return super()._reference_keys() + ("scenario_name", "alternative_name") diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index fe87bc86..171eefd2 100644 --- 
a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -14,7 +14,6 @@ """ -from .db_mapping_query_mixin import DatabaseMappingQueryMixin from .db_mapping_base import DatabaseMappingBase from .db_mapping_add_mixin import DatabaseMappingAddMixin from .db_mapping_check_mixin import DatabaseMappingCheckMixin @@ -25,7 +24,6 @@ class DatabaseMapping( - DatabaseMappingQueryMixin, DatabaseMappingCheckMixin, DatabaseMappingAddMixin, DatabaseMappingUpdateMixin, diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 71a2d775..2b38c87c 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -69,7 +69,7 @@ def __init__(self, *args, **kwargs): self._next_id = Table("next_id", self._metadata, autoload=True) @contextmanager - def generate_ids(self, tablename, dry_run=False): + def generate_ids(self, tablename): """Manages id generation for new items to be added to the db. Args: @@ -78,7 +78,6 @@ def generate_ids(self, tablename, dry_run=False): Yields: self._IdGenerator: an object that generates a new id every time it is called. """ - connection = self.engine.connect() if dry_run else self.connection fieldname = { "entity_class": "entity_class_id", "object_class": "entity_class_id", @@ -98,51 +97,40 @@ def generate_ids(self, tablename, dry_run=False): "parameter_value_metadata": "parameter_value_metadata_id", "entity_metadata": "entity_metadata_id", }[tablename] - select_next_id = select([self._next_id]) - next_id_row = connection.execute(select_next_id).first() - if next_id_row is None: - next_id = None - stmt = self._next_id.insert() - else: - next_id = getattr(next_id_row, fieldname) - stmt = self._next_id.update() - if next_id is None: - real_tablename = self._real_tablename(tablename) - table = self._metadata.tables[real_tablename] - id_field = self._id_fields.get(real_tablename, "id") - select_max_id = select([func.max(getattr(table.c, id_field))]) - max_id = connection.execute(select_max_id).scalar() - next_id = max_id + 1 if max_id else 1 - gen = self._IdGenerator(next_id) - try: - yield gen - finally: - connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) - if dry_run: - connection.close() + with self.engine.begin() as connection: + select_next_id = select([self._next_id]) + next_id_row = connection.execute(select_next_id).first() + if next_id_row is None: + next_id = None + stmt = self._next_id.insert() + else: + next_id = getattr(next_id_row, fieldname) + stmt = self._next_id.update() + if next_id is None: + real_tablename = self._real_tablename(tablename) + table = self._metadata.tables[real_tablename] + id_field = self._id_fields.get(real_tablename, "id") + select_max_id = select([func.max(getattr(table.c, id_field))]) + max_id = connection.execute(select_max_id).scalar() + next_id = max_id + 1 if max_id else 1 + gen = self._IdGenerator(next_id) + try: + yield gen + finally: + connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) - def _add_commit_id_and_ids(self, tablename, *items, dry_run=False): + def _add_commit_id_and_ids(self, tablename, *items): if not items: return [], set() - commit_id = self._make_commit_id(dry_run=dry_run) - with self.generate_ids(tablename, dry_run=dry_run) as new_id: + commit_id = self._make_commit_id() + with self.generate_ids(tablename) as new_id: for item in items: item["commit_id"] = commit_id if "id" not in item: item["id"] = new_id() - def add_items( - self, - tablename, - *items, 
- check=True, - strict=False, - return_dups=False, - return_items=False, - readd=False, - dry_run=False, - ): - """Add items to db. + def add_items(self, tablename, *items, check=True, strict=False): + """Add items to cache. Args: tablename (str) @@ -150,32 +138,20 @@ def add_items( check (bool): Whether or not to check integrity strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. - return_dups (bool): Whether or not already existing and duplicated entries should also be returned. - return_items (bool): Return full items rather than just ids - readd (bool): Readds items directly Returns: set: ids or items successfully added list(SpineIntegrityError): found violations """ - if readd: - if not dry_run: - for _ in self._do_add_items(tablename, *items): - pass - return items if return_items else {x["id"] for x in items}, [] if check: checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=False, strict=strict) else: checked_items, intgr_error_log = list(items), [] - ids = self._add_items(tablename, *checked_items, dry_run=dry_run) - if return_items: - return checked_items, intgr_error_log - if return_dups: - ids.update(set(x.id for x in intgr_error_log if x.id)) - return ids, intgr_error_log + _ = self._add_items(tablename, *checked_items) + return checked_items, intgr_error_log - def _add_items(self, tablename, *items, dry_run=False): - """Add items to database without checking integrity. + def _add_items(self, tablename, *items): + """Add items to cache without checking integrity. Args: tablename (str) @@ -186,37 +162,22 @@ def _add_items(self, tablename, *items, dry_run=False): Returns: ids (set): added instances' ids """ - self._add_commit_id_and_ids(tablename, *items, dry_run=dry_run) - if not dry_run: - for _ in self._do_add_items(tablename, *items): - pass + self._add_commit_id_and_ids(tablename, *items) + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + for item in items: + table_cache.add_item(item, new=True) return {item["id"] for item in items} - def _get_table_for_insert(self, tablename): - """ - Returns the table name where to perform insertion. - - Subclasses can override this method to insert to another table instead (e.g., diff...) 
- - Args: - tablename (str): target database table name - - Yields: - str: database table name - """ - return self._metadata.tables[tablename] - def _do_add_items(self, tablename, *items_to_add): + """Add items to DB without checking integrity.""" try: for tablename_, items_to_add_ in self._items_to_add_per_table(tablename, items_to_add): - table = self._get_table_for_insert(tablename_) + table = self._metadata.tables[self._real_tablename(tablename_)] self._checked_execute(table.insert(), [{**item} for item in items_to_add_]) - yield tablename_ except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e - else: - self._has_pending_changes = True @staticmethod def _items_to_add_per_table(tablename, items_to_add): @@ -265,8 +226,14 @@ def _items_to_add_per_table(tablename, items_to_add): yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) yield ("entity_alternative", ea_items_to_add) - elif tablename == "object_class": + elif tablename == "relationship_class": + ecd_items_to_add = [ + {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} + for item in items_to_add + for position, dimension_id in enumerate(item["object_class_id_list"]) + ] yield ("entity_class", items_to_add) + yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "object": ea_items_to_add = [ {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} @@ -279,14 +246,6 @@ def _items_to_add_per_table(tablename, items_to_add): ] yield ("entity", items_to_add) yield ("entity_alternative", ea_items_to_add) - elif tablename == "relationship_class": - ecd_items_to_add = [ - {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} - for item in items_to_add - for position, dimension_id in enumerate(item["object_class_id_list"]) - ] - yield ("entity_class", items_to_add) - yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "relationship": ee_items_to_add = [ { @@ -380,7 +339,7 @@ def add_entity_metadata(self, *items, **kwargs): def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) - def _get_or_add_metadata_ids_for_items(self, *items, check, strict, dry_run): + def _get_or_add_metadata_ids_for_items(self, *items, check, strict): cache = self.cache metadata_ids = {} for entry in cache.get("metadata", {}).values(): @@ -395,9 +354,7 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict, dry_run): items_missing_metadata_ids.setdefault(item["metadata_name"], {})[item["metadata_value"]] = item else: item["metadata_id"] = existing_id - added_metadata, errors = self.add_items( - "metadata", *metadata_to_add, check=check, strict=strict, return_items=True, dry_run=dry_run - ) + added_metadata, errors = self.add_items("metadata", *metadata_to_add, check=check, strict=strict) for x in added_metadata: cache.table_cache("metadata").add_item(x) if errors: @@ -410,37 +367,20 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict, dry_run): item["metadata_id"] = new_metadata_ids[metadata_name][metadata_value] return added_metadata, errors - def _add_ext_item_metadata(self, table_name, *items, check=True, strict=False, return_items=False, dry_run=False): - # Note, that even though return_items can be False, it doesn't make much sense here because we'll be mixing - # metadata and entity metadata ids. 
+ def _add_ext_item_metadata(self, table_name, *items, check=True, strict=False): self.fetch_all({table_name}, include_ancestors=True) - cache = self.cache - added_metadata, metadata_errors = self._get_or_add_metadata_ids_for_items( - *items, check=check, strict=strict, dry_run=dry_run - ) + added_metadata, metadata_errors = self._get_or_add_metadata_ids_for_items(*items, check=check, strict=strict) if metadata_errors: - if not return_items: - return added_metadata, metadata_errors - return {i["id"] for i in added_metadata}, metadata_errors - added_item_metadata, item_errors = self.add_items( - table_name, *items, check=check, strict=strict, return_items=True, dry_run=dry_run - ) + return added_metadata, metadata_errors + added_item_metadata, item_errors = self.add_items(table_name, *items, check=check, strict=strict) errors = metadata_errors + item_errors - if not return_items: - return {i["id"] for i in added_metadata + added_item_metadata}, errors return added_metadata + added_item_metadata, errors - def add_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, readd=False, dry_run=False): - return self._add_ext_item_metadata( - "entity_metadata", *items, check=check, strict=strict, return_items=return_items, dry_run=dry_run - ) + def add_ext_entity_metadata(self, *items, check=True, strict=False): + return self._add_ext_item_metadata("entity_metadata", *items, check=check, strict=strict) - def add_ext_parameter_value_metadata( - self, *items, check=True, strict=False, return_items=False, readd=False, dry_run=False - ): - return self._add_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict, return_items=return_items, dry_run=dry_run - ) + def add_ext_parameter_value_metadata(self, *items, check=True, strict=False): + return self._add_ext_item_metadata("parameter_value_metadata", *items, check=check, strict=strict) def _add_entity_classes(self, *items): return self._add_items("entity_class", *items) @@ -576,39 +516,3 @@ def add_parameter_value(self, **kwargs): sq = self.parameter_value_sq ids, _ = self.add_parameter_values(kwargs, strict=True) return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def get_or_add_object_class(self, **kwargs): - """Stage an object class item for insertion if it doesn't already exists in the db. - - :returns: - - **item** -- The item successfully staged for insertion or already existing. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.object_class_sq - ids, _ = self.add_object_classes(kwargs, return_dups=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def get_or_add_object(self, **kwargs): - """Stage an object item for insertion if it doesn't already exists in the db. - - :returns: - - **item** -- The item successfully staged for insertion or already existing. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.object_sq - ids, _ = self.add_objects(kwargs, return_dups=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def get_or_add_parameter_definition(self, **kwargs): - """Stage a parameter definition item for insertion if it doesn't already exists in the db. - - :returns: - - **item** -- The item successfully staged for insertion or already existing. 
- - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.parameter_definition_sq - ids, _ = self.add_parameter_definitions(kwargs, return_dups=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c0579a1d..be9c3af4 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -20,6 +20,7 @@ import time from collections import Counter from types import MethodType +from concurrent.futures import ThreadPoolExecutor from sqlalchemy import create_engine, case, MetaData, Table, Column, false, and_, func, inspect, cast, Integer, or_ from sqlalchemy.sql.expression import label, Alias from sqlalchemy.engine.url import make_url, URL @@ -46,6 +47,7 @@ from .spine_db_client import get_db_url_from_server from .db_cache import DBCache + logging.getLogger("alembic").setLevel(logging.CRITICAL) @@ -82,7 +84,8 @@ def __init__( apply_filters=True, memory=False, sqlite_timeout=1800, - advance_cache_query=None, + asynchronous=False, + chunk_size=None, ): """ Args: @@ -93,9 +96,11 @@ def __init__( create (bool): Whether or not to create a Spine db at the given URL if it's not already. apply_filters (bool): Whether or not filters in the URL's query part are applied to the database map. memory (bool): Whether or not to use a sqlite memory db as replacement for this DB map. + sqlite_timeout (int): How many seconds to wait before raising connection errors. + asynchronous (bool): Whether or not communication with the db should be done asynchronously. + chunk_size (int, optional): How many rows to fetch from the DB at a time when populating the cache. + If not specified, then all rows are fetched at once. """ - if advance_cache_query is None: - advance_cache_query = self._advance_cache_query # FIXME: We should also check the server memory property and use it here db_url = get_db_url_from_server(db_url) self.db_url = str(db_url) @@ -111,20 +116,22 @@ def __init__( self.codename = self._make_codename(codename) self._memory = memory self._memory_dirty = False + self._asynchronous = asynchronous self._original_engine = self.create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine - self.connection = self.engine.connect() - if self._memory: - copy_database_bind(self.connection, self._original_engine) listen(self.engine, 'close', self._receive_engine_close) + self.executor = self._make_executor() + self.connection = self.executor.submit(self.engine.connect).result() + self.session = Session(self.connection, **self._session_kwargs) + if self._memory: + self.executor.submit(copy_database_bind, self.connection, self._original_engine) self._metadata = MetaData(self.connection) - self._metadata.reflect() + _ = self.executor.submit(self._metadata.reflect).result() self._tablenames = [t.name for t in self._metadata.sorted_tables] - self.session = Session(self.connection, **self._session_kwargs) - self.cache = DBCache(advance_cache_query) + self.cache = DBCache(self, chunk_size=chunk_size) # Subqueries that select everything from each table self._commit_sq = None self._alternative_sq = None @@ -239,7 +246,19 @@ def __enter__(self): return self def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.connection.close() + self.close() + + def _make_executor(self): + return 
ThreadPoolExecutor(max_workers=1) if self._asynchronous else _Executor() + + def close(self): + if not self.connection.closed: + self.executor.submit(self.connection.close) + self.executor.shutdown() + + def reconnect(self): + self.executor = self._make_executor() + self.connection = self.executor.submit(self.engine.connect).result() def _descendant_tablenames(self, tablename): child_tablenames = { @@ -285,7 +304,7 @@ def get_table(self, tablename): def commit_id(self): return self._commit_id - def _make_commit_id(self, dry_run=False): + def _make_commit_id(self): return None def _check_commit(self, comment): @@ -387,10 +406,7 @@ def upgrade_to_head(rev, context): def _receive_engine_close(self, dbapi_con, _connection_record): if dbapi_con == self.connection.connection.connection and self._memory_dirty: - copy_database_bind(self._original_engine, self.connection) - - def reconnect(self): - self.connection = self.engine.connect() + copy_database_bind(self._original_engine, self.engine) def in_(self, column, values): """Returns an expression equivalent to column.in_(values), that circumvents the @@ -405,9 +421,8 @@ def in_(self, column, values): Column("value", column.type, primary_key=True), prefixes=['TEMPORARY'], ) - in_value.create(self.connection, checkfirst=True) - python_type = column.type.python_type - self._checked_execute(in_value.insert(), [{"value": python_type(val)} for val in set(values)]) + self.executor.submit(in_value.create, self.connection, checkfirst=True).result() + self._checked_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) return column.in_(self.query(in_value.c.value)) def _get_table_to_sq_attr(self): @@ -1940,7 +1955,7 @@ def override_create_import_alternative(self, method): def _checked_execute(self, stmt, items): if not items: return - return self.connection.execute(stmt, items) + return self.executor.submit(self.connection.execute, stmt, items).result() def _get_primary_key(self, tablename): pk = self.composite_pks.get(tablename) @@ -1972,11 +1987,6 @@ def fetch_all(self, tablenames, include_descendants=False, include_ancestors=Fal for tablename in tablenames & self.cache_sqs.keys(): self.cache.fetch_all(tablename) - def _advance_cache_query(self, tablename): - if tablename in self.cache: - return [] - return self.query(getattr(self, self.cache_sqs[tablename])).yield_per(1000).enable_eagerloads(False).all() - def _object_class_id(self): return case([(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.id)], else_=None) @@ -2050,24 +2060,42 @@ def _metadata_usage_counts(self): def __del__(self): try: - self.connection.close() + self.close() except AttributeError: pass - def make_temporary_table(self, table_name, *columns): - """Creates a temporary table. 
+ def get_filter_configs(self): + return self._filter_configs + - Args: - table_name (str): table name - *columns: table's columns +class _Future: + def __init__(self): + self._result = None + self._exception = None - Returns: - Table: created table - """ - table = Table(table_name, self._metadata, *columns, prefixes=["TEMPORARY"]) - table.drop(self.connection, checkfirst=True) - table.create(self.connection) - return table + def set_result(self, result): + self._result = result - def get_filter_configs(self): - return self._filter_configs + def set_exception(self, exception): + self._exception = exception + + def add_done_callback(self, callback): + callback(self) + + def result(self): + if self._exception is not None: + raise self._exception + return self._result + + +class _Executor: + def submit(self, fn, *args, **kwargs): + future = _Future() + try: + future.set_result(fn(*args, **kwargs)) + except Exception as exc: + future.set_exception(exc) + return future + + def shutdown(self): + pass diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 6c709da4..0eaf5610 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -19,20 +19,16 @@ class DatabaseMappingCommitMixin: - """Provides methods to commit or rollback pending changes onto a Spine database. - Unlike Diff..., there's no "staging area", i.e., all changes are applied directly on the 'original' tables. - So no regrets. But it's much faster than maintaining the staging area and diff tables, - so ideal for, e.g., Spine Toolbox's Importer that operates 'in one go'. - """ + """Provides methods to commit or rollback pending changes onto a Spine database.""" def __init__(self, *args, **kwargs): """Initialize class.""" super().__init__(*args, **kwargs) self._commit_id = None - self._has_pending_changes = False def has_pending_changes(self): - return self._has_pending_changes + # FIXME + return True def _get_sqlite_lock(self): """Commits the session's natural transaction and begins a new locking one.""" @@ -40,14 +36,10 @@ def _get_sqlite_lock(self): self.session.commit() self.session.execute("BEGIN IMMEDIATE") - def _make_commit_id(self, dry_run=False): + def _make_commit_id(self): if self._commit_id is None: - if dry_run: - with self.engine.begin() as connection: - self._commit_id = self._do_make_commit_id(connection) - else: - self._get_sqlite_lock() - self._commit_id = self._do_make_commit_id(self.connection) + with self.engine.begin() as connection: + self._commit_id = self._do_make_commit_id(connection) return self._commit_id def _do_make_commit_id(self, connection): @@ -68,19 +60,20 @@ def commit_session(self, comment): date = datetime.now(timezone.utc) upd = commit.update().where(commit.c.id == self._make_commit_id()) self._checked_execute(upd, dict(user=user, date=date, comment=comment)) - self.session.commit() + to_add, to_update, to_remove = self.cache.to_change() + for tablename, items in to_add.items(): + self._do_add_items(tablename, *items) + for tablename, items in to_update.items(): + self._do_update_items(tablename, *items) + self._do_remove_items(**to_remove) + self.executor.submit(self.session.commit) self._commit_id = None - self._has_pending_changes = False if self._memory: self._memory_dirty = True def rollback_session(self): if not self.has_pending_changes(): raise SpineDBAPIError("Nothing to rollback.") - self.reset_session() - - def reset_session(self): - self.session.rollback() - self.cache.clear() + 
self.executor.submit(self.session.rollback) + self.cache.reset_queries() self._commit_id = None - self._has_pending_changes = False diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py deleted file mode 100644 index 8ced3f20..00000000 --- a/spinedb_api/db_mapping_query_mixin.py +++ /dev/null @@ -1,289 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -"""Provides :class:`.DatabaseMappingQueryMixin`. - -""" -# TODO: Deprecate and drop this module - -from sqlalchemy import func, or_ - - -class DatabaseMappingQueryMixin: - """Provides methods to perform standard queries (``SELECT`` statements) on a Spine db.""" - - def object_class_list(self, id_list=None, ordered=True): - """Return all records from the :meth:`object_class_sq <.DatabaseMappingBase.object_class_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param bool ordered: if True, order the result by the ``display_order`` field. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.object_class_sq) - if id_list is not None: - qry = qry.filter(self.object_class_sq.c.id.in_(id_list)) - if ordered: - qry = qry.order_by(self.object_class_sq.c.display_order) - return qry - - def object_list(self, id_list=None, class_id=None): - """Return all records from the :meth:`object_sq <.DatabaseMappingBase.object_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param int class_id: If present, only return records where ``class_id`` is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.object_sq) - if id_list is not None: - qry = qry.filter(self.object_sq.c.id.in_(id_list)) - if class_id is not None: - qry = qry.filter(self.object_sq.c.class_id == class_id) - return qry - - def wide_relationship_class_list(self, id_list=None, object_class_id=None): - """Return all records from the - :meth:`wide_relationship_class_sq <.DatabaseMappingBase.wide_relationship_class_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param int object_class_id: If present, only return records where ``object_class_id`` is equal to this. 
- - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.wide_relationship_class_sq) - if id_list is not None: - qry = qry.filter(self.wide_relationship_class_sq.c.id.in_(id_list)) - if object_class_id is not None: - qry = qry.filter( - or_( - self.wide_relationship_class_sq.c.object_class_id_list.like(f"%,{object_class_id},%"), - self.wide_relationship_class_sq.c.object_class_id_list.like(f"{object_class_id},%"), - self.wide_relationship_class_sq.c.object_class_id_list.like(f"%,{object_class_id}"), - self.wide_relationship_class_sq.c.object_class_id_list == str(object_class_id), - ) - ) - return qry - - def wide_relationship_list(self, id_list=None, class_id=None, object_id=None): - """Return all records from the - :meth:`wide_relationship_sq <.DatabaseMappingBase.wide_relationship_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param int class_id: If present, only return records where ``class_id`` is equal to this. - :param int object_id: If present, only return records where ``object_id`` is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.wide_relationship_sq) - if id_list is not None: - qry = qry.filter(self.wide_relationship_sq.c.id.in_(id_list)) - if class_id is not None: - qry = qry.filter(self.wide_relationship_sq.c.class_id == class_id) - if object_id is not None: - qry = qry.filter( - or_( - self.wide_relationship_sq.c.object_id_list.like(f"%,{object_id},%"), - self.wide_relationship_sq.c.object_id_list.like(f"{object_id},%"), - self.wide_relationship_sq.c.object_id_list.like(f"%,{object_id}"), - self.wide_relationship_sq.c.object_id_list == object_id, - ) - ) - return qry - - def parameter_definition_list(self, id_list=None, object_class_id=None, relationship_class_id=None): - """Return all records from the - :meth:`parameter_definition_sq <.DatabaseMappingBase.parameter_definition_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param int object_class_id: If present, only return records where ``object_class_id`` is equal to this. - :param int relationship_class_id: If present, only return records where ``relationship_class_id`` - is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.parameter_definition_sq) - if id_list is not None: - qry = qry.filter(self.parameter_definition_sq.c.id.in_(id_list)) - if object_class_id is not None: - # to do make sure type is object - qry = qry.filter(self.parameter_definition_sq.c.object_class_id == object_class_id) - if relationship_class_id is not None: - # to do make sure type is relationship - qry = qry.filter(self.parameter_definition_sq.c.relationship_class_id == relationship_class_id) - return qry - - def object_parameter_definition_list(self, object_class_id=None, parameter_definition_id=None): - """Return all records from the - :meth:`object_parameter_definition_sq <.DatabaseMappingBase.object_parameter_definition_sq>` subquery. - - :param int object_class_id: If present, only return records where ``object_class_id`` is equal to this. - :param int parameter_definition_id: If present, only return records where ``id`` is in this list. 
- - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.object_parameter_definition_sq) - if object_class_id: - qry = qry.filter(self.object_parameter_definition_sq.c.object_class_id == object_class_id) - if parameter_definition_id: - qry = qry.filter(self.object_parameter_definition_sq.c.id == parameter_definition_id) - return qry - - def relationship_parameter_definition_list(self, relationship_class_id=None, parameter_definition_id=None): - """Return all records from the - :meth:`relationship_parameter_definition_sq <.DatabaseMappingBase.relationship_parameter_definition_sq>` - subquery. - - :param int relationship_class_id: If present, only return records where ``relationship_class_id`` - is equal to this. - :param int parameter_definition_id: If present, only return records where ``id`` is in this list. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.relationship_parameter_definition_sq) - if relationship_class_id: - qry = qry.filter(self.relationship_parameter_definition_sq.c.relationship_class_id == relationship_class_id) - if parameter_definition_id: - qry = qry.filter(self.relationship_parameter_definition_sq.c.id == parameter_definition_id) - return qry - - def wide_object_parameter_definition_list(self, object_class_id_list=None, parameter_definition_id_list=None): - """Return object classes and their parameter definitions in wide format.""" - qry = self.query( - self.object_class_sq.c.id.label("object_class_id"), - self.object_class_sq.c.name.label("object_class_name"), - self.parameter_definition_sq.c.id.label("parameter_definition_id"), - self.parameter_definition_sq.c.name.label("parameter_name"), - ).filter(self.object_class_sq.c.id == self.parameter_definition_sq.c.object_class_id) - if object_class_id_list is not None: - qry = qry.filter(self.object_class_sq.c.id.in_(object_class_id_list)) - if parameter_definition_id_list is not None: - qry = qry.filter(self.parameter_definition_sq.c.id.in_(parameter_definition_id_list)) - subqry = qry.subquery() - return self.query( - subqry.c.object_class_id, - subqry.c.object_class_name, - func.group_concat(subqry.c.parameter_definition_id).label("parameter_definition_id_list"), - func.group_concat(subqry.c.parameter_name).label("parameter_name_list"), - ).group_by(subqry.c.object_class_id, subqry.c.object_class_name) - - def wide_relationship_parameter_definition_list( - self, relationship_class_id_list=None, parameter_definition_id_list=None - ): - """Return relationship classes and their parameter definitions in wide format.""" - qry = self.query( - self.relationship_class_sq.c.id.label("relationship_class_id"), - self.relationship_class_sq.c.name.label("relationship_class_name"), - self.parameter_definition_sq.c.id.label("parameter_definition_id"), - self.parameter_definition_sq.c.name.label("parameter_name"), - ).filter(self.relationship_class_sq.c.id == self.parameter_definition_sq.c.relationship_class_id) - if relationship_class_id_list is not None: - qry = qry.filter(self.relationship_class_sq.c.id.in_(relationship_class_id_list)) - if parameter_definition_id_list is not None: - qry = qry.filter(self.parameter_definition_sq.c.id.in_(parameter_definition_id_list)) - subqry = qry.subquery() - return self.query( - subqry.c.relationship_class_id, - subqry.c.relationship_class_name, - func.group_concat(subqry.c.parameter_definition_id).label("parameter_definition_id_list"), - func.group_concat(subqry.c.parameter_name).label("parameter_name_list"), - 
).group_by(subqry.c.relationship_class_id, subqry.c.relationship_class_name) - - def parameter_value_list(self, id_list=None, object_id=None, relationship_id=None): - """Return all records from the - :meth:`parameter_value_sq <.DatabaseMappingBase.parameter_value_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - :param int object_id: If present, only return records where ``object_id`` is equal to this. - :param int relationship_id: If present, only return records where ``relationship_id`` is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.parameter_value_sq) - if id_list is not None: - qry = qry.filter(self.parameter_value_sq.c.id.in_(id_list)) - if object_id: - qry = qry.filter(self.parameter_value_sq.c.object_id == object_id) - if relationship_id: - qry = qry.filter(self.parameter_value_sq.c.relationship_id == relationship_id) - return qry - - def object_parameter_value_list(self, parameter_name=None): - """Return all records from the - :meth:`object_parameter_value_sq <.DatabaseMappingBase.object_parameter_value_sq>` subquery. - - :param str parameter_name: If present, only return records where ``parameter_name`` is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.object_parameter_value_sq) - if parameter_name: - qry = qry.filter(self.object_parameter_value_sq.c.parameter_name == parameter_name) - return qry - - def relationship_parameter_value_list(self, parameter_name=None): - """Return all records from the - :meth:`relationship_parameter_value_sq <.DatabaseMappingBase.relationship_parameter_value_sq>` subquery. - - :param str parameter_name: If present, only return records where ``parameter_name`` is equal to this. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.relationship_parameter_value_sq) - if parameter_name: - qry = qry.filter(self.relationship_parameter_value_sq.c.parameter_name == parameter_name) - return qry - - def parameter_value_list_list(self, id_list=None): - """Return all records from the - :meth:`parameter_value_list_sq <.DatabaseMappingBase.parameter_value_list_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. - - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.parameter_value_list_sq) - if id_list is not None: - qry = qry.filter(self.parameter_value_list_sq.c.id.in_(id_list)) - return qry - - def wide_parameter_value_list_list(self, id_list=None): - """Return all records from the - :meth:`wide_parameter_value_list_sq <.DatabaseMappingBase.wide_parameter_value_list_sq>` subquery. - - :param id_list: If present, only return records where ``id`` is in this list. 
- - :rtype: :class:`~sqlalchemy.orm.query.Query` - """ - qry = self.query(self.wide_parameter_value_list_sq) - if id_list is not None: - qry = qry.filter(self.wide_parameter_value_list_sq.c.id.in_(id_list)) - return qry - - def object_parameter_definition_fields(self): - """Return names of columns that would be returned by :meth:`object_parameter_definition_list`.""" - return [x["name"] for x in self.object_parameter_definition_list().column_descriptions] - - def relationship_parameter_definition_fields(self): - """Return names of columns that would be returned by :meth:`relationship_parameter_definition_list`.""" - return [x["name"] for x in self.relationship_parameter_definition_list().column_descriptions] - - def object_parameter_value_fields(self): - """Return names of columns that would be returned by :meth:`object_parameter_value_list`.""" - return [x["name"] for x in self.object_parameter_value_list().column_descriptions] - - def relationship_parameter_value_fields(self): - """Return names of columns that would be returned by :meth:`relationship_parameter_value_list`.""" - return [x["name"] for x in self.relationship_parameter_value_list().column_descriptions] - - def alternative_list(self): - """Return names of columns that would be returned by :meth:`relationship_parameter_value_list`.""" - return self.query(self.alternative_sq) diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index b9d4f564..cc5e54ef 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -20,45 +20,52 @@ class DatabaseMappingRemoveMixin: - """Provides the :meth:`remove_items` method to stage ``REMOVE`` operations over a Spine db.""" + """Provides methods to perform ``REMOVE`` operations over a Spine db.""" - # pylint: disable=redefined-builtin - def cascade_remove_items(self, **kwargs): - """Removes items by id in cascade. + def restore_items(self, tablename, *ids): + if not ids: + return [] + tablename = self._real_tablename(tablename) + table_cache = self.cache.get(tablename) + if not table_cache: + return [] + return [table_cache.restore_item(id_) for id_ in ids] - Args: - **kwargs: keyword is table name, argument is list of ids to remove - """ - cascading_ids = self.cascading_ids(**kwargs) - self.remove_items(**cascading_ids) + def remove_items(self, tablename, *ids): + if not ids: + return [] + tablename = self._real_tablename(tablename) + table_cache = self.cache.get(tablename) + if not table_cache: + return [] + ids = set(ids) + if tablename == "alternative": + # Do not remove the Base alternative + ids -= {1} + return [table_cache.remove_item(id_) for id_ in ids] - def remove_items(self, **kwargs): - """Removes items by id, *not in cascade*. + def _do_remove_items(self, **kwargs): + """Removes items from the db. 
Args: **kwargs: keyword is table name, argument is list of ids to remove """ - for tablename, ids in kwargs.items(): + cascading_ids = self.cascading_ids(**kwargs) + for tablename, ids in cascading_ids.items(): + tablename = self._real_tablename(tablename) if tablename == "alternative": # Do not remove the Base alternative ids -= {1} if not ids: continue - real_tablename = self._real_tablename(tablename) - id_field = self._id_fields.get(real_tablename, "id") - table = self._metadata.tables[real_tablename] + id_field = self._id_fields.get(tablename, "id") + table = self._metadata.tables[tablename] delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) try: - self.connection.execute(delete) - table_cache = self.cache.get(tablename) - if table_cache: - for id_ in ids: - table_cache.remove_item(id_) + self.executor.submit(self.connection.execute, delete).result() except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e - else: - self._has_pending_changes = True # pylint: disable=redefined-builtin def cascading_ids(self, **kwargs): diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index bb612e0a..5a243a58 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -21,182 +21,100 @@ class DatabaseMappingUpdateMixin: """Provides methods to perform ``UPDATE`` operations over a Spine db.""" - def _add_commit_id(self, *items): - for item in items: - item["commit_id"] = self._make_commit_id() - - def _update_items(self, tablename, *items, dry_run=False): - if not items: - return set() - if dry_run: - return {x["id"] for x in items} - # Special cases - if tablename == "entity": - return self._do_update_entities(*items) - if tablename == "scenario": - return self._do_update_scenarios(*items) - if tablename == "object": - return self._do_update_objects(*items) - if tablename == "relationship": - return self._do_update_wide_relationships(*items) - real_tablename = self._real_tablename(tablename) - self._do_update_items(real_tablename, *items) - - def _do_update_entities(self, *items): - entity_items = [] - entity_element_items = [] - entity_alternative_items = [] - for item in items: - entity_id = item["id"] - class_id = item["class_id"] - dimension_id_list = item["dimension_id_list"] - element_id_list = item["element_id_list"] - entity_items.append( - {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")} - ) - entity_element_items.extend( - [ - { - "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, - } - for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)) - ] - ) - entity_alternative_items.extend( - [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": True} - for alt_id in item["active_alternative_id_list"] - ] - + [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": False} - for alt_id in item["inactive_alternative_id_list"] - ] - ) - self._do_update_items("entity", *entity_items) - self._do_update_items("entity_element", *entity_element_items) - self._do_update_items("entity_alternative", *entity_alternative_items) - return {x["id"] for x in entity_items} + def _do_update_items(self, tablename, *items_to_update): + """Update items in DB without checking integrity.""" + try: + for tablename_, items_to_update_ in 
self._items_to_update_per_table(tablename, items_to_update): + if not items_to_update_: + continue + table = self._metadata.tables[self._real_tablename(tablename_)] + upd = table.update() + for k in self._get_primary_key(tablename_): + upd = upd.where(getattr(table.c, k) == bindparam(k)) + upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items_to_update_[0].keys()}) + self._checked_execute(upd, [{**item} for item in items_to_update_]) + except DBAPIError as e: + msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" + raise SpineDBAPIError(msg) from e - def _do_update_scenarios(self, *items): - """Returns data to add and remove, in order to set wide scenario alternatives. + @staticmethod + def _items_to_update_per_table(tablename, items_to_update): + """ + Yields tuples of string tablename, list of items to update. Needed because some update queries + actually need to update records in more than one table. Args: - *items: One or more wide scenario :class:`dict` objects to set. - Each item must include the following keys: - - - "id": integer scenario id - - "alternative_id_list": list of alternative ids for that scenario + tablename (str): target database table name + items_to_update (list): items to update - Returns - list: narrow scenario_alternative :class:`dict` objects to add. - set: integer scenario_alternative ids to remove + Yields: + tuple: database table name, items to update """ - self.fetch_all({"scenario_alternative", "scenario"}) - cache = self.cache - current_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()} - scenario_alternative_ids = { - (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values() - } - scen_alts_to_add = [] - scen_alt_ids_to_remove = set() - for item in items: - scenario_id = item["id"] - alternative_id_list = item["alternative_id_list"] - current_alternative_id_list = current_alternative_id_lists[scenario_id] - for k, alternative_id in enumerate(alternative_id_list): - item_to_add = {"scenario_id": scenario_id, "alternative_id": alternative_id, "rank": k + 1} - scen_alts_to_add.append(item_to_add) - for alternative_id in current_alternative_id_list: - scen_alt_ids_to_remove.add(scenario_alternative_ids[scenario_id, alternative_id]) - self.remove_items(scenario_alternative=scen_alt_ids_to_remove) - self.add_items("scenario_alternative", *scen_alts_to_add) - return self._do_update_items("scenario", *items) - - def _do_update_objects(self, *items): - entity_items = [] - entity_alternative_items = [] - for item in items: - entity_id = item["id"] - class_id = item["class_id"] - entity_items.append( - {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")} - ) - entity_alternative_items.extend( - [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": True} - for alt_id in item["active_alternative_id_list"] - ] - + [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": False} - for alt_id in item["inactive_alternative_id_list"] - ] - ) - self._do_update_items("entity", *entity_items) - self._do_update_items("entity_alternative", *entity_alternative_items) - return {x["id"] for x in entity_items} - - def _do_update_wide_relationships(self, *items): - entity_items = [] - entity_element_items = [] - entity_alternative_items = [] - for item in items: - entity_id = item["id"] - class_id = item["class_id"] - object_class_id_list = item["object_class_id_list"] - object_id_list = 
item["object_id_list"] - entity_items.append( - {"id": entity_id, "class_id": class_id, "name": item["name"], "description": item.get("description")} - ) - entity_element_items.extend( - [ + if tablename == "entity": + entity_items = [] + entity_element_items = [] + for item in items_to_update: + entity_id = item["id"] + class_id = item["class_id"] + dimension_id_list = item["dimension_id_list"] + element_id_list = item["element_id_list"] + entity_items.append( { - "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, + "id": entity_id, + "class_id": class_id, + "name": item["name"], + "description": item.get("description"), } - for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)) - ] - ) - entity_alternative_items.extend( - [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": True} - for alt_id in item["active_alternative_id_list"] - ] - + [ - {"entity_id": entity_id, "alternative_id": alt_id, "active": False} - for alt_id in item["inactive_alternative_id_list"] - ] - ) - self._do_update_items("entity", *entity_items) - self._do_update_items("entity_element", *entity_element_items) - self._do_update_items("entity_alternative", *entity_alternative_items) - return {x["id"] for x in entity_items} - - def _do_update_items(self, tablename, *items): - if not items: - return - self._add_commit_id(*items) - table = self._metadata.tables[tablename] - upd = table.update() - for k in self._get_primary_key(tablename): - upd = upd.where(getattr(table.c, k) == bindparam(k)) - upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items[0].keys()}) - try: - self._checked_execute(upd, [{**item} for item in items]) - except DBAPIError as e: - msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" - raise SpineDBAPIError(msg) from e + ) + entity_element_items.extend( + [ + { + "entity_class_id": class_id, + "entity_id": entity_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, + } + for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)) + ] + ) + yield ("entity", entity_items) + yield ("entity_element", entity_element_items) + elif tablename == "relationship": + entity_items = [] + entity_element_items = [] + for item in items_to_update: + entity_id = item["id"] + class_id = item["class_id"] + object_class_id_list = item["object_class_id_list"] + object_id_list = item["object_id_list"] + entity_items.append( + { + "id": entity_id, + "class_id": class_id, + "name": item["name"], + "description": item.get("description"), + } + ) + entity_element_items.extend( + [ + { + "entity_class_id": class_id, + "entity_id": entity_id, + "position": position, + "dimension_id": dimension_id, + "element_id": element_id, + } + for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)) + ] + ) + yield ("entity", entity_items) + yield ("entity_element", entity_element_items) else: - self._has_pending_changes = True + yield (tablename, items_to_update) - def update_items(self, tablename, *items, check=True, strict=False, return_items=False, dry_run=False): - """Updates items. + def update_items(self, tablename, *items, check=True, strict=False): + """Updates items in cache. 
Args: tablename (str): Target database table name @@ -204,7 +122,6 @@ def update_items(self, tablename, *items, check=True, strict=False, return_items check (bool): Whether or not to check integrity strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. - return_items (bool): Return full items rather than just ids Returns: set: ids or items successfully updated @@ -214,10 +131,21 @@ def update_items(self, tablename, *items, check=True, strict=False, return_items checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=True, strict=strict) else: checked_items, intgr_error_log = list(items), [] - updated_ids = self._update_items(tablename, *checked_items, dry_run=dry_run) - if return_items: - return checked_items, intgr_error_log - return updated_ids, intgr_error_log + _ = self._update_items(tablename, *checked_items) + return checked_items, intgr_error_log + + def _update_items(self, tablename, *items): + """Updates items in cache without checking integrity.""" + if not items: + return set() + tablename = self._real_tablename(tablename) + table_cache = self.cache.get(tablename) + if table_cache is not None: + commit_id = self._make_commit_id() + for item in items: + item["commit_id"] = commit_id + table_cache.update_item(item) + return {x["id"] for x in items} def update_alternatives(self, *items, **kwargs): return self.update_items("alternative", *items, **kwargs) @@ -303,23 +231,17 @@ def update_metadata(self, *items, **kwargs): def _update_metadata(self, *items): return self._update_items("metadata", *items) - def update_ext_entity_metadata(self, *items, check=True, strict=False, return_items=False, dry_run=False): - updated_items, errors = self._update_ext_item_metadata( - "entity_metadata", *items, check=check, strict=strict, dry_run=dry_run - ) - if return_items: - return updated_items, errors - return {i["id"] for i in updated_items}, errors + def update_ext_entity_metadata(self, *items, check=True, strict=False): + updated_items, errors = self._update_ext_item_metadata("entity_metadata", *items, check=check, strict=strict) + return updated_items, errors - def update_ext_parameter_value_metadata(self, *items, check=True, strict=False, return_items=False, dry_run=False): + def update_ext_parameter_value_metadata(self, *items, check=True, strict=False): updated_items, errors = self._update_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict, dry_run=dry_run + "parameter_value_metadata", *items, check=check, strict=strict ) - if return_items: - return updated_items, errors - return {i["id"] for i in updated_items}, errors + return updated_items, errors - def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False, dry_run=False): + def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False): self.fetch_all({"entity_metadata", "parameter_value_metadata", "metadata"}) cache = self.cache metadata_ids = {} @@ -379,9 +301,7 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F all_items = [] errors = [] if updatable_metadata_items: - updated_metadata, errors = self.update_metadata( - *updatable_metadata_items, check=False, strict=strict, return_items=True, dry_run=dry_run - ) + updated_metadata, errors = self.update_metadata(*updatable_metadata_items, check=False, strict=strict) all_items += updated_metadata if errors: return all_items, errors @@ 
-390,9 +310,7 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F
         ]
         added_metadata = []
         if addable_metadata:
-            added_metadata, metadata_add_errors = self.add_metadata(
-                *addable_metadata, check=False, strict=strict, return_items=True
-            )
+            added_metadata, metadata_add_errors = self.add_metadata(*addable_metadata, check=False, strict=strict)
         all_items += added_metadata
         errors += metadata_add_errors
         if errors:
@@ -406,8 +324,41 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F
         if updatable_items:
             # FIXME: Force-clear cache before updating item metadata to ensure that added/updated metadata is found.
             updated_item_metadata, item_metadata_errors = self.update_items(
-                metadata_table, *updatable_items, check=check, strict=strict, return_items=True
+                metadata_table, *updatable_items, check=check, strict=strict
             )
             all_items += updated_item_metadata
             errors += item_metadata_errors
             return all_items, errors
+
+    def get_data_to_set_scenario_alternatives(self, *items):
+        """Returns data to add and remove, in order to set wide scenario alternatives.
+
+        Args:
+            *items: One or more wide scenario :class:`dict` objects to set.
+                Each item must include the following keys:
+
+                - "id": integer scenario id
+                - "alternative_id_list": list of alternative ids for that scenario
+
+        Returns:
+            list: scenario_alternative :class:`dict` objects to add
+            set: integer scenario_alternative ids to remove
+        """
+        self.fetch_all({"scenario_alternative", "scenario"})
+        cache = self.cache
+        current_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()}
+        scenario_alternative_ids = {
+            (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values()
+        }
+        scen_alts_to_add = []
+        scen_alt_ids_to_remove = set()
+        for item in items:
+            scenario_id = item["id"]
+            alternative_id_list = item["alternative_id_list"]
+            current_alternative_id_list = current_alternative_id_lists[scenario_id]
+            for k, alternative_id in enumerate(alternative_id_list):
+                item_to_add = {"scenario_id": scenario_id, "alternative_id": alternative_id, "rank": k + 1}
+                scen_alts_to_add.append(item_to_add)
+            for alternative_id in current_alternative_id_list:
+                scen_alt_ids_to_remove.add(scenario_alternative_ids[scenario_id, alternative_id])
+        return scen_alts_to_add, scen_alt_ids_to_remove
diff --git a/spinedb_api/diff_db_mapping.py b/spinedb_api/diff_db_mapping.py
deleted file mode 100644
index 7f06f1b8..00000000
--- a/spinedb_api/diff_db_mapping.py
+++ /dev/null
@@ -1,177 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
-###################################################################################################################### - -""" -Provides :class:`DiffDatabaseMapping`. - -""" - -from sqlalchemy.sql.expression import bindparam -from sqlalchemy.exc import DBAPIError -from .db_mapping_query_mixin import DatabaseMappingQueryMixin -from .db_mapping_check_mixin import DatabaseMappingCheckMixin -from .db_mapping_add_mixin import DatabaseMappingAddMixin -from .db_mapping_update_mixin import DatabaseMappingUpdateMixin -from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin -from .diff_db_mapping_commit_mixin import DiffDatabaseMappingCommitMixin -from .diff_db_mapping_base import DiffDatabaseMappingBase -from .filters.tools import apply_filter_stack, load_filters -from .exception import SpineDBAPIError - - -class DiffDatabaseMapping( - DatabaseMappingQueryMixin, - DatabaseMappingCheckMixin, - DatabaseMappingAddMixin, - DatabaseMappingUpdateMixin, - DatabaseMappingRemoveMixin, - DiffDatabaseMappingCommitMixin, - DiffDatabaseMappingBase, -): - """A read-write database mapping. - - Provides methods to *stage* any number of changes (namely, ``INSERT``, ``UPDATE`` and ``REMOVE`` operations) - over a Spine database, as well as to commit or rollback the batch of changes. - - For convenience, querying this mapping return results *as if* all the staged changes were already committed. - - :param str db_url: A database URL in RFC-1738 format pointing to the database to be mapped. - :param str username: A user name. If ``None``, it gets replaced by the string ``"anon"``. - :param bool upgrade: Whether or not the db at the given URL should be upgraded to the most recent version. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self._filter_configs is not None: - stack = load_filters(self._filter_configs) - apply_filter_stack(self, stack) - - def _add_items(self, tablename, *items): - self._add_commit_id_and_ids(tablename, *items) - ids = {x["id"] for x in items} - for tablename_ in self._do_add_items(tablename, *items): - self.added_item_id[tablename_].update(ids) - self._clear_subqueries(tablename_) - return ids - - def _readd_items(self, tablename, *items): - ids = set(x["id"] for x in items) - for tablename_ in self._do_add_items(tablename, *items): - self.added_item_id[tablename_].update(ids) - self._clear_subqueries(tablename_) - - def _get_table_for_insert(self, tablename): - return self._diff_table(tablename) - - def _get_items_for_update_and_insert(self, tablename, checked_items): - """Return lists of items for update and insert. 
- Items in the diff table should be updated, whereas items in the original table - should be marked as dirty and inserted into the corresponding diff table.""" - items_for_update = list() - items_for_insert = list() - dirty_ids = set() - updated_ids = set() - id_field = self._id_fields.get(tablename, "id") - for item in checked_items: - id_ = item[id_field] - updated_ids.add(id_) - if id_ in self.added_item_id[tablename] | self.updated_item_id[tablename]: - items_for_update.append(item) - else: - items_for_insert.append(item) - dirty_ids.add(id_) - return items_for_update, items_for_insert, dirty_ids, updated_ids - - def _do_update_items(self, tablename, *items): - items_for_update, items_for_insert, dirty_ids, updated_ids = self._get_items_for_update_and_insert( - tablename, items - ) - if self.committing: - try: - self._update_and_insert_items(tablename, items_for_update, items_for_insert) - self._mark_as_dirty(tablename, dirty_ids) - self.updated_item_id[tablename].update(dirty_ids) - except DBAPIError as e: - msg = f"DBAPIError while updating {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) - return updated_ids - - def _update_and_insert_items(self, tablename, items_for_update, items_for_insert): - diff_table = self._diff_table(tablename) - if items_for_update: - upd = diff_table.update() - for k in self._get_primary_key(tablename): - upd = upd.where(getattr(diff_table.c, k) == bindparam(k)) - upd = upd.values({key: bindparam(key) for key in diff_table.columns.keys() & items_for_update[0].keys()}) - self._checked_execute(upd, [{**item} for item in items_for_update]) - ins = diff_table.insert() - self._checked_execute(ins, [{**item} for item in items_for_insert]) - - def _update_wide_relationships(self, *items): - """Update relationships without checking integrity.""" - items = self._items_with_type_id("relationship", *items) - ent_items = [] - rel_ent_items = [] - for item in items: - ent_item = item.copy() - object_class_id_list = ent_item.pop("object_class_id_list", []) - object_id_list = ent_item.pop("object_id_list", []) - ent_items.append(ent_item) - for dimension, (member_class_id, member_id) in enumerate(zip(object_class_id_list, object_id_list)): - rel_ent_item = ent_item.copy() - rel_ent_item["entity_class_id"] = rel_ent_item.pop("class_id", None) - rel_ent_item["entity_id"] = rel_ent_item.pop("id", None) - rel_ent_item["dimension"] = dimension - rel_ent_item["member_class_id"] = member_class_id - rel_ent_item["member_id"] = member_id - rel_ent_items.append(rel_ent_item) - try: - ents_for_update, ents_for_insert, dirty_ent_ids, updated_ent_ids = self._get_items_for_update_and_insert( - "entity", ent_items - ) - ( - rel_ents_for_update, - rel_ents_for_insert, - dirty_rel_ent_ids, - updated_rel_ent_ids, - ) = self._get_items_for_update_and_insert("relationship_entity", rel_ent_items) - self._update_and_insert_items("entity", ents_for_update, ents_for_insert) - self._mark_as_dirty("entity", dirty_ent_ids) - self.updated_item_id["entity"].update(dirty_ent_ids) - self._update_and_insert_items("relationship_entity", rel_ents_for_update, rel_ents_for_insert) - self._mark_as_dirty("relationship_entity", dirty_rel_ent_ids) - self.updated_item_id["relationship_entity"].update(dirty_rel_ent_ids) - return updated_ent_ids.union(updated_rel_ent_ids) - except DBAPIError as e: - msg = "DBAPIError while updating relationships: {}".format(e.orig.args) - raise SpineDBAPIError(msg) - - def remove_items(self, **kwargs): - """Removes items by id, *not in cascade*. 
- - Args: - **kwargs: keyword is table name, argument is list of ids to remove - """ - if self.committing: - for tablename, ids in kwargs.items(): - table_id = self._id_fields.get(tablename, "id") - diff_table = self._diff_table(tablename) - delete = diff_table.delete().where(self.in_(getattr(diff_table.c, table_id), ids)) - try: - self.connection.execute(delete) - except DBAPIError as e: - msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) - for tablename, ids in kwargs.items(): - self.added_item_id[tablename].difference_update(ids) - self.updated_item_id[tablename].difference_update(ids) - self.removed_item_id[tablename].update(ids) - self._mark_as_dirty(tablename, ids) diff --git a/spinedb_api/diff_db_mapping_base.py b/spinedb_api/diff_db_mapping_base.py deleted file mode 100644 index b866ea20..00000000 --- a/spinedb_api/diff_db_mapping_base.py +++ /dev/null @@ -1,143 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -""" -Provides :class:`.DiffDatabaseMappingBase`. - -""" - -from datetime import datetime, timezone -from sqlalchemy import select -from sqlalchemy.sql.expression import literal, union_all -from .db_mapping_base import DatabaseMappingBase -from .helpers import labelled_columns - -# TODO: improve docstrings - - -class DiffDatabaseMappingBase(DatabaseMappingBase): - """Base class for the read-write database mapping. - - :param str db_url: A URL in RFC-1738 format pointing to the database to be mapped. - :param str username: A user name. If ``None``, it gets replaced by the string ``"anon"``. - :param bool upgrade: Whether or not the db at the given URL should be upgraded to the most recent version. - """ - - # NOTE: It works by creating and mapping a set of - # temporary 'diff' tables, where temporary changes are staged until the moment of commit. 
- - _session_kwargs = dict(autocommit=True) - - def __init__(self, *args, **kwargs): - """Initialize class.""" - super().__init__(*args, **kwargs) - self.diff_prefix = None - # Diff dictionaries - self.added_item_id = {} - self.updated_item_id = {} - self.removed_item_id = {} - self.dirty_item_id = {} - # Initialize stuff - self._init_diff_dicts() - self._create_diff_tables() - - def _init_diff_dicts(self): - """Initialize dictionaries that help keeping track of the differences.""" - self.added_item_id = {x: set() for x in self._tablenames} - self.updated_item_id = {x: set() for x in self._tablenames} - self.removed_item_id = {x: set() for x in self._tablenames} - self.dirty_item_id = {x: set() for x in self._tablenames} - - def _reset_diff_dicts(self): - self._init_diff_dicts() - self._clear_subqueries(*self._tablenames) - - def _create_diff_tables(self): - """Create diff tables.""" - diff_name_prefix = "diff_" + self.username - self.diff_prefix = diff_name_prefix + datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") + "_" - for tablename in self._tablenames: - table = self._metadata.tables[tablename] - diff_columns = [c.copy() for c in table.columns] - self.make_temporary_table(self.diff_prefix + tablename, *diff_columns) - - def _mark_as_dirty(self, tablename, ids): - """Mark items as dirty, which means the corresponding records from the original tables - are no longer valid, and they should be queried from the diff tables instead.""" - self.dirty_item_id[tablename].update(ids) - self._clear_subqueries(tablename) - - def _subquery(self, tablename): - """Overriden method to - (i) filter dirty items from original tables, and - (ii) also bring data from diff tables: - Roughly equivalent to: - SELECT * FROM orig_table WHERE id NOT IN dirty_ids - UNION ALL - SELECT * FROM diff_table - """ - orig_table = self._metadata.tables[tablename] - table_id = self._id_fields.get(tablename, "id") - qry = self.query(*labelled_columns(orig_table)).filter( - ~self.in_(getattr(orig_table.c, table_id), self.dirty_item_id[tablename]) - ) - if self.added_item_id[tablename] or self.updated_item_id[tablename]: - diff_table = self._diff_table(tablename) - if self.sa_url.drivername.startswith("mysql"): - # Work around the "can't reopen " error in MySQL. - # (This happens whenever a temporary table is used more than once in a query.) - # Basically what we do here, is dump the contents of the diff table into a - # `SELECT first row UNION ALL SELECT second row ... UNION ALL SELECT last row` statement, - # and use it as a replacement. - diff_row_selects = [ - select([literal(v).label(k) for k, v in row._asdict().items()]) for row in self.query(diff_table) - ] - diff_table = union_all(*diff_row_selects).alias() - qry = qry.union_all(self.query(*labelled_columns(diff_table))) - return qry.subquery() - - def _orig_subquery(self, tablename): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM tablename - - :param str tablename: A string indicating the table to be queried. - :type: :class:`~sqlalchemy.sql.expression.Alias` - """ - table = self._metadata.tables[tablename] - return self.query(table).subquery() - - def _diff_subquery(self, tablename): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM tablename - - :param str tablename: A string indicating the table to be queried. 
- :type: :class:`~sqlalchemy.sql.expression.Alias` - """ - return self.query(self._diff_table(tablename)).subquery() - - def diff_ids(self): - return {x: self.added_item_id[x] | self.updated_item_id[x] for x in self._tablenames} - - def _diff_table(self, tablename): - return self._metadata.tables.get(self.diff_prefix + tablename) - - def _reset_diff_mapping(self): - """Delete all records from diff tables (but don't drop the tables).""" - for tablename in self._tablenames: - table = self._diff_table(tablename) - if table is not None: - self.connection.execute(table.delete()) diff --git a/spinedb_api/diff_db_mapping_commit_mixin.py b/spinedb_api/diff_db_mapping_commit_mixin.py deleted file mode 100644 index 35c570b5..00000000 --- a/spinedb_api/diff_db_mapping_commit_mixin.py +++ /dev/null @@ -1,104 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -""" -Provides :class:`DiffDatabaseMappingCommitMixin`. - -""" - -from datetime import datetime, timezone -from sqlalchemy.exc import DBAPIError -from sqlalchemy.sql.expression import bindparam -from .exception import SpineDBAPIError - - -class DiffDatabaseMappingCommitMixin: - """Provides methods to commit or rollback staged changes onto a Spine database.""" - - def commit_session(self, comment): - """Commit staged changes to the database. - - Args: - comment (str): An informative comment explaining the nature of the commit. - """ - self._check_commit(comment) - transaction = self.connection.begin() - try: - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert().values(user=user, date=date, comment=comment) - commit_id = self.connection.execute(ins).inserted_primary_key[0] - # NOTE: Remove first, so `scenario_alternative.rank`s become 'free'. 
- # Remove - for tablename, ids in self.removed_item_id.items(): - if not ids: - continue - table = self._metadata.tables[tablename] - id_col = self._id_fields.get(tablename, "id") - self.query(table).filter(self.in_(getattr(table.c, id_col), ids)).delete(synchronize_session=False) - # Update - for tablename, ids in self.updated_item_id.items(): - if not ids: - continue - id_col = self._id_fields.get(tablename, "id") - orig_table = self._metadata.tables[tablename] - diff_table = self._diff_table(tablename) - updated_items = [] - for item in self.query(diff_table).filter(self.in_(getattr(diff_table.c, id_col), ids)): - kwargs = item._asdict() - kwargs["commit_id"] = commit_id - updated_items.append(kwargs) - upd = orig_table.update() - for k in self._get_primary_key(tablename): - upd = upd.where(getattr(orig_table.c, k) == bindparam(k)) - upd = upd.values({key: bindparam(key) for key in orig_table.columns.keys()}) - self._checked_execute(upd, updated_items) - # Add - for tablename, ids in self.added_item_id.items(): - if not ids: - continue - id_col = self._id_fields.get(tablename, "id") - orig_table = self._metadata.tables[tablename] - diff_table = self._diff_table(tablename) - new_items = [] - for item in self.query(diff_table).filter(self.in_(getattr(diff_table.c, id_col), ids)): - kwargs = item._asdict() - kwargs["commit_id"] = commit_id - new_items.append(kwargs) - self._checked_execute(orig_table.insert(), new_items) - self._reset_diff_mapping() - transaction.commit() - self._reset_diff_dicts() - if self._memory: - self._memory_dirty = True - except DBAPIError as e: - msg = "DBAPIError while committing changes: {}".format(e.orig.args) - raise SpineDBAPIError(msg) from None - - def rollback_session(self): - """Discard all staged changes.""" - if not self.has_pending_changes(): - raise SpineDBAPIError("Nothing to rollback.") - self.reset_session() - - def reset_session(self): - transaction = self.connection.begin() - try: - self._reset_diff_mapping() - transaction.commit() - self._reset_diff_dicts() - except DBAPIError as e: - msg = "DBAPIError while rolling back changes: {}".format(e.orig.args) - raise SpineDBAPIError(msg) from None - - def has_pending_changes(self): - """True if this mapping has any staged changes.""" - return any(self.added_item_id.values()) or any(self.dirty_item_id.values()) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 673d937c..8d79e6fb 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -100,7 +100,7 @@ def _get_items(db_map, tablename, ids): if not ids: return () db_map.fetch_all({tablename}, include_ancestors=True) - _process_item = _make_item_processor(db_map.cache, tablename) + _process_item = _make_item_processor(db_map, tablename) for item in _get_items_from_cache(db_map.cache, tablename, ids): yield from _process_item(item) @@ -116,24 +116,22 @@ def _get_items_from_cache(cache, tablename, ids): yield item -def _make_item_processor(cache, tablename): +def _make_item_processor(db_map, tablename): if tablename == "parameter_value_list": - return _ParameterValueListProcessor(cache) + db_map.fetch_all({"list_value"}, include_ancestors=True) + return _ParameterValueListProcessor(db_map.cache.get("list_value", {}).values()) return lambda item: (item,) class _ParameterValueListProcessor: - def __init__(self, cache): - self._list_value_by_id = cache.get("list_value", {}) + def __init__(self, value_items): + self._value_items_by_list_id = {} + for x in value_items: + 
self._value_items_by_list_id.setdefault(x.parameter_value_list_id, []).append(x) def __call__(self, item): - fields = ["name", "value", "type"] - if item.value_id_list is None: - yield KeyedTuple([item.name, None, None], fields) - return - for value_id in item.value_id_list: - val = self._list_value_by_id[value_id] - yield KeyedTuple([item.name, val.value, val.type], fields) + for list_value_item in sorted(self._value_items_by_list_id.get(item.id, ()), key=lambda x: x.index): + yield KeyedTuple([item.name, list_value_item.value, list_value_item.type], ["name", "value", "type"]) def export_parameter_value_lists(db_map, ids=Asterisk, parse_value=from_database): diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index f032e385..828851c5 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -170,7 +170,6 @@ def get_data_for_import( db_map, unparse_value=to_database, on_conflict="merge", - dry_run=False, entity_classes=(), entities=(), parameter_definitions=(), @@ -194,6 +193,11 @@ def get_data_for_import( relationship_metadata=(), object_parameter_value_metadata=(), relationship_parameter_value_metadata=(), + # FIXME: compat + tools=(), + features=(), + tool_features=(), + tool_feature_methods=(), ): """Returns an iterator of data for import, that the user can call instead of `import_data` if they want to add and update the data by themselves. @@ -226,19 +230,19 @@ def get_data_for_import( """ # NOTE: The order is important, because of references. E.g., we want to import alternatives before parameter_values if alternatives: - yield ("alternative", _get_alternatives_for_import(alternatives)) + yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) if scenarios: - yield ("scenario", _get_scenarios_for_import(scenarios)) + yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) if scenario_alternatives: if not scenarios: scenarios = (item[0] for item in scenario_alternatives) - yield ("scenario", _get_scenarios_for_import(scenarios)) + yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) if not alternatives: alternatives = (item[1] for item in scenario_alternatives) - yield ("alternative", _get_alternatives_for_import(alternatives)) - yield ("scenario_alternative", _get_scenario_alternatives_for_import(scenario_alternatives)) + yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) + yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) if entity_classes: - yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, dry_run)) + yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes)) if object_classes: yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) if relationship_classes: @@ -262,7 +266,7 @@ def get_data_for_import( _get_relationship_parameters_for_import(db_map, relationship_parameters, unparse_value), ) if entities: - yield ("entity", _get_entities_for_import(db_map, entities, dry_run)) + yield ("entity", _get_entities_for_import(db_map, entities)) if objects: yield ("object", _get_objects_for_import(db_map, objects)) if relationships: @@ -329,7 +333,7 @@ def import_entity_classes(db_map, data): return import_data(db_map, entity_classes=data) -def _get_entity_classes_for_import(db_map, data, dry_run): +def _get_entity_classes_for_import(db_map, data): db_map.fetch_all({"entity_class"}, include_ancestors=True) cache = db_map.cache entity_class_ids = {x.name: 
x.id for x in cache.get("entity_class", {}).values()}
@@ -337,7 +341,7 @@ def _get_entity_classes_for_import(db_map, data):
     error_log = []
     to_add = []
     to_update = []
-    with db_map.generate_ids("entity_class", dry_run=dry_run) as new_entity_class_id:
+    with db_map.generate_ids("entity_class") as new_entity_class_id:
         for name, *optionals in data:
             if name in checked:
                 continue
@@ -404,7 +408,7 @@ def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_i
     return name
 
 
-def _get_entities_for_import(db_map, data, dry_run):
+def _get_entities_for_import(db_map, data):
     db_map.fetch_all({"entity"}, include_ancestors=True)
     cache = db_map.cache
     entities = {x.id: x for x in cache.get("entity", {}).values()}
@@ -421,7 +425,7 @@
     to_add = []
     to_update = []
     checked = set()
-    with db_map.generate_ids("entity", dry_run=dry_run) as new_entity_id:
+    with db_map.generate_ids("entity") as new_entity_id:
         for class_name, ent_name_or_el_names, *optionals in data:
             ec_id = entity_class_ids.get(class_name, None)
             dim_ids = dimension_id_lists.get(ec_id, ())
@@ -767,7 +771,7 @@ def import_alternatives(db_map, data):
     return import_data(db_map, alternatives=data)
 
 
-def _get_alternatives_for_import(data):
+def _get_alternatives_for_import(db_map, data):
     db_map.fetch_all({"alternative"}, include_ancestors=True)
     cache = db_map.cache
     alternative_ids = {alternative.name: alternative.id for alternative in cache.get("alternative", {}).values()}
@@ -829,7 +833,7 @@ def import_scenarios(db_map, data):
     return import_data(db_map, scenarios=data)
 
 
-def _get_scenarios_for_import(data):
+def _get_scenarios_for_import(db_map, data):
     db_map.fetch_all({"scenario"}, include_ancestors=True)
     cache = db_map.cache
     scenario_ids = {scenario.name: scenario.id for scenario in cache.get("scenario", {}).values()}
@@ -889,7 +893,7 @@ def import_scenario_alternatives(db_map, data):
     return import_data(db_map, scenario_alternatives=data)
 
 
-def _get_scenario_alternatives_for_import(data):
+def _get_scenario_alternatives_for_import(db_map, data):
     db_map.fetch_all({"scenario_alternative"}, include_ancestors=True)
     cache = db_map.cache
     scenario_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()}
@@ -1826,7 +1830,7 @@ def _get_list_values_for_import(db_map, data, unparse_value):
         if max_index is not None:
             index = max_index + 1
         else:
-            index = max(current_max_index) + 1
+            index = current_max_index + 1
         item = {"parameter_value_list_id": list_id, "value": val, "type": type_, "index": index}
         try:
             check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value)
diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py
index e535bc3a..e4eac936 100644
--- a/spinedb_api/purge.py
+++ b/spinedb_api/purge.py
@@ -43,7 +43,7 @@ def purge_url(url, purge_settings, logger=None):
         logger.msg_warning.emit(f"Failed to purge url {sanitized_url}: {err}")
         return False
     success = purge(db_map, purge_settings, logger=logger)
-    db_map.connection.close()
+    db_map.close()
     return success
 
 
@@ -69,7 +69,9 @@ def purge(db_map, purge_settings, logger=None):
     try:
         if logger:
             logger.msg.emit("Purging database...")
-        db_map.cascade_remove_items(**removable_db_map_data)
+        for item_type, ids in removable_db_map_data.items():
+            db_map.remove_items(item_type, *ids)
+            # FIXME: What to do here? How does one affect the DB directly, bypassing cache?
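+        # A sketch of one way the FIXME above could be resolved; this is an assumption, not
+        # the committed design. remove_items() (see db_mapping_remove_mixin above) only drops
+        # the items from the in-memory cache, whereas _do_remove_items() is the method in the
+        # same mixin that computes the cascading ids and executes the actual DELETE statements
+        # against the database, e.g.:
+        #     db_map._do_remove_items(**removable_db_map_data)
+        # with commit_session() below then recording the commit.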
db_map.commit_session("Purge database") if logger: logger.msg.emit("Database purged") From 2a9fc0cca3c3128748ec8877fda591a5c6da190f Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 9 May 2023 16:05:42 +0200 Subject: [PATCH 035/317] Rationalize integrity checks --- spinedb_api/__init__.py | 12 - ..._replace_values_with_reference_to_list_.py | 54 +- spinedb_api/check_functions.py | 618 ------ spinedb_api/db_cache.py | 413 ++-- spinedb_api/db_mapping.py | 2 - spinedb_api/db_mapping_add_mixin.py | 145 +- spinedb_api/db_mapping_base.py | 16 +- spinedb_api/db_mapping_check_mixin.py | 817 -------- spinedb_api/db_mapping_update_mixin.py | 50 +- spinedb_api/import_functions.py | 1706 +++-------------- 10 files changed, 611 insertions(+), 3222 deletions(-) delete mode 100644 spinedb_api/check_functions.py delete mode 100644 spinedb_api/db_mapping_check_mixin.py diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index d6d7069a..dae763ba 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -32,18 +32,6 @@ forward_sweep, Asterisk, ) -from .check_functions import ( - check_alternative, - check_scenario, - check_scenario_alternative, - check_object_class, - check_object, - check_wide_relationship_class, - check_wide_relationship, - check_parameter_definition, - check_parameter_value, - check_parameter_value_list, -) from .import_functions import ( import_alternatives, import_data, diff --git a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py index fde8a1bd..2183c42f 100644 --- a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py +++ b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py @@ -9,11 +9,7 @@ from sqlalchemy import MetaData from sqlalchemy.sql.expression import bindparam from sqlalchemy.orm import sessionmaker -from spinedb_api.check_functions import ( - replace_default_values_with_list_references, - replace_parameter_values_with_list_references, -) -from spinedb_api.parameter_value import from_database +from spinedb_api.parameter_value import dump_db_value, from_database, ParameterValueFormatError from spinedb_api.helpers import group_concat from spinedb_api.exception import SpineIntegrityError @@ -71,3 +67,51 @@ def upgrade(): def downgrade(): pass + + +def replace_default_values_with_list_references(item, parameter_value_lists, list_values): + parameter_value_list_id = item.get("parameter_value_list_id") + return _replace_values_with_list_references( + "parameter_definition", item, parameter_value_list_id, parameter_value_lists, list_values + ) + + +def replace_parameter_values_with_list_references(item, parameter_definitions, parameter_value_lists, list_values): + parameter_definition_id = item["parameter_definition_id"] + parameter_definition = parameter_definitions[parameter_definition_id] + parameter_value_list_id = parameter_definition["parameter_value_list_id"] + return _replace_values_with_list_references( + "parameter_value", item, parameter_value_list_id, parameter_value_lists, list_values + ) + + +def _replace_values_with_list_references(item_type, item, parameter_value_list_id, parameter_value_lists, list_values): + if parameter_value_list_id is None: + return False + if parameter_value_list_id not in parameter_value_lists: + raise SpineIntegrityError("Parameter value list not found.") + value_id_list = parameter_value_lists[parameter_value_list_id] + if value_id_list 
is None: + raise SpineIntegrityError("Parameter value list is empty!") + value_key, type_key = { + "parameter_value": ("value", "type"), + "parameter_definition": ("default_value", "default_type"), + }[item_type] + value = dict.get(item, value_key) + value_type = dict.get(item, type_key) + try: + parsed_value = from_database(value, value_type) + except ParameterValueFormatError as err: + raise SpineIntegrityError(f"Invalid {value_key} '{value}': {err}") from None + if parsed_value is None: + return False + list_value_id = next((id_ for id_ in value_id_list if list_values.get(id_) == parsed_value), None) + if list_value_id is None: + valid_values = ", ".join(f"{dump_db_value(list_values.get(id_))[0].decode('utf8')!r}" for id_ in value_id_list) + raise SpineIntegrityError( + f"Invalid {value_key} '{parsed_value}' - it should be one from the parameter value list: {valid_values}." + ) + item[value_key] = str(list_value_id).encode("UTF8") + item[type_key] = "list_value_ref" + item["list_value_id"] = list_value_id + return True diff --git a/spinedb_api/check_functions.py b/spinedb_api/check_functions.py deleted file mode 100644 index a75e8b5e..00000000 --- a/spinedb_api/check_functions.py +++ /dev/null @@ -1,618 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -"""Functions for checking whether inserting data into a Spine database leads -to the violation of integrity constraints. - -""" - -from .parameter_value import dump_db_value, from_database, ParameterValueFormatError -from .exception import SpineIntegrityError - -# NOTE: We parse each parameter value or default value before accepting it. Is it too much? - - -def check_alternative(item, current_items): - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError("Missing alternative name.") - if name in current_items: - raise SpineIntegrityError(f"There can't be more than one alternative called '{name}'.", id=current_items[name]) - - -def check_scenario(item, current_items): - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError("Missing scenario name.") - if name in current_items: - raise SpineIntegrityError(f"There can't be more than one scenario called '{name}'.", id=current_items[name]) - - -def check_scenario_alternative(item, ids_by_alt_id, ids_by_rank, scenario_names, alternative_names): - """ - Checks if given scenario alternative violates a database's integrity. 
-
-    Args:
-        item (dict): a scenario alternative item for checking; must contain the following fields:
-
-            - "scenario_id": scenario's id
-            - "alternative_id": alternative's id
-            - "rank": alternative's rank within the scenario
-        ids_by_alt_id (dict): a mapping from (scenario id, alternative id) tuples to scenario_alternative ids
-            already in the database
-        ids_by_rank (dict): a mapping from (scenario id, rank) tuples to scenario_alternative ids already in the database
-        scenario_names (dict): the names of existing scenarios in the database keyed by id
-        alternative_names (dict): the names of existing alternatives in the database keyed by id
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
-    """
-    try:
-        scen_id = item["scenario_id"]
-    except KeyError:
-        raise SpineIntegrityError("Missing scenario identifier.")
-    try:
-        alt_id = item["alternative_id"]
-    except KeyError:
-        raise SpineIntegrityError("Missing alternative identifier.")
-    try:
-        rank = item["rank"]
-    except KeyError:
-        raise SpineIntegrityError("Missing scenario alternative rank.")
-    scen_name = scenario_names.get(scen_id)
-    if scen_name is None:
-        raise SpineIntegrityError(f"Scenario with id {scen_id} does not have a name.")
-    alt_name = alternative_names.get(alt_id)
-    if alt_name is None:
-        raise SpineIntegrityError(f"Alternative with id {alt_id} does not have a name.")
-    dup_id = ids_by_alt_id.get((scen_id, alt_id))
-    if dup_id is not None:
-        raise SpineIntegrityError(f"Alternative {alt_name} already exists in scenario {scen_name}.", id=dup_id)
-    dup_id = ids_by_rank.get((scen_id, rank))
-    if dup_id is not None:
-        raise SpineIntegrityError(
-            f"Rank {rank} already exists in scenario {scen_name}. Cannot give the same rank for "
-            f"alternative {alt_name}.",
-            id=dup_id,
-        )
-
-
-def check_entity_class(item, current_items):
-    """Check whether the insertion of an entity class item results in the violation of an integrity constraint.
-
-    Args:
-        item (dict): An entity class item to be checked.
-        current_items (dict): A dictionary mapping names to ids of entity classes already in the database.
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
-    """
-    try:
-        name = item["name"]
-    except KeyError:
-        raise SpineIntegrityError("The name for the entity class is missing.")
-    if not name:
-        raise SpineIntegrityError("Entity class name is an empty string, which is not valid")
-    try:
-        dimension_id_list = item["dimension_id_list"]
-    except KeyError:
-        item["dimension_id_list"] = dimension_id_list = ()
-    if not all(id_ in current_items.values() for id_ in dimension_id_list):
-        raise SpineIntegrityError(f"One or more dimension ids for the entity class '{name}' are not in the database.")
-    if name in current_items:
-        raise SpineIntegrityError(
-            f"There can't be more than one entity class with the name '{name}'.", id=current_items[name]
-        )
-
-
-def check_entity(item, current_items_by_name, current_items_by_el_id_lst, entity_classes, entities):
-    """Check whether the insertion of an entity item results in the violation of an integrity constraint.
-
-    Args:
-        item (dict): An entity item to be checked.
-        current_items_by_name (dict): A dictionary mapping tuples (class_id, name) to ids of
-            entities already in the database.
-        current_items_by_el_id_lst (dict): A dictionary mapping tuples (class_id, element_id_list) to ids
-            of entities already in the database.
- entity_classes (dict): A dictionary of entity class items in the database keyed by id. - entities (dict): A dictionary of entity items in the database keyed by id. - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError("The name for the entity is missing.") - if not name: - raise SpineIntegrityError("Entity name is an empty string, which is not valid") - try: - class_id = item["class_id"] - except KeyError: - raise SpineIntegrityError(f"The entity class id for entity '{name}' is missing.") - if (class_id, name) in current_items_by_name: - raise SpineIntegrityError( - f"There's already an entity called '{name}' in the same class.", - id=current_items_by_name[class_id, name], - ) - dimension_id_list = entity_classes[class_id]["dimension_id_list"] - if not dimension_id_list: - return - try: - element_id_list = tuple(item["element_id_list"]) - except KeyError: - item["element_id_list"] = element_id_list = () - try: - given_dimension_id_list = tuple(entities[id_]["class_id"] for id_ in element_id_list) - except KeyError: - raise SpineIntegrityError(f"Some of the elements in entity '{name}' are not in the database.") - if given_dimension_id_list != dimension_id_list: - element_name_list = [entities[id_]["name"] for id_ in element_id_list] - entity_class_name = entity_classes[class_id]["name"] - raise SpineIntegrityError(f"Incorrect elements '{element_name_list}' for entity class '{entity_class_name}'.") - if (class_id, element_id_list) in current_items_by_el_id_lst: - element_name_list = [entities[id]["name"] for id in element_id_list] - entity_class_name = entity_classes[class_id]["name"] - raise SpineIntegrityError( - f"There's already an entity with elements {element_name_list} in class {entity_class_name}.", - id=current_items_by_el_id_lst[class_id, element_id_list], - ) - - -def check_object_class(item, current_items): - """Check whether the insertion of an object class item - results in the violation of an integrity constraint. - - Args: - item (dict): An object class item to be checked. - current_items (dict): A dictionary mapping names to ids of object classes already in the database. - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the object class name. Probably a bug, please report." - ) - if not name: - raise SpineIntegrityError("Object class name is an empty string and therefore not valid") - if name in current_items: - raise SpineIntegrityError(f"There can't be more than one object class called '{name}'.", id=current_items[name]) - - -def check_object(item, current_items, object_class_ids): - """Check whether the insertion of an object item - results in the violation of an integrity constraint. - - Args: - item (dict): An object item to be checked. - current_items (dict): A dictionary mapping tuples (class_id, name) to ids of objects already in the database. - object_class_ids (list): A list of object class ids in the database. - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - name = item["name"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the object name. Probably a bug, please report." 
- ) - if not name: - raise SpineIntegrityError("Object name is an empty string and therefore not valid") - try: - class_id = item["class_id"] - except KeyError: - raise SpineIntegrityError(f"Object '{name}' does not have an object class id.") - if class_id not in object_class_ids: - raise SpineIntegrityError(f"Object class id for object '{name}' not found.") - if (class_id, name) in current_items: - raise SpineIntegrityError( - f"There's already an object called '{name}' in the same object class.", id=current_items[class_id, name] - ) - - -def check_wide_relationship_class(wide_item, current_items, object_class_ids): - """Check whether the insertion of a relationship class item - results in the violation of an integrity constraint. - - Args: - wide_item (dict): A wide relationship class item to be checked. - current_items (dict): A dictionary mapping names to ids of relationship classes already in the database. - object_class_ids (list): A list of object class ids in the database. - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - name = wide_item["name"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the relationship class name. " - "Probably a bug, please report." - ) - if not name: - raise SpineIntegrityError(f"Name '{name}' is not valid") - try: - given_object_class_id_list = wide_item["object_class_id_list"] - except KeyError: - raise SpineIntegrityError( - f"Python KeyError: There is no dictionary keys for the object class ids of relationship class '{name}'. " - "Probably a bug, please report." - ) - if not given_object_class_id_list: - raise SpineIntegrityError(f"At least one object class is needed for the relationship class '{name}'.") - if not all(id_ in object_class_ids for id_ in given_object_class_id_list): - raise SpineIntegrityError( - f"At least one of the object class ids of the relationship class '{name}' is not in the database." - ) - if name in current_items: - raise SpineIntegrityError( - f"There can't be more than one relationship class with the name '{name}'.", id=current_items[name] - ) - - -def check_wide_relationship(wide_item, current_items_by_name, current_items_by_obj_lst, relationship_classes, objects): - """Check whether the insertion of a relationship item - results in the violation of an integrity constraint. - - Args: - wide_item (dict): A wide relationship item to be checked. - current_items_by_name (dict): A dictionary mapping tuples (class_id, name) to ids of - relationships already in the database. - current_items_by_obj_lst (dict): A dictionary mapping tuples (class_id, object_name_list) to ids - of relationships already in the database. - relationship_classes (dict): A dictionary of wide relationship class items in the database keyed by id. - objects (dict): A dictionary of object items in the database keyed by id. - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - name = wide_item["name"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the relationship name. Probably a bug, please report." - ) - if not name: - raise SpineIntegrityError("Relationship name is an empty string, which is not valid") - try: - class_id = wide_item["class_id"] - except KeyError: - raise SpineIntegrityError( - f"Python KeyError: There is no dictionary key for the relationship class id of relationship '{name}'. 
" - "Probably a bug, please report" - ) - if (class_id, name) in current_items_by_name: - raise SpineIntegrityError( - f"There's already a relationship called '{name}' in the same class.", - id=current_items_by_name[class_id, name], - ) - try: - object_class_id_list = relationship_classes[class_id]["object_class_id_list"] - except KeyError: - raise SpineIntegrityError(f"There is no object class id list for relationship '{name}'") - try: - object_id_list = tuple(wide_item["object_id_list"]) - except KeyError: - raise SpineIntegrityError(f"There is no object id list for relationship '{name}'") - try: - given_object_class_id_list = tuple(objects[id]["class_id"] for id in object_id_list) - except KeyError: - raise SpineIntegrityError(f"Some of the objects in relationship '{name}' are invalid.") - if given_object_class_id_list != object_class_id_list: - object_name_list = [objects[id]["name"] for id in object_id_list] - relationship_class_name = relationship_classes[class_id]["name"] - raise SpineIntegrityError( - f"Incorrect objects '{object_name_list}' for relationship class '{relationship_class_name}'." - ) - if (class_id, object_id_list) in current_items_by_obj_lst: - object_name_list = [objects[id]["name"] for id in object_id_list] - relationship_class_name = relationship_classes[class_id]["name"] - raise SpineIntegrityError( - "There's already a relationship between objects {} in class {}.".format( - object_name_list, relationship_class_name - ), - id=current_items_by_obj_lst[class_id, object_id_list], - ) - - -def check_entity_group(item, current_items, entities): - """Check whether the insertion of an entity group item - results in the violation of an integrity constraint. - - Args: - item (dict): An entity group item to be checked. - current_items (dict): A dictionary mapping tuples (entity_id, member_id) to ids of entity groups - already in the database. - entities (dict): A dictionary mapping entity class ids, to entity ids, to entity items already in the db - - Raises: - SpineIntegrityError: if the insertion of the item violates an integrity constraint. - """ - try: - entity_id = item["entity_id"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the entity id of entity group. " - "Probably a bug, please report." - ) - try: - member_id = item["member_id"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the member id of an entity group. " - "Probably a bug, please report." - ) - try: - entity_class_id = item["entity_class_id"] - except KeyError: - raise SpineIntegrityError( - "Python KeyError: There is no dictionary key for the entity class id of entity group. " - "Probably a bug, please report." - ) - ents = entities.get(entity_class_id) - if ents is None: - raise SpineIntegrityError("Entity class not found for entity group.") - entity = ents.get(entity_id) - if not entity: - raise SpineIntegrityError("No entity id for the entity group.") - member = ents.get(member_id) - if not member: - raise SpineIntegrityError("Entity group has no members.") - if (entity_id, member_id) in current_items: - raise SpineIntegrityError( - "{0} is already a member in {1}.".format(member["name"], entity["name"]), - id=current_items[entity_id, member_id], - ) - - -def check_parameter_definition(item, current_items, entity_class_ids, parameter_value_lists, list_values): - """Check whether the insertion of a parameter definition item - results in the violation of an integrity constraint. 
-
-    Args:
-        item (dict): A parameter definition item to be checked.
-        current_items (dict): A dictionary mapping tuples (entity_class_id, name) to ids of parameter definitions
-            already in the database.
-        entity_class_ids (Iterable): Entity class ids in the database.
-        parameter_value_lists (dict): A dictionary of value-lists in the database keyed by id.
-        list_values (dict): A dictionary of list-values in the database keyed by id.
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
-    """
-    name = item.get("name")
-    if not name:
-        raise SpineIntegrityError("No name provided for a parameter definition.")
-    entity_class_id = item.get("entity_class_id")
-    if not entity_class_id:
-        raise SpineIntegrityError(f"Missing entity class id for parameter definition '{name}'.")
-    if entity_class_id not in entity_class_ids:
-        raise SpineIntegrityError(
-            f"Entity class id for parameter definition '{name}' not found in the entity class "
-            "ids of the current database."
-        )
-    if (entity_class_id, name) in current_items:
-        raise SpineIntegrityError(
-            "There's already a parameter called {0} in entity class with id {1}.".format(name, entity_class_id),
-            id=current_items[entity_class_id, name],
-        )
-    replace_default_values_with_list_references(item, parameter_value_lists, list_values)
-
-
-def check_parameter_value(
-    item, current_items, parameter_definitions, entities, parameter_value_lists, list_values, alternatives
-):
-    """Check whether the insertion of a parameter value item results in the violation of an integrity constraint.
-
-    Args:
-        item (dict): A parameter value item to be checked.
-        current_items (dict): A dictionary mapping tuples (entity_id, parameter_definition_id, alternative_id)
-            to ids of parameter values already in the database.
-        parameter_definitions (dict): A dictionary of parameter definition items in the database keyed by id.
-        entities (dict): A dictionary of entity items already in the database keyed by id.
-        parameter_value_lists (dict): A dictionary of value-lists in the database keyed by id.
-        list_values (dict): A dictionary of list-values in the database keyed by id.
-        alternatives (set): A set of alternative ids in the database.
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
- """ - try: - parameter_definition_id = item["parameter_definition_id"] - except KeyError: - raise SpineIntegrityError("Missing parameter identifier.") - try: - parameter_definition = parameter_definitions[parameter_definition_id] - except KeyError: - raise SpineIntegrityError("Parameter not found.") - alt_id = item.get("alternative_id") - if alt_id not in alternatives: - raise SpineIntegrityError("Alternative not found.") - entity_id = item.get("entity_id") - if not entity_id: - raise SpineIntegrityError("Missing object or relationship identifier.") - try: - entity_class_id = entities[entity_id]["class_id"] - except KeyError: - raise SpineIntegrityError("Entity not found") - if entity_class_id != parameter_definition["entity_class_id"]: - entity_name = entities[entity_id]["name"] - parameter_name = parameter_definition["name"] - raise SpineIntegrityError("Incorrect entity '{}' for parameter '{}'.".format(entity_name, parameter_name)) - if (entity_id, parameter_definition_id, alt_id) in current_items: - entity_name = entities[entity_id]["name"] - parameter_name = parameter_definition["name"] - raise SpineIntegrityError( - "The value of parameter '{}' for entity '{}' is already specified.".format(parameter_name, entity_name), - id=current_items[entity_id, parameter_definition_id, alt_id], - ) - replace_parameter_values_with_list_references(item, parameter_definitions, parameter_value_lists, list_values) - - -def replace_default_values_with_list_references(item, parameter_value_lists, list_values): - parameter_value_list_id = item.get("parameter_value_list_id") - return _replace_values_with_list_references( - "parameter_definition", item, parameter_value_list_id, parameter_value_lists, list_values - ) - - -def replace_parameter_values_with_list_references(item, parameter_definitions, parameter_value_lists, list_values): - parameter_definition_id = item["parameter_definition_id"] - parameter_definition = parameter_definitions[parameter_definition_id] - parameter_value_list_id = parameter_definition["parameter_value_list_id"] - return _replace_values_with_list_references( - "parameter_value", item, parameter_value_list_id, parameter_value_lists, list_values - ) - - -def _replace_values_with_list_references(item_type, item, parameter_value_list_id, parameter_value_lists, list_values): - if parameter_value_list_id is None: - return False - if parameter_value_list_id not in parameter_value_lists: - raise SpineIntegrityError("Parameter value list not found.") - value_id_list = parameter_value_lists[parameter_value_list_id] - if value_id_list is None: - raise SpineIntegrityError("Parameter value list is empty!") - value_key, type_key = { - "parameter_value": ("value", "type"), - "parameter_definition": ("default_value", "default_type"), - }[item_type] - value = dict.get(item, value_key) - value_type = dict.get(item, type_key) - try: - parsed_value = from_database(value, value_type) - except ParameterValueFormatError as err: - raise SpineIntegrityError(f"Invalid {value_key} '{value}': {err}") from None - if parsed_value is None: - return False - list_value_id = next((id_ for id_ in value_id_list if list_values.get(id_) == parsed_value), None) - if list_value_id is None: - valid_values = ", ".join(f"{dump_db_value(list_values.get(id_))[0].decode('utf8')!r}" for id_ in value_id_list) - raise SpineIntegrityError( - f"Invalid {value_key} '{parsed_value}' - it should be one from the parameter value list: {valid_values}." 
-        )
-    item[value_key] = str(list_value_id).encode("UTF8")
-    item[type_key] = "list_value_ref"
-    item["list_value_id"] = list_value_id
-    return True
-
-
-def check_parameter_value_list(item, current_items):
-    """Check whether the insertion of a parameter value-list item results in the violation of an integrity constraint.
-
-    Args:
-        item (dict): A parameter value-list item to be checked.
-        current_items (dict): A dictionary mapping names to ids of parameter value-lists already in the database.
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
-    """
-    try:
-        name = item["name"]
-    except KeyError:
-        raise SpineIntegrityError("Missing parameter value list name.")
-    if name in current_items:
-        raise SpineIntegrityError(
-            "There can't be more than one parameter value list called '{}'.".format(name), id=current_items[name]
-        )
-
-
-def check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value):
-    """Check whether the insertion of a list value item results in the violation of an integrity constraint.
-
-    Args:
-        item (dict): A list value item to be checked.
-        list_names_by_id (dict): Mapping parameter value list ids to names.
-        list_value_ids_by_index (dict): Mapping tuples (list id, index) to ids of existing list values.
-        list_value_ids_by_value (dict): Mapping tuples (list id, type, value) to ids of existing list values.
-
-    Raises:
-        SpineIntegrityError: if the insertion of the item violates an integrity constraint.
-    """
-    keys = {"parameter_value_list_id", "index", "value", "type"}
-    missing_keys = keys - item.keys()
-    if missing_keys:
-        raise SpineIntegrityError(f"Missing keys: {', '.join(missing_keys)}.")
-    list_id = item["parameter_value_list_id"]
-    list_name = list_names_by_id.get(list_id)
-    if list_name is None:
-        raise SpineIntegrityError("Unknown parameter value list identifier.")
-    index = item["index"]
-    type_ = item["type"]
-    value = item["value"]
-    dup_id = list_value_ids_by_index.get((list_id, index))
-    if dup_id is not None:
-        raise SpineIntegrityError(f"'{list_name}' already has the index '{index}'.", id=dup_id)
-    dup_id = list_value_ids_by_value.get((list_id, type_, value))
-    if dup_id is not None:
-        raise SpineIntegrityError(f"'{list_name}' already has the value '{from_database(value, type_)}'.", id=dup_id)
-
-
-def check_metadata(item, metadata):
-    """Check whether the metadata item violates an integrity constraint.
-
-    Args:
-        item (dict): A metadata item to be checked.
-        metadata (dict): Mapping from metadata name and value to metadata id.
-
-    Raises:
-        SpineIntegrityError: if the item violates an integrity constraint.
-    """
-    keys = {"name", "value"}
-    missing_keys = keys - item.keys()
-    if missing_keys:
-        raise SpineIntegrityError(f"Missing keys: {', '.join(missing_keys)}.")
-
-
-def check_entity_metadata(item, entities, metadata):
-    """Check whether the entity metadata item violates an integrity constraint.
-
-    Args:
-        item (dict): An entity metadata item to be checked.
-        entities (set of int): Available entity ids.
-        metadata (set of int): Available metadata ids.
-
-    Raises:
-        SpineIntegrityError: if the item violates an integrity constraint.
- """ - keys = {"entity_id", "metadata_id"} - missing_keys = keys - item.keys() - if missing_keys: - raise SpineIntegrityError(f"Missing keys: {', '.join(missing_keys)}.") - if item["entity_id"] not in entities: - raise SpineIntegrityError("Unknown entity identifier.") - if item["metadata_id"] not in metadata: - raise SpineIntegrityError("Unknown metadata identifier.") - - -def check_parameter_value_metadata(item, values, metadata): - """Check whether the parameter value metadata item violates an integrity constraint. - - Args: - item (dict): An entity metadata item to be checked. - values (set of int): Available parameter value ids. - metadata (set of int): Available metadata ids. - - Raises: - SpineIntegrityError: if the item violates an integrity constraint. - """ - keys = {"parameter_value_id", "metadata_id"} - missing_keys = keys - item.keys() - if missing_keys: - raise SpineIntegrityError(f"Missing keys: {', '.join(missing_keys)}.") - if item["parameter_value_id"] not in values: - raise SpineIntegrityError("Unknown parameter value identifier.") - if item["metadata_id"] not in metadata: - raise SpineIntegrityError("Unknown metadata identifier.") diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 88bd5859..96419ebb 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -12,8 +12,10 @@ DB cache utility. """ +import uuid from contextlib import suppress from operator import itemgetter +from sqlalchemy.exc import ProgrammingError # TODO: Implement CacheItem.pop() to do lookup? @@ -114,7 +116,11 @@ def get_item(self, item_type, id_): def fetch_more(self, item_type): if item_type in self._fetched_item_types: return False - return bool(self.advance_query(item_type).result()) + # We try to advance the query directly. If we are in the wrong thread this will raise ProgrammingError. + try: + return bool(self.do_advance_query(item_type)) + except ProgrammingError: + return bool(self.advance_query(item_type).result()) def fetch_all(self, item_type): while self.fetch_more(item_type): @@ -131,62 +137,143 @@ def fetch_ref(self, item_type, id_): return self[item_type][id_] return None - def make_item(self, item_type, item): - """Returns a cache item. +class TableCache(dict): + def __init__(self, db_cache, item_type, *args, **kwargs): + """ Args: + db_cache (DBCache): the DB cache where this table cache belongs. item_type (str): the item type, equal to a table name - item (dict): the 'db item' to use as base + """ + super().__init__(*args, **kwargs) + self._db_cache = db_cache + self._item_type = item_type + self._existing = {} + + def existing(self, key, value): + """Returns the CacheItem that has the given value for the given unique constraint key, or None. 
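+
+        For example (a hypothetical lookup): existing(("name",), ("unit",)) on the
+        "entity_class" table cache returns the cached entity class named "unit",
+        or None if there is none.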
+ + Args: + key (tuple) + value (tuple) Returns: CacheItem """ - factory = { + self._db_cache.fetch_all(self._item_type) + return self._existing.get(key, {}).get(value) + + def values(self): + return (x for x in super().values() if x.is_valid()) + + @property + def _item_factory(self): + return { "entity_class": EntityClassItem, "entity": EntityItem, "entity_group": EntityGroupItem, "parameter_definition": ParameterDefinitionItem, "parameter_value": ParameterValueItem, + "list_value": ListValueItem, "scenario": ScenarioItem, "scenario_alternative": ScenarioAlternativeItem, - }.get(item_type, CacheItem) - return factory(self, item_type, **item) + "metadata": MetadataItem, + "entity_metadata": EntityMetadataItem, + "parameter_value_metadata": ParameterValueMetadataItem, + }.get(self._item_type, CacheItem) + def _make_item(self, item): + """Returns a cache item. -class TableCache(dict): - def __init__(self, db_cache, item_type, *args, **kwargs): - """ Args: - db_cache (DBCache): the DB cache where this table cache belongs. - item_type (str): the item type, equal to a table name - """ - super().__init__(*args, **kwargs) - self._db_cache = db_cache - self._item_type = item_type + item (dict): the 'db item' to use as base - def values(self): - return (x for x in super().values() if x.is_valid()) + Returns: + CacheItem + """ + return self._item_factory(self._db_cache, self._item_type, **item) + + def _current_item(self, item): + id_ = item.get("id") + if isinstance(id_, int): + # id is an int, easy + return self.get(id_) + if isinstance(id_, dict): + # id is a dict specifying the values for one of the unique constraints + return self._current_item_from_dict_id(id_) + if id_ is None: + # No id. Try to build the dict id from the item itself. Used by import_data. 
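+            # For example (hypothetical input): an item {"name": "unit"} posted without an id
+            # is matched against each unique constraint key in turn; for the key ("name",) the
+            # dict id becomes {"name": "unit"}, which is then resolved via the unique-value lookup.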
+ for key in self._item_factory.unique_constraint: + dict_id = {k: item.get(k) for k in key} + current_item = self._current_item_from_dict_id(dict_id) + if current_item: + return current_item + + def _current_item_from_dict_id(self, dict_id): + key, value = zip(*dict_id.items()) + return self.existing(key, value) + + def check_item(self, item, for_update=False): + if for_update: + current_item = self._current_item(item) + if current_item is None: + return None, f"no {self._item_type} matching {item} to update" + item = {**current_item, **item} + item["id"] = current_item["id"] + else: + current_item = None + candidate_item = self._make_item(item) + candidate_item.resolve_inverse_references() + missing_ref = candidate_item.missing_ref() + if missing_ref: + return None, f"missing {missing_ref} for {self._item_type}" + try: + for key, value in candidate_item.unique_values(): + existing_item = self.existing(key, value) + if existing_item not in (None, current_item) and existing_item.is_valid(): + kv_parts = [f"{k} '{', '.join(v) if isinstance(v, tuple) else v}'" for k, v in zip(key, value)] + head, tail = kv_parts[:-1], kv_parts[-1] + head_str = ", ".join(head) + main_parts = [head_str, tail] if head_str else [tail] + key_val = " and ".join(main_parts) + return None, f"there's already a {self._item_type} with {key_val}" + except KeyError as e: + return None, f"missing {e} for {self._item_type}" + return candidate_item._asdict(), None + + def _add_to_existing(self, item): + for key, value in item.unique_values(): + self._existing.setdefault(key, {})[value] = item + + def _remove_from_existing(self, item): + for key, value in item.unique_values(): + self._existing.get(key, {}).pop(value, None) def add_item(self, item, new=False): - self[item["id"]] = new_item = self._db_cache.make_item(self._item_type, item) + self[item["id"]] = new_item = self._make_item(item) + self._add_to_existing(new_item) new_item.new = new return new_item def update_item(self, item): current_item = self[item["id"]] + self._remove_from_existing(current_item) current_item.dirty = True current_item.update(item) + self._add_to_existing(current_item) current_item.cascade_update() def remove_item(self, id_): current_item = self.get(id_) if current_item is not None: + self._remove_from_existing(current_item) current_item.cascade_remove() return current_item def restore_item(self, id_): current_item = self.get(id_) if current_item is not None: + self._add_to_existing(current_item) current_item.cascade_restore() return current_item @@ -194,6 +281,10 @@ def restore_item(self, id_): class CacheItem(dict): """A dictionary that represents an db item.""" + unique_constraint = (("name",),) + _references = {} + _inverse_references = {} + def __init__(self, db_cache, item_type, *args, **kwargs): """ Args: @@ -214,6 +305,23 @@ def __init__(self, db_cache, item_type, *args, **kwargs): self.new = False self.dirty = False + def missing_ref(self): + for key, (ref_type, _ref_key) in self._references.values(): + try: + ref_id = self[key] + except KeyError: + return key + if isinstance(ref_id, tuple): + for x in ref_id: + if not self._get_ref(ref_type, x): + return key + elif not self._get_ref(ref_type, ref_id): + return key + + def unique_values(self): + for key in self.unique_constraint: + yield key, tuple(self[k] for k in key) + @property def removed(self): return self._removed @@ -236,27 +344,24 @@ def __repr__(self): return f"{self._item_type}{self._extended()}" def _extended(self): - return {**self, **{key: self[key] for key in 
self._reference_keys()}} + return {**self, **{key: self[key] for key in self._references}} def _asdict(self): return dict(**self) - def _reference_keys(self): - return () - - def _get_ref(self, ref_type, ref_id, source_key): + def _get_ref(self, ref_type, ref_id, strong=True): ref = self._db_cache.get_item(ref_type, ref_id) if not ref: - if source_key not in self._reference_keys(): + if not strong: return {} ref = self._db_cache.fetch_ref(ref_type, ref_id) if not ref: self._corrupted = True return {} - return self._handle_ref(ref, source_key) + return self._handle_ref(ref, strong) - def _handle_ref(self, ref, source_key): - if source_key in self._reference_keys(): + def _handle_ref(self, ref, strong): + if strong: ref.add_referrer(self) if ref.removed: self._to_remove = True @@ -272,12 +377,6 @@ def get(self, key, default=None): except KeyError: return default - def copy(self): - return type(self)(self._db_cache, self._item_type, **self) - - def updated(self, other): - return type(self)(self._db_cache, self._item_type, **{**self, **other}) - def is_valid(self): if self._valid is not None: return self._valid @@ -285,7 +384,7 @@ def is_valid(self): return False self._to_remove = False self._corrupted = False - for key in self._reference_keys(): + for key in self._references: _ = self[key] if self._to_remove: self.cascade_remove() @@ -348,6 +447,30 @@ def call_update_callbacks(self): obsolete.add(callback) self.update_callbacks -= obsolete + def __getitem__(self, key): + ref = self._references.get(key) + if ref: + key, (ref_type, ref_key) = ref + ref_id = self[key] + if isinstance(ref_id, tuple): + return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) + return self._get_ref(ref_type, ref_id).get(ref_key) + return super().__getitem__(key) + + def resolve_inverse_references(self): + for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): + id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) + if None in id_value: + continue + table_cache = self._db_cache.table_cache(ref_type) + with suppress(AttributeError): # NoneType has no attribute id, happens when existing() returns None + self[src_key] = ( + tuple(table_cache.existing(ref_key, v).id for v in zip(*id_value)) + if all(isinstance(v, tuple) for v in id_value) + else table_cache.existing(ref_key, id_value).id + ) + # FIXME: Do we need to catch the AttributeError and give it to the user instead?? 
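+        # To illustrate with an entry defined further below (EntityItem): the mapping
+        #     "class_id": (("class_name",), ("entity_class", ("name",)))
+        # means a missing "class_id" can be recovered by looking up the entity_class
+        # whose ("name",) unique key equals the given ("class_name",) value and taking its id.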
+ class DisplayIconMixin: def __getitem__(self, key): @@ -364,6 +487,9 @@ def __getitem__(self, key): class EntityClassItem(DisplayIconMixin, DescriptionMixin, CacheItem): + _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} + _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))} + def __init__(self, *args, **kwargs): dimension_id_list = kwargs.get("dimension_id_list") if dimension_id_list is None: @@ -373,16 +499,20 @@ def __init__(self, *args, **kwargs): kwargs["dimension_id_list"] = tuple(dimension_id_list) super().__init__(*args, **kwargs) - def __getitem__(self, key): - if key == "dimension_name_list": - return tuple(self._get_ref("entity_class", id_, key).get("name") for id_ in self["dimension_id_list"]) - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ("dimension_name_list",) - class EntityItem(DescriptionMixin, CacheItem): + unique_constraint = (("class_name", "name"), ("class_name", "byname")) + _references = { + "class_name": ("class_id", ("entity_class", "name")), + "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("class_id", ("entity_class", "dimension_name_list")), + "element_name_list": ("element_id_list", ("entity", "name")), + } + _inverse_references = { + "class_id": (("class_name",), ("entity_class", ("name",))), + "element_id_list": (("dimension_name_list", "element_name_list"), ("entity", ("class_name", "name"))), + } + def __init__(self, *args, **kwargs): element_id_list = kwargs.get("element_id_list") if element_id_list is None: @@ -393,42 +523,59 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __getitem__(self, key): - if key == "class_name": - return self._get_ref("entity_class", self["class_id"], key).get("name") - if key == "dimension_id_list": - return self._get_ref("entity_class", self["class_id"], key).get("dimension_id_list") - if key == "dimension_name_list": - return self._get_ref("entity_class", self["class_id"], key).get("dimension_name_list") - if key == "element_name_list": - return tuple(self._get_ref("entity", id_, key).get("name") for id_ in self["element_id_list"]) if key == "byname": return self["element_name_list"] or (self["name"],) return super().__getitem__(key) - def _reference_keys(self): - return super()._reference_keys() + ( - "class_name", - "dimension_id_list", - "dimension_name_list", - "element_name_list", - ) + def resolve_inverse_references(self): + super().resolve_inverse_references() + self._fill_name() + + def _fill_name(self): + if "name" in self: + return + base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) + name = base_name + table_cache = self._db_cache.table_cache(self._item_type) + while table_cache.existing(("class_name", "name"), (self["class_name"], name)) is not None: + name = base_name + uuid.uuid4().hex + self["name"] = name + +class EntityGroupItem(CacheItem): + unique_constraint = (("group_name", "member_name"),) + _references = { + "class_name": ("entity_class_id", ("entity_class", "name")), + "group_name": ("entity_id", ("entity", "name")), + "member_name": ("member_id", ("entity", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + } + _inverse_references = { + "entity_class_id": (("class_name",), ("entity_class", ("name",))), + "entity_id": (("class_name", "group_name"), ("entity", ("class_name", "name"))), + "member_id": 
(("class_name", "member_name"), ("entity", ("class_name", "name"))), + } -class ParameterMixin: def __getitem__(self, key): - if key in ("dimension_id_list", "dimension_name_list"): - return self._get_ref("entity_class", self["entity_class_id"], key)[key] - if key == "entity_class_name": - return self._get_ref("entity_class", self["entity_class_id"], key)["name"] - if key == "parameter_value_list_id": - return dict.get(self, key) + if key == "class_id": + return self["entity_class_id"] + if key == "group_id": + return self["entity_id"] return super().__getitem__(key) - def _reference_keys(self): - return super()._reference_keys() + ("entity_class_name", "dimension_id_list", "dimension_name_list") +class ParameterDefinitionItem(DescriptionMixin, CacheItem): + unique_constraint = (("entity_class_name", "name"),) + _references = { + "entity_class_name": ("entity_class_id", ("entity_class", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + } + _inverse_references = { + "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), + "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), + } -class ParameterDefinitionItem(DescriptionMixin, ParameterMixin, CacheItem): def __init__(self, *args, **kwargs): if kwargs.get("list_value_id") is None: kwargs["list_value_id"] = ( @@ -441,16 +588,40 @@ def __getitem__(self, key): return super().__getitem__("name") if key == "value_list_id": return super().__getitem__("parameter_value_list_id") + if key == "parameter_value_list_id": + return dict.get(self, key) if key == "value_list_name": - return self._get_ref("parameter_value_list", self["value_list_id"], key).get("name") + return self._get_ref("parameter_value_list", self["value_list_id"], strong=False).get("name") if key in ("default_value", "default_type"): if self["list_value_id"] is not None: - return self._get_ref("list_value", self["list_value_id"], key).get(key.split("_")[1]) + return self._get_ref("list_value", self["list_value_id"], strong=False).get(key.split("_")[1]) return dict.get(self, key) return super().__getitem__(key) -class ParameterValueItem(ParameterMixin, CacheItem): +class ParameterValueItem(CacheItem): + unique_constraint = (("parameter_definition_name", "entity_byname", "alternative_name"),) + _references = { + "entity_class_name": ("entity_class_id", ("entity_class", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + "parameter_definition_name": ("parameter_definition_id", ("parameter_definition", "name")), + "parameter_value_list_id": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_id")), + "entity_name": ("entity_id", ("entity", "name")), + "entity_byname": ("entity_id", ("entity", "byname")), + "element_id_list": ("entity_id", ("entity", "element_id_list")), + "element_name_list": ("entity_id", ("entity", "element_name_list")), + "alternative_name": ("alternative_id", ("alternative", "name")), + } + _inverse_references = { + "parameter_definition_id": ( + ("entity_class_name", "parameter_definition_name"), + ("parameter_definition", ("entity_class_name", "name")), + ), + "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), + "alternative_id": (("alternative_name",), ("alternative", ("name",))), + 
} + def __init__(self, *args, **kwargs): if kwargs.get("list_value_id") is None: kwargs["list_value_id"] = int(kwargs["value"]) if kwargs.get("type") == "list_value_ref" else None @@ -459,48 +630,17 @@ def __init__(self, *args, **kwargs): def __getitem__(self, key): if key == "parameter_id": return super().__getitem__("parameter_definition_id") - if key == "parameter_name": - return self._get_ref("parameter_definition", self["parameter_definition_id"], key).get("name") - if key == "entity_name": - return self._get_ref("entity", self["entity_id"], key)["name"] - if key == "entity_byname": - return self._get_ref("entity", self["entity_id"], key)["byname"] - if key in ("element_id_list", "element_name_list"): - return self._get_ref("entity", self["entity_id"], key)[key] - if key == "alternative_name": - return self._get_ref("alternative", self["alternative_id"], key).get("name") if key in ("value", "type") and self["list_value_id"] is not None: - return self._get_ref("list_value", self["list_value_id"], key).get(key) + return self._get_ref("list_value", self["list_value_id"], strong=False).get(key) return super().__getitem__(key) - def _reference_keys(self): - return super()._reference_keys() + ( - "parameter_name", - "alternative_name", - "entity_name", - "element_id_list", - "element_name_list", - ) - -class EntityGroupItem(CacheItem): - def __getitem__(self, key): - if key == "class_id": - return self["entity_class_id"] - if key == "group_id": - return self["entity_id"] - if key == "class_name": - return self._get_ref("entity_class", self["entity_class_id"], key)["name"] - if key == "group_name": - return self._get_ref("entity", self["entity_id"], key)["name"] - if key == "member_name": - return self._get_ref("entity", self["member_id"], key)["name"] - if key == "dimension_id_list": - return self._get_ref("entity_class", self["entity_class_id"], key)["dimension_id_list"] - return super().__getitem__(key) - - def _reference_keys(self): - return super()._reference_keys() + ("class_name", "group_name", "member_name", "dimension_id_list") +class ListValueItem(CacheItem): + unique_constraint = (("parameter_value_list_name", "value"), ("parameter_value_list_name", "index")) + _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} + _inverse_references = { + "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), + } class ScenarioItem(CacheItem): @@ -521,20 +661,59 @@ def __getitem__(self, key): class ScenarioAlternativeItem(CacheItem): + unique_constraint = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) + _references = { + "scenario_name": ("scenario_id", ("scenario", "name")), + "alternative_name": ("alternative_id", ("alternative", "name")), + } + _inverse_references = { + "scenario_id": (("scenario_name",), ("scenario", ("name",))), + "alternative_id": (("alternative_name",), ("alternative", ("name",))), + "before_alternative_id": (("before_alternative_name",), ("alternative", ("name",))), + } + def __getitem__(self, key): - if key == "scenario_name": - return self._get_ref("scenario", self["scenario_id"], key).get("name") - if key == "alternative_name": - return self._get_ref("alternative", self["alternative_id"], key).get("name") if key == "before_alternative_name": - return self._get_ref("alternative", self["before_alternative_id"], key).get("name") + return self._get_ref("alternative", self["before_alternative_id"], strong=False).get("name") if key == "before_alternative_id": 
- scenario = self._get_ref("scenario", self["scenario_id"], None) + scenario = self._get_ref("scenario", self["scenario_id"], strong=False) try: return scenario["alternative_id_list"][self["rank"]] except IndexError: return None return super().__getitem__(key) - def _reference_keys(self): - return super()._reference_keys() + ("scenario_name", "alternative_name") + +class MetadataItem(CacheItem): + unique_constraint = (("name", "value"),) + + +class EntityMetadataItem(CacheItem): + unique_constraint = (("entity_name", "metadata_name"),) + _references = { + "entity_name": ("entity_id", ("entity", "name")), + "metadata_name": ("metadata_id", ("metadata", "name")), + "metadata_value": ("metadata_id", ("metadata", "value")), + } + _inverse_references = { + "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), + "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + } + + +class ParameterValueMetadataItem(CacheItem): + unique_constraint = (("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name"),) + _references = { + "parameter_definition_name": ("parameter_value_id", ("parameter_value", "parameter_definition_name")), + "entity_byname": ("parameter_value_id", ("parameter_value", "entity_byname")), + "alternative_name": ("parameter_value_id", ("parameter_value", "alternative_name")), + "metadata_name": ("metadata_id", ("metadata", "name")), + "metadata_value": ("metadata_id", ("metadata", "value")), + } + _inverse_references = { + "parameter_value_id": ( + ("parameter_definition_name", "entity_byname", "alternative_name"), + ("parameter_value", ("parameter_definition_name", "entity_byname", "alternative_name")), + ), + "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + } diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 171eefd2..bb45dcff 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -16,7 +16,6 @@ from .db_mapping_base import DatabaseMappingBase from .db_mapping_add_mixin import DatabaseMappingAddMixin -from .db_mapping_check_mixin import DatabaseMappingCheckMixin from .db_mapping_update_mixin import DatabaseMappingUpdateMixin from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin from .db_mapping_commit_mixin import DatabaseMappingCommitMixin @@ -24,7 +23,6 @@ class DatabaseMapping( - DatabaseMappingCheckMixin, DatabaseMappingAddMixin, DatabaseMappingUpdateMixin, DatabaseMappingRemoveMixin, diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 2b38c87c..c489d3c5 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -141,14 +141,16 @@ def add_items(self, tablename, *items, check=True, strict=False): Returns: set: ids or items successfully added - list(SpineIntegrityError): found violations + list(str): found violations """ if check: - checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=False, strict=strict) + checked_items, errors = self.check_items(tablename, *items) else: - checked_items, intgr_error_log = list(items), [] + checked_items, errors = list(items), [] + if errors and strict: + raise SpineDBAPIError(", ".join(errors)) _ = self._add_items(tablename, *checked_items) - return checked_items, intgr_error_log + return checked_items, errors def _add_items(self, tablename, *items): """Add items to cache without checking integrity. 
@@ -381,138 +383,3 @@ def add_ext_entity_metadata(self, *items, check=True, strict=False): def add_ext_parameter_value_metadata(self, *items, check=True, strict=False): return self._add_ext_item_metadata("parameter_value_metadata", *items, check=check, strict=strict) - - def _add_entity_classes(self, *items): - return self._add_items("entity_class", *items) - - def _add_entities(self, *items): - return self._add_items("entity", *items) - - def _add_object_classes(self, *items): - return self._add_items("object_class", *items) - - def _add_objects(self, *items): - return self._add_items("object", *items) - - def _add_wide_relationship_classes(self, *items): - return self._add_items("relationship_class", *items) - - def _add_wide_relationships(self, *items): - return self._add_items("relationship", *items) - - def _add_parameter_definitions(self, *items): - return self._add_items("parameter_definition", *items) - - def _add_parameter_values(self, *items): - return self._add_items("parameter_value", *items) - - def _add_parameter_value_lists(self, *items): - return self._add_items("parameter_value_list", *items) - - def _add_list_values(self, *items): - return self._add_items("list_value", *items) - - def _add_alternatives(self, *items): - return self._add_items("alternative", *items) - - def _add_scenarios(self, *items): - return self._add_items("scenario", *items) - - def _add_scenario_alternatives(self, *items): - return self._add_items("scenario_alternative", *items) - - def _add_entity_groups(self, *items): - return self._add_items("entity_group", *items) - - def _add_metadata(self, *items): - return self._add_items("metadata", *items) - - def _add_parameter_value_metadata(self, *items): - return self._add_items("parameter_value_metadata", *items) - - def _add_entity_metadata(self, *items): - return self._add_items("entity_metadata", *items) - - def add_object_class(self, **kwargs): - """Stage an object class item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.object_class_sq - ids, _ = self.add_object_classes(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def add_object(self, **kwargs): - """Stage an object item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.object_sq - ids, _ = self.add_objects(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def add_wide_relationship_class(self, **kwargs): - """Stage a relationship class item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.wide_relationship_class_sq - ids, _ = self.add_wide_relationship_classes(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def add_wide_relationship(self, **kwargs): - """Stage a relationship item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. 
- - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.wide_relationship_sq - ids, _ = self.add_wide_relationships(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def add_parameter_definition(self, **kwargs): - """Stage a parameter definition item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.parameter_definition_sq - ids, _ = self.add_parameter_definitions(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() - - def add_parameter_value(self, **kwargs): - """Stage a parameter value item for insertion. - - :raises SpineIntegrityError: if the insertion of the item violates an integrity constraint. - - :returns: - - **new_item** -- The item successfully staged for insertion. - - :rtype: :class:`~sqlalchemy.util.KeyedTuple` - """ - sq = self.parameter_value_sq - ids, _ = self.add_parameter_values(kwargs, strict=True) - return self.query(sq).filter(sq.c.id.in_(ids)).one_or_none() diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index be9c3af4..581ca722 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -210,8 +210,8 @@ def __init__( "parameter_definition": "parameter_definition_sq", "parameter_value": "parameter_value_sq", "metadata": "metadata_sq", - "entity_metadata": "ext_entity_metadata_sq", - "parameter_value_metadata": "ext_parameter_value_metadata_sq", + "entity_metadata": "entity_metadata_sq", + "parameter_value_metadata": "parameter_value_metadata_sq", "commit": "commit_sq", } self.ancestor_tablenames = { @@ -2058,6 +2058,18 @@ def _metadata_usage_counts(self): usage_counts[entry.metadata_id] += 1 return usage_counts + def check_items(self, tablename, *items, for_update=False): + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + checked_items, errors = [], [] + for item in items: + checked_item, error = table_cache.check_item(item, for_update=for_update) + if error: + errors.append(error) + else: + checked_items.append(checked_item) + return checked_items, errors + def __del__(self): try: self.close() diff --git a/spinedb_api/db_mapping_check_mixin.py b/spinedb_api/db_mapping_check_mixin.py deleted file mode 100644 index 174eaf64..00000000 --- a/spinedb_api/db_mapping_check_mixin.py +++ /dev/null @@ -1,817 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . 
-###################################################################################################################### - -"""Provides :class:`.DatabaseMappingCheckMixin`. - -""" -# TODO: Review docstrings, they are almost good - -from contextlib import contextmanager -from .exception import SpineIntegrityError -from .check_functions import ( - check_alternative, - check_scenario, - check_scenario_alternative, - check_entity_class, - check_entity, - check_object_class, - check_object, - check_wide_relationship_class, - check_wide_relationship, - check_entity_group, - check_parameter_definition, - check_parameter_value, - check_parameter_value_list, - check_list_value, - check_entity_metadata, - check_metadata, - check_parameter_value_metadata, -) -from .parameter_value import from_database - - -# NOTE: To check for an update we remove the current instance from our lookup dictionary, -# check for an insert of the updated instance, -# and finally reinsert the instance to the dictionary -class DatabaseMappingCheckMixin: - """Provides methods to check whether insert and update operations violate Spine db integrity constraints.""" - - def check_items(self, tablename, *items, for_update=False, strict=False): - return { - "alternative": self.check_alternatives, - "scenario": self.check_scenarios, - "scenario_alternative": self.check_scenario_alternatives, - "entity": self.check_entities, - "entity_class": self.check_entity_classes, - "object": self.check_objects, - "object_class": self.check_object_classes, - "relationship_class": self.check_wide_relationship_classes, - "relationship": self.check_wide_relationships, - "entity_group": self.check_entity_groups, - "parameter_definition": self.check_parameter_definitions, - "parameter_value": self.check_parameter_values, - "parameter_value_list": self.check_parameter_value_lists, - "list_value": self.check_list_values, - "metadata": self.check_metadata, - "entity_metadata": self.check_entity_metadata, - "parameter_value_metadata": self.check_parameter_value_metadata, - }[tablename](*items, for_update=for_update, strict=strict) - - def check_alternatives(self, *items, for_update=False, strict=False): - """Check whether alternatives passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"alternative"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - alternative_ids = {x.name: x.id for x in cache.get("alternative", {}).values()} - for item in items: - try: - with self._manage_stocks( - "alternative", item, {("name",): alternative_ids}, for_update, intgr_error_log - ) as item: - check_alternative(item, alternative_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_scenarios(self, *items, for_update=False, strict=False): - """Check whether scenarios passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. 
- strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"scenario"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - scenario_ids = {x.name: x.id for x in cache.get("scenario", {}).values()} - for item in items: - try: - with self._manage_stocks( - "scenario", item, {("name",): scenario_ids}, for_update, intgr_error_log - ) as item: - check_scenario(item, scenario_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_scenario_alternatives(self, *items, for_update=False, strict=False): - """Check whether scenario alternatives passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"scenario_alternative"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - ids_by_alt_id = {} - ids_by_rank = {} - for item in cache.get("scenario_alternative", {}).values(): - ids_by_alt_id[item.scenario_id, item.alternative_id] = item.id - ids_by_rank[item.scenario_id, item.rank] = item.id - scenario_names = {s.id: s.name for s in cache.get("scenario", {}).values()} - alternative_names = {s.id: s.name for s in cache.get("alternative", {}).values()} - for item in items: - try: - with self._manage_stocks( - "scenario_alternative", - item, - {("scenario_id", "alternative_id"): ids_by_alt_id, ("scenario_id", "rank"): ids_by_rank}, - for_update, - intgr_error_log, - ) as item: - check_scenario_alternative(item, ids_by_alt_id, ids_by_rank, scenario_names, alternative_names) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_entity_classes(self, *items, for_update=False, strict=False): - """Check whether entity classes passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - self.fetch_all({"entity_class"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} - for item in items: - try: - with self._manage_stocks( - "entity_class", item, {("name",): entity_class_ids}, for_update, intgr_error_log - ) as item: - check_entity_class(item, entity_class_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_entities(self, *items, for_update=False, strict=False): - """Check whether entities passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - entity_ids_by_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} - entity_ids_by_el_id_lst = {(x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values()} - entity_classes = { - x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} - for x in cache.get("entity_class", {}).values() - } - entities = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("entity", {}).values()} - for item in items: - try: - with self._manage_stocks( - "entity", - item, - { - ("class_id", "name"): entity_ids_by_name, - ("class_id", "element_id_list"): entity_ids_by_el_id_lst, - }, - for_update, - intgr_error_log, - ) as item: - check_entity( - item, - entity_ids_by_name, - entity_ids_by_el_id_lst, - entity_classes, - entities, - ) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_object_classes(self, *items, for_update=False, strict=False): - """Check whether object classes passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity_class"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - for item in items: - try: - with self._manage_stocks( - "entity_class", item, {("name",): object_class_ids}, for_update, intgr_error_log - ) as item: - check_object_class(item, object_class_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_objects(self, *items, for_update=False, strict=False): - """Check whether objects passed as argument respect integrity constraints. 
- Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - object_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} - object_class_ids = [x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list] - for item in items: - try: - with self._manage_stocks( - "entity", item, {("class_id", "name"): object_ids}, for_update, intgr_error_log - ) as item: - check_object(item, object_ids, object_class_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_wide_relationship_classes(self, *wide_items, for_update=False, strict=False): - """Check whether relationship classes passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity_class"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_wide_items = list() - relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - object_class_ids = [x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list] - for wide_item in wide_items: - object_class_id_list = wide_item.get("object_class_id_list") - if "dimension_id_list" not in wide_item and object_class_id_list is not None: - wide_item["dimension_id_list"] = object_class_id_list - try: - with self._manage_stocks( - "entity_class", - wide_item, - {("name",): relationship_class_ids}, - for_update, - intgr_error_log, - ) as wide_item: - if "object_class_id_list" not in wide_item: - # Use CacheItem.get rather than pop since the former implements the lookup - wide_item["object_class_id_list"] = wide_item.get("dimension_id_list", ()) - check_wide_relationship_class(wide_item, relationship_class_ids, object_class_ids) - checked_wide_items.append(wide_item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_wide_items, intgr_error_log - - def check_wide_relationships(self, *wide_items, for_update=False, strict=False): - """Check whether relationships passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - self.fetch_all({"entity"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_wide_items = list() - relationship_ids_by_name = { - (x.class_id, x.name): x.id for x in cache.get("entity", {}).values() if x.element_id_list - } - relationship_ids_by_obj_lst = { - (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if x.element_id_list - } - relationship_classes = { - x.id: {"object_class_id_list": x.dimension_id_list, "name": x.name} - for x in cache.get("entity_class", {}).values() - if x.dimension_id_list - } - objects = { - x.id: {"class_id": x.class_id, "name": x.name} - for x in cache.get("entity", {}).values() - if not x.element_id_list - } - for wide_item in wide_items: - object_id_list = wide_item.get("object_id_list") - if "element_id_list" not in wide_item and object_id_list is not None: - wide_item["element_id_list"] = object_id_list - try: - with self._manage_stocks( - "entity", - wide_item, - { - ("class_id", "name"): relationship_ids_by_name, - ("class_id", "element_id_list"): relationship_ids_by_obj_lst, - }, - for_update, - intgr_error_log, - ) as wide_item: - if "object_class_id_list" not in wide_item: - # NOTE: Use CacheItem.get rather than pop since the former implements the lookup - wide_item["object_class_id_list"] = wide_item.get("dimension_id_list", ()) - if "object_id_list" not in wide_item: - # NOTE: Use CacheItem.get rather than pop since the former implements the lookup - wide_item["object_id_list"] = wide_item.get("element_id_list", ()) - check_wide_relationship( - wide_item, - relationship_ids_by_name, - relationship_ids_by_obj_lst, - relationship_classes, - objects, - ) - checked_wide_items.append(wide_item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_wide_items, intgr_error_log - - def check_entity_groups(self, *items, for_update=False, strict=False): - """Check whether entity groups passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity_group"}, include_ancestors=True) - cache = self.cache - intgr_error_log = list() - checked_items = list() - current_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} - entities = {} - for entity in cache.get("entity", {}).values(): - entities.setdefault(entity.class_id, dict())[entity.id] = entity._asdict() - for item in items: - try: - with self._manage_stocks( - "entity_group", item, {("entity_id", "member_id"): current_ids}, for_update, intgr_error_log - ) as item: - check_entity_group(item, current_ids, entities) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_parameter_definitions(self, *items, for_update=False, strict=False): - """Check whether parameter definitions passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. 
- strict (bool): Whether the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns: - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"parameter_definition", "parameter_value"}, include_ancestors=True) - cache = self.cache - parameter_definition_ids_with_values = { - value.parameter_id for value in cache.get("parameter_value", {}).values() - } - intgr_error_log = [] - checked_items = list() - parameter_definition_ids = { - (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() - } - entity_class_ids = {x.id for x in cache.get("entity_class", {}).values()} - object_class_ids = {x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - relationship_class_ids = {x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - for item in items: - try: - object_class_id = item.get("object_class_id") - relationship_class_id = item.get("relationship_class_id") - if object_class_id and relationship_class_id: - raise SpineIntegrityError("Can't associate a parameter to both an object and a relationship class.") - if object_class_id and object_class_id not in object_class_ids: - raise SpineIntegrityError("Invalid object class id.") - if relationship_class_id and relationship_class_id not in relationship_class_ids: - raise SpineIntegrityError("Invalid relationship class id.") - entity_class_id = object_class_id or relationship_class_id - if "entity_class_id" not in item and entity_class_id is not None: - item["entity_class_id"] = entity_class_id - if ( - for_update - and item["id"] in parameter_definition_ids_with_values - and item["parameter_value_list_id"] != cache["parameter_definition"][item["id"]].value_list_id - ): - raise SpineIntegrityError( - f"Can't change value list on parameter {item['name']} because it has parameter values." - ) - with self._manage_stocks( - "parameter_definition", - item, - {("entity_class_id", "name"): parameter_definition_ids}, - for_update, - intgr_error_log, - ) as full_item: - check_parameter_definition( - full_item, parameter_definition_ids, entity_class_ids, parameter_value_lists, list_values - ) - checked_items.append(full_item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_parameter_values(self, *items, for_update=False, strict=False): - """Check whether parameter values passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - self.fetch_all({"parameter_value"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() - } - parameter_definitions = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - entities = {x.id: {"class_id": x.class_id, "name": x.name} for x in cache.get("entity", {}).values()} - parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - alternatives = set(a.id for a in cache.get("alternative", {}).values()) - for item in items: - entity_id = item.get("object_id") or item.get("relationship_id") - if "entity_id" not in item and entity_id is not None: - item["entity_id"] = entity_id - try: - with self._manage_stocks( - "parameter_value", - item, - {("entity_id", "parameter_definition_id", "alternative_id"): parameter_value_ids}, - for_update, - intgr_error_log, - ) as item: - check_parameter_value( - item, - parameter_value_ids, - parameter_definitions, - entities, - parameter_value_lists, - list_values, - alternatives, - ) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_parameter_value_lists(self, *items, for_update=False, strict=False): - """Check whether parameter value-lists passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"parameter_value_list"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - parameter_value_list_ids = {x.name: x.id for x in cache.get("parameter_value_list", {}).values()} - for item in items: - try: - with self._manage_stocks( - "parameter_value_list", - item, - {("name",): parameter_value_list_ids}, - for_update, - intgr_error_log, - ) as item: - check_parameter_value_list(item, parameter_value_list_ids) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_list_values(self, *items, for_update=False, strict=False): - """Check whether list values passed as argument respect integrity constraints. - - Args: - items (Iterable): One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. 
- """ - self.fetch_all({"list_value"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - list_value_ids_by_index = { - (x.parameter_value_list_id, x.index): x.id for x in cache.get("list_value", {}).values() - } - list_value_ids_by_value = { - (x.parameter_value_list_id, x.type, x.value): x.id for x in cache.get("list_value", {}).values() - } - list_names_by_id = {x.id: x.name for x in cache.get("parameter_value_list", {}).values()} - for item in items: - try: - with self._manage_stocks( - "list_value", - item, - { - ("parameter_value_list_id", "index"): list_value_ids_by_index, - ("parameter_value_list_id", "type", "value"): list_value_ids_by_value, - }, - for_update, - intgr_error_log, - ) as item: - check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_metadata(self, *items, for_update=False, strict=False): - """Checks whether metadata respects integrity constraints. - - Args: - *items: One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"metadata"}) - cache = self.cache - intgr_error_log = [] - checked_items = list() - metadata = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - for item in items: - try: - with self._manage_stocks( - "metadata", item, {("name", "value"): metadata}, for_update, intgr_error_log - ) as item: - check_metadata(item, metadata) - if (item["name"], item["value"]) not in metadata: - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_entity_metadata(self, *items, for_update=False, strict=False): - """Checks whether entity metadata respects integrity constraints. - - Args: - *items: One or more Python :class:`dict` objects representing the items to be checked. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"entity_metadata"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - entities = {x.id for x in cache.get("entity", {}).values()} - metadata = {x.id for x in cache.get("metadata", {}).values()} - for item in items: - try: - with self._manage_stocks("entity_metadata", item, {}, for_update, intgr_error_log) as item: - check_entity_metadata(item, entities, metadata) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - def check_parameter_value_metadata(self, *items, for_update=False, strict=False): - """Checks whether parameter value metadata respects integrity constraints. - - Args: - *items: One or more Python :class:`dict` objects representing the items to be checked. 
- strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if one of the items violates an integrity constraint. - - Returns - list: items that passed the check. - list: :exc:`~.exception.SpineIntegrityError` instances corresponding to found violations. - """ - self.fetch_all({"parameter_value_metadata"}, include_ancestors=True) - cache = self.cache - intgr_error_log = [] - checked_items = list() - values = {x.id for x in cache.get("parameter_value", {}).values()} - metadata = {x.id for x in cache.get("metadata", {}).values()} - for item in items: - try: - with self._manage_stocks("parameter_value_metadata", item, {}, for_update, intgr_error_log) as item: - check_parameter_value_metadata(item, values, metadata) - checked_items.append(item) - except SpineIntegrityError as e: - if strict: - raise e - intgr_error_log.append(e) - return checked_items, intgr_error_log - - @contextmanager - def _manage_stocks(self, item_type, item, existing_ids_by_pk, for_update, intgr_error_log): - cache = self.cache - if for_update: - try: - id_ = item["id"] - except KeyError: - raise SpineIntegrityError(f"Missing {item_type} identifier.") from None - try: - full_item = cache.get(item_type, {})[id_] - except KeyError: - raise SpineIntegrityError(f"{item_type} not found.") from None - else: - id_ = None - full_item = cache.make_item(item_type, item) - try: - existing_ids_by_key = { - _get_key(full_item, pk): existing_ids for pk, existing_ids in existing_ids_by_pk.items() - } - except KeyError as e: - raise SpineIntegrityError(f"Missing key field {e} for {item_type}.") from None - if for_update: - try: - # Remove from existing - for key, existing_ids in existing_ids_by_key.items(): - del existing_ids[key] - except KeyError: - raise SpineIntegrityError(f"{item_type} not found.") from None - intgr_error_log += _fix_immutable_fields(item_type, full_item, item) - full_item.update(item) - try: - yield full_item - # Check is performed at this point - except SpineIntegrityError: # pylint: disable=try-except-raise - # Check didn't pass, so reraise - raise - else: - # Check passed, so add to existing - for key, existing_ids in existing_ids_by_key.items(): - existing_ids[key] = id_ - if for_update: - cache.get(item_type, {})[id_] = full_item - - -def _get_key_values(item, pk): - for field in pk: - value = item[field] - if isinstance(value, list): - value = tuple(value) - yield value - - -def _get_key(item, pk): - key = tuple(_get_key_values(item, pk)) - if len(key) > 1: - return key - return key[0] - - -def _fix_immutable_fields(item_type, current_item, item): - immutable_fields = { - "entity_class": ("dimension_id_list",), - "relationship_class": ("object_class_id_list",), - "object": ("class_id",), - "relationship": ("class_id",), - "entity": ("class_id",), - "parameter_definition": ("entity_class_id", "object_class_id", "relationship_class_id"), - "parameter_value": ("entity_class_id", "object_class_id", "relationship_class_id"), - }.get(item_type, ()) - fixed = [] - for field in immutable_fields: - if current_item.get(field) is None: - continue - if field in item and item[field] != current_item[field]: - fixed.append(field) - item[field] = current_item[field] - if fixed: - fixed = ', '.join([f"'{field}'" for field in fixed]) - return [SpineIntegrityError(f"Can't update fixed fields {fixed}")] - return [] diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 5a243a58..838d62aa 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ 
b/spinedb_api/db_mapping_update_mixin.py @@ -128,11 +128,13 @@ def update_items(self, tablename, *items, check=True, strict=False): list(SpineIntegrityError): found violations """ if check: - checked_items, intgr_error_log = self.check_items(tablename, *items, for_update=True, strict=strict) + checked_items, errors = self.check_items(tablename, *items, for_update=True) else: - checked_items, intgr_error_log = list(items), [] + checked_items, errors = list(items), [] + if errors and strict: + raise SpineDBAPIError(", ".join(errors)) _ = self._update_items(tablename, *checked_items) - return checked_items, intgr_error_log + return checked_items, errors def _update_items(self, tablename, *items): """Updates items in cache without checking integrity.""" @@ -150,87 +152,45 @@ def _update_items(self, tablename, *items): def update_alternatives(self, *items, **kwargs): return self.update_items("alternative", *items, **kwargs) - def _update_alternatives(self, *items): - return self._update_items("alternative", *items) - def update_scenarios(self, *items, **kwargs): return self.update_items("scenario", *items, **kwargs) - def _update_scenarios(self, *items): - return self._update_items("scenario", *items) - def update_scenario_alternatives(self, *items, **kwargs): return self.update_items("scenario_alternative", *items, **kwargs) - def _update_scenario_alternatives(self, *items): - return self._update_items("scenario_alternative", *items) - def update_entity_classes(self, *items, **kwargs): return self.update_items("entity_class", *items, **kwargs) - def _update_entity_classes(self, *items): - return self._update_items("entity_class", *items) - def update_entities(self, *items, **kwargs): return self.update_items("entity", *items, **kwargs) - def _update_entities(self, *items): - return self._update_items("entity", *items) - def update_object_classes(self, *items, **kwargs): return self.update_items("object_class", *items, **kwargs) - def _update_object_classes(self, *items): - return self._update_items("object_class", *items) - def update_objects(self, *items, **kwargs): return self.update_items("object", *items, **kwargs) - def _update_objects(self, *items): - return self._update_items("object", *items) - def update_wide_relationship_classes(self, *items, **kwargs): return self.update_items("relationship_class", *items, **kwargs) - def _update_wide_relationship_classes(self, *items): - return self._update_items("relationship_class", *items) - def update_wide_relationships(self, *items, **kwargs): return self.update_items("relationship", *items, **kwargs) - def _update_wide_relationships(self, *items): - return self._update_items("relationship", *items) - def update_parameter_definitions(self, *items, **kwargs): return self.update_items("parameter_definition", *items, **kwargs) - def _update_parameter_definitions(self, *items): - return self._update_items("parameter_definition", *items) - def update_parameter_values(self, *items, **kwargs): return self.update_items("parameter_value", *items, **kwargs) - def _update_parameter_values(self, *items): - return self._update_items("parameter_value", *items) - def update_parameter_value_lists(self, *items, **kwargs): return self.update_items("parameter_value_list", *items, **kwargs) - def _update_parameter_value_lists(self, *items): - return self._update_items("parameter_value_list", *items) - def update_list_values(self, *items, **kwargs): return self.update_items("list_value", *items, **kwargs) - def _update_list_values(self, *items): - return 
self._update_items("list_value", *items) - def update_metadata(self, *items, **kwargs): return self.update_items("metadata", *items, **kwargs) - def _update_metadata(self, *items): - return self._update_items("metadata", *items) - def update_ext_entity_metadata(self, *items, check=True, strict=False): updated_items, errors = self._update_ext_item_metadata("entity_metadata", *items, check=check, strict=strict) return updated_items, errors diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 828851c5..51e27667 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -14,27 +14,11 @@ """ -import uuid -from .exception import SpineIntegrityError, SpineDBAPIError -from .check_functions import ( - check_entity_class, - check_entity, - check_alternative, - check_object_class, - check_object, - check_wide_relationship_class, - check_wide_relationship, - check_entity_group, - check_parameter_definition, - check_parameter_value, - check_scenario, - check_parameter_value_list, - check_list_value, -) -from .parameter_value import to_database, from_database, fix_conflict +from .parameter_value import to_database, fix_conflict from .helpers import _parse_metadata # TODO: update docstrings +# FIXME: alt_id, alternative_name = db_map.get_import_alternative() class ImportErrorLogItem: @@ -111,56 +95,13 @@ def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs tuple: number of inserted/changed entities and list of ImportErrorLogItem with any import errors """ - add_items_by_tablename = { - "alternative": db_map._add_alternatives, - "scenario": db_map._add_scenarios, - "scenario_alternative": db_map._add_scenario_alternatives, - "entity_class": db_map._add_entity_classes, - "object_class": db_map._add_object_classes, - "relationship_class": db_map._add_wide_relationship_classes, - "parameter_value_list": db_map._add_parameter_value_lists, - "list_value": db_map._add_list_values, - "parameter_definition": db_map._add_parameter_definitions, - "entity": db_map._add_entities, - "object": db_map._add_objects, - "relationship": db_map._add_wide_relationships, - "entity_group": db_map._add_entity_groups, - "parameter_value": db_map._add_parameter_values, - "metadata": db_map._add_metadata, - "entity_metadata": db_map._add_entity_metadata, - "parameter_value_metadata": db_map._add_parameter_value_metadata, - } - update_items_by_tablename = { - "alternative": db_map._update_alternatives, - "scenario": db_map._update_scenarios, - "scenario_alternative": db_map._update_scenario_alternatives, - "entity_class": db_map._update_entity_classes, - "object_class": db_map._update_object_classes, - "relationship_class": db_map._update_wide_relationship_classes, - "parameter_value_list": db_map._update_parameter_value_lists, - "list_value": db_map._update_list_values, - "parameter_definition": db_map._update_parameter_definitions, - "entity": db_map._update_entities, - "object": db_map._update_objects, - "parameter_value": db_map._update_parameter_values, - } error_log = [] num_imports = 0 for tablename, (to_add, to_update, errors) in get_data_for_import( db_map, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs ): - update_items = update_items_by_tablename.get(tablename, lambda *args, **kwargs: ()) - try: - updated = update_items(*to_update) - except SpineDBAPIError as error: - updated = [] - error_log.append(ImportErrorLogItem(msg=str(error), db_type=tablename)) - add_items = add_items_by_tablename[tablename] - try: - added = 
add_items(*to_add) - except SpineDBAPIError as error: - added = [] - error_log.append(ImportErrorLogItem(msg=str(error), db_type=tablename)) + updated, _ = db_map.update_items(tablename, *to_update, check=False) + added, _ = db_map.add_items(tablename, *to_add, check=False) num_imports += len(added) + len(updated) error_log.extend(errors) return num_imports, error_log @@ -256,10 +197,7 @@ def get_data_for_import( _get_parameter_definitions_for_import(db_map, parameter_definitions, unparse_value), ) if object_parameters: - yield ( - "parameter_definition", - _get_object_parameters_for_import(db_map, object_parameters, unparse_value), - ) + yield ("parameter_definition", _get_object_parameters_for_import(db_map, object_parameters, unparse_value)) if relationship_parameters: yield ( "parameter_definition", @@ -333,48 +271,6 @@ def import_entity_classes(db_map, data): return import_data(db_map, entity_classes=data) -def _get_entity_classes_for_import(db_map, data): - db_map.fetch_all({"entity_class"}, include_ancestors=True) - cache = db_map.cache - entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} - checked = set() - error_log = [] - to_add = [] - to_update = [] - with db_map.generate_ids("entity_class") as new_entity_class_id: - for name, *optionals in data: - if name in checked: - continue - ec_id = entity_class_ids.pop(name, None) - item = ( - cache["entity_class"][ec_id]._asdict() - if ec_id is not None - else {"name": name, "description": None, "display_icon": None} - ) - item.update(dict(zip(("dimension_name_list", "description", "display_icon"), optionals))) - item["dimension_id_list"] = tuple( - entity_class_ids.get(x, None) for x in item.get("dimension_name_list", ()) - ) - try: - check_entity_class(item, entity_class_ids) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem(f"Could not import entity class '{name}': {e.msg}", db_type="entity_class") - ) - continue - finally: - if ec_id is not None: - entity_class_ids[name] = ec_id - checked.add(name) - if ec_id is not None: - item["id"] = ec_id - to_update.append(item) - else: - item["id"] = entity_class_ids[name] = new_entity_class_id() - to_add.append(item) - return to_add, to_update, error_log - - def import_entities(db_map, data): """Imports entities. 
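The import loop above now funnels every table through `db_map.update_items` and `db_map.add_items` with `check=False`, relying on `get_data_for_import` to have validated the rows first. A minimal usage sketch (the URL and imported names are illustrative):

    from spinedb_api import DiffDatabaseMapping, create_new_spine_database, import_data

    url = "sqlite:///example.sqlite"
    create_new_spine_database(url)  # start from an empty Spine database
    db_map = DiffDatabaseMapping(url)
    num_imports, error_log = import_data(
        db_map,
        object_classes=["unit", "node"],
        objects=[("unit", "power_plant_a"), ("node", "grid")],
    )
    db_map.commit_session("Initial import")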
@@ -398,87 +294,6 @@ def import_entities(db_map, data): return import_data(db_map, entities=data) -def _make_unique_entity_name(class_id, class_name, ent_name_or_el_names, class_id_name_tuples): - if isinstance(ent_name_or_el_names, str): - return ent_name_or_el_names - base_name = class_name + "_" + "__".join([en if en is not None else "None" for en in ent_name_or_el_names]) - name = base_name - while (class_id, name) in class_id_name_tuples: - name = base_name + uuid.uuid4().hex - return name - - -def _get_entities_for_import(db_map, data): - db_map.fetch_all({"entity"}, include_ancestors=True) - cache = db_map.cache - entities = {x.id: x for x in cache.get("entity", {}).values()} - entity_ids_per_name = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} - entity_ids_per_el_id_lst = { - (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if x.element_id_list - } - entity_classes = { - x.id: {"dimension_id_list": x.dimension_id_list, "name": x.name} for x in cache.get("entity_class", {}).values() - } - entity_class_ids = {x["name"]: id_ for id_, x in entity_classes.items()} - dimension_id_lists = {id_: x["dimension_id_list"] for id_, x in entity_classes.items()} - error_log = [] - to_add = [] - to_update = [] - checked = set() - with db_map.generate_ids("entity") as new_entity_id: - for class_name, ent_name_or_el_names, *optionals in data: - ec_id = entity_class_ids.get(class_name, None) - dim_ids = dimension_id_lists.get(ec_id, ()) - if isinstance(ent_name_or_el_names, str): - el_ids = () - e_key = ent_name_or_el_names - e_id = None - else: - el_ids = tuple( - entity_ids_per_name.get((dim_id, name), None) for dim_id, name in zip(dim_ids, ent_name_or_el_names) - ) - e_key = el_ids - e_id = entity_ids_per_el_id_lst.pop((ec_id, el_ids), None) - if (ec_id, e_key) in checked: - continue - if e_id is not None: - e_name = cache["entity"][e_id].name - entity_ids_per_name.pop((e_id, e_name)) - else: - e_name = _make_unique_entity_name(ec_id, class_name, ent_name_or_el_names, entity_ids_per_name) - item = ( - cache["entity"][e_id]._asdict() - if e_id is not None - else { - "name": e_name, - "class_id": ec_id, - "element_id_list": el_ids, - "dimension_id_list": dim_ids, - } - ) - item.update(dict(zip(("description",), optionals))) - try: - check_entity(item, entity_ids_per_name, entity_ids_per_el_id_lst, entity_classes, entities) - except SpineIntegrityError as e: - msg = f"Could not import entity {tuple(ent_name_or_el_names)} into '{class_name}': {e.msg}" - error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship")) - continue - finally: - if e_id is not None: - entity_ids_per_el_id_lst[ec_id, el_ids] = entity_ids_per_name[ec_id, e_name] = e_id - checked.add((ec_id, e_key)) - if e_id is not None: - item["id"] = e_id - to_update.append(item) - else: - item["id"] = entity_ids_per_el_id_lst[ec_id, el_ids] = entity_ids_per_name[ - ec_id, e_name - ] = new_entity_id() - entities[item["id"]] = item - to_add.append(item) - return to_add, to_update, error_log - - def import_entity_groups(db_map, data): """Imports list of entity groups by name with associated class name into given database mapping: Ignores duplicate and existing (group, member) tuples. 
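Rows for `import_entity_groups` are (class name, group name, member name) triples, the same shape the removed helper below iterated over; a small sketch with illustrative names:

    data = [
        ("unit", "all_units", "power_plant_a"),
        ("unit", "all_units", "power_plant_b"),
    ]
    import_entity_groups(db_map, data)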
@@ -502,39 +317,6 @@ def import_entity_groups(db_map, data): return import_data(db_map, entity_groups=data) -def _get_entity_groups_for_import(db_map, data): - db_map.fetch_all({"entity_group"}, include_ancestors=True) - cache = db_map.cache - entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} - entity_ids = {(x.class_id, x.name): x.id for x in cache.get("entity", {}).values()} - entities = {} - for ent in cache.get("entity", {}).values(): - entities.setdefault(ent.class_id, {})[ent.id] = ent._asdict() - entity_group_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} - error_log = [] - to_add = [] - seen = set() - for class_name, group_name, member_name in data: - ec_id = entity_class_ids.get(class_name) - g_id = entity_ids.get((ec_id, group_name)) - m_id = entity_ids.get((ec_id, member_name)) - if (g_id, m_id) in seen | entity_group_ids.keys(): - continue - item = {"entity_class_id": ec_id, "entity_id": g_id, "member_id": m_id} - try: - check_entity_group(item, entity_group_ids, entities) - to_add.append(item) - seen.add((g_id, m_id)) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import entity '{member_name}' into group '{group_name}': {e.msg}", - db_type="entity group", - ) - ) - return to_add, [], error_log - - def import_parameter_definitions(db_map, data, unparse_value=to_database): """Imports list of parameter definitions: @@ -557,70 +339,6 @@ def import_parameter_definitions(db_map, data, unparse_value=to_database): return import_data(db_map, parameter_definitions=data, unparse_value=unparse_value) -def _get_parameter_definitions_for_import(db_map, data, unparse_value): - db_map.fetch_all({"parameter_definition"}, include_ancestors=True) - cache = db_map.cache - parameter_definition_ids = { - (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() - } - entity_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values()} - entity_class_ids = {ec_name: id_ for id_, ec_name in entity_class_names.items()} - parameter_value_lists = {} - parameter_value_list_ids = {} - for x in cache.get("parameter_value_list", {}).values(): - parameter_value_lists[x.id] = x.value_id_list - parameter_value_list_ids[x.name] = x.id - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - error_log = [] - to_add = [] - to_update = [] - checked = set() - functions = [unparse_value, lambda x: (parameter_value_list_ids.get(x),), lambda x: (x,)] - for class_name, parameter_name, *optionals in data: - ec_id = entity_class_ids.get(class_name, None) - checked_key = (ec_id, parameter_name) - if checked_key in checked: - continue - p_id = parameter_definition_ids.pop((ec_id, parameter_name), None) - item = ( - cache["parameter_definition"][p_id]._asdict() - if p_id is not None - else { - "name": parameter_name, - "entity_class_id": ec_id, - "default_value": None, - "default_type": None, - "parameter_value_list_id": None, - "description": None, - } - ) - optionals = [y for f, x in zip(functions, optionals) for y in f(x)] - item.update(dict(zip(("default_value", "default_type", "parameter_value_list_id", "description"), optionals))) - try: - check_parameter_definition( - item, parameter_definition_ids, entity_class_names.keys(), parameter_value_lists, list_values - ) - except SpineIntegrityError as e: - # Relationship class doesn't exists - error_log.append( - ImportErrorLogItem( - msg=f"Could not 
import parameter definition '{parameter_name}' with class '{class_name}': {e.msg}", - db_type="parameter definition", - ) - ) - continue - finally: - if p_id is not None: - parameter_definition_ids[ec_id, parameter_name] = p_id - checked.add(checked_key) - if p_id is not None: - item["id"] = p_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): """Imports parameter values: @@ -644,113 +362,6 @@ def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict return import_data(db_map, parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) -def _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - db_map.fetch_all({"parameter_value"}, include_ancestors=True) - cache = db_map.cache - dimension_id_lists = {x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values()} - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() - } - parameters = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - entities = { - x.id: {"class_id": x.class_id, "name": x.name, "element_id_list": x.element_id_list} - for x in cache.get("entity", {}).values() - } - parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - parameter_ids = {(p["entity_class_id"], p["name"]): p_id for p_id, p in parameters.items()} - entity_ids = {(x["class_id"], x["element_id_list"] or x["name"]): e_id for e_id, x in entities.items()} - entity_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values()} - alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} - alternative_ids = set(alternatives.values()) - error_log = [] - to_add = [] - to_update = [] - checked = set() - for class_name, ent_name_or_el_names, parameter_name, value, *optionals in data: - ec_id = entity_class_ids.get(class_name, None) - dim_ids = dimension_id_lists.get(ec_id, ()) - el_ids = tuple(entity_ids.get((dim_id, name)) for dim_id, name in zip(dim_ids, ent_name_or_el_names)) - ent_key = el_ids or ent_name_or_el_names - e_id = entity_ids.get((ec_id, ent_key), None) - p_id = parameter_ids.get((ec_id, parameter_name), None) - if optionals: - alternative_name = optionals[0] - alt_id = alternatives.get(alternative_name) - if not alt_id: - error_log.append( - ImportErrorLogItem( - msg=( - f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', " - f"parameter '{parameter_name}': alternative {alternative_name} does not exist." - ), - db_type="parameter value", - ) - ) - continue - else: - alt_id, alternative_name = db_map.get_import_alternative() - alternative_ids.add(alt_id) - checked_key = (e_id, p_id, alt_id) - if checked_key in checked: - msg = ( - f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', " - f"parameter '{parameter_name}', alternative {alternative_name}: " - "Duplicate parameter value, only first value will be considered." 
- ) - error_log.append(ImportErrorLogItem(msg=msg, db_type="parameter_value")) - continue - pv_id = parameter_value_ids.pop((e_id, p_id, alt_id), None) - value, type_ = unparse_value(value) - if pv_id is not None: - current_pv = cache["parameter_value"][pv_id] - value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict) - item = { - "parameter_definition_id": p_id, - "entity_class_id": ec_id, - "entity_id": e_id, - "value": value, - "type": type_, - "alternative_id": alt_id, - } - try: - check_parameter_value( - item, - parameter_value_ids, - parameters, - entities, - parameter_value_lists, - list_values, - alternative_ids, - ) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import parameter value for '{ent_name_or_el_names}', class '{class_name}', " - f"parameter '{parameter_name}', alternative {alternative_name}: {e.msg}", - db_type="parameter_value", - ) - ) - continue - finally: - if pv_id is not None: - parameter_value_ids[e_id, p_id, alt_id] = pv_id - checked.add(checked_key) - if pv_id is not None: - item["id"] = pv_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_alternatives(db_map, data): """ Imports alternatives. @@ -771,46 +382,6 @@ def import_alternatives(db_map, data): return import_data(db_map, alternatives=data) -def _get_alternatives_for_import(db_map, data): - db_map.fetch_all({"alternative"}, include_ancestors=True) - cache = db_map.cache - alternative_ids = {alternative.name: alternative.id for alternative in cache.get("alternative", {}).values()} - checked = set() - to_add = [] - to_update = [] - error_log = [] - for alternative in data: - if isinstance(alternative, str): - alternative = (alternative,) - name, *optionals = alternative - if name in checked: - continue - alternative_id = alternative_ids.pop(name, None) - item = ( - cache["alternative"][alternative_id]._asdict() - if alternative_id is not None - else {"name": name, "description": None} - ) - item.update(dict(zip(("description",), optionals))) - try: - check_alternative(item, alternative_ids) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem(msg=f"Could not import alternative '{name}': {e.msg}", db_type="alternative") - ) - continue - finally: - if alternative_id is not None: - alternative_ids[name] = alternative_id - checked.add(name) - if alternative_id is not None: - item["id"] = alternative_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_scenarios(db_map, data): """ Imports scenarios. 
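As the removed helpers make explicit, an alternative row may be a plain name or a (name, description) tuple, and a scenario row a name or a (name, active, description) tuple; for instance (names illustrative):

    import_alternatives(db_map, ["low_demand", ("high_demand", "peak-load sensitivity")])
    import_scenarios(db_map, [("base", True, "base scenario")])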
@@ -833,44 +404,6 @@ def import_scenarios(db_map, data): return import_data(db_map, scenarios=data) -def _get_scenarios_for_import(db_map, data): - db_map.fetch_all({"scenario"}, include_ancestors=True) - cache = db_map.cache - scenario_ids = {scenario.name: scenario.id for scenario in cache.get("scenario", {}).values()} - checked = set() - to_add = [] - to_update = [] - error_log = [] - for scenario in data: - if isinstance(scenario, str): - scenario = (scenario,) - name, *optionals = scenario - if name in checked: - continue - scenario_id = scenario_ids.pop(name, None) - item = ( - cache["scenario"][scenario_id]._asdict() - if scenario_id is not None - else {"name": name, "active": False, "description": None} - ) - item.update(dict(zip(("active", "description"), optionals))) - try: - check_scenario(item, scenario_ids) - except SpineIntegrityError as e: - error_log.append(ImportErrorLogItem(msg=f"Could not import scenario '{name}': {e.msg}", db_type="scenario")) - continue - finally: - if scenario_id is not None: - scenario_ids[name] = scenario_id - checked.add(name) - if scenario_id is not None: - item["id"] = scenario_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_scenario_alternatives(db_map, data): """ Imports scenario alternatives. @@ -893,64 +426,44 @@ def import_scenario_alternatives(db_map, data): return import_data(db_map, scenario_alternatives=data) -def _get_scenario_alternatives_for_import(db_map, data): - db_map.fetch_all({"scenario_alternative"}, include_ancestors=True) - cache = db_map.cache - scenario_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()} - scenario_alternative_ids = { - (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values() - } - scenario_ids = {scenario.name: scenario.id for scenario in cache.get("scenario", {}).values()} - alternative_ids = {alternative.name: alternative.id for alternative in cache.get("alternative", {}).values()} - checked = set() - to_add = [] - to_update = [] - error_log = [] - for scenario_name, alternative_name, *optionals in data: - scenario_id = scenario_ids.get(scenario_name) - if not scenario_id: - error_log.append( - ImportErrorLogItem(msg=f"Scenario '{scenario_name}' not found.", db_type="scenario alternative") - ) - continue - alternative_id = alternative_ids.get(alternative_name) - if not alternative_id: - error_log.append( - ImportErrorLogItem(msg=f"Alternative '{alternative_name}' not found.", db_type="scenario alternative") - ) - continue - if (scenario_name, alternative_name) in checked: - continue - checked.add((scenario_name, alternative_name)) - if optionals and optionals[0]: - before_alt_name = optionals[0] - try: - before_alt_id = alternative_ids[before_alt_name] - except KeyError: - error_log.append( - ImportErrorLogItem(msg=f"Before alternative '{before_alt_name}' not found for '{alternative_name}'") - ) - continue - else: - before_alt_id = None - orig_alt_id_list = scenario_alternative_id_lists.get(scenario_id, []) - new_alt_id_list = [id_ for id_ in orig_alt_id_list if id_ != alternative_id] - try: - pos = new_alt_id_list.index(before_alt_id) - except ValueError: - pos = len(new_alt_id_list) - new_alt_id_list.insert(pos, alternative_id) - scenario_alternative_id_lists[scenario_id] = new_alt_id_list - for scenario_id, new_alt_id_list in scenario_alternative_id_lists.items(): - for k, alt_id in enumerate(new_alt_id_list): - id_ = 
scenario_alternative_ids.get((scenario_id, alt_id))
-            item = {"scenario_id": scenario_id, "alternative_id": alt_id, "rank": k + 1}
-            if id_ is not None:
-                item["id"] = id_
-                to_update.append(item)
-            else:
-                to_add.append(item)
-    return to_add, to_update, error_log
+def import_parameter_value_lists(db_map, data, unparse_value=to_database):
+    """Imports parameter value lists:
+
+    Example::
+
+        data = [
+            ['value_list_name', value1], ['value_list_name', value2],
+            ['another_value_list_name', value3],
+        ]
+        import_parameter_value_lists(db_map, data)
+
+    Args:
+        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
+        data (List[List/Tuple]): list/set/iterable of lists/tuples with
+            value list name and a single value
+
+    Returns:
+        (Int, List) Number of successfully inserted items, list of errors
+    """
+    return import_data(db_map, parameter_value_lists=data, unparse_value=unparse_value)
+
+
+def import_metadata(db_map, data=None):
+    """Imports metadata. Ignores duplicates.
+
+    Example::
+
+        data = ['{"name1": "value1"}', '{"name2": "value2"}']
+        import_metadata(db_map, data)
+
+    Args:
+        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
+        data (List[str]): list/set/iterable of string metadata entries in JSON format
+
+    Returns:
+        (Int, List) Number of successfully inserted items, list of errors
+    """
+    return import_data(db_map, metadata=data)
 
 
 def import_object_classes(db_map, data):
@@ -972,46 +485,6 @@
     return import_data(db_map, object_classes=data)
-
-
-def _get_object_classes_for_import(db_map, data):
-    db_map.fetch_all({"entity_class"}, include_ancestors=True)
-    cache = db_map.cache
-    object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list}
-    checked = set()
-    to_add = []
-    to_update = []
-    error_log = []
-    for object_class in data:
-        if isinstance(object_class, str):
-            object_class = (object_class,)
-        name, *optionals = object_class
-        if name in checked:
-            continue
-        oc_id = object_class_ids.pop(name, None)
-        item = (
-            cache["entity_class"][oc_id]._asdict()
-            if oc_id is not None
-            else {"name": name, "description": None, "display_icon": None}
-        )
-        item.update(dict(zip(("description", "display_icon"), optionals)))
-        try:
-            check_object_class(item, object_class_ids)
-        except SpineIntegrityError as e:
-            error_log.append(
-                ImportErrorLogItem(msg=f"Could not import object class '{name}': {e.msg}", db_type="object class")
-            )
-            continue
-        finally:
-            if oc_id is not None:
-                object_class_ids[name] = oc_id
-            checked.add(name)
-        if oc_id is not None:
-            item["id"] = oc_id
-            to_update.append(item)
-        else:
-            to_add.append(item)
-    return to_add, to_update, error_log
-
-
 def import_relationship_classes(db_map, data):
     """Imports relationship classes.
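Both new wrappers above delegate to `import_data` like the rest of this module; each value-list row pairs a list name with one value, and each metadata entry is a JSON string (values illustrative):

    import_parameter_value_lists(db_map, [("fuel_type", "coal"), ("fuel_type", "gas")])
    import_metadata(db_map, ['{"source": "example dataset"}'])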
@@ -1034,53 +507,6 @@ def import_relationship_classes(db_map, data): return import_data(db_map, relationship_classes=data) -def _get_relationship_classes_for_import(db_map, data): - db_map.fetch_all({"entity_class"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} - relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - checked = set() - error_log = [] - to_add = [] - to_update = [] - for name, oc_names, *optionals in data: - if name in checked: - continue - rc_id = relationship_class_ids.pop(name, None) - item = ( - cache["entity_class"][rc_id]._asdict() - if rc_id is not None - else { - "name": name, - "dimension_id_list": [object_class_ids.get(oc, None) for oc in oc_names], - "description": None, - "display_icon": None, - } - ) - item["object_class_id_list"] = item.pop("dimension_id_list") - item.update(dict(zip(("description", "display_icon"), optionals))) - try: - check_wide_relationship_class(item, relationship_class_ids, set(object_class_ids.values())) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - f"Could not import relationship class '{name}' with object classes {tuple(oc_names)}: {e.msg}", - db_type="relationship class", - ) - ) - continue - finally: - if rc_id is not None: - relationship_class_ids[name] = rc_id - checked.add(name) - if rc_id is not None: - item["id"] = rc_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_objects(db_map, data): """Imports list of object by name with associated object class name into given database mapping: Ignores duplicate names and existing names. @@ -1103,47 +529,6 @@ def import_objects(db_map, data): return import_data(db_map, objects=data) -def _get_objects_for_import(db_map, data): - db_map.fetch_all({"entity"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} - object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} - checked = set() - error_log = [] - to_add = [] - to_update = [] - for oc_name, name, *optionals in data: - oc_id = object_class_ids.get(oc_name, None) - if (oc_id, name) in checked: - continue - o_id = object_ids.pop((oc_id, name), None) - item = ( - cache["entity"][o_id]._asdict() - if o_id is not None - else {"name": name, "class_id": oc_id, "description": None} - ) - item.update(dict(zip(("description",), optionals))) - try: - check_object(item, object_ids, set(object_class_ids.values())) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import object '{name}' with class '{oc_name}': {e.msg}", db_type="object" - ) - ) - continue - finally: - if o_id is not None: - object_ids[oc_id, name] = o_id - checked.add((oc_id, name)) - if o_id is not None: - item["id"] = o_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_object_groups(db_map, data): """Imports list of object groups by name with associated object class name into given database mapping: Ignores duplicate and existing (group, member) tuples. 
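Object group rows use the same (class name, group name, member name) shape as entity groups; a sketch with illustrative names:

    data = [("unit", "thermal_units", "power_plant_a")]
    import_object_groups(db_map, data)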
@@ -1167,40 +552,6 @@ def import_object_groups(db_map, data): return import_data(db_map, object_groups=data) -def _get_object_groups_for_import(db_map, data): - db_map.fetch_all({"entity_group"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {oc.name: oc.id for oc in cache.get("entity_class", {}).values() if not oc.dimension_id_list} - object_ids = {(o.class_id, o.name): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} - objects = {} - for obj in cache.get("entity", {}).values(): - if not obj.element_id_list: - objects.setdefault(obj.class_id, dict())[obj.id] = obj._asdict() - entity_group_ids = {(x.group_id, x.member_id): x.id for x in cache.get("entity_group", {}).values()} - error_log = [] - to_add = [] - seen = set() - for class_name, group_name, member_name in data: - oc_id = object_class_ids.get(class_name) - g_id = object_ids.get((oc_id, group_name)) - m_id = object_ids.get((oc_id, member_name)) - if (g_id, m_id) in seen | entity_group_ids.keys(): - continue - item = {"entity_class_id": oc_id, "entity_id": g_id, "member_id": m_id} - try: - check_entity_group(item, entity_group_ids, objects) - to_add.append(item) - seen.add((g_id, m_id)) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import object '{member_name}' into group '{group_name}': {e.msg}", - db_type="entity group", - ) - ) - return to_add, [], error_log - - def import_relationships(db_map, data): """Imports relationships. @@ -1220,85 +571,6 @@ def import_relationships(db_map, data): return import_data(db_map, relationships=data) -def _make_unique_relationship_name(class_id, class_name, object_names, class_id_name_tuples): - base_name = class_name + "_" + "__".join([obj if obj is not None else "None" for obj in object_names]) - name = base_name - while (class_id, name) in class_id_name_tuples: - name = base_name + uuid.uuid4().hex - return name - - -def _get_relationships_for_import(db_map, data): - db_map.fetch_all({"entity"}, include_ancestors=True) - cache = db_map.cache - relationships = {x.name: x for x in cache.get("entity", {}).values() if x.element_id_list} - relationship_ids_per_name = {(x.class_id, x.name): x.id for x in relationships.values()} - relationship_ids_per_obj_lst = {(x.class_id, x.element_id_list): x.id for x in relationships.values()} - relationship_classes = { - x.id: {"object_class_id_list": x.dimension_id_list, "name": x.name} - for x in cache.get("entity_class", {}).values() - if x.dimension_id_list - } - objects = { - x.id: {"class_id": x.class_id, "name": x.name} - for x in cache.get("entity", {}).values() - if not x.element_id_list - } - object_ids = {(o["name"], o["class_id"]): o_id for o_id, o in objects.items()} - relationship_class_ids = {rc["name"]: rc_id for rc_id, rc in relationship_classes.items()} - object_class_id_lists = {rc_id: rc["object_class_id_list"] for rc_id, rc in relationship_classes.items()} - error_log = [] - to_add = [] - to_update = [] - checked = set() - for class_name, object_names, *optionals in data: - rc_id = relationship_class_ids.get(class_name, None) - oc_ids = object_class_id_lists.get(rc_id, []) - o_ids = tuple(object_ids.get((name, oc_id), None) for name, oc_id in zip(object_names, oc_ids)) - if (rc_id, o_ids) in checked: - continue - r_id = relationship_ids_per_obj_lst.pop((rc_id, o_ids), None) - if r_id is not None: - r_name = cache["entity"][r_id].name - relationship_ids_per_name.pop((rc_id, r_name)) - item = ( - cache["entity"][r_id]._asdict() - if r_id is not None 
- else { - "name": _make_unique_relationship_name(rc_id, class_name, object_names, relationship_ids_per_name), - "class_id": rc_id, - "element_id_list": list(o_ids), - "dimension_id_list": oc_ids, - } - ) - item["object_id_list"] = item.pop("element_id_list") - item["object_class_id_list"] = item.pop("dimension_id_list", ()) - item.update(dict(zip(("description",), optionals))) - try: - check_wide_relationship( - item, - relationship_ids_per_name, - relationship_ids_per_obj_lst, - relationship_classes, - objects, - ) - except SpineIntegrityError as e: - msg = f"Could not import relationship with objects {tuple(object_names)} into '{class_name}': {e.msg}" - error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship")) - continue - finally: - if r_id is not None: - relationship_ids_per_obj_lst[rc_id, o_ids] = r_id - relationship_ids_per_name[rc_id, r_name] = r_id - checked.add((rc_id, o_ids)) - if r_id is not None: - item["id"] = r_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_object_parameters(db_map, data, unparse_value=to_database): """Imports list of object class parameters: @@ -1321,70 +593,6 @@ def import_object_parameters(db_map, data, unparse_value=to_database): return import_data(db_map, object_parameters=data, unparse_value=unparse_value) -def _get_object_parameters_for_import(db_map, data, unparse_value): - db_map.fetch_all({"parameter_definition"}, include_ancestors=True) - cache = db_map.cache - parameter_ids = { - (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() - } - object_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - object_class_ids = {oc_name: oc_id for oc_id, oc_name in object_class_names.items()} - parameter_value_lists = {} - parameter_value_list_ids = {} - for x in cache.get("parameter_value_list", {}).values(): - parameter_value_lists[x.id] = x.value_id_list - parameter_value_list_ids[x.name] = x.id - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - checked = set() - error_log = [] - to_add = [] - to_update = [] - functions = [unparse_value, lambda x: (parameter_value_list_ids.get(x),), lambda x: (x,)] - for class_name, parameter_name, *optionals in data: - oc_id = object_class_ids.get(class_name, None) - checked_key = (oc_id, parameter_name) - if checked_key in checked: - continue - p_id = parameter_ids.pop((oc_id, parameter_name), None) - item = ( - cache["parameter_definition"][p_id]._asdict() - if p_id is not None - else { - "name": parameter_name, - "entity_class_id": oc_id, - "object_class_id": oc_id, - "default_value": None, - "default_type": None, - "parameter_value_list_id": None, - "description": None, - } - ) - optionals = [y for f, x in zip(functions, optionals) for y in f(x)] - item.update(dict(zip(("default_value", "default_type", "parameter_value_list_id", "description"), optionals))) - try: - check_parameter_definition( - item, parameter_ids, object_class_names.keys(), parameter_value_lists, list_values - ) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - f"Could not import parameter '{parameter_name}' with class '{class_name}': {e.msg}", - db_type="parameter definition", - ) - ) - continue - finally: - if p_id is not None: - parameter_ids[oc_id, parameter_name] = p_id - checked.add(checked_key) - if p_id is not None: - item["id"] = p_id - to_update.append(item) - else: - to_add.append(item) 
- return to_add, to_update, error_log - - def import_relationship_parameters(db_map, data, unparse_value=to_database): """Imports list of relationship class parameters: @@ -1407,71 +615,6 @@ def import_relationship_parameters(db_map, data, unparse_value=to_database): return import_data(db_map, relationship_parameters=data, unparse_value=unparse_value) -def _get_relationship_parameters_for_import(db_map, data, unparse_value): - db_map.fetch_all({"parameter_definition"}, include_ancestors=True) - cache = db_map.cache - parameter_ids = { - (x.entity_class_id, x.parameter_name): x.id for x in cache.get("parameter_definition", {}).values() - } - relationship_class_names = {x.id: x.name for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - relationship_class_ids = {rc_name: rc_id for rc_id, rc_name in relationship_class_names.items()} - parameter_value_lists = {} - parameter_value_list_ids = {} - for x in cache.get("parameter_value_list", {}).values(): - parameter_value_lists[x.id] = x.value_id_list - parameter_value_list_ids[x.name] = x.id - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - error_log = [] - to_add = [] - to_update = [] - checked = set() - functions = [unparse_value, lambda x: (parameter_value_list_ids.get(x),), lambda x: (x,)] - for class_name, parameter_name, *optionals in data: - rc_id = relationship_class_ids.get(class_name, None) - checked_key = (rc_id, parameter_name) - if checked_key in checked: - continue - p_id = parameter_ids.pop((rc_id, parameter_name), None) - item = ( - cache["parameter_definition"][p_id]._asdict() - if p_id is not None - else { - "name": parameter_name, - "entity_class_id": rc_id, - "relationship_class_id": rc_id, - "default_value": None, - "default_type": None, - "parameter_value_list_id": None, - "description": None, - } - ) - optionals = [y for f, x in zip(functions, optionals) for y in f(x)] - item.update(dict(zip(("default_value", "default_type", "parameter_value_list_id", "description"), optionals))) - try: - check_parameter_definition( - item, parameter_ids, relationship_class_names.keys(), parameter_value_lists, list_values - ) - except SpineIntegrityError as e: - # Relationship class doesn't exists - error_log.append( - ImportErrorLogItem( - msg=f"Could not import parameter '{parameter_name}' with class '{class_name}': {e.msg}", - db_type="parameter definition", - ) - ) - continue - finally: - if p_id is not None: - parameter_ids[rc_id, parameter_name] = p_id - checked.add(checked_key) - if p_id is not None: - item["id"] = p_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - def import_object_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): """Imports object parameter values: @@ -1491,114 +634,7 @@ def import_object_parameter_values(db_map, data, unparse_value=to_database, on_c Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data( - db_map, - object_parameter_values=data, - unparse_value=unparse_value, - on_conflict=on_conflict, - ) - - -def _get_object_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - db_map.fetch_all({"parameter_value"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", 
{}).values() - } - parameters = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - objects = { - x.id: {"class_id": x.class_id, "name": x.name} - for x in cache.get("entity", {}).values() - if not x.element_id_list - } - parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - object_ids = {(o["name"], o["class_id"]): o_id for o_id, o in objects.items()} - parameter_ids = {(p["name"], p["entity_class_id"]): p_id for p_id, p in parameters.items()} - alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} - alternative_ids = set(alternatives.values()) - error_log = [] - to_add = [] - to_update = [] - checked = set() - for class_name, object_name, parameter_name, value, *optionals in data: - oc_id = object_class_ids.get(class_name, None) - o_id = object_ids.get((object_name, oc_id), None) - p_id = parameter_ids.get((parameter_name, oc_id), None) - if optionals: - alternative_name = optionals[0] - alt_id = alternatives.get(alternative_name) - if not alt_id: - error_log.append( - ImportErrorLogItem( - msg=( - "Could not import parameter value for " - f"'{object_name}', class '{class_name}', parameter '{parameter_name}': " - f"alternative '{alternative_name}' does not exist." - ), - db_type="parameter value", - ) - ) - continue - else: - alt_id, alternative_name = db_map.get_import_alternative() - alternative_ids.add(alt_id) - checked_key = (o_id, p_id, alt_id) - if checked_key in checked: - msg = ( - f"Could not import parameter value for '{object_name}', class '{class_name}', " - f"parameter '{parameter_name}', alternative {alternative_name}: " - "Duplicate parameter value, only first value will be considered." 
- ) - error_log.append(ImportErrorLogItem(msg=msg, db_type="parameter value")) - continue - pv_id = parameter_value_ids.pop((o_id, p_id, alt_id), None) - value, type_ = unparse_value(value) - if pv_id is not None: - current_pv = cache["parameter_value"][pv_id] - value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict) - item = { - "parameter_definition_id": p_id, - "entity_class_id": oc_id, - "entity_id": o_id, - "object_class_id": oc_id, - "object_id": o_id, - "value": value, - "type": type_, - "alternative_id": alt_id, - } - try: - check_parameter_value( - item, parameter_value_ids, parameters, objects, parameter_value_lists, list_values, alternative_ids - ) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg="Could not import parameter value for '{0}', class '{1}', parameter '{2}': {3}".format( - object_name, class_name, parameter_name, e.msg - ), - db_type="parameter value", - ) - ) - continue - finally: - if pv_id is not None: - parameter_value_ids[o_id, p_id, alt_id] = pv_id - checked.add(checked_key) - if pv_id is not None: - item["id"] = pv_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log + return import_data(db_map, object_parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) def import_relationship_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): @@ -1620,268 +656,7 @@ def import_relationship_parameter_values(db_map, data, unparse_value=to_database Returns: (Int, List) Number of successful inserted objects, list of errors """ - return import_data( - db_map, - relationship_parameter_values=data, - unparse_value=unparse_value, - on_conflict=on_conflict, - ) - - -def _get_relationship_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - db_map.fetch_all({"parameter_value"}, include_ancestors=True) - cache = db_map.cache - object_class_id_lists = { - x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list - } - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() - } - parameters = { - x.id: { - "name": x.parameter_name, - "entity_class_id": x.entity_class_id, - "parameter_value_list_id": x.value_list_id, - } - for x in cache.get("parameter_definition", {}).values() - } - relationships = { - x.id: {"class_id": x.class_id, "name": x.name, "object_id_list": x.element_id_list} - for x in cache.get("entity", {}).values() - if x.element_id_list - } - parameter_value_lists = {x.id: x.value_id_list for x in cache.get("parameter_value_list", {}).values()} - list_values = {x.id: from_database(x.value, x.type) for x in cache.get("list_value", {}).values()} - parameter_ids = {(p["entity_class_id"], p["name"]): p_id for p_id, p in parameters.items()} - relationship_ids = {(r["class_id"], tuple(r["object_id_list"])): r_id for r_id, r in relationships.items()} - object_ids = {(o.name, o.class_id): o.id for o in cache.get("entity", {}).values() if not o.element_id_list} - relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - alternatives = {a.name: a.id for a in cache.get("alternative", {}).values()} - alternative_ids = set(alternatives.values()) - error_log = [] - to_add = [] - to_update = [] - checked = set() - for class_name, object_names, parameter_name, value, *optionals in data: - rc_id = relationship_class_ids.get(class_name, None) 
- oc_ids = object_class_id_lists.get(rc_id, []) - if len(object_names) == len(oc_ids): - o_ids = tuple(object_ids.get((name, oc_id), None) for name, oc_id in zip(object_names, oc_ids)) - else: - o_ids = tuple(None for _ in object_names) - r_id = relationship_ids.get((rc_id, o_ids), None) - p_id = parameter_ids.get((rc_id, parameter_name), None) - if optionals: - alternative_name = optionals[0] - alt_id = alternatives.get(alternative_name) - if not alt_id: - error_log.append( - ImportErrorLogItem( - msg=( - "Could not import parameter value for " - f"'{object_names}', class '{class_name}', parameter '{parameter_name}': " - f"alternative {alternative_name} does not exist." - ), - db_type="parameter value", - ) - ) - continue - else: - alt_id, alternative_name = db_map.get_import_alternative() - alternative_ids.add(alt_id) - checked_key = (r_id, p_id, alt_id) - if checked_key in checked: - msg = ( - f"Could not import parameter value for '{object_names}', class '{class_name}', " - f"parameter '{parameter_name}', alternative {alternative_name}: " - "Duplicate parameter value, only first value will be considered." - ) - error_log.append(ImportErrorLogItem(msg=msg, db_type="parameter value")) - continue - pv_id = parameter_value_ids.pop((r_id, p_id, alt_id), None) - value, type_ = unparse_value(value) - if pv_id is not None: - current_pv = cache["parameter_value"][pv_id] - value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict) - item = { - "parameter_definition_id": p_id, - "entity_class_id": rc_id, - "entity_id": r_id, - "relationship_class_id": rc_id, - "relationship_id": r_id, - "value": value, - "type": type_, - "alternative_id": alt_id, - } - try: - check_parameter_value( - item, - parameter_value_ids, - parameters, - relationships, - parameter_value_lists, - list_values, - alternative_ids, - ) - except SpineIntegrityError as e: - error_log.append( - ImportErrorLogItem( - msg="Could not import parameter value for '{0}', class '{1}', parameter '{2}': {3}".format( - object_names, class_name, parameter_name, e.msg - ), - db_type="parameter value", - ) - ) - continue - finally: - if pv_id is not None: - parameter_value_ids[r_id, p_id, alt_id] = pv_id - checked.add(checked_key) - if pv_id is not None: - item["id"] = pv_id - to_update.append(item) - else: - to_add.append(item) - return to_add, to_update, error_log - - -def import_parameter_value_lists(db_map, data, unparse_value=to_database): - """Imports list of parameter value lists: - - Example:: - - data = [ - ['value_list_name', value1], ['value_list_name', value2], - ['another_value_list_name', value3], - ] - import_parameter_value_lists(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with - value list name, list of values - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ - return import_data(db_map, parameter_value_lists=data, unparse_value=unparse_value) - - -def _get_parameter_value_lists_for_import(db_map, data): - db_map.fetch_all({"parameter_value_list"}, include_ancestors=True) - cache = db_map.cache - parameter_value_list_ids = {x.name: x.id for x in cache.get("parameter_value_list", {}).values()} - error_log = [] - to_add = [] - for name in list({x[0]: None for x in data}): - item = {"name": name} - try: - check_parameter_value_list(item, parameter_value_list_ids) - except SpineIntegrityError: - continue - to_add.append(item) - return to_add, 
[], error_log - - -def _get_list_values_for_import(db_map, data, unparse_value): - db_map.fetch_all({"list_value"}, include_ancestors=True) - cache = db_map.cache - value_lists_by_name = { - x.name: ( - x.id, - max( - (y.index for y in cache.get("list_value", {}).values() if y.parameter_value_list_id == x.id), default=-1 - ), - ) - for x in cache.get("parameter_value_list", {}).values() - } - list_value_ids_by_index = {(x.parameter_value_list_id, x.index): x.id for x in cache.get("list_value", {}).values()} - list_value_ids_by_value = { - (x.parameter_value_list_id, x.type, x.value): x.id for x in cache.get("list_value", {}).values() - } - list_names_by_id = {x.id: x.name for x in cache.get("parameter_value_list", {}).values()} - error_log = [] - to_add = [] - to_update = [] - seen_values = set() - max_indexes = dict() - for list_name, value in data: - try: - list_id, current_max_index = value_lists_by_name.get(list_name) - except TypeError: - # cannot unpack non-iterable NoneType object - error_log.append( - ImportErrorLogItem( - msg=f"Could not import value for list '{list_name}': list not found", db_type="list value" - ) - ) - continue - val, type_ = unparse_value(value) - if (list_id, type_, val) in seen_values: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import value for list '{list_name}': " - "Duplicate value, only first will be considered", - db_type="list value", - ) - ) - continue - max_index = max_indexes.get(list_id) - if max_index is not None: - index = max_index + 1 - else: - index = current_max_index + 1 - item = {"parameter_value_list_id": list_id, "value": val, "type": type_, "index": index} - try: - check_list_value(item, list_names_by_id, list_value_ids_by_index, list_value_ids_by_value) - except SpineIntegrityError as e: - if e.id is None: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import value '{value}' for list '{list_name}': {e.msg}", db_type="list value" - ) - ) - continue - max_indexes[list_id] = index - seen_values.add((list_id, type_, val)) - to_add.append(item) - return to_add, to_update, error_log - - -def import_metadata(db_map, data=None): - """Imports metadata. Ignores duplicates. 
- - Example:: - - data = ['{"name1": "value1"}', '{"name2": "value2"}'] - import_metadata(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of string metadata entries in JSON format - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ - return import_data(db_map, metadata=data) - - -def _get_metadata_for_import(db_map, data): - db_map.fetch_all({"metadata"}, include_ancestors=True) - cache = db_map.cache - seen = {(x.name, x.value) for x in cache.get("metadata", {}).values()} - to_add = [] - for metadata in data: - for name, value in _parse_metadata(metadata): - if (name, value) in seen: - continue - item = {"name": name, "value": value} - seen.add((name, value)) - to_add.append(item) - return to_add, [], [] - - -# TODO: import_entity_metadata, import_parameter_value_metadata + return import_data(db_map, relationship_parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) def import_object_metadata(db_map, data): @@ -1903,45 +678,6 @@ def import_object_metadata(db_map, data): return import_data(db_map, object_metadata=data) -def _get_object_metadata_for_import(db_map, data): - db_map.fetch_all({"object", "entity_metadata"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} - seen = {(x.entity_id, x.metadata_id) for x in cache.get("entity_metadata", {}).values()} - error_log = [] - to_add = [] - for class_name, object_name, metadata in data: - oc_id = object_class_ids.get(class_name, None) - o_id = object_ids.get((object_name, oc_id), None) - if o_id is None: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import object metadata: unknown object '{object_name}' of class '{class_name}'", - db_type="object metadata", - ) - ) - continue - for name, value in _parse_metadata(metadata): - m_id = metadata_ids.get((name, value), None) - if m_id is None: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import object metadata: unknown metadata '{name}': '{value}'", - db_type="object metadata", - ) - ) - continue - unique_key = (o_id, m_id) - if unique_key in seen: - continue - item = {"entity_id": o_id, "metadata_id": m_id} - seen.add(unique_key) - to_add.append(item) - return to_add, [], error_log - - def import_relationship_metadata(db_map, data): """Imports relationship metadata. Ignores duplicates. 
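As the removed helper shows, an object metadata row is (class name, object name, metadata), and it only links to metadata entries that already exist; unknown entries are reported as errors rather than created on the fly. A sketch with illustrative values::

    # The metadata entry itself must be imported first...
    import_metadata(db_map, ['{"source": "example"}'])
    # ...before it can be attached to an object.
    import_object_metadata(db_map, [["unit", "power_plant", '{"source": "example"}']])
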
@@ -1964,55 +700,6 @@ def import_relationship_metadata(db_map, data): return import_data(db_map, relationship_metadata=data) -def _get_relationship_metadata_for_import(db_map, data): - db_map.fetch_all({"relationship", "entity_metadata"}, include_ancestors=True) - cache = db_map.cache - relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - object_class_id_lists = { - x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list - } - metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} - relationship_ids = { - (x.class_id, x.element_id_list): x.id for x in cache.get("entity", {}).values() if x.element_id_list - } - seen = {(x.entity_id, x.metadata_id) for x in cache.get("entity_metadata", {}).values()} - error_log = [] - to_add = [] - for class_name, object_names, metadata in data: - rc_id = relationship_class_ids.get(class_name, None) - oc_ids = object_class_id_lists.get(rc_id, []) - o_ids = tuple(object_ids.get((name, oc_id), None) for name, oc_id in zip(object_names, oc_ids)) - r_id = relationship_ids.get((rc_id, o_ids), None) - if r_id is None: - error_log.append( - ImportErrorLogItem( - msg="Could not import relationship metadata: unknown relationship '{0}' of class '{1}'".format( - object_names, class_name - ), - db_type="relationship metadata", - ) - ) - continue - for name, value in _parse_metadata(metadata): - m_id = metadata_ids.get((name, value), None) - if m_id is None: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import relationship metadata: unknown metadata '{name}': '{value}'", - db_type="relationship metadata", - ) - ) - continue - unique_key = (r_id, m_id) - if unique_key in seen: - continue - item = {"entity_id": r_id, "metadata_id": m_id} - seen.add(unique_key) - to_add.append(item) - return to_add, [], error_log - - def import_object_parameter_value_metadata(db_map, data): """Imports object parameter value metadata. Ignores duplicates. 
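Relationship metadata rows follow the same pattern, except the entity is identified by its class name plus the list of its object names in dimension order. A sketch with illustrative values::

    import_relationship_metadata(
        db_map, [["unit__node", ["power_plant", "west"], '{"source": "example"}']]
    )
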
@@ -2035,60 +722,6 @@ def import_object_parameter_value_metadata(db_map, data): return import_data(db_map, object_parameter_value_metadata=data) -def _get_object_parameter_value_metadata_for_import(db_map, data): - db_map.fetch_all({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) - cache = db_map.cache - object_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if not x.dimension_id_list} - object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} - parameter_ids = { - (x.parameter_name, x.entity_class_id): x.id for x in cache.get("parameter_definition", {}).values() - } - alternative_ids = {a.name: a.id for a in cache.get("alternative", {}).values()} - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() - } - metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - seen = {(x.parameter_value_id, x.metadata_id) for x in cache.get("parameter_value_metadata", {}).values()} - error_log = [] - to_add = [] - for class_name, object_name, parameter_name, metadata, *optionals in data: - oc_id = object_class_ids.get(class_name, None) - o_id = object_ids.get((object_name, oc_id), None) - p_id = parameter_ids.get((parameter_name, oc_id), None) - if optionals: - alternative_name = optionals[0] - alt_id = alternative_ids.get(alternative_name, None) - else: - alt_id, alternative_name = db_map.get_import_alternative() - pv_id = parameter_value_ids.get((o_id, p_id, alt_id), None) - if pv_id is None: - msg = ( - "Could not import object parameter value metadata: " - "parameter {0} doesn't have a value for object {1}, alternative {2}".format( - parameter_name, object_name, alternative_name - ) - ) - error_log.append(ImportErrorLogItem(msg=msg, db_type="object parameter value metadata")) - continue - for name, value in _parse_metadata(metadata): - m_id = metadata_ids.get((name, value), None) - if m_id is None: - error_log.append( - ImportErrorLogItem( - msg=f"Could not import object parameter value metadata: unknown metadata '{name}': '{value}'", - db_type="object parameter value metadata", - ) - ) - continue - unique_key = (pv_id, m_id) - if unique_key in seen: - continue - item = {"parameter_value_id": pv_id, "metadata_id": m_id} - seen.add(unique_key) - to_add.append(item) - return to_add, [], error_log - - def import_relationship_parameter_value_metadata(db_map, data): """Imports relationship parameter value metadata. Ignores duplicates. 
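Parameter value metadata rows additionally name the parameter and, optionally, the alternative; when the alternative is omitted, the import alternative from db_map.get_import_alternative() is used, as in the removed helper. A sketch with illustrative values::

    import_object_parameter_value_metadata(
        db_map,
        [["unit", "power_plant", "capacity", '{"checked": "yes"}', "alt1"]],
    )
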
@@ -2111,59 +744,202 @@ def import_relationship_parameter_value_metadata(db_map, data): return import_data(db_map, relationship_parameter_value_metadata=data) -def _get_relationship_parameter_value_metadata_for_import(db_map, data): - db_map.fetch_all({"parameter_value", "parameter_value_metadata"}, include_ancestors=True) - cache = db_map.cache - relationship_class_ids = {x.name: x.id for x in cache.get("entity_class", {}).values() if x.dimension_id_list} - object_class_id_lists = { - x.id: x.dimension_id_list for x in cache.get("entity_class", {}).values() if x.dimension_id_list - } - object_ids = {(x.name, x.class_id): x.id for x in cache.get("entity", {}).values() if not x.element_id_list} - relationship_ids = { - (x.element_id_list, x.class_id): x.id for x in cache.get("entity", {}).values() if x.element_id_list - } - parameter_ids = { - (x.parameter_name, x.entity_class_id): x.id for x in cache.get("parameter_definition", {}).values() - } - alternative_ids = {a.name: a.id for a in cache.get("alternative", {}).values()} - parameter_value_ids = { - (x.entity_id, x.parameter_id, x.alternative_id): x.id for x in cache.get("parameter_value", {}).values() - } - metadata_ids = {(x.name, x.value): x.id for x in cache.get("metadata", {}).values()} - seen = {(x.parameter_value_id, x.metadata_id) for x in cache.get("parameter_value_metadata", {}).values()} - error_log = [] +def _get_items_for_import(db_map, item_type, data): + table_cache = db_map.cache.table_cache(item_type) + errors = [] to_add = [] - for class_name, object_names, parameter_name, metadata, *optionals in data: - rc_id = relationship_class_ids.get(class_name, None) - oc_ids = object_class_id_lists.get(rc_id, []) - o_ids = tuple(object_ids.get((name, oc_id), None) for name, oc_id in zip(object_names, oc_ids)) - r_id = relationship_ids.get((o_ids, rc_id), None) - p_id = parameter_ids.get((parameter_name, rc_id), None) - if optionals: - alternative_name = optionals[0] - alt_id = alternative_ids.get(alternative_name, None) - else: - alt_id, alternative_name = db_map.get_import_alternative() - pv_id = parameter_value_ids.get((r_id, p_id, alt_id), None) - if pv_id is None: - msg = ( - "Could not import relationship parameter value metadata: " - "parameter '{0}' doesn't have a value for relationship '{1}', alternative '{2}'".format( - parameter_name, object_names, alternative_name - ) - ) - error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship parameter value metadata")) - continue - for name, value in _parse_metadata(metadata): - m_id = metadata_ids.get((name, value), None) - if m_id is None: - msg = f"Could not import relationship parameter value metadata: unknown metadata '{name}': '{value}'" - error_log.append(ImportErrorLogItem(msg=msg, db_type="relationship parameter value metadata")) + to_update = [] + with db_map.generate_ids(item_type) as new_id: + for item in data: + checked_item, error = table_cache.check_item(item) + if not error: + item["id"] = new_id() + to_add.append(checked_item) continue - unique_key = (pv_id, m_id) - if unique_key in seen: + checked_item, error = table_cache.check_item(item, for_update=True) + if not error: + to_update.append(checked_item) continue - item = {"parameter_value_id": pv_id, "metadata_id": m_id} - seen.add(unique_key) - to_add.append(item) - return to_add, [], error_log + errors.append(error) + return to_add, to_update, errors + + +def _get_entity_classes_for_import(db_map, data): + key = ("name", "dimension_name_list", "description", "display_icon") + return 
_get_items_for_import( + db_map, "entity_class", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) + ) + + +def _get_entities_for_import(db_map, data): + def _data_iterator(): + for class_name, name_or_element_name_list, *optionals in data: + byname_key = "name" if isinstance(name_or_element_name_list, str) else "element_name_list" + key = ("class_name", byname_key, "description") + yield dict(zip(key, (class_name, name_or_element_name_list, *optionals))) + + return _get_items_for_import(db_map, "entity", _data_iterator()) + + +def _get_entity_groups_for_import(db_map, data): + key = ("class_name", "group_name", "member_name") + return _get_items_for_import(db_map, "entity_group", (dict(zip(key, x)) for x in data)) + + +def _get_parameter_definitions_for_import(db_map, data, unparse_value): + def _data_iterator(): + for class_name, parameter_name, *optionals in data: + if not optionals: + yield class_name, parameter_name + continue + value = optionals.pop(0) + value, type_ = unparse_value(value) + yield class_name, parameter_name, value, type_, *optionals + + key = ("entity_class_name", "name", "default_value", "default_type", "parameter_value_list_name", "description") + return _get_items_for_import(db_map, "parameter_definition", (dict(zip(key, x)) for x in _data_iterator())) + + +def _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict): + def _data_iterator(): + for class_name, entity_byname, parameter_name, value, *optionals in data: + if isinstance(entity_byname, str): + entity_byname = (entity_byname,) + value, type_ = unparse_value(value) + yield class_name, entity_byname, parameter_name, value, type_, *optionals + + key = ("entity_class_name", "entity_byname", "parameter_definition_name", "value", "type", "alternative_name") + return _get_items_for_import(db_map, "parameter_value", (dict(zip(key, x)) for x in _data_iterator())) + # FIXME: value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict) + + +def _get_alternatives_for_import(db_map, data): + key = ("name", "description") + return _get_items_for_import( + db_map, "alternative", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) + ) + + +def _get_scenarios_for_import(db_map, data): + key = ("name", "active", "description") + return _get_items_for_import( + db_map, "scenario", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) + ) + + +def _get_scenario_alternatives_for_import(db_map, data): + key = ("scenario_name", "alternative_name", "before_alternative_name") + return _get_items_for_import(db_map, "scenario_alternative", (dict(zip(key, x)) for x in data)) + + +def _get_parameter_value_lists_for_import(db_map, data): + return _get_items_for_import(db_map, "parameter_value_list", ({"name": x} for x in {x[0]: None for x in data})) + + +def _get_list_values_for_import(db_map, data, unparse_value): + def _data_iterator(): + for list_name, value in data: + value, type_ = unparse_value(value) + yield {"parameter_value_list_name": list_name, "value": value, "type": type_} + + return _get_items_for_import(db_map, "list_value", _data_iterator()) + + +def _get_metadata_for_import(db_map, data): + def _data_iterator(): + for metadata in data: + for name, value in _parse_metadata(metadata): + yield {"name": name, "value": value} + + return _get_items_for_import(db_map, "metadata", _data_iterator()) + + +def _get_entity_metadata_for_import(db_map, data): + def _data_iterator(): + for class_name, entity_byname, 
metadata in data:
+            if isinstance(entity_byname, str):
+                entity_byname = (entity_byname,)
+            for name, value in _parse_metadata(metadata):
+                yield (class_name, entity_byname, name, value)
+
+    key = ("entity_class_name", "entity_byname", "metadata_name", "metadata_value")
+    return _get_items_for_import(db_map, "entity_metadata", (dict(zip(key, x)) for x in _data_iterator()))
+
+
+def _get_parameter_value_metadata_for_import(db_map, data):
+    def _data_iterator():
+        for class_name, entity_byname, parameter_name, metadata, *optionals in data:
+            if isinstance(entity_byname, str):
+                entity_byname = (entity_byname,)
+            for name, value in _parse_metadata(metadata):
+                yield (class_name, entity_byname, parameter_name, name, value, *optionals)
+
+    key = (
+        "entity_class_name",
+        "entity_byname",
+        "parameter_definition_name",
+        "metadata_name",
+        "metadata_value",
+        "alternative_name",
+    )
+    return _get_items_for_import(db_map, "parameter_value_metadata", (dict(zip(key, x)) for x in _data_iterator()))
+
+
+# Legacy
+def _get_object_classes_for_import(db_map, data):
+    def _data_iterator():
+        for x in data:
+            if isinstance(x, str):
+                yield x
+                continue
+            name, *optionals = x
+            yield name, (), *optionals
+
+    return _get_entity_classes_for_import(db_map, _data_iterator())
+
+
+def _get_relationship_classes_for_import(db_map, data):
+    return _get_entity_classes_for_import(db_map, data)
+
+
+def _get_objects_for_import(db_map, data):
+    return _get_entities_for_import(db_map, data)
+
+
+def _get_relationships_for_import(db_map, data):
+    return _get_entities_for_import(db_map, data)
+
+
+def _get_object_groups_for_import(db_map, data):
+    return _get_entity_groups_for_import(db_map, data)
+
+
+def _get_object_parameters_for_import(db_map, data, unparse_value):
+    return _get_parameter_definitions_for_import(db_map, data, unparse_value)
+
+
+def _get_relationship_parameters_for_import(db_map, data, unparse_value):
+    return _get_parameter_definitions_for_import(db_map, data, unparse_value)
+
+
+def _get_object_parameter_values_for_import(db_map, data, unparse_value, on_conflict):
+    return _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict)
+
+
+def _get_relationship_parameter_values_for_import(db_map, data, unparse_value, on_conflict):
+    return _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict)
+
+
+def _get_object_metadata_for_import(db_map, data):
+    return _get_entity_metadata_for_import(db_map, data)
+
+
+def _get_relationship_metadata_for_import(db_map, data):
+    return _get_entity_metadata_for_import(db_map, data)
+
+
+def _get_object_parameter_value_metadata_for_import(db_map, data):
+    return _get_parameter_value_metadata_for_import(db_map, data)
+
+
+def _get_relationship_parameter_value_metadata_for_import(db_map, data):
+    return _get_parameter_value_metadata_for_import(db_map, data)

From 6cc604f4c1d79585e0ea9efc23ea3162aad65203 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 10 May 2023 17:19:41 +0200
Subject: [PATCH 036/317] Get rid of the session attribute, implement our own Query
---
 spinedb_api/db_mapping_base.py               | 358 +++++++++----------
 spinedb_api/db_mapping_commit_mixin.py       |   8 -
 spinedb_api/db_mapping_remove_mixin.py       |   2 +-
 spinedb_api/db_mapping_update_mixin.py       |   2 +-
 spinedb_api/export_mapping/export_mapping.py |  38 +-
 spinedb_api/filters/execution_filter.py      |   1 +
 spinedb_api/filters/scenario_filter.py       |   4 +-
 spinedb_api/query.py                         |  73 ++++
 8 files changed, 264 insertions(+), 222 deletions(-)
 create mode 100644 spinedb_api/query.py

diff --git 
a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 581ca722..26307b4e 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -9,9 +9,7 @@ # this program. If not, see . ###################################################################################################################### -"""Provides :class:`.DatabaseMappingBase`. - -""" +"""Provides :class:`.DatabaseMappingBase`.""" # TODO: Finish docstrings import uuid import hashlib @@ -21,11 +19,11 @@ from collections import Counter from types import MethodType from concurrent.futures import ThreadPoolExecutor -from sqlalchemy import create_engine, case, MetaData, Table, Column, false, and_, func, inspect, cast, Integer, or_ -from sqlalchemy.sql.expression import label, Alias +from sqlalchemy import create_engine, MetaData, Table, Column, Integer, inspect, case, func, cast, false, and_, or_ +from sqlalchemy.sql.expression import label, Alias, select from sqlalchemy.engine.url import make_url, URL -from sqlalchemy.orm import Session, aliased -from sqlalchemy.exc import DatabaseError +from sqlalchemy.orm import aliased +from sqlalchemy.exc import DatabaseError, ProgrammingError from sqlalchemy.event import listen from sqlalchemy.pool import NullPool from alembic.migration import MigrationContext @@ -46,6 +44,7 @@ from .filters.tools import pop_filter_configs from .spine_db_client import get_db_url_from_server from .db_cache import DBCache +from .query import Query logging.getLogger("alembic").setLevel(logging.CRITICAL) @@ -125,7 +124,6 @@ def __init__( listen(self.engine, 'close', self._receive_engine_close) self.executor = self._make_executor() self.connection = self.executor.submit(self.engine.connect).result() - self.session = Session(self.connection, **self._session_kwargs) if self._memory: self.executor.submit(copy_database_bind, self.connection, self._original_engine) self._metadata = MetaData(self.connection) @@ -155,8 +153,8 @@ def __init__( self._parameter_value_metadata_sq = None self._entity_metadata_sq = None # Special convenience subqueries that join two or more tables - self._ext_entity_class_sq = None - self._ext_entity_sq = None + self._wide_entity_class_sq = None + self._wide_entity_sq = None self._ext_parameter_value_list_sq = None self._wide_parameter_value_list_sq = None self._ord_list_value_sq = None @@ -197,23 +195,6 @@ def __init__( "entity_alternative": ("entity_id", "alternative_id"), "entity_class_dimension": ("entity_class_id", "position"), } - # Subqueries used to populate cache - self.cache_sqs = { - "entity_class": "ext_entity_class_sq", - "entity": "ext_entity_sq", - "parameter_value_list": "parameter_value_list_sq", - "list_value": "list_value_sq", - "alternative": "alternative_sq", - "scenario": "scenario_sq", - "scenario_alternative": "scenario_alternative_sq", - "entity_group": "entity_group_sq", - "parameter_definition": "parameter_definition_sq", - "parameter_value": "parameter_value_sq", - "metadata": "metadata_sq", - "entity_metadata": "entity_metadata_sq", - "parameter_value_metadata": "parameter_value_metadata_sq", - "commit": "commit_sq", - } self.ancestor_tablenames = { "scenario_alternative": ("scenario", "alternative"), "entity": ("entity_class",), @@ -239,7 +220,7 @@ def __init__( "list_value": ("parameter_value_list",), } self.descendant_tablenames = { - tablename: set(self._descendant_tablenames(tablename)) for tablename in self.cache_sqs + tablename: set(self._descendant_tablenames(tablename)) for tablename in self.ITEM_TYPES } def 
__enter__(self): @@ -422,7 +403,7 @@ def in_(self, column, values): prefixes=['TEMPORARY'], ) self.executor.submit(in_value.create, self.connection, checkfirst=True).result() - self._checked_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) + self.safe_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) return column.in_(self.query(in_value.c.value)) def _get_table_to_sq_attr(self): @@ -488,7 +469,7 @@ def query(self, *args, **kwargs): db_map.object_sq.c.class_id == db_map.object_class_sq.c.id ).group_by(db_map.object_class_sq.c.name).all() """ - return self.session.query(*args, **kwargs) + return Query(self, select(args)) def _subquery(self, tablename): """A subquery of the form: @@ -504,7 +485,7 @@ def _subquery(self, tablename): sqlalchemy.sql.expression.Alias """ table = self._metadata.tables[tablename] - return self.query(table).subquery() + return self.query(table).subquery(tablename + "_sq") @property def alternative_sq(self): @@ -545,18 +526,6 @@ def entity_class_dimension_sq(self): self._entity_class_dimension_sq = self._subquery("entity_class_dimension") return self._entity_class_dimension_sq - @property - def entity_element_sq(self): - if self._entity_element_sq is None: - self._entity_element_sq = self._subquery("entity_element") - return self._entity_element_sq - - @property - def entity_alternative_sq(self): - if self._entity_alternative_sq is None: - self._entity_alternative_sq = self._subquery("entity_alternative") - return self._entity_alternative_sq - @property def entity_sq(self): """A subquery of the form: @@ -573,7 +542,19 @@ def entity_sq(self): return self._entity_sq @property - def ext_entity_class_sq(self): + def entity_element_sq(self): + if self._entity_element_sq is None: + self._entity_element_sq = self._subquery("entity_element") + return self._entity_element_sq + + @property + def entity_alternative_sq(self): + if self._entity_alternative_sq is None: + self._entity_alternative_sq = self._subquery("entity_alternative") + return self._entity_alternative_sq + + @property + def wide_entity_class_sq(self): """A subquery of the form: .. code-block:: sql @@ -591,7 +572,7 @@ def ext_entity_class_sq(self): Returns: sqlalchemy.sql.expression.Alias """ - if self._ext_entity_class_sq is None: + if self._wide_entity_class_sq is None: entity_class_dimension_sq = ( self.query( self.entity_class_dimension_sq.c.entity_class_id, @@ -600,7 +581,7 @@ def ext_entity_class_sq(self): self.entity_class_sq.c.name.label("dimension_name"), ) .filter(self.entity_class_dimension_sq.c.dimension_id == self.entity_class_sq.c.id) - .subquery() + .subquery("entity_class_dimension_sq") ) ecd_sq = ( self.query( @@ -619,9 +600,9 @@ def ext_entity_class_sq(self): self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id, ) .order_by(self.entity_class_sq.c.id, entity_class_dimension_sq.c.position) - .subquery() + .subquery("ext_entity_class_sq") ) - self._ext_entity_class_sq = ( + self._wide_entity_class_sq = ( self.query( ecd_sq.c.id, ecd_sq.c.name, @@ -640,12 +621,12 @@ def ext_entity_class_sq(self): ecd_sq.c.display_icon, ecd_sq.c.hidden, ) - .subquery() + .subquery("wide_entity_class_sq") ) - return self._ext_entity_class_sq + return self._wide_entity_class_sq @property - def ext_entity_sq(self): + def wide_entity_sq(self): """A subquery of the form: .. 
code-block:: sql @@ -663,41 +644,41 @@ def ext_entity_sq(self): Returns: sqlalchemy.sql.expression.Alias """ - if self._ext_entity_sq is None: + if self._wide_entity_sq is None: entity_element_sq = ( self.query(self.entity_element_sq, self.entity_sq.c.name.label("element_name")) .filter(self.entity_element_sq.c.element_id == self.entity_sq.c.id) - .subquery() + .subquery("entity_element_sq") ) - entity_sq = ( + ext_entity_sq = ( self.query(self.entity_sq, entity_element_sq) .outerjoin( entity_element_sq, self.entity_sq.c.id == entity_element_sq.c.entity_id, ) .order_by(self.entity_sq.c.id, entity_element_sq.c.position) - .subquery() + .subquery("ext_entity_sq") ) - self._ext_entity_sq = ( + self._wide_entity_sq = ( self.query( - entity_sq.c.id, - entity_sq.c.class_id, - entity_sq.c.name, - entity_sq.c.description, - entity_sq.c.commit_id, - group_concat(entity_sq.c.element_id, entity_sq.c.position).label("element_id_list"), - group_concat(entity_sq.c.element_name, entity_sq.c.position).label("element_name_list"), + ext_entity_sq.c.id, + ext_entity_sq.c.class_id, + ext_entity_sq.c.name, + ext_entity_sq.c.description, + ext_entity_sq.c.commit_id, + group_concat(ext_entity_sq.c.element_id, ext_entity_sq.c.position).label("element_id_list"), + group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"), ) .group_by( - entity_sq.c.id, - entity_sq.c.class_id, - entity_sq.c.name, - entity_sq.c.description, - entity_sq.c.commit_id, + ext_entity_sq.c.id, + ext_entity_sq.c.class_id, + ext_entity_sq.c.name, + ext_entity_sq.c.description, + ext_entity_sq.c.commit_id, ) - .subquery() + .subquery("wide_entity_sq") ) - return self._ext_entity_sq + return self._wide_entity_sq @property def object_class_sq(self): @@ -713,14 +694,14 @@ def object_class_sq(self): if self._object_class_sq is None: self._object_class_sq = ( self.query( - self.ext_entity_class_sq.c.id.label("id"), - self.ext_entity_class_sq.c.name.label("name"), - self.ext_entity_class_sq.c.description.label("description"), - self.ext_entity_class_sq.c.display_order.label("display_order"), - self.ext_entity_class_sq.c.display_icon.label("display_icon"), - self.ext_entity_class_sq.c.hidden.label("hidden"), + self.wide_entity_class_sq.c.id.label("id"), + self.wide_entity_class_sq.c.name.label("name"), + self.wide_entity_class_sq.c.description.label("description"), + self.wide_entity_class_sq.c.display_order.label("display_order"), + self.wide_entity_class_sq.c.display_icon.label("display_icon"), + self.wide_entity_class_sq.c.hidden.label("hidden"), ) - .filter(self.ext_entity_class_sq.c.dimension_id_list == None) + .filter(self.wide_entity_class_sq.c.dimension_id_list == None) .subquery() ) return self._object_class_sq @@ -739,13 +720,13 @@ def object_sq(self): if self._object_sq is None: self._object_sq = ( self.query( - self.ext_entity_sq.c.id.label("id"), - self.ext_entity_sq.c.class_id.label("class_id"), - self.ext_entity_sq.c.name.label("name"), - self.ext_entity_sq.c.description.label("description"), - self.ext_entity_sq.c.commit_id.label("commit_id"), + self.wide_entity_sq.c.id.label("id"), + self.wide_entity_sq.c.class_id.label("class_id"), + self.wide_entity_sq.c.name.label("name"), + self.wide_entity_sq.c.description.label("description"), + self.wide_entity_sq.c.commit_id.label("commit_id"), ) - .filter(self.ext_entity_sq.c.element_id_list == None) + .filter(self.wide_entity_sq.c.element_id_list == None) .subquery() ) return self._object_sq @@ -768,12 +749,12 @@ def relationship_class_sq(self): 
ent_cls_dim_sq.c.entity_class_id.label("id"), ent_cls_dim_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept ent_cls_dim_sq.c.dimension_id.label("object_class_id"), - self.ext_entity_class_sq.c.name.label("name"), - self.ext_entity_class_sq.c.description.label("description"), - self.ext_entity_class_sq.c.display_icon.label("display_icon"), - self.ext_entity_class_sq.c.hidden.label("hidden"), + self.wide_entity_class_sq.c.name.label("name"), + self.wide_entity_class_sq.c.description.label("description"), + self.wide_entity_class_sq.c.display_icon.label("display_icon"), + self.wide_entity_class_sq.c.hidden.label("hidden"), ) - .filter(self.ext_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) + .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) .subquery() ) return self._relationship_class_sq @@ -797,10 +778,10 @@ def relationship_sq(self): ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept ent_el_sq.c.element_id.label("object_id"), ent_el_sq.c.entity_class_id.label("class_id"), - self.ext_entity_sq.c.name.label("name"), - self.ext_entity_sq.c.commit_id.label("commit_id"), + self.wide_entity_sq.c.name.label("name"), + self.wide_entity_sq.c.commit_id.label("commit_id"), ) - .filter(self.ext_entity_sq.c.id == ent_el_sq.c.entity_id) + .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) .subquery() ) return self._relationship_sq @@ -817,20 +798,7 @@ def entity_group_sq(self): sqlalchemy.sql.expression.Alias """ if self._entity_group_sq is None: - group_entity = aliased(self.entity_sq) - member_entity = aliased(self.entity_sq) - entity_group_sq = self._subquery("entity_group") - self._entity_group_sq = ( - self.query( - entity_group_sq.c.id, - entity_group_sq.c.entity_class_id, - group_entity.c.id.label("entity_id"), - member_entity.c.id.label("member_id"), - ) - .join(group_entity, group_entity.c.id == entity_group_sq.c.entity_id) - .join(member_entity, member_entity.c.id == entity_group_sq.c.member_id) - .subquery() - ) + self._entity_group_sq = self._subquery("entity_group") return self._entity_group_sq @property @@ -1315,13 +1283,13 @@ def ext_entity_group_sq(self): self.entity_group_sq.c.entity_class_id.label("class_id"), self.entity_group_sq.c.entity_id.label("group_id"), self.entity_group_sq.c.member_id.label("member_id"), - self.ext_entity_class_sq.c.name.label("class_name"), + self.wide_entity_class_sq.c.name.label("class_name"), group_entity.c.name.label("group_name"), member_entity.c.name.label("member_name"), label("object_class_id", self._object_class_id()), label("relationship_class_id", self._relationship_class_id()), ) - .filter(self.entity_group_sq.c.entity_class_id == self.ext_entity_class_sq.c.id) + .filter(self.entity_group_sq.c.entity_class_id == self.wide_entity_class_sq.c.id) .join(group_entity, self.entity_group_sq.c.entity_id == group_entity.c.id) .join(member_entity, self.entity_group_sq.c.member_id == member_entity.c.id) .subquery() @@ -1341,7 +1309,7 @@ def entity_parameter_definition_sq(self): self.parameter_definition_sq.c.entity_class_id, self.parameter_definition_sq.c.object_class_id, self.parameter_definition_sq.c.relationship_class_id, - self.ext_entity_class_sq.c.name.label("entity_class_name"), + self.wide_entity_class_sq.c.name.label("entity_class_name"), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), label("object_class_id_list", self._object_class_id_list()), @@ -1356,8 
+1324,8 @@ def entity_parameter_definition_sq(self): self.parameter_definition_sq.c.commit_id, ) .join( - self.ext_entity_class_sq, - self.ext_entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id, + self.wide_entity_class_sq, + self.wide_entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id, ) .outerjoin( self.parameter_value_list_sq, @@ -1365,7 +1333,7 @@ def entity_parameter_definition_sq(self): ) .outerjoin( self.wide_relationship_class_sq, - self.wide_relationship_class_sq.c.id == self.ext_entity_class_sq.c.id, + self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, ) .subquery() ) @@ -1523,13 +1491,13 @@ def entity_parameter_value_sq(self): self.parameter_definition_sq.c.entity_class_id, self.parameter_definition_sq.c.object_class_id, self.parameter_definition_sq.c.relationship_class_id, - self.ext_entity_class_sq.c.name.label("entity_class_name"), + self.wide_entity_class_sq.c.name.label("entity_class_name"), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), label("object_class_id_list", self._object_class_id_list()), label("object_class_name_list", self._object_class_name_list()), self.parameter_value_sq.c.entity_id, - self.ext_entity_sq.c.name.label("entity_name"), + self.wide_entity_sq.c.name.label("entity_name"), self.parameter_value_sq.c.object_id, self.parameter_value_sq.c.relationship_id, label("object_name", self._object_name()), @@ -1548,17 +1516,17 @@ def entity_parameter_value_sq(self): self.parameter_definition_sq, self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id, ) - .join(self.ext_entity_sq, self.parameter_value_sq.c.entity_id == self.ext_entity_sq.c.id) + .join(self.wide_entity_sq, self.parameter_value_sq.c.entity_id == self.wide_entity_sq.c.id) .join( - self.ext_entity_class_sq, - self.parameter_definition_sq.c.entity_class_id == self.ext_entity_class_sq.c.id, + self.wide_entity_class_sq, + self.parameter_definition_sq.c.entity_class_id == self.wide_entity_class_sq.c.id, ) .join(self.alternative_sq, self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) .outerjoin( self.wide_relationship_class_sq, - self.wide_relationship_class_sq.c.id == self.ext_entity_class_sq.c.id, + self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, ) - .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.ext_entity_sq.c.id) + .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.wide_entity_sq.c.id) # object_id_list might be None when objects have been filtered out .filter( or_( @@ -1688,23 +1656,23 @@ def ext_entity_metadata_sq(self): ) return self._ext_entity_metadata_sq - def _make_entity_sq(self): + def _make_entity_class_sq(self): """ - Creates a subquery for entities. + Creates a subquery for entity classes. Returns: - Alias: an entity subquery + Alias: an entity class subquery """ - return self._subquery("entity") + return self._subquery("entity_class") - def _make_entity_class_sq(self): + def _make_entity_sq(self): """ - Creates a subquery for entity classes. + Creates a subquery for entities. 
Returns: - Alias: an entity class subquery + Alias: an entity subquery """ - return self._subquery("entity_class") + return self._subquery("entity") def _make_parameter_definition_sq(self): """ @@ -1739,9 +1707,9 @@ def _make_parameter_definition_sq(self): par_def_sq.c.commit_id.label("commit_id"), par_def_sq.c.parameter_value_list_id.label("parameter_value_list_id"), ) - .join(self.ext_entity_class_sq, self.ext_entity_class_sq.c.id == par_def_sq.c.entity_class_id) + .join(self.wide_entity_class_sq, self.wide_entity_class_sq.c.id == par_def_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) - .subquery() + .subquery("clean_parameter_definition_sq") ) def _make_parameter_value_sq(self): @@ -1771,10 +1739,10 @@ def _make_parameter_value_sq(self): par_val_sq.c.commit_id.label("commit_id"), par_val_sq.c.alternative_id, ) - .join(self.ext_entity_sq, self.ext_entity_sq.c.id == par_val_sq.c.entity_id) - .join(self.ext_entity_class_sq, self.ext_entity_class_sq.c.id == par_val_sq.c.entity_class_id) + .join(self.wide_entity_sq, self.wide_entity_sq.c.id == par_val_sq.c.entity_id) + .join(self.wide_entity_class_sq, self.wide_entity_class_sq.c.id == par_val_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) - .subquery() + .subquery("clean_parameter_value_sq") ) def _make_alternative_sq(self): @@ -1830,21 +1798,15 @@ def _create_import_alternative(self): ids = self._add_alternatives({"name": self._import_alternative_name}) self._import_alternative_id = next(iter(ids)) - def override_entity_sq_maker(self, method): + def override_create_import_alternative(self, method): """ - Overrides the function that creates the ``entity_sq`` property. + Overrides the ``_create_import_alternative`` function. Args: - method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and - returns entity subquery as an :class:`Alias` object + method (Callable) """ - self._make_entity_sq = MethodType(method, self) - self._clear_subqueries("entity") - - def restore_entity_sq_maker(self): - """Restores the original function that creates the ``entity_sq`` property.""" - self._make_entity_sq = MethodType(DatabaseMappingBase._make_entity_sq, self) - self._clear_subqueries("entity") + self._create_import_alternative = MethodType(method, self) + self._import_alternative_id = None def override_entity_class_sq_maker(self, method): """ @@ -1857,10 +1819,16 @@ def override_entity_class_sq_maker(self, method): self._make_entity_class_sq = MethodType(method, self) self._clear_subqueries("entity_class") - def restore_entity_class_sq_maker(self): - """Restores the original function that creates the ``entity_class_sq`` property.""" - self._make_entity_class_sq = MethodType(DatabaseMappingBase._make_entity_class_sq, self) - self._clear_subqueries("entity_class") + def override_entity_sq_maker(self, method): + """ + Overrides the function that creates the ``entity_sq`` property. 
+ + Args: + method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and + returns entity subquery as an :class:`Alias` object + """ + self._make_entity_sq = MethodType(method, self) + self._clear_subqueries("entity") def override_parameter_definition_sq_maker(self, method): """ @@ -1873,11 +1841,6 @@ def override_parameter_definition_sq_maker(self, method): self._make_parameter_definition_sq = MethodType(method, self) self._clear_subqueries("parameter_definition") - def restore_parameter_definition_sq_maker(self): - """Restores the original function that creates the ``parameter_definition_sq`` property.""" - self._make_parameter_definition_sq = MethodType(DatabaseMappingBase._make_parameter_definition_sq, self) - self._clear_subqueries("parameter_definition") - def override_parameter_value_sq_maker(self, method): """ Overrides the function that creates the ``parameter_value_sq`` property. @@ -1889,11 +1852,6 @@ def override_parameter_value_sq_maker(self, method): self._make_parameter_value_sq = MethodType(method, self) self._clear_subqueries("parameter_value") - def restore_parameter_value_sq_maker(self): - """Restores the original function that creates the ``parameter_value_sq`` property.""" - self._make_parameter_value_sq = MethodType(DatabaseMappingBase._make_parameter_value_sq, self) - self._clear_subqueries("parameter_value") - def override_alternative_sq_maker(self, method): """ Overrides the function that creates the ``alternative_sq`` property. @@ -1905,11 +1863,6 @@ def override_alternative_sq_maker(self, method): self._make_alternative_sq = MethodType(method, self) self._clear_subqueries("alternative") - def restore_alternative_sq_maker(self): - """Restores the original function that creates the ``alternative_sq`` property.""" - self._make_alternative_sq = MethodType(DatabaseMappingBase._make_alternative_sq, self) - self._clear_subqueries("alternative") - def override_scenario_sq_maker(self, method): """ Overrides the function that creates the ``scenario_sq`` property. @@ -1921,11 +1874,6 @@ def override_scenario_sq_maker(self, method): self._make_scenario_sq = MethodType(method, self) self._clear_subqueries("scenario") - def restore_scenario_sq_maker(self): - """Restores the original function that creates the ``scenario_sq`` property.""" - self._make_scenario_sq = MethodType(DatabaseMappingBase._make_scenario_sq, self) - self._clear_subqueries("scenario") - def override_scenario_alternative_sq_maker(self, method): """ Overrides the function that creates the ``scenario_alternative_sq`` property. 
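All the override_*_sq_maker functions in this hunk follow one pattern: rebind the subquery factory on the instance via MethodType and clear every subquery cached for that table, so the next property access rebuilds it. A filter would typically use such a pair along these lines (a minimal sketch, not part of this changeset; the class-id condition and helper name are hypothetical):

    from spinedb_api.db_mapping_base import DatabaseMappingBase

    def apply_class_filter(db_map, class_id):  # hypothetical helper
        def _make_filtered_entity_sq(self):
            entity_sq = DatabaseMappingBase._make_entity_sq(self)  # default factory
            return self.query(entity_sq).filter(entity_sq.c.class_id == class_id).subquery()

        db_map.override_entity_sq_maker(_make_filtered_entity_sq)

Calling db_map.restore_entity_sq_maker() afterwards rebinds the default factory and clears the cached subqueries once more.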
@@ -1937,25 +1885,47 @@ def override_scenario_alternative_sq_maker(self, method): self._make_scenario_alternative_sq = MethodType(method, self) self._clear_subqueries("scenario_alternative") + def restore_entity_class_sq_maker(self): + """Restores the original function that creates the ``entity_class_sq`` property.""" + self._make_entity_class_sq = MethodType(DatabaseMappingBase._make_entity_class_sq, self) + self._clear_subqueries("entity_class") + + def restore_entity_sq_maker(self): + """Restores the original function that creates the ``entity_sq`` property.""" + self._make_entity_sq = MethodType(DatabaseMappingBase._make_entity_sq, self) + self._clear_subqueries("entity") + + def restore_parameter_definition_sq_maker(self): + """Restores the original function that creates the ``parameter_definition_sq`` property.""" + self._make_parameter_definition_sq = MethodType(DatabaseMappingBase._make_parameter_definition_sq, self) + self._clear_subqueries("parameter_definition") + + def restore_parameter_value_sq_maker(self): + """Restores the original function that creates the ``parameter_value_sq`` property.""" + self._make_parameter_value_sq = MethodType(DatabaseMappingBase._make_parameter_value_sq, self) + self._clear_subqueries("parameter_value") + + def restore_alternative_sq_maker(self): + """Restores the original function that creates the ``alternative_sq`` property.""" + self._make_alternative_sq = MethodType(DatabaseMappingBase._make_alternative_sq, self) + self._clear_subqueries("alternative") + + def restore_scenario_sq_maker(self): + """Restores the original function that creates the ``scenario_sq`` property.""" + self._make_scenario_sq = MethodType(DatabaseMappingBase._make_scenario_sq, self) + self._clear_subqueries("scenario") + def restore_scenario_alternative_sq_maker(self): """Restores the original function that creates the ``scenario_alternative_sq`` property.""" self._make_scenario_alternative_sq = MethodType(DatabaseMappingBase._make_scenario_alternative_sq, self) self._clear_subqueries("scenario_alternative") - def override_create_import_alternative(self, method): - """ - Overrides the ``_create_import_alternative`` function. - - Args: - method (Callable) - """ - self._create_import_alternative = MethodType(method, self) - self._import_alternative_id = None - - def _checked_execute(self, stmt, items): - if not items: - return - return self.executor.submit(self.connection.execute, stmt, items).result() + def safe_execute(self, *args): + # We try to execute directly. If we are in the wrong thread this will raise ProgrammingError. 
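+ # E.g. safe_execute(stmt, items) called from a non-owner thread falls back to + # self.executor.submit(self.connection.execute, stmt, items).result(), i.e. the call + # is routed to the executor's worker thread and blocks until it finishes.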
+ try: + return self.connection.execute(*args) + except ProgrammingError: + return self.executor.submit(self.connection.execute, *args).result() def _get_primary_key(self, tablename): pk = self.composite_pks.get(tablename) @@ -1984,36 +1954,40 @@ def fetch_all(self, tablenames, include_descendants=False, include_ancestors=Fal } if force_tablenames: tablenames |= force_tablenames - for tablename in tablenames & self.cache_sqs.keys(): + for tablename in tablenames & set(self.ITEM_TYPES): self.cache.fetch_all(tablename) def _object_class_id(self): - return case([(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.id)], else_=None) + return case( + [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None + ) def _relationship_class_id(self): - return case([(self.ext_entity_class_sq.c.dimension_id_list != None, self.ext_entity_class_sq.c.id)], else_=None) + return case( + [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id)], else_=None + ) def _object_id(self): - return case([(self.ext_entity_sq.c.element_id_list == None, self.ext_entity_sq.c.id)], else_=None) + return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id)], else_=None) def _relationship_id(self): - return case([(self.ext_entity_sq.c.element_id_list != None, self.ext_entity_sq.c.id)], else_=None) + return case([(self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id)], else_=None) def _object_class_name(self): return case( - [(self.ext_entity_class_sq.c.dimension_id_list == None, self.ext_entity_class_sq.c.name)], else_=None + [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name)], else_=None ) def _relationship_class_name(self): return case( - [(self.ext_entity_class_sq.c.dimension_id_list != None, self.ext_entity_class_sq.c.name)], else_=None + [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name)], else_=None ) def _object_class_id_list(self): return case( [ ( - self.ext_entity_class_sq.c.dimension_id_list != None, + self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_relationship_class_sq.c.object_class_id_list, ) ], @@ -2024,7 +1998,7 @@ def _object_class_name_list(self): return case( [ ( - self.ext_entity_class_sq.c.dimension_id_list != None, + self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_relationship_class_sq.c.object_class_name_list, ) ], @@ -2032,16 +2006,16 @@ def _object_class_name_list(self): ) def _object_name(self): - return case([(self.ext_entity_sq.c.element_id_list == None, self.ext_entity_sq.c.name)], else_=None) + return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name)], else_=None) def _object_id_list(self): return case( - [(self.ext_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None + [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None ) def _object_name_list(self): return case( - [(self.ext_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None + [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None ) def _metadata_usage_counts(self): diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 0eaf5610..2efc916f 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ 
b/spinedb_api/db_mapping_commit_mixin.py @@ -30,12 +30,6 @@ def has_pending_changes(self): # FIXME return True - def _get_sqlite_lock(self): - """Commits the session's natural transaction and begins a new locking one.""" - if self.sa_url.drivername == "sqlite": - self.session.commit() - self.session.execute("BEGIN IMMEDIATE") - def _make_commit_id(self): if self._commit_id is None: with self.engine.begin() as connection: @@ -66,7 +60,6 @@ def commit_session(self, comment): for tablename, items in to_update.items(): self._do_update_items(tablename, *items) self._do_remove_items(**to_remove) - self.executor.submit(self.session.commit) self._commit_id = None if self._memory: self._memory_dirty = True @@ -74,6 +67,5 @@ def commit_session(self, comment): def rollback_session(self): if not self.has_pending_changes(): raise SpineDBAPIError("Nothing to rollback.") - self.executor.submit(self.session.rollback) self.cache.reset_queries() self._commit_id = None diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index cc5e54ef..0fbd11a2 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -62,7 +62,7 @@ def _do_remove_items(self, **kwargs): table = self._metadata.tables[tablename] delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) try: - self.executor.submit(self.connection.execute, delete).result() + self.safe_execute(delete) except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 838d62aa..467ee0f8 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -32,7 +32,7 @@ def _do_update_items(self, tablename, *items_to_update): for k in self._get_primary_key(tablename_): upd = upd.where(getattr(table.c, k) == bindparam(k)) upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items_to_update_[0].keys()}) - self._checked_execute(upd, [{**item} for item in items_to_update_]) + self.safe_execute(upd, [{**item} for item in items_to_update_]) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 8a4066e7..8de7db99 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -635,16 +635,16 @@ class EntityClassMapping(ExportMapping): def add_query_columns(self, db_map, query): return query.add_columns( - db_map.ext_entity_class_sq.c.id.label("entity_class_id"), - db_map.ext_entity_class_sq.c.name.label("entity_class_name"), - db_map.ext_entity_class_sq.c.dimension_id_list.label("dimension_id_list"), - db_map.ext_entity_class_sq.c.dimension_name_list.label("dimension_name_list"), + db_map.wide_entity_class_sq.c.id.label("entity_class_id"), + db_map.wide_entity_class_sq.c.name.label("entity_class_name"), + db_map.wide_entity_class_sq.c.dimension_id_list.label("dimension_id_list"), + db_map.wide_entity_class_sq.c.dimension_name_list.label("dimension_name_list"), ) def filter_query(self, db_map, query): if any(isinstance(m, (DimensionMapping, ElementMapping)) for m in self.flatten()): - return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list != None) - return query.filter(db_map.ext_entity_class_sq.c.dimension_id_list == None) + return 
query.filter(db_map.wide_entity_class_sq.c.dimension_id_list != None) + return query.filter(db_map.wide_entity_class_sq.c.dimension_id_list == None) @staticmethod def name_field(): @@ -676,14 +676,16 @@ class EntityMapping(ExportMapping): def add_query_columns(self, db_map, query): return query.add_columns( - db_map.ext_entity_sq.c.id.label("entity_id"), - db_map.ext_entity_sq.c.name.label("entity_name"), - db_map.ext_entity_sq.c.element_id_list, - db_map.ext_entity_sq.c.element_name_list, + db_map.wide_entity_sq.c.id.label("entity_id"), + db_map.wide_entity_sq.c.name.label("entity_name"), + db_map.wide_entity_sq.c.element_id_list, + db_map.wide_entity_sq.c.element_name_list, ) def filter_query(self, db_map, query): - return query.outerjoin(db_map.ext_entity_sq, db_map.ext_entity_sq.c.class_id == db_map.ext_entity_class_sq.c.id) + return query.outerjoin( + db_map.wide_entity_sq, db_map.wide_entity_sq.c.class_id == db_map.wide_entity_class_sq.c.id + ) @staticmethod def name_field(): @@ -721,7 +723,7 @@ def add_query_columns(self, db_map, query): def filter_query(self, db_map, query): return query.outerjoin( - db_map.ext_entity_group_sq, db_map.ext_entity_group_sq.c.class_id == db_map.ext_entity_class_sq.c.id + db_map.ext_entity_group_sq, db_map.ext_entity_group_sq.c.class_id == db_map.wide_entity_class_sq.c.id ).distinct() @staticmethod @@ -747,11 +749,11 @@ class EntityGroupEntityMapping(ExportMapping): def add_query_columns(self, db_map, query): return query.add_columns( - db_map.ext_entity_sq.c.id.label("entity_id"), db_map.ext_entity_sq.c.name.label("entity_name") + db_map.wide_entity_sq.c.id.label("entity_id"), db_map.wide_entity_sq.c.name.label("entity_name") ) def filter_query(self, db_map, query): - return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.ext_entity_sq.c.id) + return query.filter(db_map.ext_entity_group_sq.c.member_id == db_map.wide_entity_sq.c.id) @staticmethod def name_field(): @@ -797,7 +799,7 @@ def filter_query(self, db_map, query): position=self._highlight_dimension ) conditions = ( - and_(db_map.ext_entity_class_sq.c.id == x.entity_class_id, db_map.entity_class_sq.c.id == x.dimension_id) + and_(db_map.wide_entity_class_sq.c.id == x.entity_class_id, db_map.entity_class_sq.c.id == x.dimension_id) for x in highlighted_dimension_qry ) return query.filter(or_(*conditions)) @@ -883,7 +885,7 @@ def filter_query(self, db_map, query): position=self.query_parents("highlight_dimension") ) conditions = ( - and_(db_map.ext_entity_sq.c.id == x.entity_id, db_map.entity_sq.c.id == x.element_id) + and_(db_map.wide_entity_sq.c.id == x.entity_id, db_map.entity_sq.c.id == x.element_id) for x in highlighted_element_qry ) return query.filter(or_(*conditions)) @@ -950,7 +952,7 @@ def add_query_columns(self, db_map, query): def filter_query(self, db_map, query): column_names = {c["name"] for c in query.column_descriptions} # "dimension_id" in column_names means a DimensionHighlightingMapping is acting - entity_class_sq = db_map.entity_class_sq if "dimension_id" in column_names else db_map.ext_entity_class_sq + entity_class_sq = db_map.entity_class_sq if "dimension_id" in column_names else db_map.wide_entity_class_sq return query.outerjoin( db_map.parameter_definition_sq, db_map.parameter_definition_sq.c.entity_class_id == entity_class_sq.c.id, @@ -1127,7 +1129,7 @@ def filter_query(self, db_map, query): if not self._selects_value: return query column_names = {c["name"] for c in query.column_descriptions} - entity_sq = db_map.entity_sq if "element_id" in column_names else 
db_map.ext_entity_sq + entity_sq = db_map.entity_sq if "element_id" in column_names else db_map.wide_entity_sq return query.filter( and_( db_map.parameter_value_sq.c.entity_id == entity_sq.c.id, diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index b4741f62..078e53bb 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -151,6 +151,7 @@ def _create_import_alternative(db_map, state): db_map (DatabaseMappingBase): database the state applies to state (_ExecutionFilterState): a state bound to ``db_map`` """ + # FIXME execution_item = state.execution_item scenarios = state.scenarios timestamp = state.timestamp diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index bd4bba6a..7c8d767d 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -195,7 +195,7 @@ def _make_scenario_filtered_entity_sq(db_map, state): Returns: Alias: a subquery for entity filtered by selected scenario """ - ext_entity_sq = ( + wide_entity_sq = ( db_map.query( state.original_entity_sq, func.row_number() @@ -211,7 +211,7 @@ def _make_scenario_filtered_entity_sq(db_map, state): .filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id) ).subquery() # TODO: Maybe we want to filter multi-dimensional entities involving filtered entities right here too? - return db_map.query(ext_entity_sq).filter_by(max_rank_row_number=1, is_active=True).subquery() + return db_map.query(wide_entity_sq).filter_by(max_rank_row_number=1, is_active=True).subquery() def _make_scenario_filtered_parameter_value_sq(db_map, state): diff --git a/spinedb_api/query.py b/spinedb_api/query.py new file mode 100644 index 00000000..b21ef857 --- /dev/null +++ b/spinedb_api/query.py @@ -0,0 +1,73 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . 
+###################################################################################################################### + +"""Provides :class:`.Query`.""" + +from .exception import SpineDBAPIError + + +class Query: + def __init__(self, db_map, select_): + self._db_map = db_map + self._select = select_ + self._from = None + + def subquery(self, name=None): + return self._select.alias(name) + + def filter(self, *args): + self._select = self._select.where(*args) + return self + + def _get_from(self, right, on): + from_candidates = (set(_get_descendant_tables(on)) - {right}) & set(self._select.get_children()) + if len(from_candidates) != 1: + raise SpineDBAPIError(f"can't find a unique 'from-clause' to join into, candidates are {from_candidates}") + return next(iter(from_candidates)) + + def join(self, right, on, isouter=False): + from_ = self._get_from(right, on) if self._from is None else self._from + self._from = from_.join(right, on, isouter=isouter) + self._select = self._select.select_from(self._from) + return self + + def outerjoin(self, right, on): + return self.join(right, on, isouter=True) + + def order_by(self, *args): + self._select = self._select.order_by(*args) + return self + + def group_by(self, *args): + self._select = self._select.group_by(*args) + return self + + def limit(self, *args): + self._select = self._select.limit(*args) + return self + + def offset(self, *args): + self._select = self._select.offset(*args) + return self + + def all(self): + return list(self) + + def __iter__(self): + return self._db_map.connection.execute(self._select) + + +def _get_descendant_tables(on): + for x in on.get_children(): + try: + yield x.table + except AttributeError: + yield from _get_descendant_tables(x) From 6310e41f1343fd08df8120988aaedbc48559e84f Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 10 May 2023 17:20:37 +0200 Subject: [PATCH 037/317] Complete previous commit --- spinedb_api/db_cache.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 96419ebb..90599b0d 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -72,15 +72,30 @@ def advance_query(self, item_type): def _get_next_chunk(self, item_type): try: - sq_name = self._db_map.cache_sqs[item_type] + sq_name = { + "entity_class": "wide_entity_class_sq", + "entity": "wide_entity_sq", + "parameter_value_list": "parameter_value_list_sq", + "list_value": "list_value_sq", + "alternative": "alternative_sq", + "scenario": "scenario_sq", + "scenario_alternative": "scenario_alternative_sq", + "entity_group": "entity_group_sq", + "parameter_definition": "parameter_definition_sq", + "parameter_value": "parameter_value_sq", + "metadata": "metadata_sq", + "entity_metadata": "entity_metadata_sq", + "parameter_value_metadata": "parameter_value_metadata_sq", + "commit": "commit_sq", + }[item_type] qry = self._db_map.query(getattr(self._db_map, sq_name)) except KeyError: return [] if not self._chunk_size: self._fetched_item_types.add(item_type) - return [x._asdict() for x in qry.yield_per(1000).enable_eagerloads(False)] + return qry.all() offset = self._offsets.setdefault(item_type, 0) - chunk = [x._asdict() for x in qry.limit(self._chunk_size).offset(offset)] + chunk = qry.limit(self._chunk_size).offset(offset).all() self._offsets[item_type] += len(chunk) return chunk @@ -116,11 +131,7 @@ def get_item(self, item_type, id_): def fetch_more(self, item_type): if item_type in self._fetched_item_types: return False - # We try to 
advance the query directly. If we are in the wrong thread this will raise ProgrammingError. - try: - return bool(self.do_advance_query(item_type)) - except ProgrammingError: - return bool(self.advance_query(item_type).result()) + return bool(self.do_advance_query(item_type)) def fetch_all(self, item_type): while self.fetch_more(item_type): From d8b5550524cf4abbf97edc7fed0dcc5b52c77a0f Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 12 May 2023 15:16:50 +0200 Subject: [PATCH 038/317] Introduce Status to cache items, complete Query, fix a lot of tests --- spinedb_api/db_cache.py | 382 ++++++++++++------ spinedb_api/db_mapping_add_mixin.py | 145 ++----- spinedb_api/db_mapping_base.py | 76 ++-- spinedb_api/db_mapping_commit_mixin.py | 38 +- spinedb_api/db_mapping_remove_mixin.py | 2 +- spinedb_api/db_mapping_update_mixin.py | 155 +++---- spinedb_api/export_mapping/export_mapping.py | 9 +- spinedb_api/filters/alternative_filter.py | 2 +- spinedb_api/filters/scenario_filter.py | 11 +- spinedb_api/helpers.py | 25 ++ spinedb_api/import_functions.py | 93 ++++- .../import_mapping/import_mapping_compat.py | 11 +- spinedb_api/query.py | 71 +++- tests/export_mapping/test_export_mapping.py | 125 ------ tests/filters/test_alternative_filter.py | 25 +- tests/filters/test_scenario_filter.py | 371 ++++++++--------- tests/filters/test_tool_filter.py | 12 +- tests/filters/test_tools.py | 4 +- tests/import_mapping/test_import_mapping.py | 97 ++--- tests/test_DatabaseMapping.py | 161 ++++---- tests/test_DiffDatabaseMapping.py | 290 +++++++------ tests/test_check_functions.py | 3 +- tests/test_export_functions.py | 42 -- tests/test_import_functions.py | 370 ++++------------- tests/test_migration.py | 22 +- 25 files changed, 1110 insertions(+), 1432 deletions(-) diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py index 90599b0d..61d903f9 100644 --- a/spinedb_api/db_cache.py +++ b/spinedb_api/db_cache.py @@ -15,11 +15,22 @@ import uuid from contextlib import suppress from operator import itemgetter -from sqlalchemy.exc import ProgrammingError +from enum import Enum, unique, auto +from .parameter_value import from_database # TODO: Implement CacheItem.pop() to do lookup? +@unique +class Status(Enum): + """Cache item status.""" + + committed = auto() + to_add = auto() + to_update = auto() + to_remove = auto() + + class DBCache(dict): """A dictionary that maps table names to ids to items. Used to store and retrieve database contents.""" @@ -34,20 +45,38 @@ def __init__(self, db_map, chunk_size=None): self._fetched_item_types = set() self._chunk_size = chunk_size - def to_change(self): + def commit(self): to_add = {} to_update = {} to_remove = {} for item_type, table_cache in self.items(): - new = [x for x in table_cache.values() if x.new] - dirty = [x for x in table_cache.values() if x.dirty and not x.new] - removed = {x.id for x in dict.values(table_cache) if x.removed} - if new: - to_add[item_type] = new - if dirty: - to_update[item_type] = dirty - if removed: - to_remove[item_type] = removed + for item in dict.values(table_cache): + if item.status == Status.to_add: + to_add.setdefault(item_type, []).append(item) + elif item.status == Status.to_update: + to_update.setdefault(item_type, []).append(item) + elif item.status == Status.to_remove: + to_remove.setdefault(item_type, set()).add(item["id"]) + item.status = Status.committed + # FIXME: When computing to_remove, we could at the same time fetch all tables where items should be removed + # in cascade. This could be nice. 
So we would visit the tables in order, collect removed items, + # and if we find some then we would fetch all the descendant tables and validate items in them. + # This would set the removed flag, and then we would be able to collect those items + # in subsequent iterations. + # This might solve the issue when the user removes, commits, and then undoes the removal. + # My impression is that since committing the removal action would fetch all the referrers, it would + # be possible to properly undo it. Maybe that is the case already because `cascading_ids()` + # also fetches all the descendant tablenames into cache. + # Actually, it looks like all we're missing is setting the new attribute for restored items too??!! + # Ok so when you restore an item whose removal was committed, you need to set new to True + + # Another option would be to build a list of fetched ids in a fully independent dictionary. + # Then we could compare contents of the cache with this list and easily find out which items need + # to be added, updated and removed. + # To add: Those that are valid in the cache but not in the fetched ids + # To update: Those that are both valid in the cache and in the fetched ids + # To remove: Those that are in the fetched ids but not valid in the cache. + # But this would require fetching the entire DB before committing or something like that... To think about it. return to_add, to_update, to_remove @property @@ -93,9 +122,9 @@ def _get_next_chunk(self, item_type): return [] if not self._chunk_size: self._fetched_item_types.add(item_type) - return qry.all() + return [dict(x) for x in qry] offset = self._offsets.setdefault(item_type, 0) - chunk = qry.limit(self._chunk_size).offset(offset).all() + chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)] self._offsets[item_type] += len(chunk) return chunk @@ -159,20 +188,27 @@ def __init__(self, db_cache, item_type, *args, **kwargs): super().__init__(*args, **kwargs) self._db_cache = db_cache self._item_type = item_type - self._existing = {} + self._id_by_unique_key_value = {} - def existing(self, key, value): - """Returns the CacheItem that has the given value for the given unique constraint key, or None. + def unique_key_value_to_id(self, key, value, strict=False): + """Returns the id that has the given value for the given unique key, or None. 
Args: key (tuple) value (tuple) Returns: - CacheItem + int """ + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) self._db_cache.fetch_all(self._item_type) - return self._existing.get(key, {}).get(value) + id_by_unique_value = self._id_by_unique_key_value.get(key, {}) + if strict: + return id_by_unique_value[value] + return id_by_unique_value.get(value) + + def _unique_key_value_to_item(self, key, value, strict=False): + return self.get(self.unique_key_value_to_id(key, value)) def values(self): return (x for x in super().values() if x.is_valid()) @@ -186,6 +222,7 @@ def _item_factory(self): "parameter_definition": ParameterDefinitionItem, "parameter_value": ParameterValueItem, "list_value": ListValueItem, + "alternative": AlternativeItem, "scenario": ScenarioItem, "scenario_alternative": ScenarioAlternativeItem, "metadata": MetadataItem, @@ -204,29 +241,32 @@ def _make_item(self, item): """ return self._item_factory(self._db_cache, self._item_type, **item) - def _current_item(self, item): + def current_item(self, item, skip_keys=()): id_ = item.get("id") if isinstance(id_, int): # id is an int, easy return self.get(id_) if isinstance(id_, dict): # id is a dict specifying the values for one of the unique constraints - return self._current_item_from_dict_id(id_) + key, value = zip(*id_.items()) + return self._unique_key_value_to_item(key, value) if id_ is None: - # No id. Try to build the dict id from the item itself. Used by import_data. - for key in self._item_factory.unique_constraint: - dict_id = {k: item.get(k) for k in key} - current_item = self._current_item_from_dict_id(dict_id) + # No id. Try to locate the item by the value of one of the unique keys. Used by import_data. + item = self._make_item(item) + error = item.resolve_inverse_references() + if error: + return None + error = item.polish() + if error: + return None + for key, value in item.unique_values(skip_keys=skip_keys): + current_item = self._unique_key_value_to_item(key, value) if current_item: return current_item - def _current_item_from_dict_id(self, dict_id): - key, value = zip(*dict_id.items()) - return self.existing(key, value) - - def check_item(self, item, for_update=False): + def check_item(self, item, for_update=False, skip_keys=()): if for_update: - current_item = self._current_item(item) + current_item = self.current_item(item, skip_keys=skip_keys) if current_item is None: return None, f"no {self._item_type} matching {item} to update" item = {**current_item, **item} @@ -234,57 +274,63 @@ def check_item(self, item, for_update=False): else: current_item = None candidate_item = self._make_item(item) - candidate_item.resolve_inverse_references() - missing_ref = candidate_item.missing_ref() - if missing_ref: - return None, f"missing {missing_ref} for {self._item_type}" + error = candidate_item.resolve_inverse_references() + if error: + return None, error + error = candidate_item.polish() + if error: + return None, error + invalid_ref = candidate_item.invalid_ref() + if invalid_ref: + return None, f"invalid {invalid_ref} for {self._item_type}" try: - for key, value in candidate_item.unique_values(): - existing_item = self.existing(key, value) - if existing_item not in (None, current_item) and existing_item.is_valid(): - kv_parts = [f"{k} '{', '.join(v) if isinstance(v, tuple) else v}'" for k, v in zip(key, value)] - head, tail = kv_parts[:-1], kv_parts[-1] - head_str = ", ".join(head) - main_parts = [head_str, tail] if head_str else [tail] - key_val = " and ".join(main_parts) - return None, 
f"there's already a {self._item_type} with {key_val}" + for key, value in candidate_item.unique_values(skip_keys=skip_keys): + empty = {k for k, v in zip(key, value) if v == ""} + if empty: + return None, f"invalid empty keys {empty} for {self._item_type}" + unique_item = self._unique_key_value_to_item(key, value) + if unique_item not in (None, current_item) and unique_item.is_valid(): + return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" except KeyError as e: return None, f"missing {e} for {self._item_type}" - return candidate_item._asdict(), None + return candidate_item, None - def _add_to_existing(self, item): + def _add_unique(self, item): for key, value in item.unique_values(): - self._existing.setdefault(key, {})[value] = item + self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"] - def _remove_from_existing(self, item): + def _remove_unique(self, item): for key, value in item.unique_values(): - self._existing.get(key, {}).pop(value, None) + self._id_by_unique_key_value.get(key, {}).pop(value, None) def add_item(self, item, new=False): self[item["id"]] = new_item = self._make_item(item) - self._add_to_existing(new_item) - new_item.new = new + self._add_unique(new_item) + if new: + new_item.status = Status.to_add return new_item def update_item(self, item): current_item = self[item["id"]] - self._remove_from_existing(current_item) - current_item.dirty = True + self._remove_unique(current_item) current_item.update(item) - self._add_to_existing(current_item) + self._add_unique(current_item) current_item.cascade_update() + if current_item.status != Status.to_add: + current_item.status = Status.to_update + return current_item def remove_item(self, id_): current_item = self.get(id_) if current_item is not None: - self._remove_from_existing(current_item) + self._remove_unique(current_item) current_item.cascade_remove() return current_item def restore_item(self, id_): current_item = self.get(id_) if current_item is not None: - self._add_to_existing(current_item) + self._add_unique(current_item) current_item.cascade_restore() return current_item @@ -292,7 +338,8 @@ def restore_item(self, id_): class CacheItem(dict): """A dictionary that represents an db item.""" - unique_constraint = (("name",),) + _defaults = {} + _unique_keys = (("name",),) _references = {} _inverse_references = {} @@ -313,10 +360,44 @@ def __init__(self, db_cache, item_type, *args, **kwargs): self._removed = False self._corrupted = False self._valid = None - self.new = False - self.dirty = False + self.status = Status.committed + + def is_committed(self): + return self.status == Status.committed + + def polish(self): + """Polishes this item once all it's references are resolved. Returns any errors. + + Returns: + str or None + """ + for key, default_value in self._defaults.items(): + self.setdefault(key, default_value) + return "" + + def resolve_inverse_references(self): + for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): + if dict.get(self, src_key): + # When updating items, the user might update the id keys while leaving the name keys intact. + # In this case we shouldn't overwrite the updated id keys from the obsolete name keys. + # FIXME: It feels that this is our fault, though, like it is us who keep the obsolete name keys around. 
+ continue + id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) + if None in id_value: + continue + table_cache = self._db_cache.table_cache(ref_type) + try: + src_value = ( + tuple(table_cache.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) + if all(isinstance(v, (tuple, list)) for v in id_value) + else table_cache.unique_key_value_to_id(ref_key, id_value, strict=True) + ) + self[src_key] = src_value + except KeyError as err: + # Happens at unique_key_value_to_id(..., strict=True) + return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" - def missing_ref(self): + def invalid_ref(self): for key, (ref_type, _ref_key) in self._references.values(): try: ref_id = self[key] @@ -329,9 +410,10 @@ def missing_ref(self): elif not self._get_ref(ref_type, ref_id): return key - def unique_values(self): - for key in self.unique_constraint: - yield key, tuple(self[k] for k in key) + def unique_values(self, skip_keys=()): + for key in self._unique_keys: + if key not in skip_keys: + yield key, tuple(self.get(k) for k in key) @property def removed(self): @@ -414,6 +496,8 @@ def add_weak_referrer(self, referrer): self._weak_referrers[referrer.key] = referrer def cascade_restore(self): + if self.status == Status.committed: + self.status = Status.to_add if not self._removed: return self._removed = False @@ -428,6 +512,7 @@ def cascade_restore(self): self.restore_callbacks -= obsolete def cascade_remove(self): + self.status = Status.to_remove if self._removed: return self._removed = True @@ -468,36 +553,9 @@ def __getitem__(self, key): return self._get_ref(ref_type, ref_id).get(ref_key) return super().__getitem__(key) - def resolve_inverse_references(self): - for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): - id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) - if None in id_value: - continue - table_cache = self._db_cache.table_cache(ref_type) - with suppress(AttributeError): # NoneType has no attribute id, happens when existing() returns None - self[src_key] = ( - tuple(table_cache.existing(ref_key, v).id for v in zip(*id_value)) - if all(isinstance(v, tuple) for v in id_value) - else table_cache.existing(ref_key, id_value).id - ) - # FIXME: Do we need to catch the AttributeError and give it to the user instead?? 
- -class DisplayIconMixin: - def __getitem__(self, key): - if key == "display_icon": - return dict.get(self, "display_icon") - return super().__getitem__(key) - - -class DescriptionMixin: - def __getitem__(self, key): - if key == "description": - return dict.get(self, "description") - return super().__getitem__(key) - - -class EntityClassItem(DisplayIconMixin, DescriptionMixin, CacheItem): +class EntityClassItem(CacheItem): + _defaults = {"description": None, "display_icon": None} _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))} @@ -511,8 +569,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -class EntityItem(DescriptionMixin, CacheItem): - unique_constraint = (("class_name", "name"), ("class_name", "byname")) +class EntityItem(CacheItem): + _defaults = {"description": None} + _unique_keys = (("class_name", "name"), ("class_name", "byname")) _references = { "class_name": ("class_id", ("entity_class", "name")), "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")), @@ -538,23 +597,22 @@ def __getitem__(self, key): return self["element_name_list"] or (self["name"],) return super().__getitem__(key) - def resolve_inverse_references(self): - super().resolve_inverse_references() - self._fill_name() - - def _fill_name(self): + def polish(self): + error = super().polish() + if error: + return error if "name" in self: return base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) name = base_name table_cache = self._db_cache.table_cache(self._item_type) - while table_cache.existing(("class_name", "name"), (self["class_name"], name)) is not None: - name = base_name + uuid.uuid4().hex + while table_cache.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: + name = base_name + "_" + uuid.uuid4().hex self["name"] = name class EntityGroupItem(CacheItem): - unique_constraint = (("group_name", "member_name"),) + _unique_keys = (("group_name", "member_name"),) _references = { "class_name": ("entity_class_id", ("entity_class", "name")), "group_name": ("entity_id", ("entity", "name")), @@ -575,8 +633,9 @@ def __getitem__(self, key): return super().__getitem__(key) -class ParameterDefinitionItem(DescriptionMixin, CacheItem): - unique_constraint = (("entity_class_name", "name"),) +class ParameterDefinitionItem(CacheItem): + _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} + _unique_keys = (("entity_class_name", "name"),) _references = { "entity_class_name": ("entity_class_id", ("entity_class", "name")), "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), @@ -587,12 +646,11 @@ class ParameterDefinitionItem(DescriptionMixin, CacheItem): "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), } - def __init__(self, *args, **kwargs): - if kwargs.get("list_value_id") is None: - kwargs["list_value_id"] = ( - int(kwargs["default_value"]) if kwargs.get("default_type") == "list_value_ref" else None - ) - super().__init__(*args, **kwargs) + @property + def list_value_id(self): + if dict.get(self, "default_type") == "list_value_ref": + return int(dict.__getitem__(self, "default_value")) + return None def __getitem__(self, key): if key == "parameter_name": @@ -601,23 +659,50 @@ def __getitem__(self, key): return 
super().__getitem__("parameter_value_list_id") if key == "parameter_value_list_id": return dict.get(self, key) - if key == "value_list_name": - return self._get_ref("parameter_value_list", self["value_list_id"], strong=False).get("name") + if key == "parameter_value_list_name": + return self._get_ref("parameter_value_list", self["parameter_value_list_id"], strong=False).get("name") if key in ("default_value", "default_type"): - if self["list_value_id"] is not None: - return self._get_ref("list_value", self["list_value_id"], strong=False).get(key.split("_")[1]) + list_value_id = self.list_value_id + if list_value_id is not None: + list_value_key = {"default_value": "value", "default_type": "type"}[key] + return self._get_ref("list_value", list_value_id, strong=False).get(list_value_key) return dict.get(self, key) + if key == "list_value_id": + return self.list_value_id return super().__getitem__(key) + def polish(self): + error = super().polish() + if error: + return error + default_type = self["default_type"] + default_value = self["default_value"] + list_name = self["parameter_value_list_name"] + if list_name is None: + return + if default_type == "list_value_ref": + return + parsed_value = from_database(default_value, default_type) + if parsed_value is None: + return + list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( + ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) + ) + if list_value_id is None: + return f"default value {parsed_value} of {self['name']} is not in {list_name}" + self["default_value"] = str(list_value_id).encode() + self["default_type"] = "list_value_ref" + class ParameterValueItem(CacheItem): - unique_constraint = (("parameter_definition_name", "entity_byname", "alternative_name"),) + _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_name": ("entity_class_id", ("entity_class", "name")), "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), "parameter_definition_name": ("parameter_definition_id", ("parameter_definition", "name")), "parameter_value_list_id": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_id")), + "parameter_value_list_name": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_name")), "entity_name": ("entity_id", ("entity", "name")), "entity_byname": ("entity_id", ("entity", "byname")), "element_id_list": ("entity_id", ("entity", "element_id_list")), @@ -625,6 +710,7 @@ class ParameterValueItem(CacheItem): "alternative_name": ("alternative_id", ("alternative", "name")), } _inverse_references = { + "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), "parameter_definition_id": ( ("entity_class_name", "parameter_definition_name"), ("parameter_definition", ("entity_class_name", "name")), @@ -633,30 +719,65 @@ class ParameterValueItem(CacheItem): "alternative_id": (("alternative_name",), ("alternative", ("name",))), } - def __init__(self, *args, **kwargs): - if kwargs.get("list_value_id") is None: - kwargs["list_value_id"] = int(kwargs["value"]) if kwargs.get("type") == "list_value_ref" else None - super().__init__(*args, **kwargs) + @property + def list_value_id(self): + if dict.__getitem__(self, "type") == "list_value_ref": + return int(dict.__getitem__(self, "value")) + return None def __getitem__(self, key): if key == "parameter_id": return 
super().__getitem__("parameter_definition_id") - if key in ("value", "type") and self["list_value_id"] is not None: - return self._get_ref("list_value", self["list_value_id"], strong=False).get(key) + if key == "parameter_name": + return super().__getitem__("parameter_definition_name") + if key in ("value", "type"): + list_value_id = self.list_value_id + if list_value_id: + return self._get_ref("list_value", list_value_id, strong=False).get(key) + if key == "list_value_id": + return self.list_value_id return super().__getitem__(key) + def polish(self): + list_name = self["parameter_value_list_name"] + if list_name is None: + return + type_ = self["type"] + if type_ == "list_value_ref": + return + value = self["value"] + parsed_value = from_database(value, type_) + if parsed_value is None: + return + list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( + ("parameter_value_list_name", "value", "type"), (list_name, value, type_) + ) + if list_value_id is None: + return ( + f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " + "is not in {list_name}" + ) + self["value"] = str(list_value_id).encode() + self["type"] = "list_value_ref" + class ListValueItem(CacheItem): - unique_constraint = (("parameter_value_list_name", "value"), ("parameter_value_list_name", "index")) + _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} _inverse_references = { "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), } +class AlternativeItem(CacheItem): + _defaults = {"description": None} + + class ScenarioItem(CacheItem): + _defaults = {"active": False, "description": None} + @property - def sorted_scenario_alternatives(self): + def sorted_alternatives(self): self._db_cache.fetch_all("scenario_alternative") return sorted( (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), @@ -665,14 +786,14 @@ def sorted_scenario_alternatives(self): def __getitem__(self, key): if key == "alternative_id_list": - return [x["alternative_id"] for x in self.sorted_scenario_alternatives] + return [x["alternative_id"] for x in self.sorted_alternatives] if key == "alternative_name_list": - return [x["alternative_name"] for x in self.sorted_scenario_alternatives] + return [x["alternative_name"] for x in self.sorted_alternatives] return super().__getitem__(key) class ScenarioAlternativeItem(CacheItem): - unique_constraint = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) + _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) _references = { "scenario_name": ("scenario_id", ("scenario", "name")), "alternative_name": ("alternative_id", ("alternative", "name")), @@ -680,10 +801,13 @@ class ScenarioAlternativeItem(CacheItem): _inverse_references = { "scenario_id": (("scenario_name",), ("scenario", ("name",))), "alternative_id": (("alternative_name",), ("alternative", ("name",))), - "before_alternative_id": (("before_alternative_name",), ("alternative", ("name",))), } def __getitem__(self, key): + # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. + # Since ranks go from 1 to the alternative count, the first alternative will have the second as the 'before', + # the second will have the third, etc, and the last will have None. 
+ # Note that alternatives with higher ranks overwrite the values of those with lower ranks. if key == "before_alternative_name": return self._get_ref("alternative", self["before_alternative_id"], strong=False).get("name") if key == "before_alternative_id": @@ -696,11 +820,11 @@ def __getitem__(self, key): class MetadataItem(CacheItem): - unique_constraint = (("name", "value"),) + _unique_keys = (("name", "value"),) class EntityMetadataItem(CacheItem): - unique_constraint = (("entity_name", "metadata_name"),) + _unique_keys = (("entity_name", "metadata_name"),) _references = { "entity_name": ("entity_id", ("entity", "name")), "metadata_name": ("metadata_id", ("metadata", "name")), @@ -713,7 +837,7 @@ class EntityMetadataItem(CacheItem): class ParameterValueMetadataItem(CacheItem): - unique_constraint = (("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name"),) + _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name"),) _references = { "parameter_definition_name": ("parameter_value_id", ("parameter_value", "parameter_definition_name")), "entity_byname": ("parameter_value_id", ("parameter_value", "entity_byname")), diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index c489d3c5..2f3c53a8 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -18,7 +18,8 @@ from contextlib import contextmanager from sqlalchemy import func, Table, Column, Integer, String, null, select from sqlalchemy.exc import DBAPIError -from .exception import SpineDBAPIError +from .exception import SpineIntegrityError +from .helpers import convert_legacy class DatabaseMappingAddMixin: @@ -119,16 +120,6 @@ def generate_ids(self, tablename): finally: connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) - def _add_commit_id_and_ids(self, tablename, *items): - if not items: - return [], set() - commit_id = self._make_commit_id() - with self.generate_ids(tablename) as new_id: - for item in items: - item["commit_id"] = commit_id - if "id" not in item: - item["id"] = new_id() - def add_items(self, tablename, *items, check=True, strict=False): """Add items to cache. @@ -143,43 +134,42 @@ def add_items(self, tablename, *items, check=True, strict=False): set: ids or items successfully added list(str): found violations """ - if check: - checked_items, errors = self.check_items(tablename, *items) - else: - checked_items, errors = list(items), [] - if errors and strict: - raise SpineDBAPIError(", ".join(errors)) - _ = self._add_items(tablename, *checked_items) - return checked_items, errors - - def _add_items(self, tablename, *items): - """Add items to cache without checking integrity. - - Args: - tablename (str) - items (Iterable): list of dictionaries which correspond to the instances to add - strict (bool): if True SpineIntegrityError are raised. 
Otherwise - they are caught and returned as a log - - Returns: - ids (set): added instances' ids - """ - self._add_commit_id_and_ids(tablename, *items) + added, errors = [], [] tablename = self._real_tablename(tablename) table_cache = self.cache.table_cache(tablename) - for item in items: - table_cache.add_item(item, new=True) - return {item["id"] for item in items} + with self.generate_ids(tablename) as new_id: + if not check: + for item in items: + convert_legacy(tablename, item) + if "id" not in item: + item["id"] = new_id() + added.append(table_cache.add_item(item, new=True)._asdict()) + else: + for item in items: + convert_legacy(tablename, item) + checked_item, error = table_cache.check_item(item) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + item = checked_item._asdict() + if "id" not in item: + item["id"] = new_id() + added.append(table_cache.add_item(item, new=True)._asdict()) + return added, errors def _do_add_items(self, tablename, *items_to_add): """Add items to DB without checking integrity.""" try: for tablename_, items_to_add_ in self._items_to_add_per_table(tablename, items_to_add): + if not items_to_add_: + continue table = self._metadata.tables[self._real_tablename(tablename_)] - self._checked_execute(table.insert(), [{**item} for item in items_to_add_]) + self.connection_execute(table.insert(), [dict(item) for item in items_to_add_]) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) from e + raise SpineIntegrityError(msg) from e @staticmethod def _items_to_add_per_table(tablename, items_to_add): @@ -194,13 +184,13 @@ def _items_to_add_per_table(tablename, items_to_add): Yields: tuple: database table name, items to add """ + yield (tablename, items_to_add) if tablename == "entity_class": ecd_items_to_add = [ {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} for item in items_to_add for position, dimension_id in enumerate(item["dimension_id_list"]) ] - yield ("entity_class", items_to_add) yield ("entity_class_dimension", ecd_items_to_add) elif tablename == "entity": ee_items_to_add = [ @@ -216,79 +206,7 @@ def _items_to_add_per_table(tablename, items_to_add): zip(item["element_id_list"], item["dimension_id_list"]) ) ] - ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} - for item in items_to_add - for alternative_id in item["active_alternative_id_list"] - ] + [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} - for item in items_to_add - for alternative_id in item["inactive_alternative_id_list"] - ] - yield ("entity", items_to_add) - yield ("entity_element", ee_items_to_add) - yield ("entity_alternative", ea_items_to_add) - elif tablename == "relationship_class": - ecd_items_to_add = [ - {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} - for item in items_to_add - for position, dimension_id in enumerate(item["object_class_id_list"]) - ] - yield ("entity_class", items_to_add) - yield ("entity_class_dimension", ecd_items_to_add) - elif tablename == "object": - ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} - for item in items_to_add - for alternative_id in item["active_alternative_id_list"] - ] + [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} - for item in items_to_add - for alternative_id in 
item["inactive_alternative_id_list"] - ] - yield ("entity", items_to_add) - yield ("entity_alternative", ea_items_to_add) - elif tablename == "relationship": - ee_items_to_add = [ - { - "entity_id": item["id"], - "entity_class_id": item["class_id"], - "position": position, - "element_id": element_id, - "dimension_id": dimension_id, - } - for item in items_to_add - for position, (element_id, dimension_id) in enumerate( - zip(item["object_id_list"], item["object_class_id_list"]) - ) - ] - ea_items_to_add = [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": True} - for item in items_to_add - for alternative_id in item["active_alternative_id_list"] - ] + [ - {"entity_id": item["id"], "alternative_id": alternative_id, "active": False} - for item in items_to_add - for alternative_id in item["inactive_alternative_id_list"] - ] - yield ("entity", items_to_add) yield ("entity_element", ee_items_to_add) - yield ("entity_alternative", ea_items_to_add) - elif tablename == "parameter_definition": - for item in items_to_add: - item["entity_class_id"] = ( - item.get("object_class_id") or item.get("relationship_class_id") or item.get("entity_class_id") - ) - yield ("parameter_definition", items_to_add) - elif tablename == "parameter_value": - for item in items_to_add: - item["entity_id"] = item.get("object_id") or item.get("relationship_id") or item.get("entity_id") - item["entity_class_id"] = ( - item.get("object_class_id") or item.get("relationship_class_id") or item.get("entity_class_id") - ) - yield ("parameter_value", items_to_add) - else: - yield (tablename, items_to_add) def add_object_classes(self, *items, **kwargs): return self.add_items("object_class", *items, **kwargs) @@ -342,9 +260,8 @@ def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) def _get_or_add_metadata_ids_for_items(self, *items, check, strict): - cache = self.cache metadata_ids = {} - for entry in cache.get("metadata", {}).values(): + for entry in self.cache.get("metadata", {}).values(): metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id metadata_to_add = [] items_missing_metadata_ids = {} @@ -357,8 +274,6 @@ def _get_or_add_metadata_ids_for_items(self, *items, check, strict): else: item["metadata_id"] = existing_id added_metadata, errors = self.add_items("metadata", *metadata_to_add, check=check, strict=strict) - for x in added_metadata: - cache.table_cache("metadata").add_item(x) if errors: return added_metadata, errors new_metadata_ids = {} diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 26307b4e..65d871e9 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -20,7 +20,7 @@ from types import MethodType from concurrent.futures import ThreadPoolExecutor from sqlalchemy import create_engine, MetaData, Table, Column, Integer, inspect, case, func, cast, false, and_, or_ -from sqlalchemy.sql.expression import label, Alias, select +from sqlalchemy.sql.expression import label, Alias from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import aliased from sqlalchemy.exc import DatabaseError, ProgrammingError @@ -184,9 +184,10 @@ def __init__( self._table_to_sq_attr = {} # Table primary ids map: self._id_fields = { + "entity_class_dimension": "entity_class_id", + "entity_element": "entity_id", "object_class": "entity_class_id", "relationship_class": "entity_class_id", - "entity_class_dimension": "entity_class_id", "object": "entity_id", "relationship": 
"entity_id", } @@ -232,6 +233,14 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): def _make_executor(self): return ThreadPoolExecutor(max_workers=1) if self._asynchronous else _Executor() + def call_in_right_thread(self, fn, *args, **kwargs): + # We try to call directly. If we are in the wrong thread this will raise ProgrammingError. + # Then we can execute in the executor thread. + try: + return fn(*args, **kwargs) + except ProgrammingError: + return self.executor.submit(fn, *args, **kwargs).result() + def close(self): if not self.connection.closed: self.executor.submit(self.connection.close) @@ -282,23 +291,6 @@ def _real_tablename(self, tablename): def get_table(self, tablename): return self._metadata.tables[tablename] - def commit_id(self): - return self._commit_id - - def _make_commit_id(self): - return None - - def _check_commit(self, comment): - """Raises if commit not possible. - - Args: - comment (str): commit message - """ - if not self.has_pending_changes(): - raise SpineDBAPIError("Nothing to commit.") - if not comment: - raise SpineDBAPIError("Commit message cannot be empty.") - def _make_codename(self, codename): if codename: return str(codename) @@ -402,8 +394,8 @@ def in_(self, column, values): Column("value", column.type, primary_key=True), prefixes=['TEMPORARY'], ) - self.executor.submit(in_value.create, self.connection, checkfirst=True).result() - self.safe_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) + self.call_in_right_thread(in_value.create, self.connection, checkfirst=True) + self.connection_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) return column.in_(self.query(in_value.c.value)) def _get_table_to_sq_attr(self): @@ -469,7 +461,7 @@ def query(self, *args, **kwargs): db_map.object_sq.c.class_id == db_map.object_class_sq.c.id ).group_by(db_map.object_class_sq.c.name).all() """ - return Query(self, select(args)) + return Query(self, *args) def _subquery(self, tablename): """A subquery of the form: @@ -702,7 +694,7 @@ def object_class_sq(self): self.wide_entity_class_sq.c.hidden.label("hidden"), ) .filter(self.wide_entity_class_sq.c.dimension_id_list == None) - .subquery() + .subquery("object_class_sq") ) return self._object_class_sq @@ -727,7 +719,7 @@ def object_sq(self): self.wide_entity_sq.c.commit_id.label("commit_id"), ) .filter(self.wide_entity_sq.c.element_id_list == None) - .subquery() + .subquery("object_sq") ) return self._object_sq @@ -755,7 +747,7 @@ def relationship_class_sq(self): self.wide_entity_class_sq.c.hidden.label("hidden"), ) .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) - .subquery() + .subquery("relationship_class_sq") ) return self._relationship_class_sq @@ -782,7 +774,7 @@ def relationship_sq(self): self.wide_entity_sq.c.commit_id.label("commit_id"), ) .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) - .subquery() + .subquery("relationship_sq") ) return self._relationship_sq @@ -1920,12 +1912,8 @@ def restore_scenario_alternative_sq_maker(self): self._make_scenario_alternative_sq = MethodType(DatabaseMappingBase._make_scenario_alternative_sq, self) self._clear_subqueries("scenario_alternative") - def safe_execute(self, *args): - # We try to execute directly. If we are in the wrong thread this will raise ProgrammingError. 
- try: - return self.connection.execute(*args) - except ProgrammingError: - return self.executor.submit(self.connection.execute, *args).result() + def connection_execute(self, *args): + return self.call_in_right_thread(self.connection.execute, *args) def _get_primary_key(self, tablename): pk = self.composite_pks.get(tablename) @@ -1940,8 +1928,8 @@ def _reset_mapping(self): """ for tablename in self._tablenames: table = self._metadata.tables[tablename] - self.connection.execute(table.delete()) - self.connection.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") + self.connection_execute(table.delete()) + self.connection_execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") def fetch_all(self, tablenames, include_descendants=False, include_ancestors=False, force_tablenames=None): if include_descendants: @@ -2026,23 +2014,14 @@ def _metadata_usage_counts(self): """ cache = self.cache usage_counts = Counter() - for entry in cache.get("entity_metadata", {}).values(): + for entry in dict.values(cache.get("entity_metadata", {})): usage_counts[entry.metadata_id] += 1 - for entry in cache.get("parameter_value_metadata", {}).values(): + for entry in dict.values(cache.get("parameter_value_metadata", {})): usage_counts[entry.metadata_id] += 1 return usage_counts - def check_items(self, tablename, *items, for_update=False): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - checked_items, errors = [], [] - for item in items: - checked_item, error = table_cache.check_item(item, for_update=for_update) - if error: - errors.append(error) - else: - checked_items.append(checked_item) - return checked_items, errors + def get_filter_configs(self): + return self._filter_configs def __del__(self): try: @@ -2050,9 +2029,6 @@ def __del__(self): except AttributeError: pass - def get_filter_configs(self): - return self._filter_configs - class _Future: def __init__(self): diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 2efc916f..ac109d5e 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -21,51 +21,31 @@ class DatabaseMappingCommitMixin: """Provides methods to commit or rollback pending changes onto a Spine database.""" - def __init__(self, *args, **kwargs): - """Initialize class.""" - super().__init__(*args, **kwargs) - self._commit_id = None - - def has_pending_changes(self): - # FIXME - return True - - def _make_commit_id(self): - if self._commit_id is None: - with self.engine.begin() as connection: - self._commit_id = self._do_make_commit_id(connection) - return self._commit_id - - def _do_make_commit_id(self, connection): - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert() - return connection.execute(ins, {"user": user, "date": date, "comment": "uncomplete"}).inserted_primary_key[0] - def commit_session(self, comment): """Commits current session to the database. 
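# A sketch of the caller-side flow that results from this rewrite; the
# payload and commit messages below are assumptions for illustration.
from spinedb_api import DatabaseMapping, SpineDBAPIError

db_map = DatabaseMapping("sqlite://", create=True)
db_map.add_items("alternative", {"name": "high_demand"})
try:
    db_map.commit_session("")  # rejected before anything is written to the DB
except SpineDBAPIError as error:
    print(error)  # "Commit message cannot be empty."
db_map.commit_session("Add an alternative.")
db_map.connection.close()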
Args: comment (str): commit message """ - self._check_commit(comment) - commit = self._metadata.tables["commit"] + if not comment: + raise SpineDBAPIError("Commit message cannot be empty.") user = self.username date = datetime.now(timezone.utc) - upd = commit.update().where(commit.c.id == self._make_commit_id()) - self._checked_execute(upd, dict(user=user, date=date, comment=comment)) - to_add, to_update, to_remove = self.cache.to_change() + ins = self._metadata.tables["commit"].insert() + commit_id = self.connection_execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + to_add, to_update, to_remove = self.cache.commit() + if not to_add and not to_update and not to_remove: + raise SpineDBAPIError("Nothing to commit.") for tablename, items in to_add.items(): self._do_add_items(tablename, *items) for tablename, items in to_update.items(): self._do_update_items(tablename, *items) self._do_remove_items(**to_remove) - self._commit_id = None if self._memory: self._memory_dirty = True def rollback_session(self): - if not self.has_pending_changes(): + to_add, to_update, to_remove = self.cache.commit() + if not to_add and not to_update and not to_remove: raise SpineDBAPIError("Nothing to rollback.") self.cache.reset_queries() - self._commit_id = None diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 0fbd11a2..2f6ec533 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -62,7 +62,7 @@ def _do_remove_items(self, **kwargs): table = self._metadata.tables[tablename] delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) try: - self.safe_execute(delete) + self.connection_execute(delete) except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 467ee0f8..c4b86fbb 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -15,7 +15,8 @@ from collections import Counter from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam -from .exception import SpineDBAPIError +from .exception import SpineIntegrityError +from .helpers import convert_legacy class DatabaseMappingUpdateMixin: @@ -32,10 +33,10 @@ def _do_update_items(self, tablename, *items_to_update): for k in self._get_primary_key(tablename_): upd = upd.where(getattr(table.c, k) == bindparam(k)) upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items_to_update_[0].keys()}) - self.safe_execute(upd, [{**item} for item in items_to_update_]) + self.connection_execute(upd, [dict(item) for item in items_to_update_]) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" - raise SpineDBAPIError(msg) from e + raise SpineIntegrityError(msg) from e @staticmethod def _items_to_update_per_table(tablename, items_to_update): @@ -50,68 +51,22 @@ def _items_to_update_per_table(tablename, items_to_update): Yields: tuple: database table name, items to update """ + yield (tablename, items_to_update) if tablename == "entity": - entity_items = [] - entity_element_items = [] - for item in items_to_update: - entity_id = item["id"] - class_id = item["class_id"] - dimension_id_list = item["dimension_id_list"] - element_id_list = item["element_id_list"] - entity_items.append( - { - "id": entity_id, - "class_id": class_id, - "name": item["name"], - 
"description": item.get("description"), - } + ee_items_to_update = [ + { + "entity_id": item["id"], + "entity_class_id": item["class_id"], + "position": position, + "element_id": element_id, + "dimension_id": dimension_id, + } + for item in items_to_update + for position, (element_id, dimension_id) in enumerate( + zip(item["element_id_list"], item["dimension_id_list"]) ) - entity_element_items.extend( - [ - { - "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, - } - for position, (dimension_id, element_id) in enumerate(zip(dimension_id_list, element_id_list)) - ] - ) - yield ("entity", entity_items) - yield ("entity_element", entity_element_items) - elif tablename == "relationship": - entity_items = [] - entity_element_items = [] - for item in items_to_update: - entity_id = item["id"] - class_id = item["class_id"] - object_class_id_list = item["object_class_id_list"] - object_id_list = item["object_id_list"] - entity_items.append( - { - "id": entity_id, - "class_id": class_id, - "name": item["name"], - "description": item.get("description"), - } - ) - entity_element_items.extend( - [ - { - "entity_class_id": class_id, - "entity_id": entity_id, - "position": position, - "dimension_id": dimension_id, - "element_id": element_id, - } - for position, (dimension_id, element_id) in enumerate(zip(object_class_id_list, object_id_list)) - ] - ) - yield ("entity", entity_items) - yield ("entity_element", entity_element_items) - else: - yield (tablename, items_to_update) + ] + yield ("entity_element", ee_items_to_update) def update_items(self, tablename, *items, check=True, strict=False): """Updates items in cache. @@ -127,27 +82,25 @@ def update_items(self, tablename, *items, check=True, strict=False): set: ids or items successfully updated list(SpineIntegrityError): found violations """ - if check: - checked_items, errors = self.check_items(tablename, *items, for_update=True) - else: - checked_items, errors = list(items), [] - if errors and strict: - raise SpineDBAPIError(", ".join(errors)) - _ = self._update_items(tablename, *checked_items) - return checked_items, errors - - def _update_items(self, tablename, *items): - """Updates items in cache without checking integrity.""" - if not items: - return set() + updated, errors = [], [] tablename = self._real_tablename(tablename) - table_cache = self.cache.get(tablename) - if table_cache is not None: - commit_id = self._make_commit_id() + table_cache = self.cache.table_cache(tablename) + if not check: + for item in items: + convert_legacy(tablename, item) + updated.append(table_cache.update_item(item)._asdict()) + else: for item in items: - item["commit_id"] = commit_id - table_cache.update_item(item) - return {x["id"] for x in items} + convert_legacy(tablename, item) + checked_item, error = table_cache.check_item(item, for_update=True) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + item = checked_item._asdict() + updated.append(table_cache.update_item(item)._asdict()) + return updated, errors def update_alternatives(self, *items, **kwargs): return self.update_items("alternative", *items, **kwargs) @@ -290,11 +243,11 @@ def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=F errors += item_metadata_errors return all_items, errors - def get_data_to_set_scenario_alternatives(self, *items): + def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): """Returns data to add and 
remove, in order to set wide scenario alternatives. Args: - *items: One or more wide scenario :class:`dict` objects to set. + *scenarios: One or more wide scenario :class:`dict` objects to set. Each item must include the following keys: - "id": integer scenario id @@ -304,21 +257,25 @@ def get_data_to_set_scenario_alternatives(self, *items): list: scenario_alternative :class:`dict` objects to add. set: integer scenario_alternative ids to remove """ - self.fetch_all({"scenario_alternative", "scenario"}) - cache = self.cache - current_alternative_id_lists = {x.id: x.alternative_id_list for x in cache.get("scenario", {}).values()} - scenario_alternative_ids = { - (x.scenario_id, x.alternative_id): x.id for x in cache.get("scenario_alternative", {}).values() - } scen_alts_to_add = [] scen_alt_ids_to_remove = set() - for item in items: - scenario_id = item["id"] - alternative_id_list = item["alternative_id_list"] - current_alternative_id_list = current_alternative_id_lists[scenario_id] - for k, alternative_id in enumerate(alternative_id_list): - item_to_add = {"scenario_id": scenario_id, "alternative_id": alternative_id, "rank": k + 1} + errors = [] + for scen in scenarios: + current_scen = self.cache.table_cache("scenario").current_item(scen) + if current_scen is None: + error = f"no scenario matching {scen} to set alternatives for" + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + for k, alternative_id in enumerate(scen.get("alternative_id_list", ())): + item_to_add = {"scenario_id": current_scen["id"], "alternative_id": alternative_id, "rank": k + 1} + scen_alts_to_add.append(item_to_add) + for k, alternative_name in enumerate(scen.get("alternative_name_list", ())): + item_to_add = {"scenario_id": current_scen["id"], "alternative_name": alternative_name, "rank": k + 1} scen_alts_to_add.append(item_to_add) - for alternative_id in current_alternative_id_list: - scen_alt_ids_to_remove.add(scenario_alternative_ids[scenario_id, alternative_id]) - return scen_alts_to_add, scen_alt_ids_to_remove + for alternative_id in current_scen["alternative_id_list"]: + scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id} + current_scen_alt = self.cache.table_cache("scenario_alternative").current_item(scen_alt) + scen_alt_ids_to_remove.add(current_scen_alt["id"]) + return scen_alts_to_add, scen_alt_ids_to_remove, errors diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 8de7db99..ff7d696c 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -421,7 +421,7 @@ def rows(self, db_map, title_state): generator(dict) """ qry = self._build_query(db_map, title_state) - for db_row in qry.yield_per(1000): + for db_row in qry: yield from self.get_rows_recursive(db_row) def has_titles(self): @@ -506,7 +506,7 @@ def _non_unique_titles(self, db_map, limit=None): tuple(str,dict): title, and associated title state dictionary """ qry = self._build_title_query(db_map) - for db_row in qry.yield_per(1000): + for db_row in qry: yield from self.get_titles_recursive(db_row, limit=limit) def titles(self, db_map, limit=None): @@ -575,7 +575,7 @@ def make_header(self, db_map, title_state, buddies): Returns dict: a mapping from column index to string header """ - query = _Rewindable(self._build_header_query(db_map, title_state, buddies).yield_per(1000)) + query = _Rewindable(self._build_header_query(db_map, title_state, buddies)) return 
self.make_header_recursive(query, buddies) @@ -1488,9 +1488,6 @@ def __init__(self, query, condition): self._query = query self._condition = condition - def yield_per(self, count): - return _FilteredQuery(self._query.yield_per(count), self._condition) - def filter(self, *args, **kwargs): return _FilteredQuery(self._query.filter(*args, **kwargs), self._condition) diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index 2236dff9..67a01907 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -99,7 +99,7 @@ def alternative_filter_shorthand_to_config(shorthand): Returns: dict: alternative filter configuration """ - filter_type, separator, tokens = shorthand.partition(":'") + _filter_type, _separator, tokens = shorthand.partition(":'") alternatives = tokens.split("':'") alternatives[-1] = alternatives[-1][:-1] return alternative_filter_config(alternatives) diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 7c8d767d..f3cd90a8 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -31,8 +31,9 @@ def apply_scenario_filter_to_subqueries(db_map, scenario): scenario (str or int): scenario name or id """ state = _ScenarioFilterState(db_map, scenario) - make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) - db_map.override_entity_sq_maker(make_entity_sq) + # FIXME + # make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) + # db_map.override_entity_sq_maker(make_entity_sq) make_parameter_value_sq = partial(_make_scenario_filtered_parameter_value_sq, state=state) db_map.override_parameter_value_sq_maker(make_parameter_value_sq) make_alternative_sq = partial(_make_scenario_filtered_alternative_sq, state=state) @@ -204,14 +205,14 @@ def _make_scenario_filtered_entity_sq(db_map, state): order_by=desc(db_map.scenario_alternative_sq.c.rank), ) .label("max_rank_row_number"), - db_map.entity_alternative_sq.active.label("active"), + db_map.entity_alternative_sq.c.active.label("active"), ) .filter(state.original_entity_sq.c.id == db_map.entity_alternative_sq.c.entity_id) .filter(db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id) .filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id) ).subquery() # TODO: Maybe we want to filter multi-dimensional entities involving filtered entities right here too? 
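# The rank-based selection in these filters follows the usual window-function
# pattern: number the rows per entity in descending rank order, then keep row
# number 1. A self-contained sketch with an invented table, for illustration:
from sqlalchemy import Column, Integer, MetaData, Table, desc, func, select

metadata = MetaData()
entity_alternative = Table(
    "entity_alternative",
    metadata,
    Column("entity_id", Integer),
    Column("rank", Integer),
    Column("active", Integer),
)
ranked = select(
    [
        entity_alternative,
        func.row_number()
        .over(partition_by=[entity_alternative.c.entity_id], order_by=desc(entity_alternative.c.rank))
        .label("max_rank_row_number"),
    ]
).subquery()
# Only the highest-ranked row per entity survives the filter:
winners = select([ranked]).where(ranked.c.max_rank_row_number == 1)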
- return db_map.query(wide_entity_sq).filter_by(max_rank_row_number=1, is_active=True).subquery() + return db_map.query(wide_entity_sq).filter_by(max_rank_row_number=1, active=True).subquery() def _make_scenario_filtered_parameter_value_sq(db_map, state): @@ -237,7 +238,7 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state): state.original_parameter_value_sq.c.entity_id, ], order_by=desc(db_map.scenario_alternative_sq.c.rank), - ) + ) # the one with the highest rank will have row_number equal to 1, so it will 'win' in the filter below .label("max_rank_row_number"), ) .filter(state.original_parameter_value_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 9d5f6b53..4c1056d5 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -824,3 +824,28 @@ def remove_credentials_from_url(url): if parsed.username is None: return url return urlunparse(parsed._replace(netloc=parsed.netloc.partition("@")[-1])) + + +def convert_legacy(tablename, item): + if tablename in ("entity_class", "entity"): + object_class_id_list = tuple(item.pop("object_class_id_list", ())) + if object_class_id_list: + item["dimension_id_list"] = object_class_id_list + object_class_name_list = tuple(item.pop("object_class_name_list", ())) + if object_class_name_list: + item["dimension_name_list"] = object_class_name_list + if tablename == "entity": + object_id_list = tuple(item.pop("object_id_list", ())) + if object_id_list: + item["element_id_list"] = object_id_list + object_name_list = tuple(item.pop("object_name_list", ())) + if object_name_list: + item["element_name_list"] = object_name_list + if tablename in ("parameter_definition", "parameter_value"): + entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) + if entity_class_id: + item["entity_class_id"] = entity_class_id + if tablename == "parameter_value": + entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) + if entity_id: + item["entity_id"] = entity_id diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 51e27667..7b20de90 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -176,10 +176,10 @@ def get_data_for_import( yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) if scenario_alternatives: if not scenarios: - scenarios = (item[0] for item in scenario_alternatives) + scenarios = list({item[0]: None for item in scenario_alternatives}) yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) if not alternatives: - alternatives = (item[1] for item in scenario_alternatives) + alternatives = list({item[1]: None for item in scenario_alternatives}) yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) if entity_classes: @@ -744,26 +744,48 @@ def import_relationship_parameter_value_metadata(db_map, data): return import_data(db_map, relationship_parameter_value_metadata=data) -def _get_items_for_import(db_map, item_type, data): +def _get_items_for_import(db_map, item_type, data, skip_keys=()): table_cache = db_map.cache.table_cache(item_type) errors = [] to_add = [] to_update = [] + seen = {} with db_map.generate_ids(item_type) as new_id: for item in data: - checked_item, error = table_cache.check_item(item) - if not error: - item["id"] = new_id() + checked_item, add_error = 
table_cache.check_item(item, skip_keys=skip_keys) + if not add_error: + if not _check_seen(item_type, checked_item, seen, errors): + continue + checked_item["id"] = new_id() to_add.append(checked_item) continue - checked_item, error = table_cache.check_item(item, for_update=True) - if not error: + checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=skip_keys) + if not update_error: + if not _check_seen(item_type, checked_item, seen, errors): + continue + # FIXME: Maybe check that item and checked_item are different before updating??? to_update.append(checked_item) continue - errors.append(error) + errors.append(add_error) return to_add, to_update, errors +def _check_seen(item_type, checked_item, seen, errors): + dupe_key = _add_to_seen(checked_item, seen) + if not dupe_key: + return True + if item_type in ("parameter_value",): + errors.append(f"attempting to import more than one {item_type} with {dupe_key} - only first will be considered") + return False + + +def _add_to_seen(checked_item, seen): + for key, value in checked_item.unique_values(): + if value in seen.get(key, set()): + return dict(zip(key, value)) + seen.setdefault(key, set()).add(value) + + def _get_entity_classes_for_import(db_map, data): key = ("name", "dimension_name_list", "description", "display_icon") return _get_items_for_import( @@ -806,7 +828,8 @@ def _data_iterator(): if isinstance(entity_byname, str): entity_byname = (entity_byname,) value, type_ = unparse_value(value) - yield class_name, entity_byname, parameter_name, value, type_, *optionals + alternative_name = optionals[0] if optionals else "Base" + yield class_name, entity_byname, parameter_name, value, type_, alternative_name key = ("entity_class_name", "entity_byname", "parameter_definition_name", "value", "type", "alternative_name") return _get_items_for_import(db_map, "parameter_value", (dict(zip(key, x)) for x in _data_iterator())) @@ -828,8 +851,34 @@ def _get_scenarios_for_import(db_map, data): def _get_scenario_alternatives_for_import(db_map, data): - key = ("scenario_name", "alternative_name", "before_alternative_name") - return _get_items_for_import(db_map, "scenario_alternative", (dict(zip(key, x)) for x in data)) + alt_name_list_by_scen_name, errors = {}, [] + for scen_name, alt_name, *optionals in data: + scen = db_map.cache.table_cache("scenario").current_item({"name": scen_name}) + if scen is None: + errors.append(f"no scenario with name {scen_name} to set alternatives for") + continue + alternative_name_list = alt_name_list_by_scen_name.setdefault(scen_name, scen["alternative_name_list"]) + if alt_name in alternative_name_list: + alternative_name_list.remove(alt_name) + before_alt_name = optionals[0] if optionals else None + if before_alt_name is None: + alternative_name_list.append(alt_name) + continue + if before_alt_name in alternative_name_list: + pos = alternative_name_list.index(before_alt_name) + alternative_name_list.insert(pos, alt_name) + else: + errors.append(f"{before_alt_name} is not in {scen_name}") + + def _data_iterator(): + for scen_name, alternative_name_list in alt_name_list_by_scen_name.items(): + for k, alt_name in enumerate(alternative_name_list): + yield {"scenario_name": scen_name, "alternative_name": alt_name, "rank": k + 1} + + to_add, to_update, more_errors = _get_items_for_import( + db_map, "scenario_alternative", _data_iterator(), skip_keys=(("scenario_name", "rank"),) + ) + return to_add, to_update, errors + more_errors def _get_parameter_value_lists_for_import(db_map, data): @@ 
-838,9 +887,23 @@ def _get_parameter_value_lists_for_import(db_map, data): def _get_list_values_for_import(db_map, data, unparse_value): def _data_iterator(): + index_by_list_name = {} for list_name, value in data: value, type_ = unparse_value(value) - yield {"parameter_value_list_name": list_name, "value": value, "type": type_} + index = index_by_list_name.get(list_name) + if index is None: + current_list = db_map.cache.table_cache("parameter_value_list").current_item({"name": list_name}) + index = max( + ( + x["index"] + for x in db_map.cache.get("list_value", {}).values() + if x["parameter_value_list_id"] == current_list["id"] + ), + default=-1, + ) + index += 1 + index_by_list_name[list_name] = index + yield {"parameter_value_list_name": list_name, "value": value, "type": type_, "index": index} return _get_items_for_import(db_map, "list_value", _data_iterator()) @@ -871,8 +934,9 @@ def _data_iterator(): for class_name, entity_byname, parameter_name, metadata, *optionals in data: if isinstance(entity_byname, str): entity_byname = (entity_byname,) + alternative_name = optionals[0] if optionals else "Base" for name, value in _parse_metadata(metadata): - yield (class_name, entity_byname, parameter_name, name, value, *optionals) + yield (class_name, entity_byname, parameter_name, name, value, alternative_name) key = ( "entity_class_name", @@ -891,6 +955,7 @@ def _data_iterator(): for x in data: if isinstance(x, str): yield x + continue name, *optionals = x yield name, (), *optionals diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index a400bffd..c8170b51 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -75,18 +75,17 @@ def import_mapping_from_dict(map_dict): "Alternative": _alternative_mapping_from_dict, "Scenario": _scenario_mapping_from_dict, "ScenarioAlternative": _scenario_alternative_mapping_from_dict, - # FIXME - # "Tool": _tool_mapping_from_dict, - # "Feature": _feature_mapping_from_dict, - # "ToolFeature": _tool_feature_mapping_from_dict, - # "ToolFeatureMethod": _tool_feature_method_mapping_from_dict, "ObjectGroup": _object_group_mapping_from_dict, "ParameterValueList": _parameter_value_list_mapping_from_dict, } from_dict = legacy_mapping_from_dict.get(map_type) if from_dict is not None: return from_dict(map_dict) - raise ValueError(f'invalid "map_type" value, expected any of {", ".join(legacy_mapping_from_dict)}, got {map_type}') + obsolete_types = ("Tool", "Feature", "ToolFeature", "ToolFeatureMethod") + invalid = "obsolete" if map_type in obsolete_types else "unknown" + raise ValueError( + f'{invalid} "map_type" value, expected any of {", ".join(legacy_mapping_from_dict)}, got {map_type}' + ) def _parameter_value_list_mapping_from_dict(map_dict): diff --git a/spinedb_api/query.py b/spinedb_api/query.py index b21ef857..f80894b4 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -11,31 +11,49 @@ """Provides :class:`.Query`.""" +from sqlalchemy import select, and_ +from sqlalchemy.sql.functions import count from .exception import SpineDBAPIError class Query: - def __init__(self, db_map, select_): + def __init__(self, db_map, *entities): self._db_map = db_map - self._select = select_ + self._entities = entities + self._select = select(entities) self._from = None + @property + def column_descriptions(self): + return [{"name": c.name} for c in self._select.columns] + def subquery(self, name=None): return 
self._select.alias(name) + def add_columns(self, *columns): + self._entities += columns + self._select = select(self._entities) + return self + def filter(self, *args): self._select = self._select.where(*args) return self + def filter_by(self, **kwargs): + if len(self._entities) != 1: + raise SpineDBAPIError(f"can't find a unique 'from-clause' to filter, candidates are {self._entities}") + return self.filter(and_(getattr(self._entities[0].c, k) == v for k, v in kwargs.items())) + def _get_from(self, right, on): + if self._from is not None: + return self._from from_candidates = (set(_get_descendant_tables(on)) - {right}) & set(self._select.get_children()) if len(from_candidates) != 1: raise SpineDBAPIError(f"can't find a unique 'from-clause' to join into, candidates are {from_candidates}") return next(iter(from_candidates)) def join(self, right, on, isouter=False): - from_ = self._get_from(right, on) if self._from is None else self._from - self._from = from_.join(right, on, isouter=isouter) + self._from = self._get_from(right, on).join(right, on, isouter=isouter) self._select = self._select.select_from(self._from) return self @@ -58,11 +76,52 @@ def offset(self, *args): self._select = self._select.offset(*args) return self + def distinct(self, *args): + self._select = self._select.distinct(*args) + return self + + def having(self, *args): + self._select = self._select.having(*args) + return self + + def _result(self): + return self._db_map.connection_execute(self._select) + def all(self): - return list(self) + return self._result().fetchall() + + def first(self): + return self._result().first() + + def one_or_none(self): + result = self._result() + first = result.fetchone() + if first is None: + return None + second = result.fetchone() + if second is not None: + raise SpineDBAPIError("multiple results found for one_or_none()") + return first + + def scalar(self): + return self._result().scalar() + + def count(self): + return self._db_map.connection_execute(select([count()]).select_from(self._select)).scalar() def __iter__(self): - return self._db_map.connection.execute(self._select) + return self._result() + + +def _get_leaves(parent): + children = parent.get_children() + if not children: + try: + yield parent.table + except AttributeError: + pass + for child in children: + yield from _get_leaves(child) def _get_descendant_tables(on): diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index f0eb2a52..d7219d2c 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -17,7 +17,6 @@ from spinedb_api import ( DatabaseMapping, import_alternatives, - import_features, import_object_classes, import_object_parameter_values, import_object_parameters, @@ -27,9 +26,6 @@ import_relationships, import_scenario_alternatives, import_scenarios, - import_tool_features, - import_tool_feature_methods, - import_tools, Map, ) from spinedb_api.import_functions import import_object_groups @@ -47,8 +43,6 @@ FixedValueMapping, ExpandedParameterValueMapping, ExpandedParameterDefaultValueMapping, - FeatureEntityClassMapping, - FeatureParameterDefinitionMapping, from_dict, EntityGroupMapping, EntityGroupEntityMapping, @@ -69,13 +63,6 @@ ScenarioActiveFlagMapping, ScenarioAlternativeMapping, ScenarioMapping, - ToolMapping, - ToolFeatureEntityClassMapping, - ToolFeatureParameterDefinitionMapping, - ToolFeatureRequiredFlagMapping, - ToolFeatureMethodEntityClassMapping, - ToolFeatureMethodMethodMapping, - 
ToolFeatureMethodParameterDefinitionMapping, ) from spinedb_api.mapping import unflatten @@ -1005,118 +992,6 @@ def test_scenario_alternative_mapping(self): self.assertEqual(tables, {None: [["s1", "a1"], ["s1", "a2"], ["s2", "a2"], ["s2", "a3"]]}) db_map.connection.close() - def test_tool_mapping(self): - db_map = DatabaseMapping("sqlite://", create=True) - import_tools(db_map, ("tool1", "tool2")) - db_map.commit_session("Add test data.") - tool_mapping = ToolMapping(0) - tables = dict() - for title, title_key in titles(tool_mapping, db_map): - tables[title] = list(rows(tool_mapping, db_map, title_key)) - self.assertEqual(tables, {None: [["tool1"], ["tool2"]]}) - db_map.connection.close() - - def test_feature_mapping(self): - db_map = DatabaseMapping("sqlite://", create=True) - import_object_classes(db_map, ("oc1", "oc2")) - import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) - import_object_parameters( - db_map, - ( - ("oc1", "p1", "feat1", "features"), - ("oc1", "p2", "feat1", "features"), - ("oc2", "p3", "feat2", "features"), - ), - ) - import_features(db_map, (("oc1", "p2"), ("oc2", "p3"))) - db_map.commit_session("Add test data.") - class_mapping = FeatureEntityClassMapping(0) - parameter_mapping = FeatureParameterDefinitionMapping(1) - class_mapping.child = parameter_mapping - tables = dict() - for title, title_key in titles(class_mapping, db_map): - tables[title] = list(rows(class_mapping, db_map, title_key)) - self.assertEqual(tables, {None: [["oc1", "p2"], ["oc2", "p3"]]}) - db_map.connection.close() - - def test_tool_feature_mapping(self): - db_map = DatabaseMapping("sqlite://", create=True) - import_object_classes(db_map, ("oc1", "oc2")) - import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) - import_object_parameters( - db_map, - ( - ("oc1", "p1", "feat1", "features"), - ("oc1", "p2", "feat1", "features"), - ("oc2", "p3", "feat2", "features"), - ), - ) - import_features(db_map, (("oc1", "p1"), ("oc1", "p2"), ("oc2", "p3"))) - import_tools(db_map, ("tool1", "tool2")) - import_tool_features( - db_map, (("tool1", "oc1", "p1", True), ("tool1", "oc2", "p3", False), ("tool2", "oc1", "p1", True)) - ) - db_map.commit_session("Add test data.") - mapping = unflatten( - [ - ToolMapping(Position.table_name), - ToolFeatureEntityClassMapping(0), - ToolFeatureParameterDefinitionMapping(1), - ToolFeatureRequiredFlagMapping(2), - ] - ) - tables = dict() - for title, title_key in titles(mapping, db_map): - tables[title] = list(rows(mapping, db_map, title_key)) - expected = {"tool1": [["oc1", "p1", True], ["oc2", "p3", False]], "tool2": [["oc1", "p1", True]]} - self.assertEqual(tables, expected) - db_map.connection.close() - - def test_tool_feature_method_mapping(self): - db_map = DatabaseMapping("sqlite://", create=True) - import_object_classes(db_map, ("oc1", "oc2")) - import_parameter_value_lists(db_map, (("features", "feat1"), ("features", "feat2"))) - import_object_parameters( - db_map, - ( - ("oc1", "p1", "feat1", "features"), - ("oc1", "p2", "feat1", "features"), - ("oc2", "p3", "feat2", "features"), - ), - ) - import_features(db_map, (("oc1", "p1"), ("oc1", "p2"), ("oc2", "p3"))) - import_tools(db_map, ("tool1", "tool2")) - import_tool_features( - db_map, (("tool1", "oc1", "p1", True), ("tool1", "oc2", "p3", False), ("tool2", "oc1", "p1", True)) - ) - import_tool_feature_methods( - db_map, - ( - ("tool1", "oc1", "p1", "feat1"), - ("tool1", "oc1", "p1", "feat2"), - ("tool2", "oc1", "p1", "feat1"), - ("tool2", "oc1", "p1", 
"feat2"), - ), - ) - db_map.commit_session("Add test data.") - mapping = unflatten( - [ - ToolMapping(Position.table_name), - ToolFeatureMethodEntityClassMapping(0), - ToolFeatureMethodParameterDefinitionMapping(1), - ToolFeatureMethodMethodMapping(2), - ] - ) - tables = dict() - for title, title_key in titles(mapping, db_map): - tables[title] = list(rows(mapping, db_map, title_key)) - expected = { - "tool1": [["oc1", "p1", "feat1"], ["oc1", "p1", "feat2"]], - "tool2": [["oc1", "p1", "feat1"], ["oc1", "p1", "feat2"]], - } - self.assertEqual(tables, expected) - db_map.connection.close() - def test_header(self): db_map = DatabaseMapping("sqlite://", create=True) import_object_classes(db_map, ("oc",)) diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index 5b69f54a..c24da2c0 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -26,6 +26,7 @@ import_object_parameter_values, import_object_parameters, import_objects, + SpineDBAPIError, ) from spinedb_api.filters.alternative_filter import ( alternative_filter_config, @@ -49,20 +50,17 @@ def setUp(self): create_new_spine_database(self._db_url) self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) - self._diff_db_map = DatabaseMapping(self._db_url) def tearDown(self): self._out_map.connection.close() self._db_map.connection.close() - self._diff_db_map.connection.close() def test_alternative_filter_without_scenarios_or_alternatives(self): self._build_data_without_alternatives() self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_alternative_filter_to_parameter_value_sq(db_map, []) - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(parameters, []) + apply_alternative_filter_to_parameter_value_sq(self._db_map, []) + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(parameters, []) def test_alternative_filter_without_scenarios_or_alternatives_uncommitted_data(self): self._build_data_without_alternatives() @@ -74,18 +72,17 @@ def test_alternative_filter_without_scenarios_or_alternatives_uncommitted_data(s def test_alternative_filter(self): self._build_data_with_single_alternative() self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_alternative_filter_to_parameter_value_sq(db_map, ["alternative"]) - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"23.0") + apply_alternative_filter_to_parameter_value_sq(self._db_map, ["alternative"]) + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(len(parameters), 1) + self.assertEqual(parameters[0].value, b"23.0") def test_alternative_filter_uncommitted_data(self): self._build_data_with_single_alternative() - apply_alternative_filter_to_parameter_value_sq(self._out_map, ["alternative"]) + with self.assertRaises(SpineDBAPIError): + apply_alternative_filter_to_parameter_value_sq(self._out_map, ["alternative"]) parameters = self._out_map.query(self._out_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"23.0") + self.assertEqual(len(parameters), 0) self._out_map.rollback_session() def test_alternative_filter_from_dict(self): diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 
0addaa69..5adf8d2f 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -32,6 +32,7 @@ import_relationships, import_scenario_alternatives, import_scenarios, + SpineDBAPIError, ) from spinedb_api.filters.scenario_filter import ( scenario_filter_config, @@ -55,12 +56,10 @@ def setUp(self): create_new_spine_database(self._db_url) self._out_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) - self._diff_db_map = DatabaseMapping(self._db_url) def tearDown(self): self._out_map.connection.close() self._db_map.connection.close() - self._diff_db_map.connection.close() def _build_data_with_single_scenario(self): import_alternatives(self._out_map, ["alternative"]) @@ -74,38 +73,13 @@ def _build_data_with_single_scenario(self): def test_scenario_filter(self): _build_data_with_single_scenario(self._out_map) - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"23.0") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": "alternative", - "alternative_id_list": "2", - "id": 1, - "commit_id": 2, - } - ], - ) - - def test_scenario_filter_uncommitted_data(self): - _build_data_with_single_scenario(self._out_map, commit=False) - apply_scenario_filter_to_subqueries(self._out_map, "scenario") - parameters = self._out_map.query(self._out_map.parameter_value_sq).all() + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") - alternatives = [a._asdict() for a in self._out_map.query(self._out_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [s._asdict() for s in self._out_map.query(self._out_map.wide_scenario_sq).all()] + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] self.assertEqual( scenarios, [ @@ -116,39 +90,49 @@ def test_scenario_filter_uncommitted_data(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": 2, + "commit_id": None, } ], ) + + def test_scenario_filter_uncommitted_data(self): + _build_data_with_single_scenario(self._out_map, commit=False) + with self.assertRaises(SpineDBAPIError): + apply_scenario_filter_to_subqueries(self._out_map, "scenario") + parameters = self._out_map.query(self._out_map.parameter_value_sq).all() + self.assertEqual(len(parameters), 0) + alternatives = [dict(a) for a in self._out_map.query(self._out_map.alternative_sq)] + self.assertEqual(alternatives, [{"name": "Base", "description": "Base alternative", "id": 1, "commit_id": 1}]) + scenarios = self._out_map.query(self._out_map.wide_scenario_sq).all() + self.assertEqual(len(scenarios), 0) 
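# The filter never engaged here: its state is resolved against committed rows
# only, so the uncommitted "scenario" is not found and SpineDBAPIError is
# raised above; the queries that follow therefore reflect the committed state
# of the database (just the Base alternative, no scenarios).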
self._out_map.rollback_session() def test_scenario_filter_works_for_object_parameter_value_sq(self): _build_data_with_single_scenario(self._out_map) - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.object_parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"23.0") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": "alternative", - "alternative_id_list": "2", - "id": 1, - "commit_id": 2, - } - ], - ) + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.object_parameter_value_sq).all() + self.assertEqual(len(parameters), 1) + self.assertEqual(parameters[0].value, b"23.0") + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ + { + "name": "scenario", + "description": None, + "active": True, + "alternative_name_list": "alternative", + "alternative_id_list": "2", + "id": 1, + "commit_id": None, + } + ], + ) def test_scenario_filter_works_for_relationship_parameter_value_sq(self): - _build_data_with_single_scenario(self._out_map) + _build_data_with_single_scenario(self._out_map, commit=False) import_relationship_classes(self._out_map, [("relationship_class", ["object_class"])]) import_relationship_parameters(self._out_map, [("relationship_class", "relationship_parameter")]) import_relationships(self._out_map, [("relationship_class", ["object"])]) @@ -159,28 +143,27 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self): self._out_map, [("relationship_class", ["object"], "relationship_parameter", 23.0, "alternative")] ) self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.relationship_parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"23.0") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": "alternative", - "alternative_id_list": "2", - "id": 1, - "commit_id": 2, - } - ], - ) + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.relationship_parameter_value_sq).all() + self.assertEqual(len(parameters), 1) + self.assertEqual(parameters[0].value, b"23.0") + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + scenarios = [dict(s) for s in 
self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ + { + "name": "scenario", + "description": None, + "active": True, + "alternative_name_list": "alternative", + "alternative_id_list": "2", + "id": 1, + "commit_id": None, + } + ], + ) def test_scenario_filter_selects_highest_ranked_alternative(self): import_alternatives(self._out_map, ["alternative3"]) @@ -203,35 +186,34 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): ], ) self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"2000.0") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual( - alternatives, - [ - {"name": "alternative3", "description": None, "id": 2, "commit_id": 2}, - {"name": "alternative1", "description": None, "id": 3, "commit_id": 2}, - {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, - ], - ) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": "alternative1,alternative3,alternative2", - "alternative_id_list": "3,2,4", - "id": 1, - "commit_id": 2, - } - ], - ) + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(len(parameters), 1) + self.assertEqual(parameters[0].value, b"2000.0") + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual( + alternatives, + [ + {"name": "alternative3", "description": None, "id": 2, "commit_id": None}, + {"name": "alternative1", "description": None, "id": 3, "commit_id": None}, + {"name": "alternative2", "description": None, "id": 4, "commit_id": None}, + ], + ) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ + { + "name": "scenario", + "description": None, + "active": True, + "alternative_name_list": "alternative1,alternative3,alternative2", + "alternative_id_list": "3,2,4", + "id": 1, + "commit_id": None, + } + ], + ) def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(self): import_alternatives(self._out_map, ["alternative3"]) @@ -265,35 +247,34 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s ], ) self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 1) - self.assertEqual(parameters[0].value, b"2000.0") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual( - alternatives, - [ - {"name": "alternative3", "description": None, "id": 2, "commit_id": 2}, - {"name": "alternative1", "description": None, "id": 3, "commit_id": 2}, - {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, - ], - ) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": 
"alternative1,alternative3,alternative2", - "alternative_id_list": "3,2,4", - "id": 1, - "commit_id": 2, - } - ], - ) + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(len(parameters), 1) + self.assertEqual(parameters[0].value, b"2000.0") + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual( + alternatives, + [ + {"name": "alternative3", "description": None, "id": 2, "commit_id": None}, + {"name": "alternative1", "description": None, "id": 3, "commit_id": None}, + {"name": "alternative2", "description": None, "id": 4, "commit_id": None}, + ], + ) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ + { + "name": "scenario", + "description": None, + "active": True, + "alternative_name_list": "alternative1,alternative3,alternative2", + "alternative_id_list": "3,2,4", + "id": 1, + "commit_id": None, + } + ], + ) def test_scenario_filter_for_multiple_objects_and_parameters(self): import_alternatives(self._out_map, ["alternative"]) @@ -313,42 +294,41 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): import_scenarios(self._out_map, [("scenario", True)]) import_scenario_alternatives(self._out_map, [("scenario", "alternative")]) self._out_map.commit_session("Add test data") - for db_map in [self._db_map, self._diff_db_map]: - apply_scenario_filter_to_subqueries(db_map, "scenario") - parameters = db_map.query(db_map.parameter_value_sq).all() - self.assertEqual(len(parameters), 4) - object_names = {o.id: o.name for o in db_map.query(db_map.object_sq).all()} - alternative_names = {a.id: a.name for a in db_map.query(db_map.alternative_sq).all()} - parameter_names = {d.id: d.name for d in db_map.query(db_map.parameter_definition_sq).all()} - datamined_values = dict() - for parameter in parameters: - self.assertEqual(alternative_names[parameter.alternative_id], "alternative") - parameter_values = datamined_values.setdefault(object_names[parameter.object_id], dict()) - parameter_values[parameter_names[parameter.parameter_definition_id]] = parameter.value - self.assertEqual( - datamined_values, + apply_scenario_filter_to_subqueries(self._db_map, "scenario") + parameters = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(len(parameters), 4) + object_names = {o.id: o.name for o in self._db_map.query(self._db_map.object_sq).all()} + alternative_names = {a.id: a.name for a in self._db_map.query(self._db_map.alternative_sq).all()} + parameter_names = {d.id: d.name for d in self._db_map.query(self._db_map.parameter_definition_sq).all()} + datamined_values = dict() + for parameter in parameters: + self.assertEqual(alternative_names[parameter.alternative_id], "alternative") + parameter_values = datamined_values.setdefault(object_names[parameter.object_id], dict()) + parameter_values[parameter_names[parameter.parameter_definition_id]] = parameter.value + self.assertEqual( + datamined_values, + { + "object1": {"parameter1": b"10.0", "parameter2": b"11.0"}, + "object2": {"parameter1": b"20.0", "parameter2": b"22.0"}, + }, + ) + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ { - 
"object1": {"parameter1": b"10.0", "parameter2": b"11.0"}, - "object2": {"parameter1": b"20.0", "parameter2": b"22.0"}, - }, - ) - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario", - "description": None, - "active": True, - "alternative_name_list": "alternative", - "alternative_id_list": "2", - "id": 1, - "commit_id": 2, - } - ], - ) + "name": "scenario", + "description": None, + "active": True, + "alternative_name_list": "alternative", + "alternative_id_list": "2", + "id": 1, + "commit_id": None, + } + ], + ) def test_filters_scenarios_and_alternatives(self): import_scenarios(self._out_map, ("scenario1", "scenario2")) @@ -363,31 +343,30 @@ def test_filters_scenarios_and_alternatives(self): ), ) self._out_map.commit_session("Add test data.") - for db_map in (self._db_map, self._diff_db_map): - apply_scenario_filter_to_subqueries(db_map, "scenario2") - alternatives = [a._asdict() for a in db_map.query(db_map.alternative_sq)] - self.assertEqual( - alternatives, - [ - {"name": "alternative2", "description": None, "id": 3, "commit_id": 2}, - {"name": "alternative3", "description": None, "id": 4, "commit_id": 2}, - ], - ) - scenarios = [s._asdict() for s in db_map.query(db_map.wide_scenario_sq).all()] - self.assertEqual( - scenarios, - [ - { - "name": "scenario2", - "description": None, - "active": False, - "alternative_name_list": "alternative2,alternative3", - "alternative_id_list": "3,4", - "id": 2, - "commit_id": 2, - } - ], - ) + apply_scenario_filter_to_subqueries(self._db_map, "scenario2") + alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] + self.assertEqual( + alternatives, + [ + {"name": "alternative2", "description": None, "id": 3, "commit_id": None}, + {"name": "alternative3", "description": None, "id": 4, "commit_id": None}, + ], + ) + scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] + self.assertEqual( + scenarios, + [ + { + "name": "scenario2", + "description": None, + "active": False, + "alternative_name_list": "alternative2,alternative3", + "alternative_id_list": "3,4", + "id": 2, + "commit_id": None, + } + ], + ) class TestScenarioFilterUtils(unittest.TestCase): diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index ccb37f36..0c0ca922 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -18,7 +18,6 @@ import unittest from sqlalchemy.engine.url import URL from spinedb_api import ( - apply_tool_filter_to_entity_sq, create_new_spine_database, DatabaseMapping, import_object_classes, @@ -30,20 +29,11 @@ import_relationship_parameter_values, import_relationship_parameters, import_parameter_value_lists, - import_tools, - import_features, - import_tool_features, - import_tool_feature_methods, SpineDBAPIError, ) -from spinedb_api.filters.tool_filter import ( - tool_filter_config, - tool_filter_config_to_shorthand, - tool_filter_from_dict, - tool_filter_shorthand_to_config, -) +@unittest.skip("obsolete, but need to adapt into the scenario filter") class TestToolEntityFilter(unittest.TestCase): _db_url = None _temp_dir = None diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index e783e77a..0c63544a 100644 --- a/tests/filters/test_tools.py +++ 
b/tests/filters/test_tools.py @@ -306,8 +306,8 @@ def test_get_scenario_name(self): self.assertEqual(name_from_dict(config), "scenario_name") def test_get_tool_name(self): - config = filter_config("tool_filter", "tool_name") - self.assertEqual(name_from_dict(config), "tool_name") + with self.assertRaises(KeyError): + _ = filter_config("tool_filter", "tool_name") def test_returns_none_if_name_not_found(self): config = entity_class_renamer_config(name="rename") diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index c9e66bd4..5f6e0b6c 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -224,30 +224,16 @@ def test_scenario_alternative_mapping(self): self.assertEqual(types, expected) def test_tool_mapping(self): - mapping = import_mapping_from_dict({"map_type": "Tool"}) - d = mapping_to_dict(mapping) - types = [m["map_type"] for m in d] - expected = ['Tool'] - self.assertEqual(types, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict({"map_type": "Tool"}) def test_tool_feature_mapping(self): - mapping = import_mapping_from_dict({"map_type": "ToolFeature"}) - d = mapping_to_dict(mapping) - types = [m["map_type"] for m in d] - expected = ['Tool', 'ToolFeatureEntityClass', 'ToolFeatureParameterDefinition', 'ToolFeatureRequiredFlag'] - self.assertEqual(types, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict({"map_type": "ToolFeature"}) def test_tool_feature_method_mapping(self): - mapping = import_mapping_from_dict({"map_type": "ToolFeatureMethod"}) - d = mapping_to_dict(mapping) - types = [m["map_type"] for m in d] - expected = [ - 'Tool', - 'ToolFeatureMethodEntityClass', - 'ToolFeatureMethodParameterDefinition', - 'ToolFeatureMethodMethod', - ] - self.assertEqual(types, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict({"map_type": "ToolFeatureMethod"}) def test_parameter_value_list_mapping(self): mapping = import_mapping_from_dict({"map_type": "ParameterValueList"}) @@ -433,32 +419,18 @@ def test_ScenarioAlternative_to_dict_from_dict(self): def test_Tool_to_dict_from_dict(self): mapping = {"map_type": "Tool", "name": 0} - mapping = import_mapping_from_dict(mapping) - out = mapping_to_dict(mapping) - expected = [{'map_type': 'Tool', 'position': 0}] - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict(mapping) def test_Feature_to_dict_from_dict(self): mapping = {"map_type": "Feature", "entity_class_name": 0, "parameter_definition_name": 1} - mapping = import_mapping_from_dict(mapping) - out = mapping_to_dict(mapping) - expected = [ - {'map_type': 'FeatureEntityClass', 'position': 0}, - {'map_type': 'FeatureParameterDefinition', 'position': 1}, - ] - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict(mapping) def test_ToolFeature_to_dict_from_dict(self): mapping = {"map_type": "ToolFeature", "name": 0, "entity_class_name": 1, "parameter_definition_name": 2} - mapping = import_mapping_from_dict(mapping) - out = mapping_to_dict(mapping) - expected = [ - {'map_type': 'Tool', 'position': 0}, - {'map_type': 'ToolFeatureEntityClass', 'position': 1}, - {'map_type': 'ToolFeatureParameterDefinition', 'position': 2}, - {'map_type': 'ToolFeatureRequiredFlag', 'position': 'hidden', 'value': 'false'}, - ] - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict(mapping) def 
test_ToolFeatureMethod_to_dict_from_dict(self): mapping = { @@ -468,15 +440,8 @@ def test_ToolFeatureMethod_to_dict_from_dict(self): "parameter_definition_name": 2, "method": 3, } - mapping = import_mapping_from_dict(mapping) - out = mapping_to_dict(mapping) - expected = [ - {'map_type': 'Tool', 'position': 0}, - {'map_type': 'ToolFeatureMethodEntityClass', 'position': 1}, - {'map_type': 'ToolFeatureMethodParameterDefinition', 'position': 2}, - {'map_type': 'ToolFeatureMethodMethod', 'position': 3}, - ] - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + import_mapping_from_dict(mapping) def test_MapValueMapping_from_dict_to_dict(self): mapping_dict = { @@ -1733,30 +1698,24 @@ def test_read_tool(self): data = iter(input_data) data_header = next(data) mapping = {"map_type": "Tool", "name": 0} - out, errors = get_mapped_data(data, [mapping], data_header) - expected = {"tools": {"tool1", "second_tool", "last_one"}} - self.assertFalse(errors) - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + get_mapped_data(data, [mapping], data_header) def test_read_feature(self): input_data = [["Class", "Parameter"], ["class1", "param1"], ["class2", "param2"]] data = iter(input_data) data_header = next(data) mapping = {"map_type": "Feature", "entity_class_name": 0, "parameter_definition_name": 1} - out, errors = get_mapped_data(data, [mapping], data_header) - expected = {"features": {("class1", "param1"), ("class2", "param2")}} - self.assertFalse(errors) - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + get_mapped_data(data, [mapping], data_header) def test_read_tool_feature(self): input_data = [["Tool", "Class", "Parameter"], ["tool1", "class1", "param1"], ["tool2", "class2", "param2"]] data = iter(input_data) data_header = next(data) mapping = {"map_type": "ToolFeature", "name": 0, "entity_class_name": 1, "parameter_definition_name": 2} - out, errors = get_mapped_data(data, [mapping], data_header) - expected = {"tool_features": [["tool1", "class1", "param1", False], ["tool2", "class2", "param2", False]]} - self.assertFalse(errors) - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + get_mapped_data(data, [mapping], data_header) def test_read_tool_feature_with_required_flag(self): input_data = [ @@ -1773,10 +1732,8 @@ def test_read_tool_feature_with_required_flag(self): "parameter_definition_name": 2, "required": 3, } - out, errors = get_mapped_data(data, [mapping], data_header) - expected = {"tool_features": [["tool1", "class1", "param1", False], ["tool2", "class2", "param2", True]]} - self.assertFalse(errors) - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + get_mapped_data(data, [mapping], data_header) def test_read_tool_feature_method(self): input_data = [ @@ -1793,14 +1750,8 @@ def test_read_tool_feature_method(self): "parameter_definition_name": 2, "method": 3, } - out, errors = get_mapped_data(data, [mapping], data_header) - expected = dict() - expected["tool_feature_methods"] = [ - ["tool1", "class1", "param1", "meth1"], - ["tool2", "class2", "param2", "meth2"], - ] - self.assertFalse(errors) - self.assertEqual(out, expected) + with self.assertRaises(ValueError): + get_mapped_data(data, [mapping], data_header) def test_read_object_group_without_parameters(self): input_data = [ diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 1127c4b8..a91de2a7 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -315,6 +315,7 @@ def 
test_commit_sq_hides_pending_commit(self): def test_alternative_sq(self): import_functions.import_alternatives(self._db_map, (("alt1", "test alternative"),)) + self._db_map.commit_session("test") alternative_rows = self._db_map.query(self._db_map.alternative_sq).all() expected_names_and_descriptions = {"Base": "Base alternative", "alt1": "test alternative"} self.assertEqual(len(alternative_rows), len(expected_names_and_descriptions)) @@ -326,6 +327,7 @@ def test_alternative_sq(self): def test_scenario_sq(self): import_functions.import_scenarios(self._db_map, (("scen1", True, "test scenario"),)) + self._db_map.commit_session("test") scenario_rows = self._db_map.query(self._db_map.scenario_sq).all() self.assertEqual(len(scenario_rows), 1) self.assertEqual(scenario_rows[0].name, "scen1") @@ -338,6 +340,7 @@ def test_ext_linked_scenario_alternative_sq(self): import_functions.import_scenario_alternatives(self._db_map, (("scen1", "alt2"),)) import_functions.import_scenario_alternatives(self._db_map, (("scen1", "alt3"),)) import_functions.import_scenario_alternatives(self._db_map, (("scen1", "alt1"),)) + self._db_map.commit_session("test") scenario_alternative_rows = self._db_map.query(self._db_map.ext_linked_scenario_alternative_sq).all() self.assertEqual(len(scenario_alternative_rows), 3) expected_befores = {"alt2": "alt3", "alt3": "alt1", "alt1": None} @@ -359,6 +362,7 @@ def test_ext_linked_scenario_alternative_sq(self): def test_entity_class_sq(self): obj_classes = self.create_object_classes() relationship_classes = self.create_relationship_classes() + self._db_map.commit_session("test") results = self._db_map.query(self._db_map.entity_class_sq).all() # Check that number of results matches total entities self.assertEqual(len(results), len(obj_classes) + len(relationship_classes)) @@ -371,6 +375,7 @@ def test_entity_sq(self): objects = self.create_objects() self.create_relationship_classes() relationships = self.create_relationships() + self._db_map.commit_session("test") entity_rows = self._db_map.query(self._db_map.entity_sq).all() self.assertEqual(len(entity_rows), len(objects) + len(relationships)) object_names = [o[1] for o in objects] @@ -381,6 +386,7 @@ def test_entity_sq(self): def test_object_class_sq_picks_object_classes_only(self): obj_classes = self.create_object_classes() self.create_relationship_classes() + self._db_map.commit_session("test") class_rows = self._db_map.query(self._db_map.object_class_sq).all() self.assertEqual(len(class_rows), len(obj_classes)) for row, expected_name in zip(class_rows, obj_classes): @@ -391,6 +397,7 @@ def test_object_sq_picks_objects_only(self): objects = self.create_objects() self.create_relationship_classes() self.create_relationships() + self._db_map.commit_session("test") object_rows = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(object_rows), len(objects)) for row, expected_object in zip(object_rows, objects): @@ -399,6 +406,7 @@ def test_object_sq_picks_objects_only(self): def test_wide_relationship_class_sq(self): self.create_object_classes() relationship_classes = self.create_relationship_classes() + self._db_map.commit_session("test") class_rows = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(class_rows), 2) for row, relationship_class in zip(class_rows, relationship_classes): @@ -411,6 +419,7 @@ def test_wide_relationship_sq(self): relationship_classes = self.create_relationship_classes() object_classes = {rel_class[0]: rel_class[1] for rel_class in relationship_classes} 
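The hunks in this file repeatedly insert a commit_session() call before
querying, because pending additions are no longer served by query(). A
minimal sketch of the resulting pattern, assuming an in-memory
DatabaseMapping as these tests use (names and the commit message are
illustrative, not new API):

    from spinedb_api import DatabaseMapping, import_functions

    db_map = DatabaseMapping("sqlite://", create=True)
    import_functions.import_object_classes(db_map, ("my_class",))
    db_map.commit_session("test")  # pending imports become queryable here
    names = [row.name for row in db_map.query(db_map.object_class_sq)]
    assert names == ["my_class"]
    db_map.connection.close()
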
relationships = self.create_relationships() + self._db_map.commit_session("test") relationship_rows = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationship_rows), 2) for row, relationship in zip(relationship_rows, relationships): @@ -422,6 +431,7 @@ def test_wide_relationship_sq(self): def test_parameter_definition_sq_for_object_class(self): self.create_object_classes() import_functions.import_object_parameters(self._db_map, (("class1", "par1"),)) + self._db_map.commit_session("test") definition_rows = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].name, "par1") @@ -432,6 +442,7 @@ def test_parameter_definition_sq_for_relationship_class(self): self.create_object_classes() self.create_relationship_classes() import_functions.import_relationship_parameters(self._db_map, (("rel1", "par1"),)) + self._db_map.commit_session("test") definition_rows = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].name, "par1") @@ -442,6 +453,7 @@ def test_entity_parameter_definition_sq_for_object_class(self): self.create_object_classes() self.create_relationship_classes() import_functions.import_object_parameters(self._db_map, (("class1", "par1"),)) + self._db_map.commit_session("test") definition_rows = self._db_map.query(self._db_map.entity_parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].parameter_name, "par1") @@ -456,6 +468,7 @@ def test_entity_parameter_definition_sq_for_relationship_class(self): object_classes = self.create_object_classes() self.create_relationship_classes() import_functions.import_relationship_parameters(self._db_map, (("rel2", "par1"),)) + self._db_map.commit_session("test") definition_rows = self._db_map.query(self._db_map.entity_parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].parameter_name, "par1") @@ -473,6 +486,7 @@ def test_entity_parameter_definition_sq_with_multiple_relationship_classes_but_s rel_parameter_definitions = [('rel1', 'rpar1a')] import_functions.import_object_parameters(self._db_map, obj_parameter_definitions) import_functions.import_relationship_parameters(self._db_map, rel_parameter_definitions) + self._db_map.commit_session("test") results = self._db_map.query(self._db_map.entity_parameter_definition_sq).all() # Check that number of results matches total entities self.assertEqual(len(results), len(obj_parameter_definitions) + len(rel_parameter_definitions)) @@ -499,6 +513,7 @@ def test_entity_parameter_values(self): relationship_parameter_values = [('rel1', ['obj11'], 'rpar1a', 1.1), ('rel2', ['obj11', 'obj21'], 'rpar2a', 42)] _, errors = import_functions.import_relationship_parameter_values(self._db_map, relationship_parameter_values) self.assertFalse(errors) + self._db_map.commit_session("test") results = self._db_map.query(self._db_map.entity_parameter_value_sq).all() # Check that number of results matches total entities self.assertEqual(len(results), len(object_parameter_values) + len(relationship_parameter_values)) @@ -517,6 +532,7 @@ def test_wide_parameter_value_list_sq(self): self._db_map, (("list1", "value1"), ("list1", "value2"), ("list2", "valueA")) ) self.assertEqual(errors, []) + self._db_map.commit_session("test") value_lists = self._db_map.query(self._db_map.wide_parameter_value_list_sq).all() 
self.assertEqual(len(value_lists), 2) self.assertEqual(value_lists[0].name, "list1") @@ -530,37 +546,12 @@ def setUp(self): def tearDown(self): self._db_map.connection.close() - def test_update_method_of_tool_feature_method(self): - import_functions.import_object_classes(self._db_map, ("object_class1", "object_class2")) - import_functions.import_parameter_value_lists( - self._db_map, (("value_list", "value1"), ("value_list", "value2")) - ) - import_functions.import_object_parameters( - self._db_map, (("object_class1", "parameter1", "value1", "value_list"), ("object_class1", "parameter2")) - ) - import_functions.import_features(self._db_map, (("object_class1", "parameter1"),)) - import_functions.import_tools(self._db_map, ("tool1",)) - import_functions.import_tool_features(self._db_map, (("tool1", "object_class1", "parameter1"),)) - import_functions.import_tool_feature_methods( - self._db_map, (("tool1", "object_class1", "parameter1", "value2"),) - ) - self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_tool_feature_methods( - {"id": 1, "method_index": 0, "method": to_database("value1")[0]} - ) - self.assertEqual(errors, []) - self.assertEqual(updated_ids, {1}) - self._db_map.commit_session("Update data.") - tool_feature_methods = self._db_map.query(self._db_map.ext_tool_feature_method_sq).all() - self.assertEqual(len(tool_feature_methods), 1) - tool_feature_method = tool_feature_methods[0] - self.assertEqual(tool_feature_method.method, to_database("value1")[0]) - def test_update_wide_relationship_class(self): _ = import_functions.import_object_classes(self._db_map, ("object_class_1",)) _ = import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") - updated_ids, errors = self._db_map.update_wide_relationship_classes({"id": 2, "name": "renamed"}) + items, errors = self._db_map.update_wide_relationship_classes({"id": 2, "name": "renamed"}) + updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {2}) self._db_map.commit_session("Update data.") @@ -572,9 +563,10 @@ def test_update_wide_relationship_class_does_not_update_member_class_id(self): import_functions.import_object_classes(self._db_map, ("object_class_1", "object_class_2")) import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") - updated_ids, errors = self._db_map.update_wide_relationship_classes( + items, errors = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "renamed", "object_class_id_list": [2]} ) + updated_ids = {x["id"] for x in items} self.assertEqual([str(err) for err in errors], ["Can't update fixed fields 'dimension_id_list'"]) self.assertEqual(updated_ids, {3}) self._db_map.commit_session("Update data.") @@ -594,9 +586,8 @@ def test_update_wide_relationship(self): ) import_functions.import_relationships(self._db_map, (("my_class", ("object_11", "object_21")),)) self._db_map.commit_session("Add test data") - updated_ids, errors = self._db_map.update_wide_relationships( - {"id": 4, "name": "renamed", "object_id_list": [2, 3]} - ) + items, errors = self._db_map.update_wide_relationships({"id": 4, "name": "renamed", "object_id_list": [2, 3]}) + updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {4}) self._db_map.commit_session("Update data.") @@ -613,7 +604,8 @@ def 
test_update_parameter_value_by_id_only(self): self._db_map, (("object_class1", "object1", "parameter1", "something"),) ) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_values({"id": 1, "value": b"something else"}) + items, errors = self._db_map.update_parameter_values({"id": 1, "value": b"something else"}) + updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") @@ -626,7 +618,8 @@ def test_update_parameter_definition_by_id_only(self): import_functions.import_object_classes(self._db_map, ("object_class1",)) import_functions.import_object_parameters(self._db_map, (("object_class1", "parameter1"),)) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_definitions({"id": 1, "name": "parameter2"}) + items, errors = self._db_map.update_parameter_definitions({"id": 1, "name": "parameter2"}) + updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") @@ -639,18 +632,19 @@ def test_update_parameter_definition_value_list(self): import_functions.import_object_classes(self._db_map, ("object_class",)) import_functions.import_object_parameters(self._db_map, (("object_class", "my_parameter"),)) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_definitions( + items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "parameter_value_list_id": 1} ) + updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) self.assertEqual( - pdefs[0]._asdict(), + dict(pdefs[0]), { - "commit_id": 3, + "commit_id": None, "default_type": None, "default_value": None, "description": None, @@ -673,9 +667,10 @@ def test_update_parameter_definition_value_list_when_values_exist_gives_error(se self._db_map, (("object_class", "my_object", "my_parameter", 23.0),) ) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_definitions( + items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "parameter_value_list_id": 1} ) + updated_ids = {x["id"] for x in items} self.assertEqual( list(map(str, errors)), ["Can't change value list on parameter my_parameter because it has parameter values."], @@ -688,13 +683,11 @@ def test_update_parameter_definitions_default_value_that_is_not_on_value_list_gi import_functions.import_objects(self._db_map, (("object_class", "my_object"),)) import_functions.import_object_parameters(self._db_map, (("object_class", "my_parameter", None, "my_list"),)) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_definitions( + items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "default_value": to_database(23.0)[0]} ) - self.assertEqual( - list(map(str, errors)), - ["Invalid default_value '23.0' - it should be one from the parameter value list: '99.0'."], - ) + updated_ids = {x["id"] for x in items} + self.assertEqual(list(map(str, errors)), ["default value 23.0 of my_parameter is not in my_list"]) self.assertEqual(updated_ids, set()) def 
test_update_parameter_definition_value_list_when_default_value_not_on_the_list_exists_gives_error(self): @@ -703,13 +696,11 @@ def test_update_parameter_definition_value_list_when_default_value_not_on_the_li import_functions.import_objects(self._db_map, (("object_class", "my_object"),)) import_functions.import_object_parameters(self._db_map, (("object_class", "my_parameter", 23.0),)) self._db_map.commit_session("Populate with initial data.") - updated_ids, errors = self._db_map.update_parameter_definitions( + items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "parameter_value_list_id": 1} ) - self.assertEqual( - list(map(str, errors)), - ["Invalid default_value '23.0' - it should be one from the parameter value list: '99.0'."], - ) + updated_ids = {x["id"] for x in items} + self.assertEqual(list(map(str, errors)), ["default value 23.0 of my_parameter is not in my_list"]) self.assertEqual(updated_ids, set()) def test_update_object_metadata(self): @@ -718,20 +709,20 @@ def test_update_object_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) self._db_map.commit_session("Add test data") - ids, errors = self._db_map.update_ext_entity_metadata( + items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual( - metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3} - ) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": None}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) self.assertEqual( - entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2} + dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": None} ) def test_update_object_metadata_reuses_existing_metadata(self): @@ -746,26 +737,28 @@ def test_update_object_metadata_reuses_existing_metadata(self): ), ) self._db_map.commit_session("Add test data") - ids, errors = self._db_map.update_ext_entity_metadata( + items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "key 2", "metadata_value": "metadata value 2"}] ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) self.assertEqual( - metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2} + dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} ) self.assertEqual( - metadata_entries[1]._asdict(), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2} + dict(metadata_entries[1]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": None} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) 
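The update_* methods in these hunks now return the updated items
themselves instead of a set of ids, so callers rebuild the old id set
with {x["id"] for x in items}. A sketch of adapting one call site,
assuming a db_map populated as in these tests (the new parameter name
is illustrative):

    items, errors = db_map.update_parameter_definitions(
        {"id": 1, "name": "renamed_parameter"}
    )
    if errors:
        raise RuntimeError("; ".join(str(err) for err in errors))
    updated_ids = {item["id"] for item in items}  # old-style id set
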
self.assertEqual( - entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3} + dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} ) self.assertEqual( - entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2} + dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": None} ) def test_update_object_metadata_keeps_metadata_still_in_use(self): @@ -780,26 +773,28 @@ def test_update_object_metadata_keeps_metadata_still_in_use(self): ), ) self._db_map.commit_session("Add test data") - ids, errors = self._db_map.update_ext_entity_metadata( + items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "new key", "metadata_value": "new value"}] ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1, 2}) + self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) self.assertEqual( - metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2} + dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} ) self.assertEqual( - metadata_entries[1]._asdict(), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3} + dict(metadata_entries[1]), {"id": 2, "name": "new key", "value": "new value", "commit_id": None} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) self.assertEqual( - entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3} + dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} ) self.assertEqual( - entity_metadata_entries[1]._asdict(), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2} + dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": None} ) def test_update_parameter_value_metadata(self): @@ -814,20 +809,20 @@ def test_update_parameter_value_metadata(self): self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) ) self._db_map.commit_session("Add test data") - ids, errors = self._db_map.update_ext_parameter_value_metadata( + items, errors = self._db_map.update_ext_parameter_value_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual( - metadata_entries[0]._asdict(), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3} - ) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": None}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 2} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": None} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): @@ -843,40 +838,42 @@ def 
test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata( self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) ) self._db_map.commit_session("Add test data") - ids, errors = self._db_map.update_ext_parameter_value_metadata( + items, errors = self._db_map.update_ext_parameter_value_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1, 2}) + self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) self.assertEqual( - metadata_entries[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2} - ) - self.assertEqual( - metadata_entries[1]._asdict(), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3} + dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} ) + self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - value_metadata_entries[0]._asdict(), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": None} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) self.assertEqual( - entity_metadata_entries[0]._asdict(), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2} + dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": None} ) def test_update_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") - ids, errors = self._db_map.update_metadata(*({"id": 1, "name": "author", "value": "Prof. T. Est"},)) + items, errors = self._db_map.update_metadata(*({"id": 1, "name": "author", "value": "Prof. T. Est"},)) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Update data") metadata_records = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_records), 1) self.assertEqual( - metadata_records[0]._asdict(), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3} + dict(metadata_records[0]), {"id": 1, "name": "author", "value": "Prof. T. 
Est", "commit_id": None} ) @@ -893,7 +890,7 @@ def test_remove_works_when_entity_groups_are_present(self): import_functions.import_objects(self._db_map, (("my_class", "my_group"),)) import_functions.import_object_groups(self._db_map, (("my_class", "my_group", "my_object"),)) self._db_map.commit_session("Add test data.") - self._db_map.cascade_remove_items(object={1}) # This shouldn't raise an exception + self._db_map.remove_items("object", 1) # This shouldn't raise an exception self._db_map.commit_session("Remove object.") objects = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(objects), 1) @@ -904,7 +901,7 @@ def test_remove_object_class(self): self._db_map.commit_session("Add test data.") my_class = self._db_map.query(self._db_map.object_class_sq).one_or_none() self.assertIsNotNone(my_class) - self._db_map.cascade_remove_items(**{"object_class": {my_class.id}}) + self._db_map.remove_items("object_class", my_class.id) self._db_map.commit_session("Remove object class.") my_class = self._db_map.query(self._db_map.object_class_sq).one_or_none() self.assertIsNone(my_class) @@ -915,7 +912,7 @@ def test_remove_relationship_class(self): self._db_map.commit_session("Add test data.") my_class = self._db_map.query(self._db_map.relationship_class_sq).one_or_none() self.assertIsNotNone(my_class) - self._db_map.cascade_remove_items(**{"relationship_class": {my_class.id}}) + self._db_map.remove_items("relationship_class", my_class.id) self._db_map.commit_session("Remove relationship class.") my_class = self._db_map.query(self._db_map.relationship_class_sq).one_or_none() self.assertIsNone(my_class) @@ -926,7 +923,7 @@ def test_remove_object(self): self._db_map.commit_session("Add test data.") my_object = self._db_map.query(self._db_map.object_sq).one_or_none() self.assertIsNotNone(my_object) - self._db_map.cascade_remove_items(**{"object": {my_object.id}}) + self._db_map.remove_items("object", my_object.id) self._db_map.commit_session("Remove object.") my_object = self._db_map.query(self._db_map.object_sq).one_or_none() self.assertIsNone(my_object) @@ -939,7 +936,7 @@ def test_remove_relationship(self): self._db_map.commit_session("Add test data.") my_relationship = self._db_map.query(self._db_map.relationship_sq).one_or_none() self.assertIsNotNone(my_relationship) - self._db_map.cascade_remove_items(**{"relationship": {2}}) + self._db_map.remove_items("relationship", 2) self._db_map.commit_session("Remove relationship.") my_relationship = self._db_map.query(self._db_map.relationship_sq).one_or_none() self.assertIsNone(my_relationship) @@ -954,7 +951,7 @@ def test_remove_parameter_value(self): self._db_map.commit_session("Add test data.") my_value = self._db_map.query(self._db_map.object_parameter_value_sq).one_or_none() self.assertIsNotNone(my_value) - self._db_map.cascade_remove_items(**{"parameter_value": {my_value.id}}) + self._db_map.remove_items("parameter_value", my_value.id) self._db_map.commit_session("Remove parameter value.") my_parameter = self._db_map.query(self._db_map.object_parameter_value_sq).one_or_none() self.assertIsNone(my_parameter) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 3e45697a..dbaf8784 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -72,8 +72,9 @@ def test_shorthand_filter_query_works(self): url = URL("sqlite") url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") out_db = DatabaseMapping(url, create=True) - out_db.add_tools({"name": 
"object_activity_control", "id": 1}) - out_db.commit_session("Add tool.") + out_db.add_scenarios({"name": "scen1"}) + out_db.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) + out_db.commit_session("Add scen.") out_db.connection.close() try: db_map = DatabaseMapping(url) @@ -95,10 +96,10 @@ def test_cascade_remove_relationship(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - ids, _ = self._db_map.add_wide_relationships( + items, _ = self._db_map.add_wide_relationships( {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} ) - self._db_map.cascade_remove_items(relationship=ids) + self._db_map.remove_items("relationship", *{x["id"] for x in items}) self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) @@ -108,37 +109,34 @@ def test_cascade_remove_relationship_from_committed_session(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - ids, _ = self._db_map.add_wide_relationships( + items, _ = self._db_map.add_wide_relationships( {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} ) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 1) - self._db_map.cascade_remove_items(relationship=ids) - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + self._db_map.remove_items("relationship", *{x["id"] for x in items}) self._db_map.commit_session("Add test data.") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) def test_remove_object(self): """Test adding and removing an object and committing""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - ids, _ = self._db_map.add_objects( + items, _ = self._db_map.add_objects( {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} ) - self._db_map.remove_items(object=ids) - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + self._db_map.remove_items("object", *{x["id"] for x in items}) self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) def test_remove_object_from_committed_session(self): """Test removing an object from a committed session""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - ids, _ = self._db_map.add_objects( + items, _ = self._db_map.add_objects( {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} ) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 2) - self._db_map.remove_items(object=ids) - self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) + self._db_map.remove_items("object", *{x["id"] for x in 
items}) self._db_map.commit_session("Add test data.") self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) @@ -146,9 +144,8 @@ def test_remove_entity_group(self): """Test adding and removing an entity group and committing""" self._db_map.add_object_classes({"name": "oc1", "id": 1}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - ids, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - self._db_map.remove_items(entity_group=ids) - self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) + items, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + self._db_map.remove_items("entity_group", *{x["id"] for x in items}) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) @@ -156,49 +153,44 @@ def test_remove_entity_group_from_committed_session(self): """Test removing an entity group from a committed session""" self._db_map.add_object_classes({"name": "oc1", "id": 1}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - ids, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + items, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 1) - self._db_map.remove_items(entity_group=ids) - self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) + self._db_map.remove_items("entity_group", *{x["id"] for x in items}) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) def test_cascade_remove_relationship_class(self): """Test adding and removing a relationship class and committing""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - ids, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.cascade_remove_items(relationship_class=ids) - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) + items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) def test_cascade_remove_relationship_class_from_committed_session(self): """Test removing a relationship class from a committed session""" self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - ids, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 1) - self._db_map.cascade_remove_items(relationship_class=ids) - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) - self._db_map.commit_session("Add test data.") + self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) + self._db_map.commit_session("remove") 
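From here on the keyword-based removal API is gone: both
cascade_remove_items(object={...}) and remove_items(object=ids) become
one positional call, remove_items(tablename, *ids), and the tests
commit before asserting. A sketch of the migration, with illustrative
ids on a db_map like the one above:

    # before: db_map.cascade_remove_items(object={1, 2})
    # after:
    db_map.remove_items("object", 1, 2)  # still cascades to dependents
    db_map.commit_session("Remove objects.")  # removal now visible to query()
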
self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) def test_remove_object_class(self): """Test adding and removing an object class and committing""" - ids, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.remove_items(object_class=ids) - self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) + items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.remove_items("object_class", *{x["id"] for x in items}) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) def test_remove_object_class_from_committed_session(self): """Test removing an object class from a committed session""" - ids, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 2) - self._db_map.remove_items(object_class=ids) - self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) + self._db_map.remove_items("object_class", *{x["id"] for x in items}) self._db_map.commit_session("Add test data.") self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) @@ -218,9 +210,9 @@ def test_remove_parameter_value(self): }, strict=True, ) + self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.remove_items(parameter_value=[1]) - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + self._db_map.remove_items("parameter_value", 1) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) @@ -242,8 +234,7 @@ def test_remove_parameter_value_from_committed_session(self): ) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.remove_items(parameter_value=[1]) - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + self._db_map.remove_items("parameter_value", 1) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) @@ -263,9 +254,9 @@ def test_cascade_remove_object_removes_parameter_value_as_well(self): }, strict=True, ) + self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.cascade_remove_items(object={1}) - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + self._db_map.remove_items("object", 1) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) @@ -287,8 +278,7 @@ def test_cascade_remove_object_from_committed_session_removes_parameter_value_as ) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.cascade_remove_items(object={1}) - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + self._db_map.remove_items("object", 1) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) @@ -307,7 +297,8 @@ def 
test_cascade_remove_metadata_removes_corresponding_entity_and_value_metadata self._db_map.commit_session("Add test data.") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self._db_map.cascade_remove_items(**{"metadata": {metadata[0].id}}) + self._db_map.remove_items("metadata", metadata[0].id) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) @@ -322,7 +313,8 @@ def test_cascade_remove_entity_metadata_removes_corresponding_metadata(self): self._db_map.commit_session("Add test data.") entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) - self._db_map.cascade_remove_items(**{"entity_metadata": {entity_metadata[0].id}}) + self._db_map.remove_items("entity_metadata", entity_metadata[0].id) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 1) @@ -342,7 +334,8 @@ def test_cascade_remove_entity_metadata_leaves_metadata_used_by_value_intact(sel self._db_map.commit_session("Add test data.") entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) - self._db_map.cascade_remove_items(**{"entity_metadata": {entity_metadata[0].id}}) + self._db_map.remove_items("entity_metadata", entity_metadata[0].id) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 1) @@ -362,7 +355,8 @@ def test_cascade_remove_value_metadata_leaves_metadata_used_by_entity_intact(sel self._db_map.commit_session("Add test data.") parameter_value_metadata = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(parameter_value_metadata), 1) - self._db_map.cascade_remove_items(**{"parameter_value_metadata": {parameter_value_metadata[0].id}}) + self._db_map.remove_items("parameter_value_metadata", parameter_value_metadata[0].id) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 1) self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0) @@ -373,7 +367,8 @@ def test_cascade_remove_object_removes_its_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) self._db_map.commit_session("Add test data.") - self._db_map.cascade_remove_items(**{"object": {1}}) + self._db_map.remove_items("object", 1) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) 
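As the metadata tests above and below check, removal still cascades:
dropping an entity or parameter value also drops its entity_metadata or
parameter_value_metadata rows, and metadata entries left without any
referrer disappear too, while shared entries survive. Only the explicit
commit added in each hunk makes that observable. A compact sketch,
assuming (as in the first of these tests) that nothing else references
the metadata entry:

    db_map.remove_items("entity_metadata", 1)
    db_map.commit_session("Remove test data.")
    # the metadata row referenced only by that entity_metadata is gone too
    assert len(db_map.query(db_map.metadata_sq).all()) == 0
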
self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) @@ -388,7 +383,8 @@ def test_cascade_remove_relationship_removes_its_metadata(self): self._db_map, (("my_class", ("my_object",), '{"title": "My metadata."}'),) ) self._db_map.commit_session("Add test data.") - self._db_map.cascade_remove_items(**{"relationship": {2}}) + self._db_map.remove_items("relationship", 2) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.relationship_sq).all()), 0) @@ -405,7 +401,8 @@ def test_cascade_remove_parameter_value_removes_its_metadata(self): self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) ) self._db_map.commit_session("Add test data.") - self._db_map.cascade_remove_items(**{"parameter_value": {1}}) + self._db_map.remove_items("parameter_value", 1) + self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) @@ -420,8 +417,8 @@ def tearDown(self): def test_add_and_retrieve_many_objects(self): """Tests add many objects into db and retrieving them.""" - ids, _ = self._db_map.add_object_classes({"name": "testclass"}) - class_id = next(iter(ids)) + items, _ = self._db_map.add_object_classes({"name": "testclass"}) + class_id = next(iter(items))["id"] added = self._db_map.add_objects(*[{"name": str(i), "class_id": class_id} for i in range(1001)])[0] self.assertEqual(len(added), 1001) self._db_map.commit_session("test_commit") @@ -430,6 +427,7 @@ def test_add_and_retrieve_many_objects(self): def test_add_object_classes(self): """Test that adding object classes works.""" self._db_map.add_object_classes({"name": "fish"}, {"name": "dog"}) + self._db_map.commit_session("add") object_classes = self._db_map.query(self._db_map.object_class_sq).all() self.assertEqual(len(object_classes), 2) self.assertEqual(object_classes[0].name, "fish") @@ -443,6 +441,7 @@ def test_add_object_class_with_invalid_name(self): def test_add_object_classes_with_same_name(self): """Test that adding two object classes with the same name only adds one of them.""" self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"}) + self._db_map.commit_session("add") object_classes = self._db_map.query(self._db_map.object_class_sq).all() self.assertEqual(len(object_classes), 1) self.assertEqual(object_classes[0].name, "fish") @@ -457,6 +456,7 @@ def test_add_objects(self): """Test that adding objects works.""" self._db_map.add_object_classes({"name": "fish"}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "dory", "class_id": 1}) + self._db_map.commit_session("add") objects = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(objects), 2) self.assertEqual(objects[0].name, "nemo") @@ -474,6 +474,7 @@ def test_add_objects_with_same_name(self): """Test that adding two objects with the same name only adds one of them.""" self._db_map.add_object_classes({"name": "fish"}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "nemo", "class_id": 1}) + self._db_map.commit_session("add") objects = self._db_map.query(self._db_map.object_sq).all() self.assertEqual(len(objects), 1) 
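The add_* methods change the same way as update_*: they return
(items, errors), and an id is read off an item when needed, as
test_add_and_retrieve_many_objects does above. A one-line sketch of the
new idiom:

    items, errors = db_map.add_object_classes({"name": "fish"})
    class_id = next(iter(items))["id"]  # id of the freshly added class
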
self.assertEqual(objects[0].name, "nemo") @@ -498,8 +499,9 @@ def test_add_relationship_classes(self): self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]} ) - diff_table = self._db_map.get_table("entity_class_dimension") - ent_cls_dims = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("entity_class_dimension") + ent_cls_dims = self._db_map.query(table).all() rel_clss = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(ent_cls_dims), 4) self.assertEqual(rel_clss[0].name, "rc1") @@ -519,10 +521,13 @@ def test_add_relationship_classes_with_same_name(self): """Test that adding two relationship classes with the same name only adds one of them.""" self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) self._db_map.add_wide_relationship_classes( - {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc1", "object_class_id_list": [1, 2]} + {"name": "rc1", "object_class_id_list": [1, 2]}, + {"name": "rc1", "object_class_id_list": [1, 2]}, + strict=False, ) - diff_table = self._db_map.get_table("entity_class_dimension") - ecs_dims = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("entity_class_dimension") + ecs_dims = self._db_map.query(table).all() relationship_classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(ecs_dims), 2) self.assertEqual(len(relationship_classes), 1) @@ -570,6 +575,7 @@ def test_add_relationships(self): self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2]}) self._db_map.add_objects({"name": "o1", "class_id": 1}, {"name": "o2", "class_id": 2}) self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) + self._db_map.commit_session("add") ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all() relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(ent_els), 2) @@ -586,7 +592,7 @@ def test_add_relationship_with_invalid_name(self): self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1]}, strict=True) self._db_map.add_objects({"name": "o1", "class_id": 1}, strict=True) with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationships({"name": "", "class_id": 1, "object_id_list": [1]}, strict=True) + self._db_map.add_wide_relationships({"name": "", "class_id": 2, "object_id_list": [1]}, strict=True) def test_add_identical_relationships(self): """Test that adding two relationships with the same class and same objects only adds the first one.""" @@ -597,6 +603,7 @@ def test_add_identical_relationships(self): {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, ) + self._db_map.commit_session("add") relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationships), 1) @@ -679,8 +686,9 @@ def test_add_entity_groups(self): self._db_map.add_object_classes({"name": "oc1", "id": 1}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - diff_table = self._db_map.get_table("entity_group") - entity_groups = 
self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("entity_group") + entity_groups = self._db_map.query(table).all() self.assertEqual(len(entity_groups), 1) self.assertEqual(entity_groups[0].entity_id, 1) self.assertEqual(entity_groups[0].entity_class_id, 1) @@ -723,8 +731,9 @@ def test_add_parameter_definitions(self): {"name": "color", "object_class_id": 1, "description": "test1"}, {"name": "relative_speed", "relationship_class_id": 3, "description": "test2"}, ) - diff_table = self._db_map.get_table("parameter_definition") - parameter_definitions = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("parameter_definition") + parameter_definitions = self._db_map.query(table).all() self.assertEqual(len(parameter_definitions), 2) self.assertEqual(parameter_definitions[0].name, "color") self.assertEqual(parameter_definitions[0].entity_class_id, 1) @@ -746,8 +755,9 @@ def test_add_parameter_definitions_with_same_name(self): self._db_map.add_parameter_definitions( {"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3} ) - diff_table = self._db_map.get_table("parameter_definition") - parameter_definitions = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("parameter_definition") + parameter_definitions = self._db_map.query(table).all() self.assertEqual(len(parameter_definitions), 2) self.assertEqual(parameter_definitions[0].name, "color") self.assertEqual(parameter_definitions[1].name, "color") @@ -790,20 +800,21 @@ def test_add_parameter_values(self): import_functions.import_relationships(self._db_map, [("fish_dog", ("nemo", "pluto"))]) import_functions.import_object_parameters(self._db_map, [("fish", "color")]) import_functions.import_relationship_parameters(self._db_map, [("fish_dog", "rel_speed")]) + self._db_map.commit_session("add") color_id = ( - self._db_map.parameter_definition_list() + self._db_map.query(self._db_map.parameter_definition_sq) .filter(self._db_map.parameter_definition_sq.c.name == "color") .first() .id ) rel_speed_id = ( - self._db_map.parameter_definition_list() + self._db_map.query(self._db_map.parameter_definition_sq) .filter(self._db_map.parameter_definition_sq.c.name == "rel_speed") .first() .id ) - nemo_row = self._db_map.object_list().filter(self._db_map.entity_sq.c.name == "nemo").first() - nemo__pluto_row = self._db_map.wide_relationship_list().filter().first() + nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.object_sq.c.name == "nemo").first() + nemo__pluto_row = self._db_map.query(self._db_map.wide_relationship_sq).first() self._db_map.add_parameter_values( { "parameter_definition_id": color_id, @@ -820,8 +831,9 @@ def test_add_parameter_values(self): "alternative_id": 1, }, ) - diff_table = self._db_map.get_table("parameter_value") - parameter_values = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("parameter_value") + parameter_values = self._db_map.query(table).all() self.assertEqual(len(parameter_values), 2) self.assertEqual(parameter_values[0].parameter_definition_id, 1) self.assertEqual(parameter_values[0].entity_id, 1) @@ -853,13 +865,14 @@ def test_add_same_parameter_value_twice(self): import_functions.import_object_classes(self._db_map, ["fish"]) import_functions.import_objects(self._db_map, [("fish", "nemo")]) import_functions.import_object_parameters(self._db_map, 
[("fish", "color")]) + self._db_map.commit_session("add") color_id = ( - self._db_map.parameter_definition_list() + self._db_map.query(self._db_map.parameter_definition_sq) .filter(self._db_map.parameter_definition_sq.c.name == "color") .first() .id ) - nemo_row = self._db_map.object_list().filter(self._db_map.entity_sq.c.name == "nemo").first() + nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.entity_sq.c.name == "nemo").first() self._db_map.add_parameter_values( { "parameter_definition_id": color_id, @@ -876,8 +889,9 @@ def test_add_same_parameter_value_twice(self): "alternative_id": 1, }, ) - diff_table = self._db_map.get_table("parameter_value") - parameter_values = self._db_map.query(diff_table).all() + self._db_map.commit_session("add") + table = self._db_map.get_table("parameter_value") + parameter_values = self._db_map.query(table).all() self.assertEqual(len(parameter_values), 1) self.assertEqual(parameter_values[0].parameter_definition_id, 1) self.assertEqual(parameter_values[0].entity_id, 1) @@ -900,66 +914,76 @@ def test_add_existing_parameter_value(self): strict=False, ) self.assertEqual( - [str(e) for e in errors], ["The value of parameter 'color' for entity 'nemo' is already specified."] + [str(e) for e in errors], + [ + "there's already a parameter_value with " + "{'parameter_definition_name': 'color', 'entity_byname': ('nemo',), 'alternative_name': 'Base'}" + ], ) def test_add_alternative(self): - ids, errors = self._db_map.add_alternatives({"name": "my_alternative"}) + items, errors = self._db_map.add_alternatives({"name": "my_alternative"}) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {2}) + self._db_map.commit_session("Add test data.") alternatives = self._db_map.query(self._db_map.alternative_sq).all() self.assertEqual(len(alternatives), 2) self.assertEqual( - alternatives[0]._asdict(), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} + dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} ) self.assertEqual( - alternatives[1]._asdict(), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} + dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": None} ) def test_add_scenario(self): - ids, errors = self._db_map.add_scenarios({"name": "my_scenario"}) + items, errors = self._db_map.add_scenarios({"name": "my_scenario"}) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Add test data.") scenarios = self._db_map.query(self._db_map.scenario_sq).all() self.assertEqual(len(scenarios), 1) self.assertEqual( - scenarios[0]._asdict(), - {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, + dict(scenarios[0]), + {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": None}, ) def test_add_scenario_alternative(self): import_functions.import_scenarios(self._db_map, ("my_scenario",)) self._db_map.commit_session("Add test data.") - ids, errors = self._db_map.add_scenario_alternatives({"scenario_id": 1, "alternative_id": 1, "rank": 0}) + items, errors = self._db_map.add_scenario_alternatives({"scenario_id": 1, "alternative_id": 1, "rank": 0}) + ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.commit_session("Add test data.") scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all() 
self.assertEqual(len(scenario_alternatives), 1) self.assertEqual( - scenario_alternatives[0]._asdict(), - {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, + dict(scenario_alternatives[0]), + {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": None}, ) def test_add_metadata(self): items, errors = self._db_map.add_metadata({"name": "test name", "value": "test_add_metadata"}, strict=False) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) self.assertEqual( - metadata[0]._asdict(), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} + dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": None} ) def test_add_metadata_that_exists_does_not_add_it(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_metadata({"name": "title", "value": "My metadata."}, strict=False) - self.assertEqual(errors, []) - self.assertEqual(items, set()) + items, _ = self._db_map.add_metadata({"name": "title", "value": "My metadata."}, strict=False) + self.assertEqual(items, []) metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(metadata[0]._asdict(), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) + self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": None}) def test_add_entity_metadata_for_object(self): import_functions.import_object_classes(self._db_map, ("fish",)) @@ -967,13 +991,14 @@ def test_add_entity_metadata_for_object(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - entity_metadata[0]._asdict(), + dict(entity_metadata[0]), { "entity_id": 1, "entity_name": "leviathan", @@ -981,7 +1006,7 @@ def test_add_entity_metadata_for_object(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -993,13 +1018,14 @@ def test_add_entity_metadata_for_relationship(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_entity_metadata({"entity_id": 2, "metadata_id": 1}, strict=False) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - entity_metadata[0]._asdict(), + dict(entity_metadata[0]), { "entity_id": 2, "entity_name": "my_relationship_class_my_object", @@ -1007,15 +1033,13 @@ def test_add_entity_metadata_for_relationship(self): "metadata_value": "My 
metadata.", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) def test_add_entity_metadata_doesnt_raise_with_empty_cache(self): - items, errors = self._db_map.add_entity_metadata( - {"entity_id": 1, "metadata_id": 1}, cache=DBCache(lambda *args, **kwargs: None), strict=False - ) - self.assertEqual(items, set()) + items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) + self.assertEqual(items, []) self.assertEqual(len(errors), 1) def test_add_ext_entity_metadata_for_object(self): @@ -1025,13 +1049,14 @@ def test_add_ext_entity_metadata_for_object(self): items, errors = self._db_map.add_ext_entity_metadata( {"entity_id": 1, "metadata_name": "key", "metadata_value": "object metadata"}, strict=False ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - entity_metadata[0]._asdict(), + dict(entity_metadata[0]), { "entity_id": 1, "entity_name": "leviathan", @@ -1039,7 +1064,7 @@ def test_add_ext_entity_metadata_for_object(self): "metadata_value": "object metadata", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -1051,16 +1076,17 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an items, errors = self._db_map.add_ext_entity_metadata( {"entity_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}, strict=False ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add entity metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None}) entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( - entity_metadata[0]._asdict(), + dict(entity_metadata[0]), { "entity_id": 1, "entity_name": "leviathan", @@ -1068,7 +1094,7 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -1082,13 +1108,14 @@ def test_add_parameter_value_metadata(self): items, errors = self._db_map.add_parameter_value_metadata( {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, strict=False ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add value metadata") value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - value_metadata[0]._asdict(), + dict(value_metadata[0]), { "alternative_name": "Base", "entity_name": "leviathan", @@ -1098,7 +1125,7 @@ def test_add_parameter_value_metadata(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -1108,7 +1135,8 @@ def test_add_parameter_value_metadata_doesnt_raise_with_empty_cache(self): cache=DBCache(lambda *args, **kwargs: None), strict=False, ) - 
self.assertEqual(items, set()) + ids = {x["id"] for x in items} + self.assertEqual(ids, set()) self.assertEqual(len(errors), 1) def test_add_ext_parameter_value_metadata(self): @@ -1126,13 +1154,14 @@ def test_add_ext_parameter_value_metadata(self): }, strict=False, ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add value metadata") value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - value_metadata[0]._asdict(), + dict(value_metadata[0]), { "alternative_name": "Base", "entity_name": "leviathan", @@ -1142,7 +1171,7 @@ def test_add_ext_parameter_value_metadata(self): "metadata_value": "parameter metadata", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -1157,16 +1186,17 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata.", "alternative_id": 1}, strict=False, ) + ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(items, {1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Add value metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(metadata[0]._asdict(), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None}) value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( - value_metadata[0]._asdict(), + dict(value_metadata[0]), { "alternative_name": "Base", "entity_name": "leviathan", @@ -1176,7 +1206,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": 3, + "commit_id": None, }, ) @@ -1191,9 +1221,11 @@ def tearDown(self): def test_update_object_classes(self): """Test that updating object classes works.""" self._db_map.add_object_classes({"id": 1, "name": "fish"}, {"id": 2, "name": "dog"}) - ids, intgr_error_log = self._db_map.update_object_classes( + items, intgr_error_log = self._db_map.update_object_classes( {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.object_class_sq object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} self.assertEqual(intgr_error_log, []) @@ -1204,7 +1236,9 @@ def test_update_objects(self): """Test that updating objects works.""" self._db_map.add_object_classes({"id": 1, "name": "fish"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) - ids, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) + items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} self.assertEqual(intgr_error_log, []) @@ -1215,28 +1249,28 @@ def test_update_objects_not_committed(self): """Test that updating objects works.""" 
self._db_map.add_object_classes({"id": 1, "name": "some_class"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) - ids, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) + items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} self.assertEqual(intgr_error_log, []) self.assertEqual(objects[1], "klaus") self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") - self._db_map.commit_session("update") - self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") def test_update_committed_object(self): """Test that updating objects works.""" self._db_map.add_object_classes({"id": 1, "name": "some_class"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) self._db_map.commit_session("update") - ids, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) + items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} self.assertEqual(intgr_error_log, []) self.assertEqual(objects[1], "klaus") self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") - self._db_map.commit_session("update") - self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") def test_update_relationship_classes(self): """Test that updating relationship classes works.""" @@ -1245,9 +1279,11 @@ def test_update_relationship_classes(self): {"id": 3, "name": "dog__fish", "object_class_id_list": [1, 2]}, {"id": 4, "name": "fish__dog", "object_class_id_list": [2, 1]}, ) - ids, intgr_error_log = self._db_map.update_wide_relationship_classes( + items, intgr_error_log = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_class_sq rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} self.assertEqual(intgr_error_log, []) @@ -1266,9 +1302,11 @@ def test_update_relationships(self): self._db_map.add_wide_relationships( {"id": 4, "name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2], "object_class_id_list": [1, 2]} ) - ids, intgr_error_log = self._db_map.update_wide_relationships( + items, intgr_error_log = self._db_map.update_wide_relationships( {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_sq rels = { x.id: {"name": x.name, "object_id_list": x.object_id_list} diff --git a/tests/test_check_functions.py b/tests/test_check_functions.py index 47d75a9e..fb5ce399 100644 --- a/tests/test_check_functions.py +++ b/tests/test_check_functions.py @@ -16,9 +16,8 @@ from spinedb_api.db_cache import DBCache, ParameterValueItem from spinedb_api.exception import SpineIntegrityError -from spinedb_api.check_functions import replace_parameter_values_with_list_references - +@unittest.skip("obsolete, but need to adapt to 
current check system") class TestCheckFunctions(unittest.TestCase): def setUp(self): self.data = [ diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 3ab2f58f..385117f2 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -21,10 +21,6 @@ export_data, export_scenarios, export_scenario_alternatives, - export_tools, - export_features, - export_tool_features, - export_tool_feature_methods, import_alternatives, import_object_classes, import_object_parameter_values, @@ -37,10 +33,6 @@ import_relationships, import_scenarios, import_scenario_alternatives, - import_tools, - import_features, - import_tool_features, - import_tool_feature_methods, ) @@ -52,40 +44,6 @@ def setUp(self): def tearDown(self): self._db_map.connection.close() - def test_export_tools(self): - import_tools(self._db_map, [("tool", "Description")]) - exported = export_tools(self._db_map) - self.assertEqual(exported, [("tool", "Description")]) - - def test_export_features(self): - import_object_classes(self._db_map, ["object_class1", "object_class2"]) - import_parameter_value_lists(self._db_map, [['value_list', 'value1'], ['value_list', 'value2']]) - import_object_parameters(self._db_map, [["object_class1", "parameter1", "value1", "value_list"]]) - import_features(self._db_map, [["object_class1", "parameter1", "Description"]]) - exported = export_features(self._db_map) - self.assertEqual(exported, [("object_class1", "parameter1", "value_list", "Description")]) - - def test_export_tool_features(self): - import_object_classes(self._db_map, ["object_class1", "object_class2"]) - import_parameter_value_lists(self._db_map, [['value_list', 'value1'], ['value_list', 'value2']]) - import_object_parameters(self._db_map, [["object_class1", "parameter1", "value1", "value_list"]]) - import_features(self._db_map, [["object_class1", "parameter1", "Description"]]) - import_tools(self._db_map, ["tool1"]) - import_tool_features(self._db_map, [["tool1", "object_class1", "parameter1"]]) - exported = export_tool_features(self._db_map) - self.assertEqual(exported, [("tool1", "object_class1", "parameter1", False)]) - - def test_export_tool_feature_methods(self): - import_object_classes(self._db_map, ["object_class1", "object_class2"]) - import_parameter_value_lists(self._db_map, [['value_list', 'value1'], ['value_list', 'value2']]) - import_object_parameters(self._db_map, [["object_class1", "parameter1", "value1", "value_list"]]) - import_features(self._db_map, [["object_class1", "parameter1", "Description"]]) - import_tools(self._db_map, ["tool1"]) - import_tool_features(self._db_map, [["tool1", "object_class1", "parameter1"]]) - import_tool_feature_methods(self._db_map, [["tool1", "object_class1", "parameter1", "value2"]]) - exported = export_tool_feature_methods(self._db_map) - self.assertEqual(exported, [("tool1", "object_class1", "parameter1", "value2")]) - def test_export_alternatives(self): import_alternatives(self._db_map, [("alternative", "Description")]) exported = export_alternatives(self._db_map) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 501db071..e381c1f6 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -31,10 +31,6 @@ import_scenario_alternatives, import_scenarios, import_parameter_value_lists, - import_tools, - import_features, - import_tool_features, - import_tool_feature_methods, import_metadata, import_object_metadata, import_relationship_metadata, @@ -120,6 +116,7 @@ def 
test_import_object_class(self): db_map = create_diff_db_map() _, errors = import_object_classes(db_map, ["new_class"]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("new_class", [oc.name for oc in db_map.query(db_map.object_class_sq)]) db_map.connection.close() @@ -130,6 +127,7 @@ def test_import_valid_objects(self): import_object_classes(db_map, ["object_class"]) _, errors = import_objects(db_map, [["object_class", "new_object"]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("new_object", [o.name for o in db_map.query(db_map.object_sq)]) db_map.connection.close() @@ -144,6 +142,7 @@ def test_import_two_objects_with_same_name(self): import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_objects(db_map, [["object_class1", "object"], ["object_class2", "object"]]) self.assertFalse(errors) + db_map.commit_session("test") objects = { o.class_name: o.name for o in db_map.query( @@ -158,9 +157,11 @@ def test_import_existing_object(self): db_map = create_diff_db_map() import_object_classes(db_map, ["object_class"]) import_objects(db_map, [["object_class", "object"]]) + db_map.commit_session("test") self.assertIn("object", [o.name for o in db_map.query(db_map.object_sq)]) _, errors = import_objects(db_map, [["object_class", "object"]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("object", [o.name for o in db_map.query(db_map.object_sq)]) db_map.connection.close() @@ -171,6 +172,7 @@ def test_import_valid_relationship_class(self): import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) self.assertFalse(errors) + db_map.commit_session("test") relationship_classes = { rc.name: rc.object_class_name_list for rc in db_map.query(db_map.wide_relationship_class_sq) } @@ -183,7 +185,8 @@ def test_import_relationship_class_with_invalid_object_class_name(self): import_object_classes(db_map, ["object_class"]) _, errors = import_relationship_classes(db_map, [["relationship_class", ["object_class", "nonexistent"]]]) self.assertTrue(errors) - self.assertFalse([rc for rc in db_map.query(db_map.wide_relationship_class_sq)]) + db_map.commit_session("test") + self.assertFalse(db_map.query(db_map.wide_relationship_class_sq).all()) db_map.connection.close() def test_import_relationship_class_name_twice(self): @@ -193,6 +196,7 @@ def test_import_relationship_class_name_twice(self): db_map, [["new_rc", ["object_class1", "object_class2"]], ["new_rc", ["object_class1", "object_class2"]]] ) self.assertFalse(errors) + db_map.commit_session("test") relationship_classes = { rc.name: rc.object_class_name_list for rc in db_map.query(db_map.wide_relationship_class_sq) } @@ -213,6 +217,7 @@ def test_import_relationship_class_with_one_object_class_as_None(self): import_object_classes(db_map, ["object_class1"]) _, errors = import_relationship_classes(db_map, [["new_rc", ["object_class", None]]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse([rc for rc in db_map.query(db_map.wide_relationship_class_sq)]) db_map.connection.close() @@ -223,6 +228,7 @@ def test_import_valid_object_class_parameter(self): import_object_classes(db_map, ["object_class"]) _, errors = import_object_parameters(db_map, [["object_class", "new_parameter"]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("new_parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) 
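The recurring '+ db_map.commit_session("test")' lines apply the same session semantics to the import functions: imports are staged in memory and reach the database, and hence query(), only at commit time. A sketch of the pattern, reusing the calls these tests make (in-memory database assumed):

    from spinedb_api import DatabaseMapping, import_functions

    db_map = DatabaseMapping("sqlite://", create=True)
    count, errors = import_functions.import_object_classes(db_map, ["object_class"])
    assert count == 1 and not errors
    # A query here would not show "object_class" yet; it is only staged.
    db_map.commit_session("test")
    assert "object_class" in [oc.name for oc in db_map.query(db_map.object_class_sq)]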
db_map.connection.close() @@ -239,6 +245,7 @@ def test_import_object_class_parameter_name_twice(self): db_map, [["object_class1", "new_parameter"], ["object_class2", "new_parameter"]] ) self.assertFalse(errors) + db_map.commit_session("test") definitions = { definition.object_class_name: definition.parameter_name for definition in db_map.query(db_map.object_parameter_definition_sq) @@ -251,8 +258,10 @@ def test_import_existing_object_class_parameter(self): db_map = create_diff_db_map() import_object_classes(db_map, ["object_class"]) import_object_parameters(db_map, [["object_class", "parameter"]]) + db_map.commit_session("test") self.assertIn("parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) _, errors = import_object_parameters(db_map, [["object_class", "parameter"]]) + db_map.commit_session("test") self.assertIn("parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) self.assertFalse(errors) db_map.connection.close() @@ -279,6 +288,7 @@ def test_import_valid_relationship_class_parameter(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationship_parameters(db_map, [["relationship_class", "new_parameter"]]) self.assertFalse(errors) + db_map.commit_session("test") definitions = { d.class_name: d.name for d in db_map.query( @@ -310,6 +320,7 @@ def test_import_relationship_class_parameter_name_twice(self): db_map, [["relationship_class1", "new_parameter"], ["relationship_class2", "new_parameter"]] ) self.assertFalse(errors) + db_map.commit_session("test") definitions = { d.class_name: d.name for d in db_map.query( @@ -344,6 +355,7 @@ def test_import_relationships(self): import_relationship_classes(db_map, (("relationship_class", ("object_class",)),)) _, errors = import_relationships(db_map, (("relationship_class", ("object",)),)) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("relationship_class_object", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -353,6 +365,7 @@ def test_import_valid_relationship(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -363,6 +376,7 @@ def test_import_valid_relationship_with_object_name_in_multiple_classes(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", ["duplicate", "object2"]]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("relationship_class_duplicate__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -371,6 +385,7 @@ def test_import_relationship_with_invalid_class_name(self): self.populate(db_map) _, errors = import_relationships(db_map, [["nonexistent_relationship_class", ["object1", "object2"]]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -380,6 +395,7 @@ def test_import_relationship_with_invalid_object_name(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = 
import_relationships(db_map, [["relationship_class", ["nonexistent_object", "object2"]]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -388,9 +404,11 @@ def test_import_existing_relationship(self): self.populate(db_map) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) + db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) self.assertFalse(errors) + db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -400,6 +418,7 @@ def test_import_relationship_with_one_None_object(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", [None, "object2"]]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.connection.close() @@ -417,9 +436,7 @@ def test_import_object_parameter_definition(self): self.assertEqual(errors, []) self.assertEqual(count, 1) self._db_map.commit_session("Add test data.") - parameter_definitions = [ - row._asdict() for row in self._db_map.query(self._db_map.object_parameter_definition_sq) - ] + parameter_definitions = [dict(row) for row in self._db_map.query(self._db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -446,9 +463,7 @@ def test_import_object_parameter_definition_with_value_list(self): self.assertEqual(errors, []) self.assertEqual(count, 1) self._db_map.commit_session("Add test data.") - parameter_definitions = [ - row._asdict() for row in self._db_map.query(self._db_map.object_parameter_definition_sq) - ] + parameter_definitions = [dict(row) for row in self._db_map.query(self._db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -475,9 +490,7 @@ def test_import_object_parameter_definition_with_default_value_from_value_list(s self.assertEqual(errors, []) self.assertEqual(count, 1) self._db_map.commit_session("Add test data.") - parameter_definitions = [ - row._asdict() for row in self._db_map.query(self._db_map.object_parameter_definition_sq) - ] + parameter_definitions = [dict(row) for row in self._db_map.query(self._db_map.object_parameter_definition_sq)] self.assertEqual( parameter_definitions, [ @@ -501,13 +514,7 @@ def test_import_object_parameter_definition_with_default_value_from_value_list_f import_object_classes(self._db_map, ["my_object_class"]) import_parameter_value_lists(self._db_map, (("my_list", 99.0),)) count, errors = import_object_parameters(self._db_map, (("my_object_class", "my_parameter", 23.0, "my_list"),)) - self.assertEqual( - [error.msg for error in errors], - [ - "Could not import parameter 'my_parameter' with class 'my_object_class': " - "Invalid default_value '23.0' - it should be one from the parameter value list: '99.0'." 
- ], - ) + self.assertEqual(errors, ["default value 23.0 of my_parameter is not in my_list"]) self.assertEqual(count, 0) @@ -530,6 +537,7 @@ def test_import_valid_object_parameter_value(self): self.populate(db_map) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", 1]]) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b"1"} self.assertEqual(values, expected) @@ -540,6 +548,7 @@ def test_import_valid_object_parameter_value_string(self): self.populate(db_map) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "value_string"]]) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"value_string"'} self.assertEqual(values, expected) @@ -551,6 +560,7 @@ def test_import_valid_object_parameter_value_with_duplicate_object_name(self): import_objects(db_map, [["object_class1", "duplicate_object"], ["object_class2", "duplicate_object"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "duplicate_object", "parameter", 1]]) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_class_name: {v.object_name: v.value} for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object_class1": {"duplicate_object": b"1"}} self.assertEqual(values, expected) @@ -562,6 +572,7 @@ def test_import_valid_object_parameter_value_with_duplicate_parameter_name(self) import_object_parameters(db_map, [["object_class2", "parameter"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", 1]]) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_class_name: {v.object_name: v.value} for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object_class1": {"object1": b"1"}} self.assertEqual(values, expected) @@ -573,6 +584,7 @@ def test_import_object_parameter_value_with_invalid_object(self): import_object_parameters(db_map, [["object_class", "parameter"]]) _, errors = import_object_parameter_values(db_map, [["object_class", "nonexistent_object", "parameter", 1]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse(db_map.query(db_map.object_parameter_value_sq).all()) db_map.connection.close() @@ -582,6 +594,7 @@ def test_import_object_parameter_value_with_invalid_parameter(self): import_objects(db_map, ["object_class", "object"]) _, errors = import_object_parameter_values(db_map, [["object_class", "object", "nonexistent_parameter", 1]]) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse(db_map.query(db_map.object_parameter_value_sq).all()) db_map.connection.close() @@ -591,6 +604,7 @@ def test_import_existing_object_parameter_value_update_the_value(self): import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "initial_value"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "new_value"]]) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"new_value"'} self.assertEqual(values, expected) @@ -606,6 +620,7 @@ def test_import_existing_object_parameter_value_on_conflict_keep(self): db_map, [["object_class1", "object1", "parameter", new_value]], 
on_conflict="keep" ) self.assertFalse(errors) + db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) self.assertEqual(['2000-01-01T01:00:00', '2000-01-01T02:00:00'], [str(x) for x in value.indexes]) @@ -622,6 +637,7 @@ def test_import_existing_object_parameter_value_on_conflict_replace(self): db_map, [["object_class1", "object1", "parameter", new_value]], on_conflict="replace" ) self.assertFalse(errors) + db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) self.assertEqual(['2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in value.indexes]) @@ -638,6 +654,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge(self): db_map, [["object_class1", "object1", "parameter", new_value]], on_conflict="merge" ) self.assertFalse(errors) + db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) self.assertEqual( @@ -664,6 +681,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge_map(self): db_map, [["object_class1", "object1", "parameter", new_value]], on_conflict="merge" ) self.assertFalse(errors) + db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() map_ = from_database(pv.value, pv.type) self.assertEqual(['xxx'], [str(x) for x in map_.indexes]) @@ -682,6 +700,7 @@ def test_import_duplicate_object_parameter_value(self): [["object_class1", "object1", "parameter", "first"], ["object_class1", "object1", "parameter", "second"]], ) self.assertTrue(errors) + db_map.commit_session("test") values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"first"'} self.assertEqual(values, expected) @@ -696,13 +715,9 @@ def test_import_object_parameter_value_with_alternative(self): ) self.assertFalse(errors) self.assertEqual(count, 1) + db_map.commit_session("test") values = { - v.object_name: (v.value, v.alternative_name) - for v in db_map.query( - db_map.object_parameter_value_sq, db_map.alternative_sq.c.name.label("alternative_name") - ) - .filter(db_map.object_parameter_value_sq.c.alternative_id == db_map.alternative_sq.c.id) - .all() + v.object_name: (v.value, v.alternative_name) for v in db_map.query(db_map.object_parameter_value_sq).all() } expected = {"object1": (b"1", "alternative")} self.assertEqual(values, expected) @@ -727,6 +742,7 @@ def test_valid_object_parameter_value_from_value_list(self): count, errors = import_object_parameter_values(db_map, (("object_class", "my_object", "parameter", 5.0),)) self.assertEqual(count, 1) self.assertEqual(errors, []) + db_map.commit_session("test") values = db_map.query(db_map.object_parameter_value_sq).all() self.assertEqual(len(values), 1) value = values[0] @@ -751,6 +767,7 @@ def test_import_valid_relationship_parameter_value(self): db_map, [["relationship_class", ["object1", "object2"], "parameter", 1]] ) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b"1"} self.assertEqual(values, expected) @@ -765,6 +782,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_parameter_name db_map, [["relationship_class", 
["object1", "object2"], "parameter", 1]] ) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b"1"} self.assertEqual(values, expected) @@ -779,6 +797,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_object_name(se db_map, [["relationship_class", ["duplicate_object", "duplicate_object"], "parameter", 1]] ) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"duplicate_object,duplicate_object": b"1"} self.assertEqual(values, expected) @@ -791,6 +810,7 @@ def test_import_relationship_parameter_value_with_invalid_object(self): db_map, [["relationship_class", ["nonexistent_object", "object2"], "parameter", 1]] ) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) db_map.connection.close() @@ -801,6 +821,7 @@ def test_import_relationship_parameter_value_with_invalid_relationship_class(sel db_map, [["nonexistent_class", ["object1", "object2"], "parameter", 1]] ) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) db_map.connection.close() @@ -811,6 +832,7 @@ def test_import_relationship_parameter_value_with_invalid_parameter(self): db_map, [["relationship_class", ["object1", "object2"], "nonexistent_parameter", 1]] ) self.assertTrue(errors) + db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) db_map.connection.close() @@ -824,6 +846,7 @@ def test_import_existing_relationship_parameter_value(self): db_map, [["relationship_class", ["object1", "object2"], "parameter", "new_value"]] ) self.assertFalse(errors) + db_map.commit_session("test") values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b'"new_value"'} self.assertEqual(values, expected) @@ -840,6 +863,7 @@ def test_import_duplicate_relationship_parameter_value(self): ], ) self.assertTrue(errors) + db_map.commit_session("test") values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b'"first"'} self.assertEqual(values, expected) @@ -854,13 +878,10 @@ def test_import_relationship_parameter_value_with_alternative(self): ) self.assertFalse(errors) self.assertEqual(count, 1) + db_map.commit_session("test") values = { v.object_name_list: (v.value, v.alternative_name) - for v in db_map.query( - db_map.relationship_parameter_value_sq, db_map.alternative_sq.c.name.label("alternative_name") - ) - .filter(db_map.relationship_parameter_value_sq.c.alternative_id == db_map.alternative_sq.c.id) - .all() + for v in db_map.query(db_map.relationship_parameter_value_sq).all() } expected = {"object1,object2": (b"1", "alternative")} self.assertEqual(values, expected) @@ -889,6 +910,7 @@ def test_valid_relationship_parameter_value_from_value_list(self): ) self.assertEqual(count, 1) self.assertEqual(errors, []) + db_map.commit_session("test") values = db_map.query(db_map.relationship_parameter_value_sq).all() self.assertEqual(len(values), 1) value = values[0] @@ -922,6 +944,7 @@ def test_list_with_single_value(self): count, errors = import_parameter_value_lists(self._db_map, (("list_1", 23.0),)) self.assertEqual(errors, []) 
self.assertEqual(count, 2) + self._db_map.commit_session("test") value_lists = self._db_map.query(self._db_map.parameter_value_list_sq).all() list_values = self._db_map.query(self._db_map.list_value_sq).all() self.assertEqual(len(value_lists), 1) @@ -938,7 +961,8 @@ def test_import_twelfth_value(self): self.assertEqual(count, n_values + 1) count, errors = import_parameter_value_lists(self._db_map, (("list_1", 23.0),)) self.assertEqual(errors, []) - self.assertEqual(count, 1) + self.assertEqual(count, 2) + self._db_map.commit_session("test") value_lists = self._db_map.query(self._db_map.parameter_value_list_sq).all() self.assertEqual(len(value_lists), 1) self.assertEqual(value_lists[0].name, "list_1") @@ -956,6 +980,7 @@ def test_single_alternative(self): count, errors = import_alternatives(db_map, ["alternative"]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") alternatives = [a.name for a in db_map.query(db_map.alternative_sq)] self.assertEqual(len(alternatives), 2) self.assertIn("Base", alternatives) @@ -967,6 +992,7 @@ def test_alternative_description(self): count, errors = import_alternatives(db_map, [["alternative", "description"]]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") alternatives = {a.name: a.description for a in db_map.query(db_map.alternative_sq)} expected = {"Base": "Base alternative", "alternative": "description"} self.assertEqual(alternatives, expected) @@ -977,6 +1003,7 @@ def test_update_alternative_description(self): count, errors = import_alternatives(db_map, [["Base", "new description"]]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") alternatives = {a.name: a.description for a in db_map.query(db_map.alternative_sq)} expected = {"Base": "new description"} self.assertEqual(alternatives, expected) @@ -989,6 +1016,7 @@ def test_single_scenario(self): count, errors = import_scenarios(db_map, ["scenario"]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": None}) db_map.connection.close() @@ -998,6 +1026,7 @@ def test_scenario_with_description(self): count, errors = import_scenarios(db_map, [["scenario", False, "description"]]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": "description"}) db_map.connection.close() @@ -1008,6 +1037,7 @@ def test_update_scenario_description(self): count, errors = import_scenarios(db_map, [["scenario", False, "new description"]]) self.assertEqual(count, 1) self.assertFalse(errors) + db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": "new description"}) db_map.connection.close() @@ -1061,6 +1091,7 @@ def test_fails_with_nonexistent_before_alternative(self): self.assertEqual(count, 2) def scenario_alternatives(self): + self._db_map.commit_session("test") scenario_alternative_qry = ( self._db_map.query( self._db_map.scenario_sq.c.name.label("scenario_name"), @@ -1077,260 +1108,13 @@ def scenario_alternatives(self): return scenario_alternatives -class TestImportTool(unittest.TestCase): - def test_single_tool(self): - db_map = create_diff_db_map() - count, errors = import_tools(db_map, ["tool"]) - self.assertEqual(count, 1) - 
self.assertFalse(errors) - tools = [x.name for x in db_map.query(db_map.tool_sq)] - self.assertEqual(len(tools), 1) - self.assertIn("tool", tools) - db_map.connection.close() - - def test_tool_description(self): - db_map = create_diff_db_map() - count, errors = import_tools(db_map, [["tool", "description"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - tools = {x.name: x.description for x in db_map.query(db_map.tool_sq)} - expected = {"tool": "description"} - self.assertEqual(tools, expected) - db_map.connection.close() - - def test_update_tool_description(self): - db_map = create_diff_db_map() - count, errors = import_tools(db_map, [["tool", "description"]]) - count, errors = import_tools(db_map, [["tool", "new description"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - tools = {x.name: x.description for x in db_map.query(db_map.tool_sq)} - expected = {"tool": "new description"} - self.assertEqual(tools, expected) - db_map.connection.close() - - -class TestImportFeature(unittest.TestCase): - @staticmethod - def populate(db_map): - import_object_classes(db_map, ["object_class1", "object_class2"]) - import_parameter_value_lists( - db_map, [['value_list', 'value1'], ['value_list', 'value2'], ['value_list', 'value3']] - ) - import_object_parameters( - db_map, [["object_class1", "parameter1", "value1", "value_list"], ["object_class1", "parameter2"]] - ) - - def test_single_feature(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_features(db_map, [["object_class1", "parameter1"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - features = [ - (x.entity_class_name, x.parameter_definition_name, x.parameter_value_list_name) - for x in db_map.query(db_map.ext_feature_sq) - ] - self.assertEqual(len(features), 1) - self.assertIn(("object_class1", "parameter1", "value_list"), features) - db_map.connection.close() - - def test_feature_for_parameter_without_value_list(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_features(db_map, [["object_class1", "parameter2"]]) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_feature_description(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_features(db_map, [["object_class1", "parameter1", "description"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - features = { - (x.entity_class_name, x.parameter_definition_name, x.parameter_value_list_name): x.description - for x in db_map.query(db_map.ext_feature_sq) - } - expected = {("object_class1", "parameter1", "value_list"): "description"} - self.assertEqual(features, expected) - db_map.connection.close() - - def test_update_feature_description(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_features(db_map, [["object_class1", "parameter1", "description"]]) - count, errors = import_features(db_map, [["object_class1", "parameter1", "new description"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - features = { - (x.entity_class_name, x.parameter_definition_name, x.parameter_value_list_name): x.description - for x in db_map.query(db_map.ext_feature_sq) - } - expected = {("object_class1", "parameter1", "value_list"): "new description"} - self.assertEqual(features, expected) - db_map.connection.close() - - -class TestImportToolFeature(unittest.TestCase): - @staticmethod - def populate(db_map): - import_object_classes(db_map, ["object_class1", "object_class2"]) 
- import_parameter_value_lists( - db_map, [['value_list', 'value1'], ['value_list', 'value2'], ['value_list', 'value3']] - ) - import_object_parameters( - db_map, [["object_class1", "parameter1", "value1", "value_list"], ["object_class1", "parameter2"]] - ) - import_features(db_map, [["object_class1", "parameter1"]]) - import_tools(db_map, ["tool1"]) - - def test_single_tool_feature(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_features(db_map, [["tool1", "object_class1", "parameter1"]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - tool_features = [ - (x.tool_name, x.entity_class_name, x.parameter_definition_name, x.required) - for x in db_map.query(db_map.ext_tool_feature_sq) - ] - self.assertEqual(len(tool_features), 1) - self.assertIn(("tool1", "object_class1", "parameter1", False), tool_features) - db_map.connection.close() - - def test_tool_feature_with_non_feature_parameter(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_features(db_map, [["tool1", "object_class1", "parameter2"]]) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_tool_feature_with_non_existing_tool(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_features(db_map, [["non_existing_tool", "object_class1", "parameter1"]]) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_tool_feature_required(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_features(db_map, [["tool1", "object_class1", "parameter1", True]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - tool_features = [ - (x.tool_name, x.entity_class_name, x.parameter_definition_name, x.required) - for x in db_map.query(db_map.ext_tool_feature_sq) - ] - self.assertEqual(len(tool_features), 1) - self.assertIn(("tool1", "object_class1", "parameter1", True), tool_features) - db_map.connection.close() - - def test_update_tool_feature_required(self): - db_map = create_diff_db_map() - self.populate(db_map) - import_tool_features(db_map, [["tool1", "object_class1", "parameter1"]]) - count, errors = import_tool_features(db_map, [["tool1", "object_class1", "parameter1", True]]) - self.assertEqual(count, 1) - self.assertFalse(errors) - tool_features = [ - (x.tool_name, x.entity_class_name, x.parameter_definition_name, x.required) - for x in db_map.query(db_map.ext_tool_feature_sq) - ] - self.assertEqual(len(tool_features), 1) - self.assertIn(("tool1", "object_class1", "parameter1", True), tool_features) - db_map.connection.close() - - -class TestImportToolFeatureMethod(unittest.TestCase): - @staticmethod - def populate(db_map): - import_object_classes(db_map, ["object_class1", "object_class2"]) - import_parameter_value_lists( - db_map, [['value_list', 'value1'], ['value_list', 'value2'], ['value_list', 'value3']] - ) - import_object_parameters( - db_map, [["object_class1", "parameter1", "value1", "value_list"], ["object_class1", "parameter2"]] - ) - import_features(db_map, [["object_class1", "parameter1"]]) - import_tools(db_map, ["tool1"]) - import_tool_features(db_map, [["tool1", "object_class1", "parameter1"]]) - - def test_import_a_couple_of_tool_feature_methods(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_feature_methods( - db_map, - [["tool1", "object_class1", "parameter1", "value2"], ["tool1", "object_class1", "parameter1", "value3"]], 
- ) - self.assertEqual(count, 2) - self.assertFalse(errors) - tool_feature_methods = [ - (x.tool_name, x.entity_class_name, x.parameter_definition_name, from_database(x.method)) - for x in db_map.query(db_map.ext_tool_feature_method_sq) - ] - self.assertEqual(len(tool_feature_methods), 2) - self.assertIn(("tool1", "object_class1", "parameter1", "value2"), tool_feature_methods) - self.assertIn(("tool1", "object_class1", "parameter1", "value3"), tool_feature_methods) - db_map.connection.close() - - def test_tool_feature_method_with_non_feature_parameter(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_feature_methods(db_map, [["tool1", "object_class1", "parameter2", "method"]]) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_tool_feature_method_with_non_existing_tool(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_feature_methods( - db_map, [["non_existing_tool", "object_class1", "parameter1", "value2"]] - ) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_tool_feature_method_with_invalid_method(self): - db_map = create_diff_db_map() - self.populate(db_map) - count, errors = import_tool_feature_methods( - db_map, [["tool1", "object_class1", "parameter1", "invalid_method"]] - ) - self.assertEqual(count, 0) - self.assertTrue(errors) - db_map.connection.close() - - def test_tool_feature_method_with_db_server_style_method(self): - db_map = DatabaseMapping("sqlite://", create=True) - self.populate(db_map) - db_map.commit_session("Add test data.") - count, errors = import_tool_feature_methods( - db_map, [["tool1", "object_class1", "parameter1", [b'"value1"', None]]], unparse_value=_unparse_value - ) - self.assertEqual(errors, []) - self.assertEqual(count, 1) - tool_feature_methods = db_map.query(db_map.ext_tool_feature_method_sq).all() - self.assertEqual(len(tool_feature_methods), 1) - self.assertEqual(tool_feature_methods[0].entity_class_name, "object_class1") - self.assertEqual(from_database(tool_feature_methods[0].method), "value1") - self.assertEqual(tool_feature_methods[0].parameter_definition_name, "parameter1") - self.assertEqual(tool_feature_methods[0].parameter_value_list_name, "value_list") - self.assertEqual(tool_feature_methods[0].tool_name, "tool1") - db_map.connection.close() - - class TestImportMetadata(unittest.TestCase): def test_import_metadata(self): db_map = create_diff_db_map() count, errors = import_metadata(db_map, ['{"name": "John", "age": 17}', '{"name": "Charly", "age": 90}']) self.assertEqual(count, 4) self.assertFalse(errors) + db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(len(metadata), 4) self.assertIn(("name", "John"), metadata) @@ -1344,6 +1128,7 @@ def test_import_metadata_with_duplicate_entry(self): count, errors = import_metadata(db_map, ['{"name": "John", "age": 17}', '{"name": "Charly", "age": 17}']) self.assertEqual(count, 3) self.assertFalse(errors) + db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(len(metadata), 3) self.assertIn(("name", "John"), metadata) @@ -1354,6 +1139,7 @@ def test_import_metadata_with_duplicate_entry(self): def test_import_metadata_with_nested_dict(self): db_map = create_diff_db_map() count, errors = import_metadata(db_map, ['{"name": "John", "info": {"age": 17, "city": "LA"}}']) + db_map.commit_session("test") 
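import_metadata() splits each JSON object into one metadata entry per top-level key, which is why the nested-dict case above reports two entries; the nested structure apparently stays with its key as a single value. A sketch under that reading:

    count, errors = import_functions.import_metadata(
        db_map, ['{"name": "John", "info": {"age": 17, "city": "LA"}}']
    )
    assert count == 2 and not errors  # one entry for "name", one for "info"
    db_map.commit_session("test")
    metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)]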
metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(count, 2) self.assertFalse(errors) @@ -1365,6 +1151,7 @@ def test_import_metadata_with_nested_dict(self): def test_import_metadata_with_nested_list(self): db_map = create_diff_db_map() count, errors = import_metadata(db_map, ['{"contributors": [{"name": "John"}, {"name": "Charly"}]}']) + db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(count, 2) self.assertFalse(errors) @@ -1376,6 +1163,7 @@ def test_import_metadata_with_nested_list(self): def test_import_unformatted_metadata(self): db_map = create_diff_db_map() count, errors = import_metadata(db_map, ['not a JSON object']) + db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(count, 1) self.assertFalse(errors) @@ -1409,6 +1197,7 @@ def test_import_object_metadata(self): ) self.assertEqual(count, 4) self.assertFalse(errors) + db_map.commit_session("test") metadata = [ (x.entity_name, x.metadata_name, x.metadata_value) for x in db_map.query(db_map.ext_entity_metadata_sq) ] @@ -1431,6 +1220,7 @@ def test_import_relationship_metadata(self): ) self.assertEqual(count, 4) self.assertFalse(errors) + db_map.commit_session("test") metadata = [(x.metadata_name, x.metadata_value) for x in db_map.query(db_map.ext_entity_metadata_sq)] self.assertEqual(len(metadata), 4) self.assertIn(('co-author', 'John'), metadata) @@ -1458,10 +1248,11 @@ def test_import_object_parameter_value_metadata(self): ) self.assertEqual(errors, []) self.assertEqual(count, 2) + self._db_map.commit_session("test") metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(metadata), 2) self.assertEqual( - metadata[0]._asdict(), + dict(metadata[0]), { "alternative_name": "Base", "entity_name": "object", @@ -1471,11 +1262,11 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": 2, + "commit_id": None, }, ) self.assertEqual( - metadata[1]._asdict(), + dict(metadata[1]), { "alternative_name": "Base", "entity_name": "object", @@ -1485,7 +1276,7 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": 2, + "commit_id": None, }, ) @@ -1501,10 +1292,11 @@ def test_import_relationship_parameter_value_metadata(self): ) self.assertEqual(errors, []) self.assertEqual(count, 2) + self._db_map.commit_session("test") metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(metadata), 2) self.assertEqual( - metadata[0]._asdict(), + dict(metadata[0]), { "alternative_name": "Base", "entity_name": "relationship_class_object", @@ -1514,11 +1306,11 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": 2, + "commit_id": None, }, ) self.assertEqual( - metadata[1]._asdict(), + dict(metadata[1]), { "alternative_name": "Base", "entity_name": "relationship_class_object", @@ -1528,7 +1320,7 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": 2, + "commit_id": None, }, ) diff --git a/tests/test_migration.py b/tests/test_migration.py index 11f1f69e..4f2cb0f2 100644 --- a/tests/test_migration.py +++ 
b/tests/test_migration.py @@ -24,7 +24,8 @@ class TestMigration(unittest.TestCase): @unittest.skip( - "default_values's server_default has been changed from 0 to NULL in the create scrip, but there's no associated upgrade script yet." + "default_values's server_default has been changed from 0 to NULL in the create script, " + "but there's no associated upgrade script yet." ) def test_upgrade_schema(self): """Tests that the upgrade scripts produce the same schema as the function to create @@ -93,27 +94,30 @@ def test_upgrade_content(self): engine.execute("INSERT INTO parameter_value (parameter_id, relationship_id, value) VALUES (3, 2, '-1')") # Upgrade the db and check that our stuff is still there db_map = DatabaseMapping(db_url, upgrade=True) - object_classes = {x.id: x.name for x in db_map.object_class_list()} - objects = {x.id: (object_classes[x.class_id], x.name) for x in db_map.object_list()} - rel_clss = {x.id: (x.name, x.object_class_name_list) for x in db_map.wide_relationship_class_list()} + object_classes = {x.id: x.name for x in db_map.query(db_map.object_class_sq)} + objects = {x.id: (object_classes[x.class_id], x.name) for x in db_map.query(db_map.object_sq)} + rel_clss = { + x.id: (x.name, x.object_class_name_list) for x in db_map.query(db_map.wide_relationship_class_sq) + } rels = { - x.id: (rel_clss[x.class_id][0], x.name, x.object_name_list) for x in db_map.wide_relationship_list() + x.id: (rel_clss[x.class_id][0], x.name, x.object_name_list) + for x in db_map.query(db_map.wide_relationship_sq) } obj_par_defs = { x.id: (object_classes[x.object_class_id], x.parameter_name) - for x in db_map.object_parameter_definition_list() + for x in db_map.query(db_map.object_parameter_definition_sq) } rel_par_defs = { x.id: (rel_clss[x.relationship_class_id][0], x.parameter_name) - for x in db_map.relationship_parameter_definition_list() + for x in db_map.query(db_map.relationship_parameter_definition_sq) } obj_par_vals = { (obj_par_defs[x.parameter_id][1], objects[x.object_id][1], x.value) - for x in db_map.object_parameter_value_list() + for x in db_map.query(db_map.object_parameter_value_sq) } rel_par_vals = { (rel_par_defs[x.parameter_id][1], rels[x.relationship_id][1], x.value) - for x in db_map.relationship_parameter_value_list() + for x in db_map.query(db_map.relationship_parameter_value_sq) } self.assertTrue(len(object_classes), 2) self.assertTrue(len(objects), 3) From f7bc863abdb75cd80a396b1350168647b90ac9b2 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 16 May 2023 08:13:57 +0200 Subject: [PATCH 039/317] Introduce _TempId to generate uncommitted ids without the DB --- spinedb_api/db_cache.py | 854 ------------------------- spinedb_api/db_mapping_add_mixin.py | 203 ++---- spinedb_api/db_mapping_base.py | 45 +- spinedb_api/db_mapping_commit_mixin.py | 19 +- spinedb_api/db_mapping_remove_mixin.py | 31 +- spinedb_api/db_mapping_update_mixin.py | 149 ++--- spinedb_api/import_functions.py | 117 +--- tests/test_DatabaseMapping.py | 10 +- tests/test_DiffDatabaseMapping.py | 68 +- tests/test_import_functions.py | 4 +- 10 files changed, 183 insertions(+), 1317 deletions(-) delete mode 100644 spinedb_api/db_cache.py diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py deleted file mode 100644 index 61d903f9..00000000 --- a/spinedb_api/db_cache.py +++ /dev/null @@ -1,854 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine
Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### -""" -DB cache utility. - -""" -import uuid -from contextlib import suppress -from operator import itemgetter -from enum import Enum, unique, auto -from .parameter_value import from_database - -# TODO: Implement CacheItem.pop() to do lookup? - - -@unique -class Status(Enum): - """Cache item status.""" - - committed = auto() - to_add = auto() - to_update = auto() - to_remove = auto() - - -class DBCache(dict): - """A dictionary that maps table names to ids to items. Used to store and retrieve database contents.""" - - def __init__(self, db_map, chunk_size=None): - """ - Args: - db_map (DatabaseMapping) - """ - super().__init__() - self._db_map = db_map - self._offsets = {} - self._fetched_item_types = set() - self._chunk_size = chunk_size - - def commit(self): - to_add = {} - to_update = {} - to_remove = {} - for item_type, table_cache in self.items(): - for item in dict.values(table_cache): - if item.status == Status.to_add: - to_add.setdefault(item_type, []).append(item) - elif item.status == Status.to_update: - to_update.setdefault(item_type, []).append(item) - elif item.status == Status.to_remove: - to_remove.setdefault(item_type, set()).add(item["id"]) - item.status = Status.committed - # FIXME: When computing to_remove, we could at the same time fetch all tables where items should be removed - # in cascade. This could be nice. So we would visit the tables in order, collect removed items, - # and if we find some then we would fetch all the descendant tables and validate items in them. - # This would set the removed flag, and then we would be able to collect those items - # in subsequent iterations. - # This might solve the issue when the user removes, commits, and then undoes the removal. - # My impression is since committing the removal action would fetch all the referrers, then it would - # be possible to properly undo it. Maybe that is the case already because `cascading_ids()` - # also fetches all the descendant tablenams into cache. - # Actually, it looks like all we're missing is setting the new attribute for restored items too??!! - # Ok so when you restore and item whose removal was committed, you need to set new to True - - # Another option would be to build a list of fetched ids in a fully independent dictionary. - # Then we could compare contents of the cache with this list and easily find out which items need - # to be added, updated and removed. - # To add: Those that are valid in the cache but not in fetched id - # To update: Those that are both valid in the cache and in fetched id - # To remove: Those that are in fetched id but not valid in the cache. - # But this would require fetching the entire DB before committing or something like that... To think about it. 
- return to_add, to_update, to_remove - - @property - def fetched_item_types(self): - return self._fetched_item_types - - def reset_queries(self): - """Resets queries and clears caches.""" - self._offsets.clear() - self._fetched_item_types.clear() - - def advance_query(self, item_type): - """Schedules an advance of the DB query that fetches items of given type. - - Args: - item_type (str) - - Returns: - Future - """ - return self._db_map.executor.submit(self.do_advance_query, item_type) - - def _get_next_chunk(self, item_type): - try: - sq_name = { - "entity_class": "wide_entity_class_sq", - "entity": "wide_entity_sq", - "parameter_value_list": "parameter_value_list_sq", - "list_value": "list_value_sq", - "alternative": "alternative_sq", - "scenario": "scenario_sq", - "scenario_alternative": "scenario_alternative_sq", - "entity_group": "entity_group_sq", - "parameter_definition": "parameter_definition_sq", - "parameter_value": "parameter_value_sq", - "metadata": "metadata_sq", - "entity_metadata": "entity_metadata_sq", - "parameter_value_metadata": "parameter_value_metadata_sq", - "commit": "commit_sq", - }[item_type] - qry = self._db_map.query(getattr(self._db_map, sq_name)) - except KeyError: - return [] - if not self._chunk_size: - self._fetched_item_types.add(item_type) - return [dict(x) for x in qry] - offset = self._offsets.setdefault(item_type, 0) - chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)] - self._offsets[item_type] += len(chunk) - return chunk - - def do_advance_query(self, item_type): - """Advances the DB query that fetches items of given type and caches the results. - - Args: - item_type (str) - - Returns: - list: items fetched from the DB - """ - chunk = self._get_next_chunk(item_type) - if not chunk: - self._fetched_item_types.add(item_type) - return [] - table_cache = self.table_cache(item_type) - for item in chunk: - # FIXME: This will overwrite working changes after a refresh - table_cache.add_item(item) - return chunk - - def table_cache(self, item_type): - return self.setdefault(item_type, TableCache(self, item_type)) - - def get_item(self, item_type, id_): - table_cache = self.get(item_type, {}) - item = table_cache.get(id_) - if item is None: - return {} - return item - - def fetch_more(self, item_type): - if item_type in self._fetched_item_types: - return False - return bool(self.do_advance_query(item_type)) - - def fetch_all(self, item_type): - while self.fetch_more(item_type): - pass - - def fetch_ref(self, item_type, id_): - while self.fetch_more(item_type): - with suppress(KeyError): - return self[item_type][id_] - # It is possible that fetching was completed between deciding to call this function - # and starting the while loop above resulting in self.fetch_more() to return False immediately. - # Therefore, we should try one last time if the ref is available. - with suppress(KeyError): - return self[item_type][id_] - return None - - -class TableCache(dict): - def __init__(self, db_cache, item_type, *args, **kwargs): - """ - Args: - db_cache (DBCache): the DB cache where this table cache belongs. - item_type (str): the item type, equal to a table name - """ - super().__init__(*args, **kwargs) - self._db_cache = db_cache - self._item_type = item_type - self._id_by_unique_key_value = {} - - def unique_key_value_to_id(self, key, value, strict=False): - """Returns the id that has the given value for the given unique key, or None. 
- - Args: - key (tuple) - value (tuple) - - Returns: - int - """ - value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - self._db_cache.fetch_all(self._item_type) - id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - if strict: - return id_by_unique_value[value] - return id_by_unique_value.get(value) - - def _unique_key_value_to_item(self, key, value, strict=False): - return self.get(self.unique_key_value_to_id(key, value)) - - def values(self): - return (x for x in super().values() if x.is_valid()) - - @property - def _item_factory(self): - return { - "entity_class": EntityClassItem, - "entity": EntityItem, - "entity_group": EntityGroupItem, - "parameter_definition": ParameterDefinitionItem, - "parameter_value": ParameterValueItem, - "list_value": ListValueItem, - "alternative": AlternativeItem, - "scenario": ScenarioItem, - "scenario_alternative": ScenarioAlternativeItem, - "metadata": MetadataItem, - "entity_metadata": EntityMetadataItem, - "parameter_value_metadata": ParameterValueMetadataItem, - }.get(self._item_type, CacheItem) - - def _make_item(self, item): - """Returns a cache item. - - Args: - item (dict): the 'db item' to use as base - - Returns: - CacheItem - """ - return self._item_factory(self._db_cache, self._item_type, **item) - - def current_item(self, item, skip_keys=()): - id_ = item.get("id") - if isinstance(id_, int): - # id is an int, easy - return self.get(id_) - if isinstance(id_, dict): - # id is a dict specifying the values for one of the unique constraints - key, value = zip(*id_.items()) - return self._unique_key_value_to_item(key, value) - if id_ is None: - # No id. Try to locate the item by the value of one of the unique keys. Used by import_data. - item = self._make_item(item) - error = item.resolve_inverse_references() - if error: - return None - error = item.polish() - if error: - return None - for key, value in item.unique_values(skip_keys=skip_keys): - current_item = self._unique_key_value_to_item(key, value) - if current_item: - return current_item - - def check_item(self, item, for_update=False, skip_keys=()): - if for_update: - current_item = self.current_item(item, skip_keys=skip_keys) - if current_item is None: - return None, f"no {self._item_type} matching {item} to update" - item = {**current_item, **item} - item["id"] = current_item["id"] - else: - current_item = None - candidate_item = self._make_item(item) - error = candidate_item.resolve_inverse_references() - if error: - return None, error - error = candidate_item.polish() - if error: - return None, error - invalid_ref = candidate_item.invalid_ref() - if invalid_ref: - return None, f"invalid {invalid_ref} for {self._item_type}" - try: - for key, value in candidate_item.unique_values(skip_keys=skip_keys): - empty = {k for k, v in zip(key, value) if v == ""} - if empty: - return None, f"invalid empty keys {empty} for {self._item_type}" - unique_item = self._unique_key_value_to_item(key, value) - if unique_item not in (None, current_item) and unique_item.is_valid(): - return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" - except KeyError as e: - return None, f"missing {e} for {self._item_type}" - return candidate_item, None - - def _add_unique(self, item): - for key, value in item.unique_values(): - self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"] - - def _remove_unique(self, item): - for key, value in item.unique_values(): - self._id_by_unique_key_value.get(key, {}).pop(value, None) - - def add_item(self, item, 
new=False): - self[item["id"]] = new_item = self._make_item(item) - self._add_unique(new_item) - if new: - new_item.status = Status.to_add - return new_item - - def update_item(self, item): - current_item = self[item["id"]] - self._remove_unique(current_item) - current_item.update(item) - self._add_unique(current_item) - current_item.cascade_update() - if current_item.status != Status.to_add: - current_item.status = Status.to_update - return current_item - - def remove_item(self, id_): - current_item = self.get(id_) - if current_item is not None: - self._remove_unique(current_item) - current_item.cascade_remove() - return current_item - - def restore_item(self, id_): - current_item = self.get(id_) - if current_item is not None: - self._add_unique(current_item) - current_item.cascade_restore() - return current_item - - -class CacheItem(dict): - """A dictionary that represents an db item.""" - - _defaults = {} - _unique_keys = (("name",),) - _references = {} - _inverse_references = {} - - def __init__(self, db_cache, item_type, *args, **kwargs): - """ - Args: - db_cache (DBCache): the DB cache where this item belongs. - """ - super().__init__(*args, **kwargs) - self._db_cache = db_cache - self._item_type = item_type - self._referrers = {} - self._weak_referrers = {} - self.restore_callbacks = set() - self.update_callbacks = set() - self.remove_callbacks = set() - self._to_remove = False - self._removed = False - self._corrupted = False - self._valid = None - self.status = Status.committed - - def is_committed(self): - return self.status == Status.committed - - def polish(self): - """Polishes this item once all it's references are resolved. Returns any errors. - - Returns: - str or None - """ - for key, default_value in self._defaults.items(): - self.setdefault(key, default_value) - return "" - - def resolve_inverse_references(self): - for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): - if dict.get(self, src_key): - # When updating items, the user might update the id keys while leaving the name keys intact. - # In this case we shouldn't overwrite the updated id keys from the obsolete name keys. - # FIXME: It feels that this is our fault, though, like it is us who keep the obsolete name keys around. 
- continue - id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) - if None in id_value: - continue - table_cache = self._db_cache.table_cache(ref_type) - try: - src_value = ( - tuple(table_cache.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) - if all(isinstance(v, (tuple, list)) for v in id_value) - else table_cache.unique_key_value_to_id(ref_key, id_value, strict=True) - ) - self[src_key] = src_value - except KeyError as err: - # Happens at unique_key_value_to_id(..., strict=True) - return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" - - def invalid_ref(self): - for key, (ref_type, _ref_key) in self._references.values(): - try: - ref_id = self[key] - except KeyError: - return key - if isinstance(ref_id, tuple): - for x in ref_id: - if not self._get_ref(ref_type, x): - return key - elif not self._get_ref(ref_type, ref_id): - return key - - def unique_values(self, skip_keys=()): - for key in self._unique_keys: - if key not in skip_keys: - yield key, tuple(self.get(k) for k in key) - - @property - def removed(self): - return self._removed - - @property - def item_type(self): - return self._item_type - - @property - def key(self): - if dict.get(self, "id") is None: - return None - return (self._item_type, self["id"]) - - def __getattr__(self, name): - """Overridden method to return the dictionary key named after the attribute, or None if it doesn't exist.""" - return self.get(name) - - def __repr__(self): - return f"{self._item_type}{self._extended()}" - - def _extended(self): - return {**self, **{key: self[key] for key in self._references}} - - def _asdict(self): - return dict(**self) - - def _get_ref(self, ref_type, ref_id, strong=True): - ref = self._db_cache.get_item(ref_type, ref_id) - if not ref: - if not strong: - return {} - ref = self._db_cache.fetch_ref(ref_type, ref_id) - if not ref: - self._corrupted = True - return {} - return self._handle_ref(ref, strong) - - def _handle_ref(self, ref, strong): - if strong: - ref.add_referrer(self) - if ref.removed: - self._to_remove = True - else: - ref.add_weak_referrer(self) - if ref.removed: - return {} - return ref - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def is_valid(self): - if self._valid is not None: - return self._valid - if self._removed or self._corrupted: - return False - self._to_remove = False - self._corrupted = False - for key in self._references: - _ = self[key] - if self._to_remove: - self.cascade_remove() - self._valid = not self._removed and not self._corrupted - return self._valid - - def add_referrer(self, referrer): - if referrer.key is None: - return - self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer) - - def add_weak_referrer(self, referrer): - if referrer.key is None: - return - if referrer.key not in self._referrers: - self._weak_referrers[referrer.key] = referrer - - def cascade_restore(self): - if self.status == Status.committed: - self.status = Status.to_add - if not self._removed: - return - self._removed = False - for referrer in self._referrers.values(): - referrer.cascade_restore() - for weak_referrer in self._weak_referrers.values(): - weak_referrer.call_update_callbacks() - obsolete = set() - for callback in self.restore_callbacks: - if not callback(self): - obsolete.add(callback) - self.restore_callbacks -= obsolete - - def cascade_remove(self): - self.status = Status.to_remove - if self._removed: - return - self._removed = True - self._to_remove = False - 
self._valid = None - obsolete = set() - for callback in self.remove_callbacks: - if not callback(self): - obsolete.add(callback) - self.remove_callbacks -= obsolete - for referrer in self._referrers.values(): - referrer.cascade_remove() - for weak_referrer in self._weak_referrers.values(): - weak_referrer.call_update_callbacks() - - def cascade_update(self): - self.call_update_callbacks() - for weak_referrer in self._weak_referrers.values(): - weak_referrer.call_update_callbacks() - for referrer in self._referrers.values(): - referrer.cascade_update() - - def call_update_callbacks(self): - self.pop("parsed_value", None) - obsolete = set() - for callback in self.update_callbacks: - if not callback(self): - obsolete.add(callback) - self.update_callbacks -= obsolete - - def __getitem__(self, key): - ref = self._references.get(key) - if ref: - key, (ref_type, ref_key) = ref - ref_id = self[key] - if isinstance(ref_id, tuple): - return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) - return self._get_ref(ref_type, ref_id).get(ref_key) - return super().__getitem__(key) - - -class EntityClassItem(CacheItem): - _defaults = {"description": None, "display_icon": None} - _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} - _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))} - - def __init__(self, *args, **kwargs): - dimension_id_list = kwargs.get("dimension_id_list") - if dimension_id_list is None: - dimension_id_list = () - if isinstance(dimension_id_list, str): - dimension_id_list = (int(id_) for id_ in dimension_id_list.split(",")) - kwargs["dimension_id_list"] = tuple(dimension_id_list) - super().__init__(*args, **kwargs) - - -class EntityItem(CacheItem): - _defaults = {"description": None} - _unique_keys = (("class_name", "name"), ("class_name", "byname")) - _references = { - "class_name": ("class_id", ("entity_class", "name")), - "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("class_id", ("entity_class", "dimension_name_list")), - "element_name_list": ("element_id_list", ("entity", "name")), - } - _inverse_references = { - "class_id": (("class_name",), ("entity_class", ("name",))), - "element_id_list": (("dimension_name_list", "element_name_list"), ("entity", ("class_name", "name"))), - } - - def __init__(self, *args, **kwargs): - element_id_list = kwargs.get("element_id_list") - if element_id_list is None: - element_id_list = () - if isinstance(element_id_list, str): - element_id_list = (int(id_) for id_ in element_id_list.split(",")) - kwargs["element_id_list"] = tuple(element_id_list) - super().__init__(*args, **kwargs) - - def __getitem__(self, key): - if key == "byname": - return self["element_name_list"] or (self["name"],) - return super().__getitem__(key) - - def polish(self): - error = super().polish() - if error: - return error - if "name" in self: - return - base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) - name = base_name - table_cache = self._db_cache.table_cache(self._item_type) - while table_cache.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: - name = base_name + "_" + uuid.uuid4().hex - self["name"] = name - - -class EntityGroupItem(CacheItem): - _unique_keys = (("group_name", "member_name"),) - _references = { - "class_name": ("entity_class_id", ("entity_class", "name")), - "group_name": ("entity_id", ("entity", "name")), - "member_name": ("member_id", 
("entity", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - } - _inverse_references = { - "entity_class_id": (("class_name",), ("entity_class", ("name",))), - "entity_id": (("class_name", "group_name"), ("entity", ("class_name", "name"))), - "member_id": (("class_name", "member_name"), ("entity", ("class_name", "name"))), - } - - def __getitem__(self, key): - if key == "class_id": - return self["entity_class_id"] - if key == "group_id": - return self["entity_id"] - return super().__getitem__(key) - - -class ParameterDefinitionItem(CacheItem): - _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} - _unique_keys = (("entity_class_name", "name"),) - _references = { - "entity_class_name": ("entity_class_id", ("entity_class", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), - } - _inverse_references = { - "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), - "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), - } - - @property - def list_value_id(self): - if dict.get(self, "default_type") == "list_value_ref": - return int(dict.__getitem__(self, "default_value")) - return None - - def __getitem__(self, key): - if key == "parameter_name": - return super().__getitem__("name") - if key == "value_list_id": - return super().__getitem__("parameter_value_list_id") - if key == "parameter_value_list_id": - return dict.get(self, key) - if key == "parameter_value_list_name": - return self._get_ref("parameter_value_list", self["parameter_value_list_id"], strong=False).get("name") - if key in ("default_value", "default_type"): - list_value_id = self.list_value_id - if list_value_id is not None: - list_value_key = {"default_value": "value", "default_type": "type"}[key] - return self._get_ref("list_value", list_value_id, strong=False).get(list_value_key) - return dict.get(self, key) - if key == "list_value_id": - return self.list_value_id - return super().__getitem__(key) - - def polish(self): - error = super().polish() - if error: - return error - default_type = self["default_type"] - default_value = self["default_value"] - list_name = self["parameter_value_list_name"] - if list_name is None: - return - if default_type == "list_value_ref": - return - parsed_value = from_database(default_value, default_type) - if parsed_value is None: - return - list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( - ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) - ) - if list_value_id is None: - return f"default value {parsed_value} of {self['name']} is not in {list_name}" - self["default_value"] = str(list_value_id).encode() - self["default_type"] = "list_value_ref" - - -class ParameterValueItem(CacheItem): - _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name"),) - _references = { - "entity_class_name": ("entity_class_id", ("entity_class", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), - "parameter_definition_name": ("parameter_definition_id", ("parameter_definition", "name")), - "parameter_value_list_id": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_id")), - 
"parameter_value_list_name": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_name")), - "entity_name": ("entity_id", ("entity", "name")), - "entity_byname": ("entity_id", ("entity", "byname")), - "element_id_list": ("entity_id", ("entity", "element_id_list")), - "element_name_list": ("entity_id", ("entity", "element_name_list")), - "alternative_name": ("alternative_id", ("alternative", "name")), - } - _inverse_references = { - "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), - "parameter_definition_id": ( - ("entity_class_name", "parameter_definition_name"), - ("parameter_definition", ("entity_class_name", "name")), - ), - "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), - "alternative_id": (("alternative_name",), ("alternative", ("name",))), - } - - @property - def list_value_id(self): - if dict.__getitem__(self, "type") == "list_value_ref": - return int(dict.__getitem__(self, "value")) - return None - - def __getitem__(self, key): - if key == "parameter_id": - return super().__getitem__("parameter_definition_id") - if key == "parameter_name": - return super().__getitem__("parameter_definition_name") - if key in ("value", "type"): - list_value_id = self.list_value_id - if list_value_id: - return self._get_ref("list_value", list_value_id, strong=False).get(key) - if key == "list_value_id": - return self.list_value_id - return super().__getitem__(key) - - def polish(self): - list_name = self["parameter_value_list_name"] - if list_name is None: - return - type_ = self["type"] - if type_ == "list_value_ref": - return - value = self["value"] - parsed_value = from_database(value, type_) - if parsed_value is None: - return - list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( - ("parameter_value_list_name", "value", "type"), (list_name, value, type_) - ) - if list_value_id is None: - return ( - f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " - "is not in {list_name}" - ) - self["value"] = str(list_value_id).encode() - self["type"] = "list_value_ref" - - -class ListValueItem(CacheItem): - _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) - _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} - _inverse_references = { - "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), - } - - -class AlternativeItem(CacheItem): - _defaults = {"description": None} - - -class ScenarioItem(CacheItem): - _defaults = {"active": False, "description": None} - - @property - def sorted_alternatives(self): - self._db_cache.fetch_all("scenario_alternative") - return sorted( - (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), - key=itemgetter("rank"), - ) - - def __getitem__(self, key): - if key == "alternative_id_list": - return [x["alternative_id"] for x in self.sorted_alternatives] - if key == "alternative_name_list": - return [x["alternative_name"] for x in self.sorted_alternatives] - return super().__getitem__(key) - - -class ScenarioAlternativeItem(CacheItem): - _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) - _references = { - "scenario_name": ("scenario_id", ("scenario", "name")), - "alternative_name": ("alternative_id", ("alternative", "name")), - } - _inverse_references = { - "scenario_id": 
(("scenario_name",), ("scenario", ("name",))), - "alternative_id": (("alternative_name",), ("alternative", ("name",))), - } - - def __getitem__(self, key): - # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. - # Since ranks go from 1 to the alternative count, the first alternative will have the second as the 'before', - # the second will have the third, etc, and the last will have None. - # Note that alternatives with higher ranks overwrite the values of those with lower ranks. - if key == "before_alternative_name": - return self._get_ref("alternative", self["before_alternative_id"], strong=False).get("name") - if key == "before_alternative_id": - scenario = self._get_ref("scenario", self["scenario_id"], strong=False) - try: - return scenario["alternative_id_list"][self["rank"]] - except IndexError: - return None - return super().__getitem__(key) - - -class MetadataItem(CacheItem): - _unique_keys = (("name", "value"),) - - -class EntityMetadataItem(CacheItem): - _unique_keys = (("entity_name", "metadata_name"),) - _references = { - "entity_name": ("entity_id", ("entity", "name")), - "metadata_name": ("metadata_id", ("metadata", "name")), - "metadata_value": ("metadata_id", ("metadata", "value")), - } - _inverse_references = { - "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), - "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), - } - - -class ParameterValueMetadataItem(CacheItem): - _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name"),) - _references = { - "parameter_definition_name": ("parameter_value_id", ("parameter_value", "parameter_definition_name")), - "entity_byname": ("parameter_value_id", ("parameter_value", "entity_byname")), - "alternative_name": ("parameter_value_id", ("parameter_value", "alternative_name")), - "metadata_name": ("metadata_id", ("metadata", "name")), - "metadata_value": ("metadata_id", ("metadata", "value")), - } - _inverse_references = { - "parameter_value_id": ( - ("parameter_definition_name", "entity_byname", "alternative_name"), - ("parameter_value", ("parameter_definition_name", "entity_byname", "alternative_name")), - ), - "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), - } diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 2f3c53a8..ee7775f6 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -14,9 +14,6 @@ """ # TODO: improve docstrings -from datetime import datetime -from contextlib import contextmanager -from sqlalchemy import func, Table, Column, Integer, String, null, select from sqlalchemy.exc import DBAPIError from .exception import SpineIntegrityError from .helpers import convert_legacy @@ -25,101 +22,6 @@ class DatabaseMappingAddMixin: """Provides methods to perform ``INSERT`` operations over a Spine db.""" - class _IdGenerator: - def __init__(self, next_id): - self._next_id = next_id - - @property - def next_id(self): - return self._next_id - - def __call__(self): - try: - return self._next_id - finally: - self._next_id += 1 - - def __init__(self, *args, **kwargs): - """Initialize class.""" - super().__init__(*args, **kwargs) - self._next_id = self._metadata.tables.get("next_id") - if self._next_id is None: - self._next_id = Table( - "next_id", - self._metadata, - Column("user", String(155), primary_key=True), - Column("date", String(155), 
primary_key=True), - Column("entity_id", Integer, server_default=null()), - Column("entity_class_id", Integer, server_default=null()), - Column("entity_group_id", Integer, server_default=null()), - Column("parameter_definition_id", Integer, server_default=null()), - Column("parameter_value_id", Integer, server_default=null()), - Column("parameter_value_list_id", Integer, server_default=null()), - Column("list_value_id", Integer, server_default=null()), - Column("alternative_id", Integer, server_default=null()), - Column("scenario_id", Integer, server_default=null()), - Column("scenario_alternative_id", Integer, server_default=null()), - Column("metadata_id", Integer, server_default=null()), - Column("parameter_value_metadata_id", Integer, server_default=null()), - Column("entity_metadata_id", Integer, server_default=null()), - ) - try: - self._next_id.create(self.connection) - except DBAPIError: - # Some other concurrent process must have beaten us to create the table - self._next_id = Table("next_id", self._metadata, autoload=True) - - @contextmanager - def generate_ids(self, tablename): - """Manages id generation for new items to be added to the db. - - Args: - tablename (str): the table to which items will be added - - Yields: - self._IdGenerator: an object that generates a new id every time it is called. - """ - fieldname = { - "entity_class": "entity_class_id", - "object_class": "entity_class_id", - "relationship_class": "entity_class_id", - "entity": "entity_id", - "object": "entity_id", - "relationship": "entity_id", - "entity_group": "entity_group_id", - "parameter_definition": "parameter_definition_id", - "parameter_value": "parameter_value_id", - "parameter_value_list": "parameter_value_list_id", - "list_value": "list_value_id", - "alternative": "alternative_id", - "scenario": "scenario_id", - "scenario_alternative": "scenario_alternative_id", - "metadata": "metadata_id", - "parameter_value_metadata": "parameter_value_metadata_id", - "entity_metadata": "entity_metadata_id", - }[tablename] - with self.engine.begin() as connection: - select_next_id = select([self._next_id]) - next_id_row = connection.execute(select_next_id).first() - if next_id_row is None: - next_id = None - stmt = self._next_id.insert() - else: - next_id = getattr(next_id_row, fieldname) - stmt = self._next_id.update() - if next_id is None: - real_tablename = self._real_tablename(tablename) - table = self._metadata.tables[real_tablename] - id_field = self._id_fields.get(real_tablename, "id") - select_max_id = select([func.max(getattr(table.c, id_field))]) - max_id = connection.execute(select_max_id).scalar() - next_id = max_id + 1 if max_id else 1 - gen = self._IdGenerator(next_id) - try: - yield gen - finally: - connection.execute(stmt, {"user": self.username, "date": datetime.utcnow(), fieldname: gen.next_id}) - def add_items(self, tablename, *items, check=True, strict=False): """Add items to cache. 
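Note on the removal above: the `next_id` table and `generate_ids()` were the old way of reserving ids from the database before inserting anything. This commit's `_TempId` (defined in `db_cache_impl.py`, which the patch does not show) replaces them with in-memory placeholder ids that are resolved against `inserted_primary_key` at commit time. A minimal sketch of the idea, assuming the real class behaves like an `int`; every name here except `resolve()` is an assumption:

    # A placeholder id: a negative int that can never collide with a committed
    # (positive) id, plus a resolve() hook called once the DB assigns the real id.
    class _TempId(int):
        _next_id = 0

        def __new__(cls):
            cls._next_id -= 1
            return super().__new__(cls, cls._next_id)

        def __init__(self):
            self._resolve_callbacks = []

        def add_resolve_callback(self, callback):
            self._resolve_callbacks.append(callback)

        def resolve(self, db_id):
            # _do_add_items() below pops the placeholder before the INSERT and
            # feeds inserted_primary_key[0] back in here.
            for callback in self._resolve_callbacks:
                callback(db_id)

This is why `add_items()` below no longer needs `generate_ids()`: ids only become real when `commit_session()` writes the cache to the DB.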
@@ -137,42 +39,46 @@ def add_items(self, tablename, *items, check=True, strict=False): added, errors = [], [] tablename = self._real_tablename(tablename) table_cache = self.cache.table_cache(tablename) - with self.generate_ids(tablename) as new_id: - if not check: - for item in items: - convert_legacy(tablename, item) - if "id" not in item: - item["id"] = new_id() - added.append(table_cache.add_item(item, new=True)._asdict()) - else: - for item in items: - convert_legacy(tablename, item) - checked_item, error = table_cache.check_item(item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - continue - item = checked_item._asdict() - if "id" not in item: - item["id"] = new_id() - added.append(table_cache.add_item(item, new=True)._asdict()) + if not check: + for item in items: + convert_legacy(tablename, item) + added.append(table_cache.add_item(item, new=True)._asdict()) + else: + for item in items: + convert_legacy(tablename, item) + checked_item, error = table_cache.check_item(item) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + item = checked_item._asdict() + added.append(table_cache.add_item(item, new=True)._asdict()) return added, errors - def _do_add_items(self, tablename, *items_to_add): + def _do_add_items(self, connection, tablename, *items_to_add): """Add items to DB without checking integrity.""" + if not items_to_add: + return try: - for tablename_, items_to_add_ in self._items_to_add_per_table(tablename, items_to_add): + table = self._metadata.tables[self._real_tablename(tablename)] + for item in items_to_add: + item = item._asdict() + temp_id = item.pop("id") if hasattr(item["id"], "resolve") else None + id_ = connection.execute(table.insert(), item).inserted_primary_key[0] + if temp_id: + temp_id.resolve(id_) + for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue table = self._metadata.tables[self._real_tablename(tablename_)] - self.connection_execute(table.insert(), [dict(item) for item in items_to_add_]) + connection.execute(table.insert(), items_to_add_) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" raise SpineIntegrityError(msg) from e @staticmethod - def _items_to_add_per_table(tablename, items_to_add): + def _extra_items_to_add_per_table(tablename, items_to_add): """ Yields tuples of string tablename, list of items to insert. Needed because some insert queries actually need to insert records to more than one table. 
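To make that fan-out concrete (an illustration, not part of the patch): an N-dimensional class inserted into `entity_class` also needs one `entity_class_dimension` row per dimension, which is what the generator below yields. Ids and names here are made up:

    # Hypothetical entity_class item with two dimensions.
    item = {"id": 5, "name": "unit__node", "dimension_id_list": (2, 3)}

    # Extra rows destined for the entity_class_dimension table.
    ecd_rows = [
        {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id}
        for position, dimension_id in enumerate(item["dimension_id_list"])
    ]
    assert ecd_rows == [
        {"entity_class_id": 5, "position": 0, "dimension_id": 2},
        {"entity_class_id": 5, "position": 1, "dimension_id": 3},
    ]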
@@ -184,7 +90,6 @@ def _items_to_add_per_table(tablename, items_to_add): Yields: tuple: database table name, items to add """ - yield (tablename, items_to_add) if tablename == "entity_class": ecd_items_to_add = [ {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} @@ -259,42 +164,16 @@ def add_entity_metadata(self, *items, **kwargs): def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) - def _get_or_add_metadata_ids_for_items(self, *items, check, strict): - metadata_ids = {} - for entry in self.cache.get("metadata", {}).values(): - metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id - metadata_to_add = [] - items_missing_metadata_ids = {} - for item in items: - existing_values = metadata_ids.get(item["metadata_name"]) - existing_id = existing_values.get(item["metadata_value"]) if existing_values is not None else None - if existing_values is None or existing_id is None: - metadata_to_add.append({"name": item["metadata_name"], "value": item["metadata_value"]}) - items_missing_metadata_ids.setdefault(item["metadata_name"], {})[item["metadata_value"]] = item - else: - item["metadata_id"] = existing_id - added_metadata, errors = self.add_items("metadata", *metadata_to_add, check=check, strict=strict) - if errors: - return added_metadata, errors - new_metadata_ids = {} - for added in added_metadata: - new_metadata_ids.setdefault(added["name"], {})[added["value"]] = added["id"] - for metadata_name, value_to_item in items_missing_metadata_ids.items(): - for metadata_value, item in value_to_item.items(): - item["metadata_id"] = new_metadata_ids[metadata_name][metadata_value] - return added_metadata, errors - - def _add_ext_item_metadata(self, table_name, *items, check=True, strict=False): - self.fetch_all({table_name}, include_ancestors=True) - added_metadata, metadata_errors = self._get_or_add_metadata_ids_for_items(*items, check=check, strict=strict) - if metadata_errors: - return added_metadata, metadata_errors - added_item_metadata, item_errors = self.add_items(table_name, *items, check=check, strict=strict) - errors = metadata_errors + item_errors - return added_metadata + added_item_metadata, errors - - def add_ext_entity_metadata(self, *items, check=True, strict=False): - return self._add_ext_item_metadata("entity_metadata", *items, check=check, strict=strict) - - def add_ext_parameter_value_metadata(self, *items, check=True, strict=False): - return self._add_ext_item_metadata("parameter_value_metadata", *items, check=check, strict=strict) + def add_ext_entity_metadata(self, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + self.add_items("metadata", *metadata_items, **kwargs) + return self.add_items("entity_metadata", *items, **kwargs) + + def add_ext_parameter_value_metadata(self, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + self.add_items("metadata", *metadata_items, **kwargs) + return self.add_items("parameter_value_metadata", *items, **kwargs) + + def get_metadata_to_add_with_entity_metadata_items(self, *items): + metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items) + return [x for x in metadata_items if not self.cache.table_cache("metadata").current_item(x)] diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 65d871e9..841b9a13 100644 --- a/spinedb_api/db_mapping_base.py +++ 
b/spinedb_api/db_mapping_base.py @@ -16,7 +16,6 @@ import os import logging import time -from collections import Counter from types import MethodType from concurrent.futures import ThreadPoolExecutor from sqlalchemy import create_engine, MetaData, Table, Column, Integer, inspect, case, func, cast, false, and_, or_ @@ -43,7 +42,7 @@ ) from .filters.tools import pop_filter_configs from .spine_db_client import get_db_url_from_server -from .db_cache import DBCache +from .db_cache_impl import DBCache from .query import Query @@ -230,6 +229,9 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_val, _exc_tb): self.close() + def get_filter_configs(self): + return self._filter_configs + def _make_executor(self): return ThreadPoolExecutor(max_workers=1) if self._asynchronous else _Executor() @@ -268,18 +270,6 @@ def _descendant_tablenames(self, tablename): yield child yield from self._descendant_tablenames(child) - def sorted_tablenames(self): - tablenames = list(self.ITEM_TYPES) - sorted_tablenames = [] - while tablenames: - tablename = tablenames.pop(0) - ancestors = self.ancestor_tablenames.get(tablename) - if ancestors is None or all(x in sorted_tablenames for x in ancestors): - sorted_tablenames.append(tablename) - else: - tablenames.append(tablename) - return sorted_tablenames - def _real_tablename(self, tablename): return { "object_class": "entity_class", @@ -384,6 +374,8 @@ def _receive_engine_close(self, dbapi_con, _connection_record): def in_(self, column, values): """Returns an expression equivalent to column.in_(values), that circumvents the 'too many sql variables' problem in sqlite.""" + # FIXME + return column.in_(values) if not values: return false() if not self.sa_url.drivername.startswith("sqlite"): @@ -396,7 +388,7 @@ def in_(self, column, values): ) self.call_in_right_thread(in_value.create, self.connection, checkfirst=True) self.connection_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) - return column.in_(self.query(in_value.c.value)) + return column.in_({x.value for x in self.query(in_value.c.value)}) def _get_table_to_sq_attr(self): if not self._table_to_sq_attr: @@ -2006,22 +1998,19 @@ def _object_name_list(self): [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None ) - def _metadata_usage_counts(self): - """Counts references to metadata name, value pairs in entity_metadata and parameter_value_metadata tables. + def advance_cache_query(self, item_type, callback=None): + """Schedules an advance of the DB query that fetches items of given type. 
+ + Args: + item_type (str) Returns: - Counter: usage counts keyed by metadata id + Future """ - cache = self.cache - usage_counts = Counter() - for entry in dict.values(cache.get("entity_metadata", {})): - usage_counts[entry.metadata_id] += 1 - for entry in dict.values(cache.get("parameter_value_metadata", {})): - usage_counts[entry.metadata_id] += 1 - return usage_counts - - def get_filter_configs(self): - return self._filter_configs + if not callback: + return self.cache.advance_query(item_type) + future = self.executor.submit(self.cache.advance_query, item_type) + future.add_done_callback(lambda future: callback(future.result())) def __del__(self): try: diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index ac109d5e..e28aa8e6 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -29,18 +29,19 @@ def commit_session(self, comment): """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert() - commit_id = self.connection_execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] to_add, to_update, to_remove = self.cache.commit() if not to_add and not to_update and not to_remove: raise SpineDBAPIError("Nothing to commit.") - for tablename, items in to_add.items(): - self._do_add_items(tablename, *items) - for tablename, items in to_update.items(): - self._do_update_items(tablename, *items) - self._do_remove_items(**to_remove) + user = self.username + date = datetime.now(timezone.utc) + ins = self._metadata.tables["commit"].insert() + with self.engine.begin() as connection: + commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + for tablename, items in to_add.items(): + self._do_add_items(connection, tablename, *items) + for tablename, items in to_update.items(): + self._do_update_items(connection, tablename, *items) + self._do_remove_items(connection, **to_remove) if self._memory: self._memory_dirty = True diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 2f6ec533..07fdc0b7 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -44,7 +44,7 @@ def remove_items(self, tablename, *ids): ids -= {1} return [table_cache.remove_item(id_) for id_ in ids] - def _do_remove_items(self, **kwargs): + def _do_remove_items(self, connection, **kwargs): """Removes items from the db. 
Args: @@ -62,7 +62,7 @@ def _do_remove_items(self, **kwargs): table = self._metadata.tables[tablename] delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) try: - self.connection_execute(delete) + connection.execute(delete) except DBAPIError as e: msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e @@ -234,23 +234,16 @@ def _metadata_cascading_ids(self, ids): self._merge(cascading_ids, value_metadata) return cascading_ids - def _non_referenced_metadata_ids(self, ids, metadata_table_name): - cache = self.cache - metadata_id_counts = self._metadata_usage_counts() - cascading_ids = {} - metadata = cache.get(metadata_table_name, {}) - for id_ in ids: - metadata_id_counts[metadata[id_].metadata_id] -= 1 - zero_count_metadata_ids = {id_ for id_, count in metadata_id_counts.items() if count == 0} - self._merge(cascading_ids, {"metadata": zero_count_metadata_ids}) - return cascading_ids - def _entity_metadata_cascading_ids(self, ids): - cascading_ids = {"entity_metadata": set(ids)} - cascading_ids.update(self._non_referenced_metadata_ids(ids, "entity_metadata")) - return cascading_ids + return {"entity_metadata": set(ids)} def _parameter_value_metadata_cascading_ids(self, ids): - cascading_ids = {"parameter_value_metadata": set(ids)} - cascading_ids.update(self._non_referenced_metadata_ids(ids, "parameter_value_metadata")) - return cascading_ids + return {"parameter_value_metadata": set(ids)} + + def get_metadata_ids_to_remove(self): + used_metadata_ids = set() + for x in self.cache.get("entity_metadata", {}).values(): + used_metadata_ids.add(x["metadata_id"]) + for x in self.cache.get("parameter_value_metadata", {}).values(): + used_metadata_ids.add(x["metadata_id"]) + return {x["id"] for x in self.cache.get("metadata", {}).values()} - used_metadata_ids diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index c4b86fbb..30a2464d 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -12,7 +12,6 @@ """Provides :class:`DatabaseMappingUpdateMixin`. 
""" -from collections import Counter from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam from .exception import SpineIntegrityError @@ -22,24 +21,31 @@ class DatabaseMappingUpdateMixin: """Provides methods to perform ``UPDATE`` operations over a Spine db.""" - def _do_update_items(self, tablename, *items_to_update): + def _make_update_stmt(self, tablename, keys): + table = self._metadata.tables[self._real_tablename(tablename)] + upd = table.update() + for k in self._get_primary_key(tablename): + upd = upd.where(getattr(table.c, k) == bindparam(k)) + return upd.values({key: bindparam(key) for key in table.columns.keys() & keys}) + + def _do_update_items(self, connection, tablename, *items_to_update): """Update items in DB without checking integrity.""" + if not items_to_update: + return try: - for tablename_, items_to_update_ in self._items_to_update_per_table(tablename, items_to_update): + upd = self._make_update_stmt(tablename, items_to_update[0].keys()) + connection.execute(upd, [item._asdict() for item in items_to_update]) + for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): if not items_to_update_: continue - table = self._metadata.tables[self._real_tablename(tablename_)] - upd = table.update() - for k in self._get_primary_key(tablename_): - upd = upd.where(getattr(table.c, k) == bindparam(k)) - upd = upd.values({key: bindparam(key) for key in table.columns.keys() & items_to_update_[0].keys()}) - self.connection_execute(upd, [dict(item) for item in items_to_update_]) + upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) + connection.execute(upd, items_to_update_) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" raise SpineIntegrityError(msg) from e @staticmethod - def _items_to_update_per_table(tablename, items_to_update): + def _extra_items_to_update_per_table(tablename, items_to_update): """ Yields tuples of string tablename, list of items to update. Needed because some update queries actually need to update records in more than one table. 
@@ -51,7 +57,6 @@ def _items_to_update_per_table(tablename, items_to_update): Yields: tuple: database table name, items to update """ - yield (tablename, items_to_update) if tablename == "entity": ee_items_to_update = [ { @@ -98,8 +103,9 @@ def update_items(self, tablename, *items, check=True, strict=False): raise SpineIntegrityError(error) errors.append(error) continue - item = checked_item._asdict() - updated.append(table_cache.update_item(item)._asdict()) + if checked_item: + item = checked_item._asdict() + updated.append(table_cache.update_item(item)._asdict()) return updated, errors def update_alternatives(self, *items, **kwargs): @@ -144,104 +150,25 @@ def update_list_values(self, *items, **kwargs): def update_metadata(self, *items, **kwargs): return self.update_items("metadata", *items, **kwargs) - def update_ext_entity_metadata(self, *items, check=True, strict=False): - updated_items, errors = self._update_ext_item_metadata("entity_metadata", *items, check=check, strict=strict) - return updated_items, errors - - def update_ext_parameter_value_metadata(self, *items, check=True, strict=False): - updated_items, errors = self._update_ext_item_metadata( - "parameter_value_metadata", *items, check=check, strict=strict - ) - return updated_items, errors - - def _update_ext_item_metadata(self, metadata_table, *items, check=True, strict=False): - self.fetch_all({"entity_metadata", "parameter_value_metadata", "metadata"}) - cache = self.cache - metadata_ids = {} - for entry in cache.get("metadata", {}).values(): - metadata_ids.setdefault(entry.name, {})[entry.value] = entry.id - item_metadata_cache = cache[metadata_table] - metadata_usage_counts = self._metadata_usage_counts() - updatable_items = [] - homeless_items = [] - for item in items: - metadata_name = item["metadata_name"] - metadata_value = item["metadata_value"] - metadata_id = metadata_ids.get(metadata_name, {}).get(metadata_value) - if metadata_id is None: - homeless_items.append(item) - continue - item["metadata_id"] = metadata_id - previous_metadata_id = item_metadata_cache[item["id"]]["metadata_id"] - metadata_usage_counts[previous_metadata_id] -= 1 - metadata_usage_counts[metadata_id] += 1 - updatable_items.append(item) - homeless_item_metadata_usage_counts = Counter() - for item in homeless_items: - homeless_item_metadata_usage_counts[item_metadata_cache[item["id"]].metadata_id] += 1 - updatable_metadata_items = [] - future_metadata_ids = {} - for metadata_id, count in homeless_item_metadata_usage_counts.items(): - if count == metadata_usage_counts[metadata_id]: - for cached_item in item_metadata_cache.values(): - if cached_item["metadata_id"] == metadata_id: - found = False - for item in homeless_items: - if item["id"] == cached_item["id"]: - metadata_name = item["metadata_name"] - metadata_value = item["metadata_value"] - updatable_metadata_items.append( - {"id": metadata_id, "name": metadata_name, "value": metadata_value} - ) - future_metadata_ids.setdefault(metadata_name, {})[metadata_value] = metadata_id - metadata_usage_counts[metadata_id] = 0 - found = True - break - if found: - break - items_needing_new_metadata = [] - for item in homeless_items: - metadata_name = item["metadata_name"] - metadata_value = item["metadata_value"] - metadata_id = future_metadata_ids.get(metadata_name, {}).get(metadata_value) - if metadata_id is None: - items_needing_new_metadata.append(item) - continue - if item_metadata_cache[item["id"]]["metadata_id"] == metadata_id: - continue - item["metadata_id"] = metadata_id - 
updatable_items.append(item) - all_items = [] - errors = [] - if updatable_metadata_items: - updated_metadata, errors = self.update_metadata(*updatable_metadata_items, check=False, strict=strict) - all_items += updated_metadata - if errors: - return all_items, errors - addable_metadata = [ - {"name": i["metadata_name"], "value": i["metadata_value"]} for i in items_needing_new_metadata - ] - added_metadata = [] - if addable_metadata: - added_metadata, metadata_add_errors = self.add_metadata(*addable_metadata, check=False, strict=strict) - all_items += added_metadata - errors += metadata_add_errors - if errors: - return all_items, errors - added_metadata_ids = {} - for item in added_metadata: - added_metadata_ids.setdefault(item["name"], {})[item["value"]] = item["id"] - for item in items_needing_new_metadata: - item["metadata_id"] = added_metadata_ids[item["metadata_name"]][item["metadata_value"]] - updatable_items.append(item) - if updatable_items: - # FIXME: Force-clear cache before updating item metadata to ensure that added/updated metadata is found. - updated_item_metadata, item_metadata_errors = self.update_items( - metadata_table, *updatable_items, check=check, strict=strict - ) - all_items += updated_item_metadata - errors += item_metadata_errors - return all_items, errors + def update_entity_metadata(self, *items, **kwargs): + return self.update_items("entity_metadata", *items, **kwargs) + + def update_parameter_value_metadata(self, *items, **kwargs): + return self.update_items("parameter_value_metadata", *items, **kwargs) + + def _update_ext_item_metadata(self, tablename, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + added, errors = self.add_items("metadata", *metadata_items, **kwargs) + updated, more_errors = self.update_items(tablename, *items, **kwargs) + metadata_ids = self.get_metadata_ids_to_remove() + self.remove_items("metadata", *metadata_ids) + return added + updated, errors + more_errors + + def update_ext_entity_metadata(self, *items, **kwargs): + return self._update_ext_item_metadata("entity_metadata", *items, **kwargs) + + def update_ext_parameter_value_metadata(self, *items, **kwargs): + return self._update_ext_item_metadata("parameter_value_metadata", *items, **kwargs) def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): """Returns data to add and remove, in order to set wide scenario alternatives. 
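With this change the ``update_ext_*`` methods no longer reconcile metadata usage counts by hand: they add whatever metadata entries the items refer to, update the item-metadata rows, and then drop metadata entries that are no longer referenced. Assuming a DatabaseMapping ``db_map`` that already has an entity_metadata row with id 1 (the setup around this call is illustrative), the new one-shot flow looks like:

    items, errors = db_map.update_ext_entity_metadata(
        {"id": 1, "metadata_name": "title", "metadata_value": "new value"}
    )
    # "items" contains both the metadata rows that were added or updated and the
    # updated entity_metadata rows, which is why the tests below check len(items)
    # rather than a fixed set of ids.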
diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 7b20de90..524e316e 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -130,11 +130,13 @@ def get_data_for_import( scenarios=(), scenario_alternatives=(), metadata=(), + entity_metadata=(), + parameter_value_metadata=(), object_metadata=(), relationship_metadata=(), object_parameter_value_metadata=(), relationship_parameter_value_metadata=(), - # FIXME: compat + # legacy tools=(), features=(), tool_features=(), @@ -187,7 +189,7 @@ def get_data_for_import( if object_classes: yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) if relationship_classes: - yield ("relationship_class", _get_relationship_classes_for_import(db_map, relationship_classes)) + yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes)) if parameter_value_lists: yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists)) yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, unparse_value)) @@ -197,22 +199,22 @@ def get_data_for_import( _get_parameter_definitions_for_import(db_map, parameter_definitions, unparse_value), ) if object_parameters: - yield ("parameter_definition", _get_object_parameters_for_import(db_map, object_parameters, unparse_value)) + yield ("parameter_definition", _get_parameter_definitions_for_import(db_map, object_parameters, unparse_value)) if relationship_parameters: yield ( "parameter_definition", - _get_relationship_parameters_for_import(db_map, relationship_parameters, unparse_value), + _get_parameter_definitions_for_import(db_map, relationship_parameters, unparse_value), ) if entities: yield ("entity", _get_entities_for_import(db_map, entities)) if objects: - yield ("object", _get_objects_for_import(db_map, objects)) + yield ("object", _get_entities_for_import(db_map, objects)) if relationships: - yield ("relationship", _get_relationships_for_import(db_map, relationships)) + yield ("relationship", _get_entities_for_import(db_map, relationships)) if entity_groups: yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups)) if object_groups: - yield ("entity_group", _get_object_groups_for_import(db_map, object_groups)) + yield ("entity_group", _get_entity_groups_for_import(db_map, object_groups)) if parameter_values: yield ( "parameter_value", @@ -221,31 +223,28 @@ def get_data_for_import( if object_parameter_values: yield ( "parameter_value", - _get_object_parameter_values_for_import(db_map, object_parameter_values, unparse_value, on_conflict), + _get_parameter_values_for_import(db_map, object_parameter_values, unparse_value, on_conflict), ) if relationship_parameter_values: yield ( "parameter_value", - _get_relationship_parameter_values_for_import( - db_map, relationship_parameter_values, unparse_value, on_conflict - ), + _get_parameter_values_for_import(db_map, relationship_parameter_values, unparse_value, on_conflict), ) if metadata: yield ("metadata", _get_metadata_for_import(db_map, metadata)) + if entity_metadata: + yield ("metadata", _get_metadata_for_import(db_map, (metadata for _, _, metadata in entity_metadata))) + yield ("entity_metadata", _get_entity_metadata_for_import(db_map, entity_metadata)) + if parameter_value_metadata: + yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) if object_metadata: - yield ("entity_metadata", _get_object_metadata_for_import(db_map, 
object_metadata)) + yield from get_data_for_import(db_map, entity_metadata=object_metadata) if relationship_metadata: - yield ("entity_metadata", _get_relationship_metadata_for_import(db_map, relationship_metadata)) + yield from get_data_for_import(db_map, entity_metadata=relationship_metadata) if object_parameter_value_metadata: - yield ( - "parameter_value_metadata", - _get_object_parameter_value_metadata_for_import(db_map, object_parameter_value_metadata), - ) + yield from get_data_for_import(db_map, parameter_value_metadata=object_parameter_value_metadata) if relationship_parameter_value_metadata: - yield ( - "parameter_value_metadata", - _get_relationship_parameter_value_metadata_for_import(db_map, relationship_parameter_value_metadata), - ) + yield from get_data_for_import(db_map, parameter_value_metadata=relationship_parameter_value_metadata) def import_entity_classes(db_map, data): @@ -750,27 +749,25 @@ def _get_items_for_import(db_map, item_type, data, skip_keys=()): to_add = [] to_update = [] seen = {} - with db_map.generate_ids(item_type) as new_id: - for item in data: - checked_item, add_error = table_cache.check_item(item, skip_keys=skip_keys) - if not add_error: - if not _check_seen(item_type, checked_item, seen, errors): - continue - checked_item["id"] = new_id() - to_add.append(checked_item) + for item in data: + checked_item, add_error = table_cache.check_item(item, skip_keys=skip_keys) + if not add_error: + if not _check_unique(item_type, checked_item, seen, errors): continue - checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=skip_keys) - if not update_error: - if not _check_seen(item_type, checked_item, seen, errors): + to_add.append(checked_item) + continue + checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=skip_keys) + if not update_error: + if checked_item: + if not _check_unique(item_type, checked_item, seen, errors): continue - # FIXME: Maybe check that item and checked_item are different before updating??? 
to_update.append(checked_item) - continue - errors.append(add_error) + continue + errors.append(add_error) return to_add, to_update, errors -def _check_seen(item_type, checked_item, seen, errors): +def _check_unique(item_type, checked_item, seen, errors): dupe_key = _add_to_seen(checked_item, seen) if not dupe_key: return True @@ -960,51 +957,3 @@ def _data_iterator(): yield name, (), *optionals return _get_entity_classes_for_import(db_map, _data_iterator()) - - -def _get_relationship_classes_for_import(db_map, data): - return _get_entity_classes_for_import(db_map, data) - - -def _get_objects_for_import(db_map, data): - return _get_entities_for_import(db_map, data) - - -def _get_relationships_for_import(db_map, data): - return _get_entities_for_import(db_map, data) - - -def _get_object_groups_for_import(db_map, data): - return _get_entity_groups_for_import(db_map, data) - - -def _get_object_parameters_for_import(db_map, data, unparse_value): - return _get_parameter_definitions_for_import(db_map, data, unparse_value) - - -def _get_relationship_parameters_for_import(db_map, data, unparse_value): - return _get_parameter_definitions_for_import(db_map, data, unparse_value) - - -def _get_object_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - return _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict) - - -def _get_relationship_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - return _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict) - - -def _get_object_metadata_for_import(db_map, data): - return _get_entity_metadata_for_import(db_map, data) - - -def _get_relationship_metadata_for_import(db_map, data): - return _get_entity_metadata_for_import(db_map, data) - - -def _get_object_parameter_value_metadata_for_import(db_map, data): - return _get_parameter_value_metadata_for_import(db_map, data) - - -def _get_relationship_parameter_value_metadata_for_import(db_map, data): - return _get_parameter_value_metadata_for_import(db_map, data) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index a91de2a7..8fd0ea3f 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -812,17 +812,16 @@ def test_update_parameter_value_metadata(self): items, errors = self._db_map.update_ext_parameter_value_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 2) self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": None}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": None} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": None} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): @@ -841,9 +840,8 @@ def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata( items, errors = 
self._db_map.update_ext_parameter_value_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1, 2}) + self.assertEqual(len(items), 2) self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index dbaf8784..5c70d9e0 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -22,7 +22,6 @@ from sqlalchemy.util import KeyedTuple from spinedb_api.db_mapping import DatabaseMapping from spinedb_api.exception import SpineIntegrityError -from spinedb_api.db_cache import DBCache from spinedb_api import import_functions, SpineDBAPIError @@ -153,10 +152,10 @@ def test_remove_entity_group_from_committed_session(self): """Test removing an entity group from a committed session""" self._db_map.add_object_classes({"name": "oc1", "id": 1}) self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - items, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) self._db_map.commit_session("add") self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 1) - self._db_map.remove_items("entity_group", *{x["id"] for x in items}) + self._db_map.remove_items("entity_group", 1) self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) @@ -454,7 +453,7 @@ def test_add_object_class_with_same_name_as_existing_one(self): def test_add_objects(self): """Test that adding objects works.""" - self._db_map.add_object_classes({"name": "fish"}) + self._db_map.add_object_classes({"name": "fish", "id": 1}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "dory", "class_id": 1}) self._db_map.commit_session("add") objects = self._db_map.query(self._db_map.object_sq).all() @@ -472,7 +471,7 @@ def test_add_object_with_invalid_name(self): def test_add_objects_with_same_name(self): """Test that adding two objects with the same name only adds one of them.""" - self._db_map.add_object_classes({"name": "fish"}) + self._db_map.add_object_classes({"name": "fish", "id": 1}) self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "nemo", "class_id": 1}) self._db_map.commit_session("add") objects = self._db_map.query(self._db_map.object_sq).all() @@ -495,7 +494,7 @@ def test_add_object_with_invalid_class(self): def test_add_relationship_classes(self): """Test that adding relationship classes works.""" - self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes( {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]} ) @@ -519,7 +518,7 @@ def test_add_relationship_classes_with_invalid_name(self): def test_add_relationship_classes_with_same_name(self): """Test that adding two relationship classes with the same name only adds one of them.""" - self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) self._db_map.add_wide_relationship_classes( {"name": "rc1", 
"object_class_id_list": [1, 2]}, {"name": "rc1", "object_class_id_list": [1, 2]}, @@ -571,9 +570,9 @@ def test_add_relationship_class_with_invalid_object_class(self): def test_add_relationships(self): """Test that adding relationships works.""" - self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "class_id": 1}, {"name": "o2", "class_id": 2}) + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3}) + self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2}) self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) self._db_map.commit_session("add") ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all() @@ -588,7 +587,7 @@ def test_add_relationships(self): def test_add_relationship_with_invalid_name(self): """Test that adding object classes with empty name raises error""" - self._db_map.add_object_classes({"name": "oc1"}, strict=True) + self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1]}, strict=True) self._db_map.add_objects({"name": "o1", "class_id": 1}, strict=True) with self.assertRaises(SpineIntegrityError): @@ -596,9 +595,9 @@ def test_add_relationship_with_invalid_name(self): def test_add_identical_relationships(self): """Test that adding two relationships with the same class and same objects only adds the first one.""" - self._db_map.add_object_classes({"name": "oc1"}, {"name": "oc2"}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "class_id": 1}, {"name": "o2", "class_id": 2}) + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3}) + self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2}) self._db_map.add_wide_relationships( {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, @@ -903,6 +902,7 @@ def test_add_existing_parameter_value(self): import_functions.import_objects(self._db_map, [("fish", "nemo")]) import_functions.import_object_parameters(self._db_map, [("fish", "color")]) import_functions.import_object_parameter_values(self._db_map, [("fish", "nemo", "color", "orange")]) + self._db_map.commit_session("add") _, errors = self._db_map.add_parameter_values( { "parameter_definition_id": 1, @@ -923,9 +923,8 @@ def test_add_existing_parameter_value(self): def test_add_alternative(self): items, errors = self._db_map.add_alternatives({"name": "my_alternative"}) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {2}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add test data.") alternatives = self._db_map.query(self._db_map.alternative_sq).all() self.assertEqual(len(alternatives), 2) @@ -938,9 +937,8 @@ def test_add_alternative(self): def test_add_scenario(self): items, errors = self._db_map.add_scenarios({"name": "my_scenario"}) - ids = {x["id"] for x in items} 
self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add test data.") scenarios = self._db_map.query(self._db_map.scenario_sq).all() self.assertEqual(len(scenarios), 1) @@ -953,9 +951,8 @@ def test_add_scenario_alternative(self): import_functions.import_scenarios(self._db_map, ("my_scenario",)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_scenario_alternatives({"scenario_id": 1, "alternative_id": 1, "rank": 0}) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add test data.") scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all() self.assertEqual(len(scenario_alternatives), 1) @@ -966,9 +963,8 @@ def test_add_scenario_alternative(self): def test_add_metadata(self): items, errors = self._db_map.add_metadata({"name": "test name", "value": "test_add_metadata"}, strict=False) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) @@ -991,9 +987,8 @@ def test_add_entity_metadata_for_object(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) @@ -1018,9 +1013,8 @@ def test_add_entity_metadata_for_relationship(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_entity_metadata({"entity_id": 2, "metadata_id": 1}, strict=False) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) @@ -1049,9 +1043,8 @@ def test_add_ext_entity_metadata_for_object(self): items, errors = self._db_map.add_ext_entity_metadata( {"entity_id": 1, "metadata_name": "key", "metadata_value": "object metadata"}, strict=False ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add entity metadata") entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) @@ -1076,9 +1069,8 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an items, errors = self._db_map.add_ext_entity_metadata( {"entity_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}, strict=False ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add entity metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) @@ -1108,9 +1100,8 @@ def 
test_add_parameter_value_metadata(self): items, errors = self._db_map.add_parameter_value_metadata( {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, strict=False ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add value metadata") value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) @@ -1131,12 +1122,9 @@ def test_add_parameter_value_metadata(self): def test_add_parameter_value_metadata_doesnt_raise_with_empty_cache(self): items, errors = self._db_map.add_parameter_value_metadata( - {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, - cache=DBCache(lambda *args, **kwargs: None), - strict=False, + {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1} ) - ids = {x["id"] for x in items} - self.assertEqual(ids, set()) + self.assertEqual(len(items), 0) self.assertEqual(len(errors), 1) def test_add_ext_parameter_value_metadata(self): @@ -1154,9 +1142,8 @@ def test_add_ext_parameter_value_metadata(self): }, strict=False, ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add value metadata") value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) @@ -1186,9 +1173,8 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata.", "alternative_id": 1}, strict=False, ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 1) self._db_map.commit_session("Add value metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index e381c1f6..8f44d1dd 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -161,7 +161,6 @@ def test_import_existing_object(self): self.assertIn("object", [o.name for o in db_map.query(db_map.object_sq)]) _, errors = import_objects(db_map, [["object_class", "object"]]) self.assertFalse(errors) - db_map.commit_session("test") self.assertIn("object", [o.name for o in db_map.query(db_map.object_sq)]) db_map.connection.close() @@ -261,7 +260,6 @@ def test_import_existing_object_class_parameter(self): db_map.commit_session("test") self.assertIn("parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) _, errors = import_object_parameters(db_map, [["object_class", "parameter"]]) - db_map.commit_session("test") self.assertIn("parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) self.assertFalse(errors) db_map.connection.close() @@ -961,7 +959,7 @@ def test_import_twelfth_value(self): self.assertEqual(count, n_values + 1) count, errors = import_parameter_value_lists(self._db_map, (("list_1", 23.0),)) self.assertEqual(errors, []) - self.assertEqual(count, 2) + self.assertEqual(count, 1) self._db_map.commit_session("test") value_lists = self._db_map.query(self._db_map.parameter_value_list_sq).all() self.assertEqual(len(value_lists), 1) From 00ed8b33aa50617f119450a20ff6287bc036eab9 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 16 May 2023 16:52:55 +0200 Subject: [PATCH 040/317] Fix various bugs found via tests --- 
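The new cache gives every item a Status (committed, to_add, to_update or to_remove); DBCacheBase.commit() sweeps the cache in reference order, partitions pending items into the three buckets, and resets everything to committed. A rough sketch of the intended lifecycle, assuming an open DatabaseMapping ``db_map`` and using the names introduced below:

    cache = DBCache(db_map)
    table_cache = cache.table_cache("alternative")
    item = table_cache.add_item({"name": "alt1"}, new=True)  # status is now Status.to_add
    table_cache.update_item({"id": item["id"], "description": "x"})  # stays to_add; only committed items flip to to_update
    to_add, to_update, to_remove = cache.commit()  # item ends up in to_add["alternative"]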
spinedb_api/db_cache_base.py | 664 ++++++++++++++++++++++++ spinedb_api/db_cache_impl.py | 412 +++++++++++++++ spinedb_api/db_mapping_add_mixin.py | 27 +- spinedb_api/db_mapping_base.py | 113 ++-- spinedb_api/db_mapping_remove_mixin.py | 178 +------ spinedb_api/db_mapping_update_mixin.py | 6 +- spinedb_api/export_functions.py | 4 +- spinedb_api/filters/execution_filter.py | 20 +- spinedb_api/helpers.py | 27 +- spinedb_api/import_functions.py | 25 +- spinedb_api/query.py | 8 +- tests/test_DatabaseMapping.py | 27 +- tests/test_DiffDatabaseMapping.py | 23 +- 13 files changed, 1178 insertions(+), 356 deletions(-) create mode 100644 spinedb_api/db_cache_base.py create mode 100644 spinedb_api/db_cache_impl.py diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py new file mode 100644 index 00000000..72a146a0 --- /dev/null +++ b/spinedb_api/db_cache_base.py @@ -0,0 +1,664 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +""" +DB cache base. + +""" +from contextlib import suppress +from enum import Enum, unique, auto +from functools import cmp_to_key + +# TODO: Implement CacheItem.pop() to do lookup? + + +@unique +class Status(Enum): + """Cache item status.""" + + committed = auto() + to_add = auto() + to_update = auto() + to_remove = auto() + + +class DBCacheBase(dict): + """A dictionary that maps table names to ids to items. 
Used to store and retrieve database contents."""
+
+    def __init__(self, chunk_size=None):
+        super().__init__()
+        self._offsets = {}
+        self._fetched_item_types = set()
+        self._chunk_size = chunk_size
+
+    def _item_factory(self, item_type):
+        raise NotImplementedError()
+
+    def _query(self, item_type):
+        raise NotImplementedError()
+
+    def make_item(self, item_type, **item):
+        factory = self._item_factory(item_type)
+        return factory(self, item_type, **item)
+
+    def _cmp_item_type(self, a, b):
+        if a in self._item_factory(b).ref_types():
+            # a should come before b
+            return -1
+        if b in self._item_factory(a).ref_types():
+            # a should come after b
+            return 1
+        return 0
+
+    def _sorted_item_types(self):
+        return sorted(self, key=cmp_to_key(self._cmp_item_type))
+
+    def commit(self):
+        to_add = {}
+        to_update = {}
+        to_remove = {}
+        for item_type in sorted(self, key=cmp_to_key(self._cmp_item_type)):
+            table_cache = self[item_type]
+            for item in dict.values(table_cache):
+                _ = item.is_valid()
+                if item.status == Status.to_add:
+                    to_add.setdefault(item_type, []).append(item)
+                elif item.status == Status.to_update:
+                    to_update.setdefault(item_type, []).append(item)
+                elif item.status == Status.to_remove:
+                    to_remove.setdefault(item_type, set()).add(item["id"])
+                item.status = Status.committed
+            if to_remove.get(item_type):
+                # Fetch descendants, so that they are validated in the next iterations of the loop.
+                # This allows removal in cascade.
+                for x in self:
+                    if self._cmp_item_type(item_type, x) < 0:
+                        self.fetch_all(x)
+        return to_add, to_update, to_remove
+
+    @property
+    def fetched_item_types(self):
+        return self._fetched_item_types
+
+    def reset_queries(self):
+        """Resets queries and clears caches."""
+        self._offsets.clear()
+        self._fetched_item_types.clear()
+
+    def _get_next_chunk(self, item_type):
+        qry = self._query(item_type)
+        if not self._chunk_size:
+            self._fetched_item_types.add(item_type)
+            return [dict(x) for x in qry]
+        offset = self._offsets.setdefault(item_type, 0)
+        chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)]
+        self._offsets[item_type] += len(chunk)
+        return chunk
+
+    def advance_query(self, item_type):
+        """Advances the DB query that fetches items of the given type and caches the results.
+
+        Args:
+            item_type (str)
+
+        Returns:
+            list: items fetched from the DB
+        """
+        chunk = self._get_next_chunk(item_type)
+        if not chunk:
+            self._fetched_item_types.add(item_type)
+            return []
+        table_cache = self.table_cache(item_type)
+        for item in chunk:
+            # FIXME: This will overwrite working changes after a refresh
+            table_cache.add_item(item)
+        return chunk
+
+    def table_cache(self, item_type):
+        return self.setdefault(item_type, _TableCache(self, item_type))
+
+    def get_item(self, item_type, id_):
+        table_cache = self.get(item_type, {})
+        item = table_cache.get(id_)
+        if item is None:
+            return {}
+        return item
+
+    def fetch_more(self, item_type):
+        if item_type in self._fetched_item_types:
+            return False
+        return bool(self.advance_query(item_type))
+
+    def fetch_all(self, item_type):
+        while self.fetch_more(item_type):
+            pass
+
+    def fetch_ref(self, item_type, id_):
+        while self.fetch_more(item_type):
+            with suppress(KeyError):
+                return self[item_type][id_]
+        # It is possible that fetching was completed between deciding to call this function
+        # and starting the while loop above, in which case self.fetch_more() returns False immediately.
+        # Therefore, we try one last time to see if the ref is available.
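+        # The suppress block below makes that last attempt; if the ref is still
+        # missing we return None, which callers such as CacheItemBase._get_ref
+        # treat as a broken reference (the referring item gets marked corrupted).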
+ with suppress(KeyError): + return self[item_type][id_] + return None + + +class _TempId(int): + _next_id = {} + + def __new__(cls, item_type): + id_ = cls._next_id.setdefault(item_type, -1) + cls._next_id[item_type] -= 1 + return super().__new__(cls, id_) + + def __init__(self, item_type): + super().__init__() + self._item_type = item_type + self._value_binds = [] + self._tuple_value_binds = [] + self._key_binds = [] + self._tuple_key_binds = [] + + def add_value_bind(self, item, key): + self._value_binds.append((item, key)) + + def add_tuple_value_bind(self, item, key): + self._tuple_value_binds.append((item, key)) + + def add_key_bind(self, item): + self._key_binds.append(item) + + def add_tuple_key_bind(self, item, key): + self._tuple_key_binds.append((item, key)) + + def remove_key_bind(self, item): + self._key_binds.remove(item) + + def remove_tuple_key_bind(self, item, key): + self._tuple_key_binds.remove((item, key)) + + def resolve(self, new_id): + for item, key in self._value_binds: + item[key] = new_id + for item, key in self._tuple_value_binds: + item[key] = tuple(new_id if v is self else v for v in item[key]) + for item in self._key_binds: + if self in item: + item[new_id] = dict.pop(item, self, None) + for item, key in self._tuple_key_binds: + if key in item: + item[tuple(new_id if k is self else k for k in key)] = dict.pop(item, key, None) + + +class _TempIdDict(dict): + def __init__(self, **kwargs): + super().__init__(**kwargs) + for key, value in kwargs.items(): + self._bind(key, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + self._bind(key, value) + + def __delitem__(self, key): + super().__delitem__(key) + self._unbind(key) + + def setdefault(self, key, default): + value = super().setdefault(key, default) + self._bind(key, value) + return value + + def update(self, other): + super().update(other) + for key, value in other.items(): + self._bind(key, value) + + def pop(self, key, default): + if key in self: + self._unbind(key) + return super().pop(key, default) + + def _bind(self, key, value): + if isinstance(value, _TempId): + value.add_value_bind(self, key) + elif isinstance(value, tuple): + for v in value: + if isinstance(v, _TempId): + v.add_tuple_value_bind(self, key) + elif isinstance(key, _TempId): + key.add_key_bind(self) + elif isinstance(key, tuple): + for k in key: + if isinstance(k, _TempId): + k.add_tuple_key_bind(self, key) + + def _unbind(self, key): + if isinstance(key, _TempId): + key.remove_key_bind(self) + elif isinstance(key, tuple): + for k in key: + if isinstance(k, _TempId): + k.remove_tuple_key_bind(self, key) + + +class _TableCache(_TempIdDict): + def __init__(self, db_cache, item_type, *args, **kwargs): + """ + Args: + db_cache (DBCache): the DB cache where this table cache belongs. + item_type (str): the item type, equal to a table name + """ + super().__init__(*args, **kwargs) + self._db_cache = db_cache + self._item_type = item_type + self._id_by_unique_key_value = {} + + def _new_id(self): + return _TempId(self._item_type) + + def unique_key_value_to_id(self, key, value, strict=False): + """Returns the id that has the given value for the given unique key, or None. 
+ + Args: + key (tuple) + value (tuple) + + Returns: + int + """ + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) + self._db_cache.fetch_all(self._item_type) + id_by_unique_value = self._id_by_unique_key_value.get(key, {}) + if strict: + return id_by_unique_value[value] + return id_by_unique_value.get(value) + + def _unique_key_value_to_item(self, key, value, strict=False): + return self.get(self.unique_key_value_to_id(key, value)) + + def values(self): + return (x for x in super().values() if x.is_valid()) + + def _make_item(self, item): + """Returns a cache item. + + Args: + item (dict): the 'db item' to use as base + + Returns: + CacheItem + """ + return self._db_cache.make_item(self._item_type, **item) + + def current_item(self, item, skip_keys=()): + """Returns a CacheItemBase that matches the given dictionary-item. + + Args: + item (dict) + + Returns: + CacheItemBase or None + """ + id_ = item.get("id") + if isinstance(id_, int): + # id is an int, easy + return self.get(id_) + if isinstance(id_, dict): + # id is a dict specifying the values for one of the unique constraints + key, value = zip(*id_.items()) + return self._unique_key_value_to_item(key, value) + if id_ is None: + # No id. Try to locate the item by the value of one of the unique keys. + # Used by import_data (and more...) + cache_item = self._make_item(item) + error = cache_item.resolve_inverse_references(item.keys()) + if error: + return None + error = cache_item.polish() + if error: + return None + for key, value in cache_item.unique_values(skip_keys=skip_keys): + current_item = self._unique_key_value_to_item(key, value) + if current_item: + return current_item + + def check_item(self, item, for_update=False, skip_keys=()): + # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, + # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) + if for_update: + current_item = self.current_item(item, skip_keys=skip_keys) + if current_item is None: + return None, f"no {self._item_type} matching {item} to update" + full_item, merge_error = current_item.merge(item) + if full_item is None: + return None, merge_error + else: + current_item = None + full_item, merge_error = item, None + candidate_item = self._make_item(full_item) + error = candidate_item.resolve_inverse_references(skip_keys=item.keys()) + if error: + return None, error + error = candidate_item.polish() + if error: + return None, error + invalid_ref = candidate_item.invalid_ref() + if invalid_ref: + return None, f"invalid {invalid_ref} for {self._item_type}" + try: + for key, value in candidate_item.unique_values(skip_keys=skip_keys): + empty = {k for k, v in zip(key, value) if v == ""} + if empty: + return None, f"invalid empty keys {empty} for {self._item_type}" + unique_item = self._unique_key_value_to_item(key, value) + if unique_item not in (None, current_item) and unique_item.is_valid(): + return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" + except KeyError as e: + return None, f"missing {e} for {self._item_type}" + if "id" not in candidate_item: + candidate_item["id"] = self._new_id() + return candidate_item, merge_error + + def _add_unique(self, item): + for key, value in item.unique_values(): + self._id_by_unique_key_value.setdefault(key, _TempIdDict())[value] = item["id"] + + def _remove_unique(self, item): + for key, value in item.unique_values(): + self._id_by_unique_key_value.get(key, {}).pop(value, None) + + def add_item(self, 
item, new=False):
+        if "id" not in item:
+            item["id"] = self._new_id()
+        self[item["id"]] = new_item = self._make_item(item)
+        self._add_unique(new_item)
+        if new:
+            new_item.status = Status.to_add
+        return new_item
+
+    def update_item(self, item):
+        current_item = self[item["id"]]
+        self._remove_unique(current_item)
+        current_item.update(item)
+        self._add_unique(current_item)
+        current_item.cascade_update()
+        if current_item.status == Status.committed:
+            current_item.status = Status.to_update
+        return current_item
+
+    def remove_item(self, id_):
+        current_item = self.get(id_)
+        if current_item is not None:
+            self._remove_unique(current_item)
+            current_item.cascade_remove()
+        return current_item
+
+    def restore_item(self, id_):
+        current_item = self.get(id_)
+        if current_item is not None:
+            self._add_unique(current_item)
+            current_item.cascade_restore()
+        return current_item
+
+
+class CacheItemBase(_TempIdDict):
+    """A dictionary that represents a db item."""
+
+    _defaults = {}
+    _unique_keys = ()
+    _references = {}
+    _inverse_references = {}
+
+    def __init__(self, db_cache, item_type, **kwargs):
+        """
+        Args:
+            db_cache (DBCache): the DB cache where this item belongs.
+        """
+        super().__init__(**kwargs)
+        self._db_cache = db_cache
+        self._item_type = item_type
+        self._referrers = _TempIdDict()
+        self._weak_referrers = _TempIdDict()
+        self.restore_callbacks = set()
+        self.update_callbacks = set()
+        self.remove_callbacks = set()
+        self._to_remove = False
+        self._removed = False
+        self._corrupted = False
+        self._valid = None
+        self.status = Status.committed
+
+    @classmethod
+    def ref_types(cls):
+        return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values())
+
+    @property
+    def removed(self):
+        return self._removed
+
+    @property
+    def item_type(self):
+        return self._item_type
+
+    @property
+    def key(self):
+        id_ = dict.get(self, "id")
+        if id_ is None:
+            return None
+        return (self._item_type, id_)
+
+    def __repr__(self):
+        return f"{self._item_type}{self._extended()}"
+
+    def __getattr__(self, name):
+        """Overridden method to return the dictionary key named after the attribute, or None if it doesn't exist."""
+        # FIXME: We should try and get rid of this one
+        return self.get(name)
+
+    def __getitem__(self, key):
+        ref = self._references.get(key)
+        if ref:
+            src_key, (ref_type, ref_key) = ref
+            ref_id = self[src_key]
+            if isinstance(ref_id, tuple):
+                return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id)
+            return self._get_ref(ref_type, ref_id).get(ref_key)
+        return super().__getitem__(key)
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def update(self, other):
+        for src_key, (ref_type, _ref_key) in self._references.values():
+            ref_id = self[src_key]
+            if src_key in other and other[src_key] != ref_id:
+                # Forget references
+                if isinstance(ref_id, tuple):
+                    for x in ref_id:
+                        self._forget_ref(ref_type, x)
+                else:
+                    self._forget_ref(ref_type, ref_id)
+        super().update(other)
+
+    def merge(self, other):
+        if all(self.get(key) == value for key, value in other.items()):
+            return None, ""
+        merged = {**self, **other}
+        merged["id"] = self["id"]
+        return merged, ""
+
+    def polish(self):
+        """Polishes this item once all its references are resolved. Returns any errors.
+ + Returns: + str or None + """ + for key, default_value in self._defaults.items(): + self.setdefault(key, default_value) + return "" + + def resolve_inverse_references(self, skip_keys=()): + for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): + if src_key in skip_keys: + continue + id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) + if None in id_value: + continue + table_cache = self._db_cache.table_cache(ref_type) + try: + self[src_key] = ( + tuple(table_cache.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) + if all(isinstance(v, (tuple, list)) for v in id_value) + else table_cache.unique_key_value_to_id(ref_key, id_value, strict=True) + ) + except KeyError as err: + # Happens at unique_key_value_to_id(..., strict=True) + return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" + + def invalid_ref(self): + for src_key, (ref_type, _ref_key) in self._references.values(): + try: + ref_id = self[src_key] + except KeyError: + return src_key + if isinstance(ref_id, tuple): + for x in ref_id: + if not self._get_ref(ref_type, x): + return src_key + elif not self._get_ref(ref_type, ref_id): + return src_key + + def unique_values(self, skip_keys=()): + for key in self._unique_keys: + if key not in skip_keys: + yield key, tuple(self.get(k) for k in key) + + def _get_ref(self, ref_type, ref_id, strong=True): + ref = self._db_cache.get_item(ref_type, ref_id) + if not ref: + if not strong: + return {} + ref = self._db_cache.fetch_ref(ref_type, ref_id) + if not ref: + self._corrupted = True + return {} + return self._handle_ref(ref, strong) + + def _handle_ref(self, ref, strong): + if strong: + ref.add_referrer(self) + if ref.removed: + self._to_remove = True + else: + ref.add_weak_referrer(self) + if ref.removed: + return {} + return ref + + def _forget_ref(self, ref_type, ref_id): + ref = self._db_cache.get_item(ref_type, ref_id) + ref.remove_referrer(self) + + def is_valid(self): + if self._valid is not None: + return self._valid + if self._removed or self._corrupted: + return False + self._to_remove = False + self._corrupted = False + for key in self._references: + _ = self[key] + if self._to_remove: + self.cascade_remove() + self._valid = not self._removed and not self._corrupted + return self._valid + + def add_referrer(self, referrer): + if referrer.key is None: + return + self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer) + + def remove_referrer(self, referrer): + if referrer.key is None: + return + self._referrers.pop(referrer.key, None) + + def add_weak_referrer(self, referrer): + if referrer.key is None: + return + if referrer.key not in self._referrers: + self._weak_referrers[referrer.key] = referrer + + def _update_weak_referrers(self): + for weak_referrer in self._weak_referrers.values(): + weak_referrer.call_update_callbacks() + + def cascade_restore(self): + if not self._removed: + return + if self.status == Status.committed: + self.status = Status.to_add + self._removed = False + for referrer in self._referrers.values(): + referrer.cascade_restore() + self._update_weak_referrers() + obsolete = set() + for callback in self.restore_callbacks: + if not callback(self): + obsolete.add(callback) + self.restore_callbacks -= obsolete + + def cascade_remove(self): + if self._removed: + return + if self.status == Status.committed: + self.status = Status.to_remove + else: + self.status = Status.committed + self._removed = True + self._to_remove = False + self._valid = None + 
obsolete = set() + for callback in self.remove_callbacks: + if not callback(self): + obsolete.add(callback) + self.remove_callbacks -= obsolete + for referrer in self._referrers.values(): + referrer.cascade_remove() + self._update_weak_referrers() + + def cascade_update(self): + self.call_update_callbacks() + for referrer in self._referrers.values(): + referrer.cascade_update() + self._update_weak_referrers() + + def call_update_callbacks(self): + self.pop("parsed_value", None) + obsolete = set() + for callback in self.update_callbacks: + if not callback(self): + obsolete.add(callback) + self.update_callbacks -= obsolete + + def _extended(self): + return {**self, **{key: self[key] for key in self._references}} + + def _asdict(self): + return dict(self) + + def is_committed(self): + return self.status == Status.committed diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py new file mode 100644 index 00000000..30a2a345 --- /dev/null +++ b/spinedb_api/db_cache_impl.py @@ -0,0 +1,412 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +""" +DB cache implementation. 
+ +""" +import uuid +from operator import itemgetter +from .parameter_value import from_database +from .db_cache_base import DBCacheBase, CacheItemBase + + +class DBCache(DBCacheBase): + def __init__(self, db_map, chunk_size=None): + """ + Args: + db_map (DatabaseMapping) + """ + super().__init__(chunk_size=chunk_size) + self._db_map = db_map + + def _item_factory(self, item_type): + return { + "entity_class": EntityClassItem, + "entity": EntityItem, + "entity_group": EntityGroupItem, + "parameter_definition": ParameterDefinitionItem, + "parameter_value": ParameterValueItem, + "parameter_value_list": ParameterValueListItem, + "list_value": ListValueItem, + "alternative": AlternativeItem, + "scenario": ScenarioItem, + "scenario_alternative": ScenarioAlternativeItem, + "metadata": MetadataItem, + "entity_metadata": EntityMetadataItem, + "parameter_value_metadata": ParameterValueMetadataItem, + }.get(item_type, CacheItemBase) + + def _query(self, item_type): + sq_name = { + "entity_class": "wide_entity_class_sq", + "entity": "wide_entity_sq", + "parameter_value_list": "parameter_value_list_sq", + "list_value": "list_value_sq", + "alternative": "alternative_sq", + "scenario": "scenario_sq", + "scenario_alternative": "scenario_alternative_sq", + "entity_group": "entity_group_sq", + "parameter_definition": "parameter_definition_sq", + "parameter_value": "parameter_value_sq", + "metadata": "metadata_sq", + "entity_metadata": "entity_metadata_sq", + "parameter_value_metadata": "parameter_value_metadata_sq", + "commit": "commit_sq", + }[item_type] + return self._db_map.query(getattr(self._db_map, sq_name)) + + +class EntityClassItem(CacheItemBase): + _defaults = {"description": None, "display_icon": None} + _unique_keys = (("name",),) + _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} + _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))} + + def __init__(self, *args, **kwargs): + dimension_id_list = kwargs.get("dimension_id_list") + if dimension_id_list is None: + dimension_id_list = () + if isinstance(dimension_id_list, str): + dimension_id_list = (int(id_) for id_ in dimension_id_list.split(",")) + kwargs["dimension_id_list"] = tuple(dimension_id_list) + super().__init__(*args, **kwargs) + + def merge(self, other): + dimension_id_list = other.pop("dimension_id_list", None) + error = ( + "can't modify dimensions of an entity class" + if dimension_id_list is not None and dimension_id_list != self["dimension_id_list"] + else "" + ) + merged, super_error = super().merge(other) + return merged, " and ".join([x for x in (super_error, error) if x]) + + +class EntityItem(CacheItemBase): + _defaults = {"description": None} + _unique_keys = (("class_name", "name"), ("class_name", "byname")) + _references = { + "class_name": ("class_id", ("entity_class", "name")), + "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("class_id", ("entity_class", "dimension_name_list")), + "element_name_list": ("element_id_list", ("entity", "name")), + } + _inverse_references = { + "class_id": (("class_name",), ("entity_class", ("name",))), + "element_id_list": (("dimension_name_list", "element_name_list"), ("entity", ("class_name", "name"))), + } + + def __init__(self, *args, **kwargs): + element_id_list = kwargs.get("element_id_list") + if element_id_list is None: + element_id_list = () + if isinstance(element_id_list, str): + element_id_list = (int(id_) for id_ in 
element_id_list.split(",")) + kwargs["element_id_list"] = tuple(element_id_list) + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + if key == "byname": + return self["element_name_list"] or (self["name"],) + return super().__getitem__(key) + + def polish(self): + error = super().polish() + if error: + return error + if "name" in self: + return + base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) + name = base_name + table_cache = self._db_cache.table_cache(self._item_type) + while table_cache.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: + name = base_name + "_" + uuid.uuid4().hex + self["name"] = name + + +class EntityGroupItem(CacheItemBase): + _unique_keys = (("group_name", "member_name"),) + _references = { + "class_name": ("entity_class_id", ("entity_class", "name")), + "group_name": ("entity_id", ("entity", "name")), + "member_name": ("member_id", ("entity", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + } + _inverse_references = { + "entity_class_id": (("class_name",), ("entity_class", ("name",))), + "entity_id": (("class_name", "group_name"), ("entity", ("class_name", "name"))), + "member_id": (("class_name", "member_name"), ("entity", ("class_name", "name"))), + } + + def __getitem__(self, key): + if key == "class_id": + return self["entity_class_id"] + if key == "group_id": + return self["entity_id"] + return super().__getitem__(key) + + +class ParameterDefinitionItem(CacheItemBase): + _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} + _unique_keys = (("entity_class_name", "name"),) + _references = { + "entity_class_name": ("entity_class_id", ("entity_class", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + } + _inverse_references = { + "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), + "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), + } + + @property + def list_value_id(self): + if dict.get(self, "default_type") == "list_value_ref": + return int(dict.__getitem__(self, "default_value")) + return None + + def __getitem__(self, key): + if key == "parameter_name": + return super().__getitem__("name") + if key == "value_list_id": + return super().__getitem__("parameter_value_list_id") + if key == "parameter_value_list_id": + return dict.get(self, key) + if key == "parameter_value_list_name": + return self._get_ref("parameter_value_list", self["parameter_value_list_id"], strong=False).get("name") + if key in ("default_value", "default_type"): + list_value_id = self.list_value_id + if list_value_id is not None: + list_value_key = {"default_value": "value", "default_type": "type"}[key] + return self._get_ref("list_value", list_value_id, strong=False).get(list_value_key) + return dict.get(self, key) + if key == "list_value_id": + return self.list_value_id + return super().__getitem__(key) + + def polish(self): + error = super().polish() + if error: + return error + default_type = self["default_type"] + default_value = self["default_value"] + list_name = self["parameter_value_list_name"] + if list_name is None: + return + if default_type == "list_value_ref": + return + parsed_value = from_database(default_value, default_type) + if parsed_value is None: + return + list_value_id = 
self._db_cache.table_cache("list_value").unique_key_value_to_id( + ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) + ) + if list_value_id is None: + return f"default value {parsed_value} of {self['name']} is not in {list_name}" + self["default_value"] = list_value_id + self["default_type"] = "list_value_ref" + + def _asdict(self): + d = super()._asdict() + if d.get("default_type") == "list_value_ref": + d["default_value"] = str(d["default_value"]).encode() + return d + + def merge(self, other): + parameter_value_list_id = other.get("parameter_value_list_id") + if ( + parameter_value_list_id is not None + and parameter_value_list_id != self["parameter_value_list_id"] + and any( + x["parameter_definition_id"] == self["id"] + for x in self._db_cache.table_cache("parameter_value").values() + ) + ): + del other["parameter_value_list_id"] + error = "can't modify the parameter value list of a parameter that already has values" + else: + error = "" + merged, super_error = super().merge(other) + return merged, " and ".join([x for x in (super_error, error) if x]) + + +class ParameterValueItem(CacheItemBase): + _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name"),) + _references = { + "entity_class_name": ("entity_class_id", ("entity_class", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + "parameter_definition_name": ("parameter_definition_id", ("parameter_definition", "name")), + "parameter_value_list_id": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_id")), + "parameter_value_list_name": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_name")), + "entity_name": ("entity_id", ("entity", "name")), + "entity_byname": ("entity_id", ("entity", "byname")), + "element_id_list": ("entity_id", ("entity", "element_id_list")), + "element_name_list": ("entity_id", ("entity", "element_name_list")), + "alternative_name": ("alternative_id", ("alternative", "name")), + } + _inverse_references = { + "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), + "parameter_definition_id": ( + ("entity_class_name", "parameter_definition_name"), + ("parameter_definition", ("entity_class_name", "name")), + ), + "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), + "alternative_id": (("alternative_name",), ("alternative", ("name",))), + } + + @property + def list_value_id(self): + if dict.__getitem__(self, "type") == "list_value_ref": + return int(dict.__getitem__(self, "value")) + return None + + def __getitem__(self, key): + if key == "parameter_id": + return super().__getitem__("parameter_definition_id") + if key == "parameter_name": + return super().__getitem__("parameter_definition_name") + if key in ("value", "type"): + list_value_id = self.list_value_id + if list_value_id: + return self._get_ref("list_value", list_value_id, strong=False).get(key) + if key == "list_value_id": + return self.list_value_id + return super().__getitem__(key) + + def polish(self): + list_name = self["parameter_value_list_name"] + if list_name is None: + return + type_ = self["type"] + if type_ == "list_value_ref": + return + value = self["value"] + parsed_value = from_database(value, type_) + if parsed_value is None: + return + list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( + 
("parameter_value_list_name", "value", "type"), (list_name, value, type_) + ) + if list_value_id is None: + return ( + f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " + "is not in {list_name}" + ) + self["value"] = list_value_id + self["type"] = "list_value_ref" + + def _asdict(self): + d = super()._asdict() + if d.get("type") == "list_value_ref": + d["value"] = str(d["value"]).encode() + return d + + +class ParameterValueListItem(CacheItemBase): + _unique_keys = (("name",),) + + +class ListValueItem(CacheItemBase): + _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) + _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} + _inverse_references = { + "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), + } + + +class AlternativeItem(CacheItemBase): + _defaults = {"description": None} + _unique_keys = (("name",),) + + +class ScenarioItem(CacheItemBase): + _defaults = {"active": False, "description": None} + _unique_keys = (("name",),) + + @property + def sorted_alternatives(self): + self._db_cache.fetch_all("scenario_alternative") + return sorted( + (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), + key=itemgetter("rank"), + ) + + def __getitem__(self, key): + if key == "alternative_id_list": + return [x["alternative_id"] for x in self.sorted_alternatives] + if key == "alternative_name_list": + return [x["alternative_name"] for x in self.sorted_alternatives] + return super().__getitem__(key) + + +class ScenarioAlternativeItem(CacheItemBase): + _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) + _references = { + "scenario_name": ("scenario_id", ("scenario", "name")), + "alternative_name": ("alternative_id", ("alternative", "name")), + } + _inverse_references = { + "scenario_id": (("scenario_name",), ("scenario", ("name",))), + "alternative_id": (("alternative_name",), ("alternative", ("name",))), + } + + def __getitem__(self, key): + # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. + # Since ranks go from 1 to the alternative count, the first alternative will have the second as the 'before', + # the second will have the third, etc, and the last will have None. + # Note that alternatives with higher ranks overwrite the values of those with lower ranks. 
+ if key == "before_alternative_name": + return self._get_ref("alternative", self["before_alternative_id"], strong=False).get("name") + if key == "before_alternative_id": + scenario = self._get_ref("scenario", self["scenario_id"], strong=False) + try: + return scenario["alternative_id_list"][self["rank"]] + except IndexError: + return None + return super().__getitem__(key) + + +class MetadataItem(CacheItemBase): + _unique_keys = (("name", "value"),) + + +class EntityMetadataItem(CacheItemBase): + _unique_keys = (("entity_name", "metadata_name", "metadata_value"),) + _references = { + "entity_name": ("entity_id", ("entity", "name")), + "metadata_name": ("metadata_id", ("metadata", "name")), + "metadata_value": ("metadata_id", ("metadata", "value")), + } + _inverse_references = { + "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), + "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + } + + +class ParameterValueMetadataItem(CacheItemBase): + _unique_keys = ( + ("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name", "metadata_value"), + ) + _references = { + "parameter_definition_name": ("parameter_value_id", ("parameter_value", "parameter_definition_name")), + "entity_byname": ("parameter_value_id", ("parameter_value", "entity_byname")), + "alternative_name": ("parameter_value_id", ("parameter_value", "alternative_name")), + "metadata_name": ("metadata_id", ("metadata", "name")), + "metadata_value": ("metadata_id", ("metadata", "value")), + } + _inverse_references = { + "parameter_value_id": ( + ("parameter_definition_name", "entity_byname", "alternative_name"), + ("parameter_value", ("parameter_definition_name", "entity_byname", "alternative_name")), + ), + "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + } diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index ee7775f6..734ff2cd 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -16,7 +16,7 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineIntegrityError -from .helpers import convert_legacy +from .query import Query class DatabaseMappingAddMixin: @@ -41,11 +41,11 @@ def add_items(self, tablename, *items, check=True, strict=False): table_cache = self.cache.table_cache(tablename) if not check: for item in items: - convert_legacy(tablename, item) + self._convert_legacy(tablename, item) added.append(table_cache.add_item(item, new=True)._asdict()) else: for item in items: - convert_legacy(tablename, item) + self._convert_legacy(tablename, item) checked_item, error = table_cache.check_item(item) if error: if strict: @@ -62,12 +62,25 @@ def _do_add_items(self, connection, tablename, *items_to_add): return try: table = self._metadata.tables[self._real_tablename(tablename)] + id_items, temp_id_items = [], [] for item in items_to_add: - item = item._asdict() - temp_id = item.pop("id") if hasattr(item["id"], "resolve") else None - id_ = connection.execute(table.insert(), item).inserted_primary_key[0] - if temp_id: + if hasattr(item["id"], "resolve"): + temp_id_items.append(item) + else: + id_items.append(item) + if id_items: + connection.execute(table.insert(), [x._asdict() for x in id_items]) + if temp_id_items: + current_ids = {x["id"] for x in Query(connection.execute, table)} + next_id = max(current_ids, default=0) + 1 + available_ids = set(range(1, next_id)) - current_ids + missing_id_count = 
len(temp_id_items) - len(available_ids) + new_ids = set(range(next_id, next_id + missing_id_count)) + ids = sorted(available_ids | new_ids) + for id_, item in zip(ids, temp_id_items): + temp_id = item["id"] temp_id.resolve(id_) + connection.execute(table.insert(), [x._asdict() for x in temp_id_items]) for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 841b9a13..a913bdfc 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -177,8 +177,6 @@ def __init__( self._relationship_parameter_value_sq = None self._ext_parameter_value_metadata_sq = None self._ext_entity_metadata_sq = None - # Import alternative suff - self._import_alternative_id = None self._import_alternative_name = None self._table_to_sq_attr = {} # Table primary ids map: @@ -195,33 +193,6 @@ def __init__( "entity_alternative": ("entity_id", "alternative_id"), "entity_class_dimension": ("entity_class_id", "position"), } - self.ancestor_tablenames = { - "scenario_alternative": ("scenario", "alternative"), - "entity": ("entity_class",), - "entity_group": ("entity_class", "entity"), - "parameter_definition": ("entity_class", "parameter_value_list", "list_value"), - "parameter_value": ( - "alternative", - "entity_class", - "entity", - "parameter_definition", - "parameter_value_list", - "list_value", - ), - "entity_metadata": ("metadata", "entity_class", "entity"), - "parameter_value_metadata": ( - "metadata", - "parameter_value", - "parameter_definition", - "entity_class", - "entity", - "alternative", - ), - "list_value": ("parameter_value_list",), - } - self.descendant_tablenames = { - tablename: set(self._descendant_tablenames(tablename)) for tablename in self.ITEM_TYPES - } def __enter__(self): return self @@ -252,24 +223,6 @@ def reconnect(self): self.executor = self._make_executor() self.connection = self.executor.submit(self.engine.connect).result() - def _descendant_tablenames(self, tablename): - child_tablenames = { - "alternative": ("parameter_value", "scenario_alternative"), - "scenario": ("scenario_alternative",), - "entity_class": ("entity", "parameter_definition"), - "entity": ("parameter_value", "entity_group", "entity_metadata"), - "parameter_definition": ("parameter_value",), - "parameter_value_list": (), - "parameter_value": ("parameter_value_metadata", "entity_metadata"), - "entity_metadata": ("metadata",), - "parameter_value_metadata": ("metadata",), - } - for parent, children in child_tablenames.items(): - if tablename == parent: - for child in children: - yield child - yield from self._descendant_tablenames(child) - def _real_tablename(self, tablename): return { "object_class": "entity_class", @@ -453,7 +406,7 @@ def query(self, *args, **kwargs): db_map.object_sq.c.class_id == db_map.object_class_sq.c.id ).group_by(db_map.object_class_sq.c.name).all() """ - return Query(self, *args) + return Query(self.connection_execute, *args) def _subquery(self, tablename): """A subquery of the form: @@ -1756,31 +1709,19 @@ def _make_scenario_alternative_sq(self): """ return self._subquery("scenario_alternative") - def get_import_alternative(self): - """Returns the id of the alternative to use as default for all import operations. + def get_import_alternative_name(self): + """Returns the name of the alternative to use as default for all import operations. 
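+
+        The alternative is created lazily by ``_create_import_alternative``,
+        which by default names it "Base".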
Returns: - int, str + str """ - if self._import_alternative_id is None: + if self._import_alternative_name is None: self._create_import_alternative() - return self._import_alternative_id, self._import_alternative_name + return self._import_alternative_name def _create_import_alternative(self): """Creates the alternative to be used as default for all import operations.""" - self.fetch_all({"alternative"}) self._import_alternative_name = "Base" - self._import_alternative_id = next( - ( - id_ - for id_, alt in self.cache.get("alternative", {}).items() - if alt.name == self._import_alternative_name - ), - None, - ) - if not self._import_alternative_id: - ids = self._add_alternatives({"name": self._import_alternative_name}) - self._import_alternative_id = next(iter(ids)) def override_create_import_alternative(self, method): """ @@ -1790,7 +1731,7 @@ def override_create_import_alternative(self, method): method (Callable) """ self._create_import_alternative = MethodType(method, self) - self._import_alternative_id = None + self._import_alternative_name = None def override_entity_class_sq_maker(self, method): """ @@ -1923,18 +1864,9 @@ def _reset_mapping(self): self.connection_execute(table.delete()) self.connection_execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") - def fetch_all(self, tablenames, include_descendants=False, include_ancestors=False, force_tablenames=None): - if include_descendants: - tablenames |= { - descendant for tablename in tablenames for descendant in self.descendant_tablenames.get(tablename, ()) - } - if include_ancestors: - tablenames |= { - ancestor for tablename in tablenames for ancestor in self.ancestor_tablenames.get(tablename, ()) - } - if force_tablenames: - tablenames |= force_tablenames - for tablename in tablenames & set(self.ITEM_TYPES): + def fetch_all(self, tablenames=None): + tablenames = set(self.ITEM_TYPES) if tablenames is None else tablenames & set(self.ITEM_TYPES) + for tablename in tablenames: self.cache.fetch_all(tablename) def _object_class_id(self): @@ -2012,6 +1944,31 @@ def advance_cache_query(self, item_type, callback=None): future = self.executor.submit(self.cache.advance_query, item_type) future.add_done_callback(lambda future: callback(future.result())) + @staticmethod + def _convert_legacy(tablename, item): + if tablename in ("entity_class", "entity"): + object_class_id_list = tuple(item.pop("object_class_id_list", ())) + if object_class_id_list: + item["dimension_id_list"] = object_class_id_list + object_class_name_list = tuple(item.pop("object_class_name_list", ())) + if object_class_name_list: + item["dimension_name_list"] = object_class_name_list + if tablename == "entity": + object_id_list = tuple(item.pop("object_id_list", ())) + if object_id_list: + item["element_id_list"] = object_id_list + object_name_list = tuple(item.pop("object_name_list", ())) + if object_name_list: + item["element_name_list"] = object_name_list + if tablename in ("parameter_definition", "parameter_value"): + entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) + if entity_class_id: + item["entity_class_id"] = entity_class_id + if tablename == "parameter_value": + entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) + if entity_id: + item["entity_id"] = entity_id + def __del__(self): try: self.close() diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 07fdc0b7..71d6d37c 100644 --- a/spinedb_api/db_mapping_remove_mixin.py 
+++ b/spinedb_api/db_mapping_remove_mixin.py @@ -50,8 +50,7 @@ def _do_remove_items(self, connection, **kwargs): Args: **kwargs: keyword is table name, argument is list of ids to remove """ - cascading_ids = self.cascading_ids(**kwargs) - for tablename, ids in cascading_ids.items(): + for tablename, ids in kwargs.items(): tablename = self._real_tablename(tablename) if tablename == "alternative": # Do not remove the Base alternative @@ -67,180 +66,7 @@ def _do_remove_items(self, connection, **kwargs): msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e - # pylint: disable=redefined-builtin - def cascading_ids(self, **kwargs): - """Returns cascading ids. - - Keyword args: - **kwargs: set of ids keyed by table name to be removed - - Returns: - cascading_ids (dict): cascading ids keyed by table name - """ - for new_tablename, old_tablenames in ( - ("entity_class", {"object_class", "relationship_class"}), - ("entity", {"object", "relationship"}), - ): - for old_tablename in old_tablenames: - ids = kwargs.pop(old_tablename, None) - if ids is not None: - # FIXME: Add deprecation warning - kwargs.setdefault(new_tablename, set()).update(ids) - self.fetch_all( - set(kwargs), - include_descendants=True, - force_tablenames={"entity_metadata", "parameter_value_metadata"} - if any(x in kwargs for x in ("entity_metadata", "parameter_value_metadata", "metadata")) - else None, - ) - ids = {} - self._merge(ids, self._entity_class_cascading_ids(kwargs.get("entity_class", set()))) - self._merge(ids, self._entity_cascading_ids(kwargs.get("entity", set()))) - self._merge(ids, self._entity_group_cascading_ids(kwargs.get("entity_group", set()))) - self._merge(ids, self._parameter_definition_cascading_ids(kwargs.get("parameter_definition", set()))) - self._merge(ids, self._parameter_value_cascading_ids(kwargs.get("parameter_value", set()))) - self._merge(ids, self._parameter_value_list_cascading_ids(kwargs.get("parameter_value_list", set()))) - self._merge(ids, self._list_value_cascading_ids(kwargs.get("list_value", set()))) - self._merge(ids, self._alternative_cascading_ids(kwargs.get("alternative", set()))) - self._merge(ids, self._scenario_cascading_ids(kwargs.get("scenario", set()))) - self._merge(ids, self._scenario_alternatives_cascading_ids(kwargs.get("scenario_alternative", set()))) - self._merge(ids, self._metadata_cascading_ids(kwargs.get("metadata", set()))) - self._merge(ids, self._entity_metadata_cascading_ids(kwargs.get("entity_metadata", set()))) - self._merge(ids, self._parameter_value_metadata_cascading_ids(kwargs.get("parameter_value_metadata", set()))) - sorted_ids = {} - while ids: - tablename = next(iter(ids)) - self._move(tablename, ids, sorted_ids) - return sorted_ids - - def _move(self, tablename, unsorted, sorted_): - for ancestor in self.ancestor_tablenames.get(tablename, ()): - self._move(ancestor, unsorted, sorted_) - to_move = unsorted.pop(tablename, None) - if to_move: - sorted_[tablename] = to_move - - @staticmethod - def _merge(left, right): - for tablename, ids in right.items(): - left.setdefault(tablename, set()).update(ids) - - def _alternative_cascading_ids(self, ids): - """Returns alternative cascading ids.""" - cache = self.cache - cascading_ids = {"alternative": set(ids)} - entity_alternatives = (x for x in dict.values(cache.get("entity_alternative", {})) if x.alternative_id in ids) - parameter_values = (x for x in dict.values(cache.get("parameter_value", {})) if x.alternative_id in ids) - scenario_alternatives = ( - x for x 
in dict.values(cache.get("scenario_alternative", {})) if x.alternative_id in ids - ) - self._merge(cascading_ids, self._entity_alternative_cascading_ids({x.id for x in entity_alternatives})) - self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values})) - self._merge(cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives})) - return cascading_ids - - def _scenario_cascading_ids(self, ids): - cache = self.cache - cascading_ids = {"scenario": set(ids)} - scenario_alternatives = [x for x in dict.values(cache.get("scenario_alternative", {})) if x.scenario_id in ids] - self._merge(cascading_ids, self._scenario_alternatives_cascading_ids({x.id for x in scenario_alternatives})) - return cascading_ids - - def _entity_class_cascading_ids(self, ids): - """Returns entity class cascading ids.""" - if not ids: - return {} - cache = self.cache - cascading_ids = {"entity_class": set(ids), "entity_class_dimension": set(ids)} - entities = [x for x in dict.values(cache.get("entity", {})) if x.class_id in ids] - entity_classes = ( - x for x in dict.values(cache.get("entity_class", {})) if set(x.dimension_id_list).intersection(ids) - ) - paramerer_definitions = [ - x for x in dict.values(cache.get("parameter_definition", {})) if x.entity_class_id in ids - ] - self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities})) - self._merge(cascading_ids, self._entity_class_cascading_ids({x.id for x in entity_classes})) - self._merge(cascading_ids, self._parameter_definition_cascading_ids({x.id for x in paramerer_definitions})) - return cascading_ids - - def _entity_cascading_ids(self, ids): - """Returns entity cascading ids.""" - if not ids: - return {} - cache = self.cache - cascading_ids = {"entity": set(ids), "entity_element": set(ids)} - entities = (x for x in dict.values(cache.get("entity", {})) if set(x.element_id_list).intersection(ids)) - entity_alternatives = (x for x in dict.values(cache.get("entity_alternative", {})) if x.entity_id in ids) - parameter_values = (x for x in dict.values(cache.get("parameter_value", {})) if x.entity_id in ids) - groups = (x for x in dict.values(cache.get("entity_group", {})) if {x.group_id, x.member_id}.intersection(ids)) - entity_metadata_ids = {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.entity_id in ids} - self._merge(cascading_ids, self._entity_cascading_ids({x.id for x in entities})) - self._merge(cascading_ids, self._entity_alternative_cascading_ids({x.id for x in entity_alternatives})) - self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values})) - self._merge(cascading_ids, self._entity_group_cascading_ids({x.id for x in groups})) - self._merge(cascading_ids, self._entity_metadata_cascading_ids(entity_metadata_ids)) - return cascading_ids - - def _entity_alternative_cascading_ids(self, ids): - return {"entity_alternative": set(ids)} - - def _entity_group_cascading_ids(self, ids): # pylint: disable=no-self-use - """Returns entity group cascading ids.""" - return {"entity_group": set(ids)} - - def _parameter_definition_cascading_ids(self, ids): - """Returns parameter definition cascading ids.""" - cache = self.cache - cascading_ids = {"parameter_definition": set(ids)} - parameter_values = [x for x in dict.values(cache.get("parameter_value", {})) if x.parameter_id in ids] - self._merge(cascading_ids, self._parameter_value_cascading_ids({x.id for x in parameter_values})) - return cascading_ids - - def 
_parameter_value_cascading_ids(self, ids): # pylint: disable=no-self-use - """Returns parameter value cascading ids.""" - cache = self.cache - cascading_ids = {"parameter_value": set(ids)} - value_metadata_ids = { - x.id for x in dict.values(cache.get("parameter_value_metadata", {})) if x.parameter_value_id in ids - } - self._merge(cascading_ids, self._parameter_value_metadata_cascading_ids(value_metadata_ids)) - return cascading_ids - - def _parameter_value_list_cascading_ids(self, ids): # pylint: disable=no-self-use - """Returns parameter value list cascading ids and adds them to the given dictionaries.""" - cascading_ids = {"parameter_value_list": set(ids)} - return cascading_ids - - def _list_value_cascading_ids(self, ids): # pylint: disable=no-self-use - """Returns parameter value list value cascading ids.""" - return {"list_value": set(ids)} - - def _scenario_alternatives_cascading_ids(self, ids): - return {"scenario_alternative": set(ids)} - - def _metadata_cascading_ids(self, ids): - cache = self.cache - cascading_ids = {"metadata": set(ids)} - entity_metadata = { - "entity_metadata": {x.id for x in dict.values(cache.get("entity_metadata", {})) if x.metadata_id in ids} - } - self._merge(cascading_ids, entity_metadata) - value_metadata = { - "parameter_value_metadata": { - x.id for x in dict.values(cache.get("parameter_value_metadata", {})) if x.metadata_id in ids - } - } - self._merge(cascading_ids, value_metadata) - return cascading_ids - - def _entity_metadata_cascading_ids(self, ids): - return {"entity_metadata": set(ids)} - - def _parameter_value_metadata_cascading_ids(self, ids): - return {"parameter_value_metadata": set(ids)} - - def get_metadata_ids_to_remove(self): + def _get_metadata_ids_to_remove(self): used_metadata_ids = set() for x in self.cache.get("entity_metadata", {}).values(): used_metadata_ids.add(x["metadata_id"]) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 30a2464d..afeb64d1 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -15,7 +15,6 @@ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam from .exception import SpineIntegrityError -from .helpers import convert_legacy class DatabaseMappingUpdateMixin: @@ -92,17 +91,16 @@ def update_items(self, tablename, *items, check=True, strict=False): table_cache = self.cache.table_cache(tablename) if not check: for item in items: - convert_legacy(tablename, item) + self._convert_legacy(tablename, item) updated.append(table_cache.update_item(item)._asdict()) else: for item in items: - convert_legacy(tablename, item) + self._convert_legacy(tablename, item) checked_item, error = table_cache.check_item(item, for_update=True) if error: if strict: raise SpineIntegrityError(error) errors.append(error) - continue if checked_item: item = checked_item._asdict() updated.append(table_cache.update_item(item)._asdict()) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 8d79e6fb..93b13d97 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -99,7 +99,7 @@ def export_data( def _get_items(db_map, tablename, ids): if not ids: return () - db_map.fetch_all({tablename}, include_ancestors=True) + db_map.fetch_all({tablename}) _process_item = _make_item_processor(db_map, tablename) for item in _get_items_from_cache(db_map.cache, tablename, ids): yield from _process_item(item) @@ -118,7 +118,7 @@ def _get_items_from_cache(cache, tablename, 
ids):


 def _make_item_processor(db_map, tablename):
     if tablename == "parameter_value_list":
-        db_map.fetch_all({"list_value"}, include_ancestors=True)
+        db_map.fetch_all({"list_value"})
         return _ParameterValueListProcessor(db_map.cache.get("list_value", {}).values())
     return lambda item: (item,)
diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py
index 078e53bb..8e53e1ea 100644
--- a/spinedb_api/filters/execution_filter.py
+++ b/spinedb_api/filters/execution_filter.py
@@ -151,23 +151,17 @@ def _create_import_alternative(db_map, state):
         db_map (DatabaseMappingBase): database the state applies to
         state (_ExecutionFilterState): a state bound to ``db_map``
     """
-    # FIXME
     execution_item = state.execution_item
     scenarios = state.scenarios
     timestamp = state.timestamp
     sep = "__" if scenarios else ""
     db_map._import_alternative_name = f"{'_'.join(scenarios)}{sep}{execution_item}@{timestamp}"
-    alt_ids, _ = db_map.add_alternatives({"name": db_map._import_alternative_name}, return_dups=True)
-    db_map._import_alternative_id = next(iter(alt_ids))
-    scenarios = [{"name": scenario} for scenario in scenarios]
-    scen_ids, _ = db_map.add_scenarios(*scenarios, return_dups=True)
-    for scen_id in scen_ids:
-        max_rank = (
-            db_map.query(func.max(db_map.scenario_alternative_sq.c.rank))
-            .filter(db_map.scenario_alternative_sq.c.scenario_id == scen_id)
-            .scalar()
-        )
-        rank = max_rank + 1 if max_rank else 1
+    db_map.add_alternatives({"name": db_map._import_alternative_name}, _strict=False)
+    # Keep ``scenarios`` holding plain names; the loop below needs them, not item dicts.
+    db_map.add_scenarios(*({"name": scen_name} for scen_name in scenarios), _strict=True)
+    for scen_name in scenarios:
+        scen = db_map.cache.table_cache("scenario").current_item({"name": scen_name})
+        rank = len(scen.sorted_alternatives) + 1  # ranks are 1-based
         db_map.add_scenario_alternatives(
-            {"scenario_id": scen_id, "alternative_id": db_map._import_alternative_id, "rank": rank}
+            {"scenario_name": scen_name, "alternative_name": db_map._import_alternative_name, "rank": rank}
         )
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index 7d6cd1fc..9a643861 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -75,7 +75,7 @@
 model_meta = MetaData(naming_convention=naming_convention)

-LONGTEXT_LENGTH = 2**32 - 1
+LONGTEXT_LENGTH = 2 ** 32 - 1


 # NOTE: Deactivated since foreign keys are too difficult to get right in the diff tables.
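
The hunk below removes ``convert_legacy()`` from helpers.py; the same conversion
now lives on ``DatabaseMappingBase._convert_legacy`` (added earlier in this patch).
A minimal sketch of what it does to a legacy relationship-style item, using a
hypothetical plain dict in place of a cache item:

    item = {"class_id": 5, "object_id_list": (1, 2), "object_name_list": ("nemo", "pluto")}
    DatabaseMappingBase._convert_legacy("entity", item)
    # item is now:
    # {"class_id": 5, "element_id_list": (1, 2), "element_name_list": ("nemo", "pluto")}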
@@ -826,28 +826,3 @@ def remove_credentials_from_url(url): if parsed.username is None: return url return urlunparse(parsed._replace(netloc=parsed.netloc.partition("@")[-1])) - - -def convert_legacy(tablename, item): - if tablename in ("entity_class", "entity"): - object_class_id_list = tuple(item.pop("object_class_id_list", ())) - if object_class_id_list: - item["dimension_id_list"] = object_class_id_list - object_class_name_list = tuple(item.pop("object_class_name_list", ())) - if object_class_name_list: - item["dimension_name_list"] = object_class_name_list - if tablename == "entity": - object_id_list = tuple(item.pop("object_id_list", ())) - if object_id_list: - item["element_id_list"] = object_id_list - object_name_list = tuple(item.pop("object_name_list", ())) - if object_name_list: - item["element_name_list"] = object_name_list - if tablename in ("parameter_definition", "parameter_value"): - entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) - if entity_class_id: - item["entity_class_id"] = entity_class_id - if tablename == "parameter_value": - entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) - if entity_id: - item["entity_id"] = entity_id diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 524e316e..4cc5c0f1 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -18,7 +18,6 @@ from .helpers import _parse_metadata # TODO: update docstrings -# FIXME: alt_id, alternative_name = db_map.get_import_alternative() class ImportErrorLogItem: @@ -824,13 +823,23 @@ def _data_iterator(): for class_name, entity_byname, parameter_name, value, *optionals in data: if isinstance(entity_byname, str): entity_byname = (entity_byname,) + alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() value, type_ = unparse_value(value) - alternative_name = optionals[0] if optionals else "Base" - yield class_name, entity_byname, parameter_name, value, type_, alternative_name - - key = ("entity_class_name", "entity_byname", "parameter_definition_name", "value", "type", "alternative_name") - return _get_items_for_import(db_map, "parameter_value", (dict(zip(key, x)) for x in _data_iterator())) - # FIXME: value, type_ = fix_conflict((value, type_), (current_pv.value, current_pv.type), on_conflict) + item = { + "entity_class_name": class_name, + "entity_byname": entity_byname, + "parameter_definition_name": parameter_name, + "alternative_name": alternative_name, + "value": None, + "type": None, + } + pv = db_map.cache.table_cache("parameter_value").current_item(item) + if pv is not None: + value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) + item.update({"value": value, "type": type_}) + yield item + + return _get_items_for_import(db_map, "parameter_value", _data_iterator()) def _get_alternatives_for_import(db_map, data): @@ -931,7 +940,7 @@ def _data_iterator(): for class_name, entity_byname, parameter_name, metadata, *optionals in data: if isinstance(entity_byname, str): entity_byname = (entity_byname,) - alternative_name = optionals[0] if optionals else "Base" + alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() for name, value in _parse_metadata(metadata): yield (class_name, entity_byname, parameter_name, name, value, alternative_name) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index f80894b4..dcee88a9 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ 
-17,8 +17,8 @@ class Query: - def __init__(self, db_map, *entities): - self._db_map = db_map + def __init__(self, execute, *entities): + self._execute = execute self._entities = entities self._select = select(entities) self._from = None @@ -85,7 +85,7 @@ def having(self, *args): return self def _result(self): - return self._db_map.connection_execute(self._select) + return self._execute(self._select) def all(self): return self._result().fetchall() @@ -107,7 +107,7 @@ def scalar(self): return self._result().scalar() def count(self): - return self._db_map.connection_execute(select([count()]).select_from(self._select)).scalar() + return self._execute(select([count()]).select_from(self._select)).scalar() def __iter__(self): return self._result() diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 8fd0ea3f..7052c013 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -566,9 +566,8 @@ def test_update_wide_relationship_class_does_not_update_member_class_id(self): items, errors = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "renamed", "object_class_id_list": [2]} ) - updated_ids = {x["id"] for x in items} - self.assertEqual([str(err) for err in errors], ["Can't update fixed fields 'dimension_id_list'"]) - self.assertEqual(updated_ids, {3}) + self.assertEqual([str(err) for err in errors], ["can't modify dimensions of an entity class"]) + self.assertEqual(len(items), 1) self._db_map.commit_session("Update data.") classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(classes), 1) @@ -670,12 +669,11 @@ def test_update_parameter_definition_value_list_when_values_exist_gives_error(se items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "parameter_value_list_id": 1} ) - updated_ids = {x["id"] for x in items} self.assertEqual( list(map(str, errors)), - ["Can't change value list on parameter my_parameter because it has parameter values."], + ["can't modify the parameter value list of a parameter that already has values"], ) - self.assertEqual(updated_ids, set()) + self.assertEqual(items, []) def test_update_parameter_definitions_default_value_that_is_not_on_value_list_gives_error(self): import_functions.import_parameter_value_lists(self._db_map, (("my_list", 99.0),)) @@ -712,17 +710,16 @@ def test_update_object_metadata(self): items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(len(items), 2) self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": None}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) self.assertEqual( - dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": None} + dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} ) def test_update_object_metadata_reuses_existing_metadata(self): @@ -745,12 +742,9 @@ def test_update_object_metadata_reuses_existing_metadata(self): 
self.assertEqual(ids, {1}) self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() - self.assertEqual(len(metadata_entries), 2) - self.assertEqual( - dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} - ) + self.assertEqual(len(metadata_entries), 1) self.assertEqual( - dict(metadata_entries[1]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": None} + dict(metadata_entries[0]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": None} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) @@ -776,9 +770,8 @@ def test_update_object_metadata_keeps_metadata_still_in_use(self): items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "new key", "metadata_value": "new value"}] ) - ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1, 2}) + self.assertEqual(len(items), 2) self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 5c70d9e0..a8df6460 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -772,25 +772,6 @@ def test_add_parameter_with_same_name_as_existing_one(self): with self.assertRaises(SpineIntegrityError): self._db_map.add_parameter_definitions({"name": "color", "object_class_id": 1}, strict=True) - def test_add_parameter_with_invalid_class(self): - """Test that adding parameter_definitions with an invalid (object or relationship) class raises and integrity error.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - with self.assertRaises(SpineIntegrityError): - self._db_map.add_parameter_definitions({"name": "color", "object_class_id": 3}, strict=True) - with self.assertRaises(SpineIntegrityError): - self._db_map.add_parameter_definitions({"name": "color", "relationship_class_id": 1}, strict=True) - - def test_add_parameter_for_both_object_and_relationship_class(self): - """Test that adding parameter_definitions associated to both and object and relationship class - raises and integrity error.""" - self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 10, "object_class_id_list": [1, 2]}) - with self.assertRaises(SpineIntegrityError): - self._db_map.add_parameter_definitions( - {"name": "color", "object_class_id": 1, "relationship_class_id": 10}, strict=True - ) - def test_add_parameter_values(self): """Test that adding parameter values works.""" import_functions.import_object_classes(self._db_map, ["fish", "dog"]) @@ -853,11 +834,11 @@ def test_add_parameter_value_with_invalid_object_or_relationship(self): _, errors = self._db_map.add_parameter_values( {"parameter_definition_id": 1, "object_id": 3, "value": b'"orange"', "alternative_id": 1}, strict=False ) - self.assertEqual([str(e) for e in errors], ["Incorrect entity 'fish_dog_nemo__pluto' for parameter 'color'."]) + self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) _, errors = self._db_map.add_parameter_values( {"parameter_definition_id": 2, 
"relationship_id": 2, "value": b"125", "alternative_id": 1}, strict=False ) - self.assertEqual([str(e) for e in errors], ["Incorrect entity 'pluto' for parameter 'rel_speed'."]) + self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) def test_add_same_parameter_value_twice(self): """Test that adding a parameter value twice only adds the first one.""" From 7d4d0fe6bbf1dee1f3f06c1bed17802ee6c95438 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 17 May 2023 08:51:48 +0200 Subject: [PATCH 041/317] Get rid of DBMappingBase.connection and move asynch back to toolbox --- spinedb_api/db_mapping_add_mixin.py | 2 +- spinedb_api/db_mapping_base.py | 105 ++--------- spinedb_api/db_mapping_update_mixin.py | 2 - spinedb_api/perfect_split.py | 4 +- spinedb_api/query.py | 8 +- spinedb_api/spine_db_server.py | 2 +- tests/export_mapping/test_export_mapping.py | 144 +++++++-------- tests/export_mapping/test_settings.py | 6 +- tests/filters/test_alternative_filter.py | 46 ++--- tests/filters/test_renamer.py | 50 ++--- tests/filters/test_scenario_filter.py | 172 ++++++++++-------- tests/filters/test_tool_filter.py | 2 +- tests/filters/test_tools.py | 18 +- tests/filters/test_value_transformer.py | 76 ++++---- tests/spine_io/exporters/test_csv_writer.py | 8 +- tests/spine_io/exporters/test_excel_writer.py | 12 +- tests/spine_io/exporters/test_gdx_writer.py | 24 +-- tests/spine_io/exporters/test_sql_writer.py | 14 +- tests/spine_io/exporters/test_writer.py | 2 +- tests/spine_io/test_excel_integration.py | 2 +- tests/test_DatabaseMapping.py | 16 +- tests/test_DiffDatabaseMapping.py | 26 +-- tests/test_export_functions.py | 2 +- tests/test_import_functions.py | 144 +++++++-------- tests/test_migration.py | 2 +- 25 files changed, 421 insertions(+), 468 deletions(-) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 734ff2cd..17c7c762 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -71,7 +71,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): if id_items: connection.execute(table.insert(), [x._asdict() for x in id_items]) if temp_id_items: - current_ids = {x["id"] for x in Query(connection.execute, table)} + current_ids = {x["id"] for x in Query(connection, table)} next_id = max(current_ids, default=0) + 1 available_ids = set(range(1, next_id)) - current_ids missing_id_count = len(temp_id_items) - len(available_ids) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index a913bdfc..0414485d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -17,12 +17,11 @@ import logging import time from types import MethodType -from concurrent.futures import ThreadPoolExecutor from sqlalchemy import create_engine, MetaData, Table, Column, Integer, inspect, case, func, cast, false, and_, or_ from sqlalchemy.sql.expression import label, Alias from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import aliased -from sqlalchemy.exc import DatabaseError, ProgrammingError +from sqlalchemy.exc import DatabaseError from sqlalchemy.event import listen from sqlalchemy.pool import NullPool from alembic.migration import MigrationContext @@ -82,7 +81,6 @@ def __init__( apply_filters=True, memory=False, sqlite_timeout=1800, - asynchronous=False, chunk_size=None, ): """ @@ -95,7 +93,6 @@ def __init__( apply_filters (bool): Whether or not filters in the URL's query part are applied to the database map. 
memory (bool): Whether or not to use a sqlite memory db as replacement for this DB map. sqlite_timeout (int): How many seconds to wait before raising connection errors. - asynchronous (bool): Whether or not communication with the db should be done asynchronously. chunk_size (int, optional): How many rows to fetch from the DB at a time when populating the cache. If not specified, then all rows are fetched at once. """ @@ -114,21 +111,19 @@ def __init__( self.codename = self._make_codename(codename) self._memory = memory self._memory_dirty = False - self._asynchronous = asynchronous self._original_engine = self.create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine listen(self.engine, 'close', self._receive_engine_close) - self.executor = self._make_executor() - self.connection = self.executor.submit(self.engine.connect).result() if self._memory: - self.executor.submit(copy_database_bind, self.connection, self._original_engine) - self._metadata = MetaData(self.connection) - _ = self.executor.submit(self._metadata.reflect).result() + copy_database_bind(self.engine, self._original_engine) + self._metadata = MetaData(self.engine) + self._metadata.reflect() self._tablenames = [t.name for t in self._metadata.sorted_tables] self.cache = DBCache(self, chunk_size=chunk_size) + self.closed = False # Subqueries that select everything from each table self._commit_sq = None self._alternative_sq = None @@ -203,25 +198,8 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): def get_filter_configs(self): return self._filter_configs - def _make_executor(self): - return ThreadPoolExecutor(max_workers=1) if self._asynchronous else _Executor() - - def call_in_right_thread(self, fn, *args, **kwargs): - # We try to call directly. If we are in the wrong thread this will raise ProgrammingError. - # Then we can execute in the executor thread. 
- try: - return fn(*args, **kwargs) - except ProgrammingError: - return self.executor.submit(fn, *args, **kwargs).result() - def close(self): - if not self.connection.closed: - self.executor.submit(self.connection.close) - self.executor.shutdown() - - def reconnect(self): - self.executor = self._make_executor() - self.connection = self.executor.submit(self.engine.connect).result() + self.closed = True def _real_tablename(self, tablename): return { @@ -321,7 +299,7 @@ def upgrade_to_head(rev, context): return engine def _receive_engine_close(self, dbapi_con, _connection_record): - if dbapi_con == self.connection.connection.connection and self._memory_dirty: + if self._memory_dirty: copy_database_bind(self._original_engine, self.engine) def in_(self, column, values): @@ -339,9 +317,10 @@ def in_(self, column, values): Column("value", column.type, primary_key=True), prefixes=['TEMPORARY'], ) - self.call_in_right_thread(in_value.create, self.connection, checkfirst=True) - self.connection_execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) - return column.in_({x.value for x in self.query(in_value.c.value)}) + with self.engine.connect() as connection: + in_value.create(connection, checkfirst=True) + connection.execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) + return column.in_(Query(connection, in_value.c.value)) def _get_table_to_sq_attr(self): if not self._table_to_sq_attr: @@ -406,7 +385,7 @@ def query(self, *args, **kwargs): db_map.object_sq.c.class_id == db_map.object_class_sq.c.id ).group_by(db_map.object_class_sq.c.name).all() """ - return Query(self.connection_execute, *args) + return Query(self.engine, *args) def _subquery(self, tablename): """A subquery of the form: @@ -1845,9 +1824,6 @@ def restore_scenario_alternative_sq_maker(self): self._make_scenario_alternative_sq = MethodType(DatabaseMappingBase._make_scenario_alternative_sq, self) self._clear_subqueries("scenario_alternative") - def connection_execute(self, *args): - return self.call_in_right_thread(self.connection.execute, *args) - def _get_primary_key(self, tablename): pk = self.composite_pks.get(tablename) if pk is None: @@ -1859,10 +1835,11 @@ def _reset_mapping(self): """Delete all records from all tables but don't drop the tables. Useful for writing tests """ - for tablename in self._tablenames: - table = self._metadata.tables[tablename] - self.connection_execute(table.delete()) - self.connection_execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") + with self.engine.connect() as connection: + for tablename in self._tablenames: + table = self._metadata.tables[tablename] + connection.execute(table.delete()) + connection.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") def fetch_all(self, tablenames=None): tablenames = set(self.ITEM_TYPES) if tablenames is None else tablenames & set(self.ITEM_TYPES) @@ -1930,19 +1907,13 @@ def _object_name_list(self): [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None ) - def advance_cache_query(self, item_type, callback=None): + def advance_cache_query(self, item_type): """Schedules an advance of the DB query that fetches items of given type. 
Args: item_type (str) - - Returns: - Future """ - if not callback: - return self.cache.advance_query(item_type) - future = self.executor.submit(self.cache.advance_query, item_type) - future.add_done_callback(lambda future: callback(future.result())) + return self.cache.advance_query(item_type) @staticmethod def _convert_legacy(tablename, item): @@ -1970,40 +1941,4 @@ def _convert_legacy(tablename, item): item["entity_id"] = entity_id def __del__(self): - try: - self.close() - except AttributeError: - pass - - -class _Future: - def __init__(self): - self._result = None - self._exception = None - - def set_result(self, result): - self._result = result - - def set_exception(self, exception): - self._exception = exception - - def add_done_callback(self, callback): - callback(self) - - def result(self): - if self._exception is not None: - raise self._exception - return self._result - - -class _Executor: - def submit(self, fn, *args, **kwargs): - future = _Future() - try: - future.set_result(fn(*args, **kwargs)) - except Exception as exc: - future.set_exception(exc) - return future - - def shutdown(self): - pass + self.close() diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index afeb64d1..3d0e0bdf 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -158,8 +158,6 @@ def _update_ext_item_metadata(self, tablename, *items, **kwargs): metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) added, errors = self.add_items("metadata", *metadata_items, **kwargs) updated, more_errors = self.update_items(tablename, *items, **kwargs) - metadata_ids = self.get_metadata_ids_to_remove() - self.remove_items("metadata", *metadata_ids) return added + updated, errors + more_errors def update_ext_entity_metadata(self, *items, **kwargs): diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index 1db77bdf..82d9b56d 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -33,7 +33,7 @@ def perfect_split(input_urls, intersection_url, diff_urls): input_db_map = DatabaseMapping(input_url) input_data_sets[input_url] = export_data(input_db_map) db_names[input_url] = input_db_map.codename - input_db_map.connection.close() + input_db_map.close() intersection_data = {} input_data_set_iter = iter(input_data_sets) left_url = next(iter(input_data_set_iter)) @@ -67,7 +67,7 @@ def perfect_split(input_urls, intersection_url, diff_urls): db_name = db_names[input_url] other_db_names = ', '.join([name for url, name in db_names.items() if url != input_url]) diff_db_map.commit_session(f"Add differences between {db_name} and {other_db_names}") - diff_db_map.connection.close() + diff_db_map.close() def _make_lookup(data): diff --git a/spinedb_api/query.py b/spinedb_api/query.py index dcee88a9..a304d107 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -17,8 +17,8 @@ class Query: - def __init__(self, execute, *entities): - self._execute = execute + def __init__(self, bind, *entities): + self._bind = bind self._entities = entities self._select = select(entities) self._from = None @@ -85,7 +85,7 @@ def having(self, *args): return self def _result(self): - return self._execute(self._select) + return self._bind.execute(self._select) def all(self): return self._result().fetchall() @@ -107,7 +107,7 @@ def scalar(self): return self._result().scalar() def count(self): - return self._execute(select([count()]).select_from(self._select)).scalar() + return 
self._bind.execute(select([count()]).select_from(self._select)).scalar() def __iter__(self): return self._result() diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 708ddb92..5b3df2fc 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -268,7 +268,7 @@ def _do_work(self): while True: input_ = self._in_queue.get() if input_ == self._CLOSE: - self._db_map.connection.close() + self._db_map.close() break request, args, kwargs = input_ handler = { diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index d7219d2c..fc6eefaa 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -72,7 +72,7 @@ def test_export_empty_table(self): db_map = DatabaseMapping("sqlite://", create=True) object_class_mapping = EntityClassMapping(0) self.assertEqual(list(rows(object_class_mapping, db_map)), []) - db_map.connection.close() + db_map.close() def test_export_single_object_class(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -80,7 +80,7 @@ def test_export_single_object_class(self): db_map.commit_session("Add test data.") object_class_mapping = EntityClassMapping(0) self.assertEqual(list(rows(object_class_mapping, db_map)), [["object_class"]]) - db_map.connection.close() + db_map.close() def test_export_objects(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -95,7 +95,7 @@ def test_export_objects(self): list(rows(object_class_mapping, db_map)), [["oc1", "o11"], ["oc1", "o12"], ["oc2", "o21"], ["oc3", "o31"], ["oc3", "o32"], ["oc3", "o33"]], ) - db_map.connection.close() + db_map.close() def test_hidden_tail(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -105,7 +105,7 @@ def test_hidden_tail(self): object_class_mapping = EntityClassMapping(0) object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(list(rows(object_class_mapping, db_map)), [["oc1"], ["oc1"]]) - db_map.connection.close() + db_map.close() def test_pivot_without_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -115,7 +115,7 @@ def test_pivot_without_values(self): object_class_mapping = EntityClassMapping(-1) object_class_mapping.child = EntityMapping(Position.hidden) self.assertEqual(list(rows(object_class_mapping, db_map)), []) - db_map.connection.close() + db_map.close() def test_hidden_tail_pivoted(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -135,7 +135,7 @@ def test_hidden_tail_pivoted(self): ) expected = [[None, None, "p1", "p2"], ["oc", "o1", "Base", "Base"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_hidden_leaf_item_in_regular_table_valid(self): object_class_mapping = EntityClassMapping(0) @@ -156,7 +156,7 @@ def test_object_groups(self): flattened = [EntityClassMapping(0), EntityGroupMapping(1)] mapping = unflatten(flattened) self.assertEqual(list(rows(mapping, db_map)), [["oc", "g1"], ["oc", "g2"]]) - db_map.connection.close() + db_map.close() def test_object_groups_with_objects(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -167,7 +167,7 @@ def test_object_groups_with_objects(self): flattened = [EntityClassMapping(0), EntityGroupMapping(1), EntityGroupEntityMapping(2)] mapping = unflatten(flattened) self.assertEqual(list(rows(mapping, db_map)), [["oc", "g1", "o1"], ["oc", "g1", "o2"], ["oc", "g2", "o3"]]) - db_map.connection.close() + db_map.close() def 
test_object_groups_with_parameter_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -192,7 +192,7 @@ def test_object_groups_with_parameter_values(self): list(rows(mapping, db_map)), [["oc", "g1", "o1", -11.0], ["oc", "g1", "o2", -12.0], ["oc", "g2", "o3", -13.0]], ) - db_map.connection.close() + db_map.close() def test_export_parameter_definitions(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -212,7 +212,7 @@ def test_export_parameter_definitions(self): ["oc2", "p21", "o21"], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_single_parameter_value_when_there_are_multiple_objects(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -231,7 +231,7 @@ def test_export_single_parameter_value_when_there_are_multiple_objects(self): object_mapping.child = value_mapping object_class_mapping.child = parameter_definition_mapping self.assertEqual(list(rows(object_class_mapping, db_map)), [["oc1", "p12", "o11", -11.0]]) - db_map.connection.close() + db_map.close() def test_export_single_parameter_value_pivoted_by_object_name(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -259,7 +259,7 @@ def test_export_single_parameter_value_pivoted_by_object_name(self): object_class_mapping.child = parameter_definition_mapping expected = [[None, None, "o11", "o12"], ["oc1", "p11", -11.0, -21.0], ["oc1", "p12", -12.0, -22.0]] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_minimum_pivot_index_need_not_be_minus_one(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -284,7 +284,7 @@ def test_minimum_pivot_index_need_not_be_minus_one(self): ["o", "oc", "p", "B", -2.2, -6.6], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_pivot_row_order(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -328,7 +328,7 @@ def test_pivot_row_order(self): ["oc1", -11.0, -12.0, -21.0, -22.0], ] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_parameter_indexes(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -365,7 +365,7 @@ def test_export_parameter_indexes(self): ["oc", "o2", "p2", "h"], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_nested_parameter_indexes(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -400,7 +400,7 @@ def test_export_nested_parameter_indexes(self): ["oc", "o2", "p", "D", None], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_nested_map_values_only(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -430,7 +430,7 @@ def test_export_nested_map_values_only(self): object_class_mapping.child = parameter_definition_mapping expected = [[23.0], [-1.1], [-2.2], [-3.3], [-4.4], [2.3]] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_full_pivot_table(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -466,7 +466,7 @@ def test_full_pivot_table(self): ["oc", "p", "B", "b", -4.4, -8.8], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_full_pivot_table_with_hidden_columns(self): db_map = DatabaseMapping("sqlite://", 
create=True) @@ -486,7 +486,7 @@ def test_full_pivot_table_with_hidden_columns(self): ["oc", None, "p", "Base", "B", -2.2, -6.6], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -515,7 +515,7 @@ def test_objects_as_pivot_header_for_indexed_values_with_alternatives(self): ["oc", None, "p", "alt", "B", -4.4, -8.8], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_objects_and_indexes_as_pivot_header(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -535,7 +535,7 @@ def test_objects_and_indexes_as_pivot_header(self): ["oc", None, "p", "Base", -1.1, -2.2, -3.3, -4.4], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_parameters(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -569,7 +569,7 @@ def test_objects_and_indexes_as_pivot_header_with_multiple_alternatives_and_para ["oc", "p2", -5.5, -6.6, -7.7, -8.8, -13.3, -14.4, -15.5, -16.6], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_empty_column_while_pivoted_handled_gracefully(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -586,7 +586,7 @@ def test_empty_column_while_pivoted_handled_gracefully(self): definition.child = value_list mapping.child = definition self.assertEqual(list(rows(mapping, db_map)), []) - db_map.connection.close() + db_map.close() def test_object_classes_as_header_row_and_objects_in_columns(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -601,7 +601,7 @@ def test_object_classes_as_header_row_and_objects_in_columns(self): list(rows(object_class_mapping, db_map)), [["oc1", "oc2", "oc3"], ["o11", "o21", "o31"], ["o12", None, "o32"], [None, None, "o33"]], ) - db_map.connection.close() + db_map.close() def test_object_classes_as_table_names(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -616,7 +616,7 @@ def test_object_classes_as_table_names(self): for title, title_key in titles(object_class_mapping, db_map): tables[title] = list(rows(object_class_mapping, db_map, title_key)) self.assertEqual(tables, {"oc1": [["o11"], ["o12"]], "oc2": [["o21"]], "oc3": [["o31"], ["o32"], ["o33"]]}) - db_map.connection.close() + db_map.close() def test_object_class_and_parameter_definition_as_table_name(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -635,7 +635,7 @@ def test_object_class_and_parameter_definition_as_table_name(self): self.assertEqual( tables, {"oc1,p11": [["o11"], ["o12"]], "oc2,p21": [["o21"]], "oc2,p22": [["o21"]], "oc3": [["o31"]]} ) - db_map.connection.close() + db_map.close() def test_object_relationship_name_as_table_name(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -651,7 +651,7 @@ def test_object_relationship_name_as_table_name(self): self.assertEqual( tables, {"rc_o1__O,o1": [["rc", "oc1", "oc2", "O"]], "rc_o2__O,o2": [["rc", "oc1", "oc2", "O"]]} ) - db_map.connection.close() + db_map.close() def test_parameter_definitions_with_value_lists(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -668,7 +668,7 @@ def test_parameter_definitions_with_value_lists(self): for title, title_key in titles(class_mapping, db_map): tables[title] = list(rows(class_mapping, db_map, title_key)) 
self.assertEqual(tables, {None: [["oc", "p1", "vl1"]]}) - db_map.connection.close() + db_map.close() def test_parameter_definitions_and_values_and_value_lists(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -691,7 +691,7 @@ def test_parameter_definitions_and_values_and_value_lists(self): for title, title_key in titles(mapping, db_map): tables[title] = list(rows(mapping, db_map, title_key)) self.assertEqual(tables, {None: [["oc", "p1", "vl", "o", -1.0]]}) - db_map.connection.close() + db_map.close() def test_parameter_definitions_and_values_and_ignorable_value_lists(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -716,7 +716,7 @@ def test_parameter_definitions_and_values_and_ignorable_value_lists(self): for title, title_key in titles(mapping, db_map): tables[title] = list(rows(mapping, db_map, title_key)) self.assertEqual(tables, {None: [["oc", "p1", "vl", "o", -1.0], ["oc", "p2", None, "o", 5.0]]}) - db_map.connection.close() + db_map.close() def test_parameter_value_lists(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -727,7 +727,7 @@ def test_parameter_value_lists(self): for title, title_key in titles(value_list_mapping, db_map): tables[title] = list(rows(value_list_mapping, db_map, title_key)) self.assertEqual(tables, {None: [["vl1"], ["vl2"]]}) - db_map.connection.close() + db_map.close() def test_parameter_value_list_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -740,7 +740,7 @@ def test_parameter_value_list_values(self): for title, title_key in titles(value_list_mapping, db_map): tables[title] = list(rows(value_list_mapping, db_map, title_key)) self.assertEqual(tables, {"vl1": [[-1.0]], "vl2": [[-2.0]]}) - db_map.connection.close() + db_map.close() def test_no_item_declared_as_title_gives_full_table(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -757,7 +757,7 @@ def test_no_item_declared_as_title_gives_full_table(self): for title, title_key in titles(object_class_mapping, db_map): tables[title] = list(rows(object_class_mapping, db_map, title_key)) self.assertEqual(tables, {None: [["o11"], ["o12"], ["o21"], ["o21"]]}) - db_map.connection.close() + db_map.close() def test_missing_values_for_alternatives(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -791,7 +791,7 @@ def test_missing_values_for_alternatives(self): ["oc", "o2", "p2", "alt2", -5.5], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_relationship_classes(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -807,7 +807,7 @@ def test_export_relationship_classes(self): list(rows(relationship_class_mapping, db_map)), [["rc1", "oc1", ""], ["rc2", "oc3", "oc2"], ["rc3", "oc2", "oc3"]], ) - db_map.connection.close() + db_map.close() def test_export_relationships(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -828,7 +828,7 @@ def test_export_relationships(self): ['rc2', 'oc2', 'oc1', 'rc2_o21__o12', 'o21', 'o12'], ] self.assertEqual(list(rows(relationship_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_relationships_with_different_dimensions(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -864,7 +864,7 @@ def test_relationships_with_different_dimensions(self): ["rc2D", "oc1", "oc2", "o12", "o22"], ] self.assertEqual(tables[None], expected) - db_map.connection.close() + db_map.close() def test_default_parameter_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -878,7 
+878,7 @@ def test_default_parameter_values(self): object_class_mapping.child = definition_mapping table = list(rows(object_class_mapping, db_map)) self.assertEqual(table, [["oc1", "p11", 3.14], ["oc2", "p21", 14.3], ["oc2", "p22", -1.0]]) - db_map.connection.close() + db_map.close() def test_indexed_default_parameter_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -908,7 +908,7 @@ def test_indexed_default_parameter_values(self): ["oc2", "p22", "D", -1.0], ] self.assertEqual(table, expected) - db_map.connection.close() + db_map.close() def test_replace_parameter_indexes_by_external_data(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -938,7 +938,7 @@ def test_replace_parameter_indexes_by_external_data(self): ["oc", "o2", "p1", "d", -2.0], ] self.assertEqual(list(rows(object_class_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_constant_mapping_as_title(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -951,7 +951,7 @@ def test_constant_mapping_as_title(self): for title, title_key in titles(constant_mapping, db_map): tables[title] = list(rows(constant_mapping, db_map, title_key)) self.assertEqual(tables, {"title_text": [["oc1"], ["oc2"], ["oc3"]]}) - db_map.connection.close() + db_map.close() def test_scenario_mapping(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -962,7 +962,7 @@ def test_scenario_mapping(self): for title, title_key in titles(scenario_mapping, db_map): tables[title] = list(rows(scenario_mapping, db_map, title_key)) self.assertEqual(tables, {None: [["s1"], ["s2"]]}) - db_map.connection.close() + db_map.close() def test_scenario_active_flag_mapping(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -975,7 +975,7 @@ def test_scenario_active_flag_mapping(self): for title, title_key in titles(scenario_mapping, db_map): tables[title] = list(rows(scenario_mapping, db_map, title_key)) self.assertEqual(tables, {None: [["s1", True], ["s2", False]]}) - db_map.connection.close() + db_map.close() def test_scenario_alternative_mapping(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -990,7 +990,7 @@ def test_scenario_alternative_mapping(self): for title, title_key in titles(scenario_mapping, db_map): tables[title] = list(rows(scenario_mapping, db_map, title_key)) self.assertEqual(tables, {None: [["s1", "a1"], ["s1", "a2"], ["s2", "a2"], ["s2", "a3"]]}) - db_map.connection.close() + db_map.close() def test_header(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1000,21 +1000,21 @@ def test_header(self): root = unflatten([EntityClassMapping(0, header="class"), EntityMapping(1, header="entity")]) expected = [["class", "entity"], ["oc", "o1"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_header_without_data_still_creates_header(self): db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([EntityClassMapping(0, header="class"), EntityMapping(1, header="object")]) expected = [["class", "object"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_header_in_half_pivot_table_without_data_still_creates_header(self): db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([EntityClassMapping(-1, header="class"), EntityMapping(9, header="object")]) expected = [["class"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def 
test_header_in_pivot_table_without_data_still_creates_header(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1029,21 +1029,21 @@ def test_header_in_pivot_table_without_data_still_creates_header(self): ) expected = [[None, "class"], ["parameter", "alternative"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_disabled_empty_data_header(self): db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([EntityClassMapping(0, header="class"), EntityMapping(1, header="object")]) expected = [] self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected) - db_map.connection.close() + db_map.close() def test_disabled_empty_data_header_in_pivot_table(self): db_map = DatabaseMapping("sqlite://", create=True) root = unflatten([EntityClassMapping(-1, header="class"), EntityMapping(0)]) expected = [] self.assertEqual(list(rows(root, db_map, empty_data_header=False)), expected) - db_map.connection.close() + db_map.close() def test_header_position(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1053,7 +1053,7 @@ def test_header_position(self): root = unflatten([EntityClassMapping(Position.header), EntityMapping(0)]) expected = [["oc"], ["o1"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_header_position_with_relationships(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1074,7 +1074,7 @@ def test_header_position_with_relationships(self): ) expected = [["", "", "oc1", "oc2"], ["rc", "rc_o11__o21", "o11", "o21"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_header_position_with_relationships_but_no_data(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1093,7 +1093,7 @@ def test_header_position_with_relationships_but_no_data(self): ) expected = [["", "", "oc1", "oc2"]] self.assertEqual(list(rows(root, db_map)), expected) - db_map.connection.close() + db_map.close() def test_header_and_pivot(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1134,7 +1134,7 @@ def test_header_and_pivot(self): ["oc", "p2", -5.5, -6.6, -7.7, -8.8, -13.3, -14.4, -15.5, -16.6], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_pivot_without_left_hand_side_has_padding_column_for_headers(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1175,7 +1175,7 @@ def test_pivot_without_left_hand_side_has_padding_column_for_headers(self): [None, -5.5, -6.6, -7.7, -8.8, -13.3, -14.4, -15.5, -16.6], ] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_count_mappings(self): object_class_mapping = EntityClassMapping(2) @@ -1266,7 +1266,7 @@ def test_setting_ignorable_flag(self): self.assertTrue(object_mapping.is_ignorable()) expected = [["oc", None]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_unsetting_ignorable_flag(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1281,7 +1281,7 @@ def test_unsetting_ignorable_flag(self): self.assertFalse(object_mapping.is_ignorable()) expected = [["oc", "o1"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_filter(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1293,7 +1293,7 @@ def test_filter(self): root_mapping = 
unflatten([EntityClassMapping(0), object_mapping]) expected = [["oc", "o1"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_hidden_tail_filter(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1305,7 +1305,7 @@ def test_hidden_tail_filter(self): root_mapping = unflatten([EntityClassMapping(0), object_mapping]) expected = [["oc1"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_index_names(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1319,7 +1319,7 @@ def test_index_names(self): ) expected = [["", "", "", "", "index", ""], ["oc", "o", "p", "Base", "a", 5.0]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_default_value_index_names_with_nested_map(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1333,7 +1333,7 @@ def test_default_value_index_names_with_nested_map(self): ) expected = [["", "", "idx1", "idx2", ""], ["oc", "p", "A", "b", 2.3]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_multiple_index_names_with_empty_database(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1342,7 +1342,7 @@ def test_multiple_index_names_with_empty_database(self): ) expected = [9 * [""]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_parameter_default_value_type(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1356,7 +1356,7 @@ def test_parameter_default_value_type(self): ["oc2", "p22", "single_value", -1.0], ] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_map_with_more_dimensions_than_index_mappings(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1370,7 +1370,7 @@ def test_map_with_more_dimensions_than_index_mappings(self): ) expected = [["oc", "p", "o", "A", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_default_map_value_with_more_dimensions_than_index_mappings(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1380,7 +1380,7 @@ def test_default_map_value_with_more_dimensions_than_index_mappings(self): mapping = entity_parameter_default_value_export(0, 1, Position.hidden, 3, [Position.hidden], [2]) expected = [["oc", "p", "A", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_map_with_single_value_mapping(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1394,7 +1394,7 @@ def test_map_with_single_value_mapping(self): ) expected = [["oc", "p", "o", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_default_map_value_with_single_value_mapping(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1404,7 +1404,7 @@ def test_default_map_value_with_single_value_mapping(self): mapping = entity_parameter_default_value_export(0, 1, Position.hidden, 2, None, None) expected = [["oc", "p", "map"]] self.assertEqual(list(rows(mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_table_gets_exported_even_without_parameter_values(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1419,7 +1419,7 @@ def test_table_gets_exported_even_without_parameter_values(self): tables[title] = 
list(rows(mapping, db_map, title_key)) expected = {"p": [["oc", ""]]} self.assertEqual(tables, expected) - db_map.connection.close() + db_map.close() def test_relationship_class_object_classes_parameters(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1436,7 +1436,7 @@ def test_relationship_class_object_classes_parameters(self): ) expected = [["rc", "oc", "p"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_relationship_class_object_classes_parameters_multiple_dimensions(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1449,7 +1449,7 @@ def test_relationship_class_object_classes_parameters_multiple_dimensions(self): ) expected = [["rc", "oc1", "p11", "oc2"], ["rc", "oc1", "p12", "oc2"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_highlight_relationship_objects(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1475,7 +1475,7 @@ def test_highlight_relationship_objects(self): ["rc", "oc1", "oc2", "rc_o12__o22", "o12", "o22"], ] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() def test_export_object_parameters_while_exporting_relationships(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -1499,7 +1499,7 @@ def test_export_object_parameters_while_exporting_relationships(self): ) expected = [["rc", "oc", "rc_o", "o", "p", "Base", 23.0]] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() if __name__ == "__main__": diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index d490af98..36be53d3 100644 --- a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -93,7 +93,7 @@ def test_export_with_parameter_values(self): [numpy.datetime64("2022-06-22T12:00:00"), -2.2, -4.4], ] self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.connection.close() + db_map.close() class TestEntityClassDimensionParameterDefaultValueExport(unittest.TestCase): @@ -101,7 +101,7 @@ def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_export_with_two_dimensions(self): import_object_classes(self._db_map, ("oc1", "oc2")) @@ -130,7 +130,7 @@ def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_export_with_two_dimensions(self): import_object_classes(self._db_map, ("oc1", "oc2")) diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index c24da2c0..00160a6d 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -48,30 +48,30 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DatabaseMapping(self._db_url) + self._out_db_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._out_map.connection.close() - self._db_map.connection.close() + self._out_db_map.close() + self._db_map.close() def test_alternative_filter_without_scenarios_or_alternatives(self): self._build_data_without_alternatives() - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_alternative_filter_to_parameter_value_sq(self._db_map, []) 
parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(parameters, []) def test_alternative_filter_without_scenarios_or_alternatives_uncommitted_data(self): self._build_data_without_alternatives() - apply_alternative_filter_to_parameter_value_sq(self._out_map, alternatives=[]) - parameters = self._out_map.query(self._out_map.parameter_value_sq).all() + apply_alternative_filter_to_parameter_value_sq(self._out_db_map, alternatives=[]) + parameters = self._out_db_map.query(self._out_db_map.parameter_value_sq).all() self.assertEqual(parameters, []) - self._out_map.rollback_session() + self._out_db_map.rollback_session() def test_alternative_filter(self): self._build_data_with_single_alternative() - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_alternative_filter_to_parameter_value_sq(self._db_map, ["alternative"]) parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -80,14 +80,14 @@ def test_alternative_filter(self): def test_alternative_filter_uncommitted_data(self): self._build_data_with_single_alternative() with self.assertRaises(SpineDBAPIError): - apply_alternative_filter_to_parameter_value_sq(self._out_map, ["alternative"]) - parameters = self._out_map.query(self._out_map.parameter_value_sq).all() + apply_alternative_filter_to_parameter_value_sq(self._out_db_map, ["alternative"]) + parameters = self._out_db_map.query(self._out_db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 0) - self._out_map.rollback_session() + self._out_db_map.rollback_session() def test_alternative_filter_from_dict(self): self._build_data_with_single_alternative() - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") config = alternative_filter_config(["alternative"]) alternative_filter_from_dict(self._db_map, config) parameters = self._db_map.query(self._db_map.parameter_value_sq).all() @@ -95,18 +95,18 @@ def test_alternative_filter_from_dict(self): self.assertEqual(parameters[0].value, b"23.0") def _build_data_without_alternatives(self): - import_object_classes(self._out_map, ["object_class"]) - import_objects(self._out_map, [("object_class", "object")]) - import_object_parameters(self._out_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 23.0)]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter")]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", 23.0)]) def _build_data_with_single_alternative(self): - import_alternatives(self._out_map, ["alternative"]) - import_object_classes(self._out_map, ["object_class"]) - import_objects(self._out_map, [("object_class", "object")]) - import_object_parameters(self._out_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 23.0, "alternative")]) + import_alternatives(self._out_db_map, ["alternative"]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter")]) + 
import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", -1.0)]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", 23.0, "alternative")]) class TestAlternativeFilterWithMemoryDatabase(unittest.TestCase): @@ -119,7 +119,7 @@ def setUp(self): self._db_map.commit_session("Add initial data.") def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_alternative_names_with_colons(self): self._add_value_in_alternative(23.0, "new@2023-23-23T11:12:13") diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index ddaf8a15..8f2b68e0 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -47,12 +47,12 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DatabaseMapping(self._db_url) + self._out_db_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._out_map.connection.close() - self._db_map.connection.close() + self._out_db_map.close() + self._db_map.close() def test_renaming_empty_database(self): apply_renaming_to_entity_class_sq(self._db_map, {"some_name": "another_name"}) @@ -60,8 +60,8 @@ def test_renaming_empty_database(self): self.assertEqual(classes, []) def test_renaming_singe_entity_class(self): - import_object_classes(self._out_map, ("old_name",)) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("old_name",)) + self._out_db_map.commit_session("Add test data") apply_renaming_to_entity_class_sq(self._db_map, {"old_name": "new_name"}) classes = list(self._db_map.query(self._db_map.entity_class_sq).all()) self.assertEqual(len(classes), 1) @@ -74,24 +74,24 @@ def test_renaming_singe_entity_class(self): self.assertEqual(class_row.name, "new_name") def test_renaming_singe_relationship_class(self): - import_object_classes(self._out_map, ("object_class",)) - import_relationship_classes(self._out_map, (("old_name", ("object_class",)),)) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("object_class",)) + import_relationship_classes(self._out_db_map, (("old_name", ("object_class",)),)) + self._out_db_map.commit_session("Add test data") apply_renaming_to_entity_class_sq(self._db_map, {"old_name": "new_name"}) classes = list(self._db_map.query(self._db_map.relationship_class_sq).all()) self.assertEqual(len(classes), 1) self.assertEqual(classes[0].name, "new_name") def test_renaming_multiple_entity_classes(self): - import_object_classes(self._out_map, ("object_class1", "object_class2")) + import_object_classes(self._out_db_map, ("object_class1", "object_class2")) import_relationship_classes( - self._out_map, + self._out_db_map, ( ("relationship_class1", ("object_class1", "object_class2")), ("relationship_class2", ("object_class2", "object_class1")), ), ) - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_renaming_to_entity_class_sq( self._db_map, {"object_class1": "new_object_class", "relationship_class1": "new_relationship_class"} ) @@ -116,8 +116,8 @@ def test_entity_class_renamer_config(self): ) def test_entity_class_renamer_from_dict(self): - import_object_classes(self._out_map, ("old_name",)) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("old_name",)) + self._out_db_map.commit_session("Add test data") config = entity_class_renamer_config(old_name="new_name") 
entity_class_renamer_from_dict(self._db_map, config) classes = list(self._db_map.query(self._db_map.entity_class_sq).all()) @@ -152,12 +152,12 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DatabaseMapping(self._db_url) + self._out_db_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._out_map.connection.close() - self._db_map.connection.close() + self._out_db_map.close() + self._db_map.close() def test_renaming_empty_database(self): apply_renaming_to_parameter_definition_sq(self._db_map, {"some_name": "another_name"}) @@ -165,9 +165,9 @@ def test_renaming_empty_database(self): self.assertEqual(classes, []) def test_renaming_single_parameter(self): - import_object_classes(self._out_map, ("object_class",)) - import_object_parameters(self._out_map, (("object_class", "old_name"),)) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("object_class",)) + import_object_parameters(self._out_db_map, (("object_class", "old_name"),)) + self._out_db_map.commit_session("Add test data") apply_renaming_to_parameter_definition_sq(self._db_map, {"object_class": {"old_name": "new_name"}}) parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all()) self.assertEqual(len(parameters), 1) @@ -192,9 +192,9 @@ def test_renaming_single_parameter(self): self.assertEqual(parameter_row.name, "new_name") def test_renaming_applies_to_correct_parameter(self): - import_object_classes(self._out_map, ("oc1", "oc2")) - import_object_parameters(self._out_map, (("oc1", "param"), ("oc2", "param"))) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("oc1", "oc2")) + import_object_parameters(self._out_db_map, (("oc1", "param"), ("oc2", "param"))) + self._out_db_map.commit_session("Add test data") apply_renaming_to_parameter_definition_sq(self._db_map, {"oc2": {"param": "new_name"}}) parameters = list(self._db_map.query(self._db_map.entity_parameter_definition_sq).all()) self.assertEqual(len(parameters), 2) @@ -212,9 +212,9 @@ def test_parameter_renamer_config(self): ) def test_parameter_renamer_from_dict(self): - import_object_classes(self._out_map, ("object_class",)) - import_object_parameters(self._out_map, (("object_class", "old_name"),)) - self._out_map.commit_session("Add test data") + import_object_classes(self._out_db_map, ("object_class",)) + import_object_parameters(self._out_db_map, (("object_class", "old_name"),)) + self._out_db_map.commit_session("Add test data") config = parameter_renamer_config({"object_class": {"old_name": "new_name"}}) parameter_renamer_from_dict(self._db_map, config) parameters = list(self._db_map.query(self._db_map.parameter_definition_sq).all()) diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 5adf8d2f..d0210787 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -54,25 +54,25 @@ def setUpClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DatabaseMapping(self._db_url) + self._out_db_map = DatabaseMapping(self._db_url) self._db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._out_map.connection.close() - self._db_map.connection.close() + self._out_db_map.close() + self._db_map.close() def _build_data_with_single_scenario(self): - import_alternatives(self._out_map, ["alternative"]) - import_object_classes(self._out_map, ["object_class"]) - 
import_objects(self._out_map, [("object_class", "object")]) - import_object_parameters(self._out_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 23.0, "alternative")]) - import_scenarios(self._out_map, [("scenario", True)]) - import_scenario_alternatives(self._out_map, [("scenario", "alternative")]) + import_alternatives(self._out_db_map, ["alternative"]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter")]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", -1.0)]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", 23.0, "alternative")]) + import_scenarios(self._out_db_map, [("scenario", True)]) + import_scenario_alternatives(self._out_db_map, [("scenario", "alternative")]) def test_scenario_filter(self): - _build_data_with_single_scenario(self._out_map) + _build_data_with_single_scenario(self._out_db_map) apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -96,19 +96,19 @@ def test_scenario_filter(self): ) def test_scenario_filter_uncommitted_data(self): - _build_data_with_single_scenario(self._out_map, commit=False) + _build_data_with_single_scenario(self._out_db_map, commit=False) with self.assertRaises(SpineDBAPIError): - apply_scenario_filter_to_subqueries(self._out_map, "scenario") - parameters = self._out_map.query(self._out_map.parameter_value_sq).all() + apply_scenario_filter_to_subqueries(self._out_db_map, "scenario") + parameters = self._out_db_map.query(self._out_db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 0) - alternatives = [dict(a) for a in self._out_map.query(self._out_map.alternative_sq)] + alternatives = [dict(a) for a in self._out_db_map.query(self._out_db_map.alternative_sq)] self.assertEqual(alternatives, [{"name": "Base", "description": "Base alternative", "id": 1, "commit_id": 1}]) - scenarios = self._out_map.query(self._out_map.wide_scenario_sq).all() + scenarios = self._out_db_map.query(self._out_db_map.wide_scenario_sq).all() self.assertEqual(len(scenarios), 0) - self._out_map.rollback_session() + self._out_db_map.rollback_session() def test_scenario_filter_works_for_object_parameter_value_sq(self): - _build_data_with_single_scenario(self._out_map) + _build_data_with_single_scenario(self._out_db_map) apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.object_parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -132,17 +132,17 @@ def test_scenario_filter_works_for_object_parameter_value_sq(self): ) def test_scenario_filter_works_for_relationship_parameter_value_sq(self): - _build_data_with_single_scenario(self._out_map, commit=False) - import_relationship_classes(self._out_map, [("relationship_class", ["object_class"])]) - import_relationship_parameters(self._out_map, [("relationship_class", "relationship_parameter")]) - import_relationships(self._out_map, [("relationship_class", ["object"])]) + _build_data_with_single_scenario(self._out_db_map, commit=False) + import_relationship_classes(self._out_db_map, [("relationship_class", ["object_class"])]) + 
import_relationship_parameters(self._out_db_map, [("relationship_class", "relationship_parameter")]) + import_relationships(self._out_db_map, [("relationship_class", ["object"])]) import_relationship_parameter_values( - self._out_map, [("relationship_class", ["object"], "relationship_parameter", -1)] + self._out_db_map, [("relationship_class", ["object"], "relationship_parameter", -1)] ) import_relationship_parameter_values( - self._out_map, [("relationship_class", ["object"], "relationship_parameter", 23.0, "alternative")] + self._out_db_map, [("relationship_class", ["object"], "relationship_parameter", 23.0, "alternative")] ) - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.relationship_parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -166,26 +166,32 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self): ) def test_scenario_filter_selects_highest_ranked_alternative(self): - import_alternatives(self._out_map, ["alternative3"]) - import_alternatives(self._out_map, ["alternative1"]) - import_alternatives(self._out_map, ["alternative2"]) - import_object_classes(self._out_map, ["object_class"]) - import_objects(self._out_map, [("object_class", "object")]) - import_object_parameters(self._out_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 10.0, "alternative1")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 2000.0, "alternative2")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 300.0, "alternative3")]) - import_scenarios(self._out_map, [("scenario", True)]) + import_alternatives(self._out_db_map, ["alternative3"]) + import_alternatives(self._out_db_map, ["alternative1"]) + import_alternatives(self._out_db_map, ["alternative2"]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter")]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", -1.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 10.0, "alternative1")] + ) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 2000.0, "alternative2")] + ) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 300.0, "alternative3")] + ) + import_scenarios(self._out_db_map, [("scenario", True)]) import_scenario_alternatives( - self._out_map, + self._out_db_map, [ ("scenario", "alternative2"), ("scenario", "alternative3", "alternative2"), ("scenario", "alternative1", "alternative3"), ], ) - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -216,21 +222,27 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): ) def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(self): - import_alternatives(self._out_map, ["alternative3"]) - 
import_alternatives(self._out_map, ["alternative1"]) - import_alternatives(self._out_map, ["alternative2"]) - import_alternatives(self._out_map, ["non_active_alternative"]) - import_object_classes(self._out_map, ["object_class"]) - import_objects(self._out_map, [("object_class", "object")]) - import_object_parameters(self._out_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 10.0, "alternative1")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 2000.0, "alternative2")]) - import_object_parameter_values(self._out_map, [("object_class", "object", "parameter", 300.0, "alternative3")]) - import_scenarios(self._out_map, [("scenario", True)]) - import_scenarios(self._out_map, [("non_active_scenario", False)]) + import_alternatives(self._out_db_map, ["alternative3"]) + import_alternatives(self._out_db_map, ["alternative1"]) + import_alternatives(self._out_db_map, ["alternative2"]) + import_alternatives(self._out_db_map, ["non_active_alternative"]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter")]) + import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", -1.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 10.0, "alternative1")] + ) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 2000.0, "alternative2")] + ) + import_object_parameter_values( + self._out_db_map, [("object_class", "object", "parameter", 300.0, "alternative3")] + ) + import_scenarios(self._out_db_map, [("scenario", True)]) + import_scenarios(self._out_db_map, [("non_active_scenario", False)]) import_scenario_alternatives( - self._out_map, + self._out_db_map, [ ("scenario", "alternative2"), ("scenario", "alternative3", "alternative2"), @@ -238,7 +250,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s ], ) import_scenario_alternatives( - self._out_map, + self._out_db_map, [ ("non_active_scenario", "non_active_alternative"), ("scenario", "alternative2", "non_active_alternative"), @@ -246,7 +258,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s ("scenario", "alternative1", "alternative3"), ], ) - self._out_map.commit_session("Add test data") + self._out_db_map.commit_session("Add test data") apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 1) @@ -277,23 +289,31 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s ) def test_scenario_filter_for_multiple_objects_and_parameters(self): - import_alternatives(self._out_map, ["alternative"]) - import_object_classes(self._out_map, ["object_class"]) - import_objects(self._out_map, [("object_class", "object1")]) - import_objects(self._out_map, [("object_class", "object2")]) - import_object_parameters(self._out_map, [("object_class", "parameter1")]) - import_object_parameters(self._out_map, [("object_class", "parameter2")]) - import_object_parameter_values(self._out_map, [("object_class", "object1", "parameter1", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", 
"object1", "parameter1", 10.0, "alternative")]) - import_object_parameter_values(self._out_map, [("object_class", "object1", "parameter2", -1.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object1", "parameter2", 11.0, "alternative")]) - import_object_parameter_values(self._out_map, [("object_class", "object2", "parameter1", -2.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object2", "parameter1", 20.0, "alternative")]) - import_object_parameter_values(self._out_map, [("object_class", "object2", "parameter2", -2.0)]) - import_object_parameter_values(self._out_map, [("object_class", "object2", "parameter2", 22.0, "alternative")]) - import_scenarios(self._out_map, [("scenario", True)]) - import_scenario_alternatives(self._out_map, [("scenario", "alternative")]) - self._out_map.commit_session("Add test data") + import_alternatives(self._out_db_map, ["alternative"]) + import_object_classes(self._out_db_map, ["object_class"]) + import_objects(self._out_db_map, [("object_class", "object1")]) + import_objects(self._out_db_map, [("object_class", "object2")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter1")]) + import_object_parameters(self._out_db_map, [("object_class", "parameter2")]) + import_object_parameter_values(self._out_db_map, [("object_class", "object1", "parameter1", -1.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object1", "parameter1", 10.0, "alternative")] + ) + import_object_parameter_values(self._out_db_map, [("object_class", "object1", "parameter2", -1.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object1", "parameter2", 11.0, "alternative")] + ) + import_object_parameter_values(self._out_db_map, [("object_class", "object2", "parameter1", -2.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object2", "parameter1", 20.0, "alternative")] + ) + import_object_parameter_values(self._out_db_map, [("object_class", "object2", "parameter2", -2.0)]) + import_object_parameter_values( + self._out_db_map, [("object_class", "object2", "parameter2", 22.0, "alternative")] + ) + import_scenarios(self._out_db_map, [("scenario", True)]) + import_scenario_alternatives(self._out_db_map, [("scenario", "alternative")]) + self._out_db_map.commit_session("Add test data") apply_scenario_filter_to_subqueries(self._db_map, "scenario") parameters = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(parameters), 4) @@ -331,10 +351,10 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): ) def test_filters_scenarios_and_alternatives(self): - import_scenarios(self._out_map, ("scenario1", "scenario2")) - import_alternatives(self._out_map, ("alternative1", "alternative2", "alternative3")) + import_scenarios(self._out_db_map, ("scenario1", "scenario2")) + import_alternatives(self._out_db_map, ("alternative1", "alternative2", "alternative3")) import_scenario_alternatives( - self._out_map, + self._out_db_map, ( ("scenario1", "alternative2"), ("scenario1", "alternative1", "alternative2"), @@ -342,7 +362,7 @@ def test_filters_scenarios_and_alternatives(self): ("scenario2", "alternative2", "alternative3"), ), ) - self._out_map.commit_session("Add test data.") + self._out_db_map.commit_session("Add test data.") apply_scenario_filter_to_subqueries(self._db_map, "scenario2") alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] self.assertEqual( diff --git 
a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index 0c0ca922..a85455da 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -48,7 +48,7 @@ def setUp(self): self._db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def _build_data_with_tools(self): import_object_classes(self._db_map, ["object_class"]) diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index 0c63544a..41ecad30 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -94,7 +94,7 @@ def setUpClass(cls): db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") - db_map.connection.close() + db_map.close() def test_empty_stack(self): db_map = DatabaseMapping(self._db_url) @@ -103,7 +103,7 @@ def test_empty_stack(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("object_class", None, None)]) finally: - db_map.connection.close() + db_map.close() def test_single_renaming_filter(self): db_map = DatabaseMapping(self._db_url) @@ -113,7 +113,7 @@ def test_single_renaming_filter(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("renamed_once", None, None)]) finally: - db_map.connection.close() + db_map.close() def test_two_renaming_filters(self): db_map = DatabaseMapping(self._db_url) @@ -126,7 +126,7 @@ def test_two_renaming_filters(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("renamed_twice", None, None)]) finally: - db_map.connection.close() + db_map.close() class TestFilteredDatabaseMap(unittest.TestCase): @@ -141,7 +141,7 @@ def setUpClass(cls): db_map = DatabaseMapping(cls._db_url, create=True) import_object_classes(db_map, ("object_class",)) db_map.commit_session("Add test data.") - db_map.connection.close() + db_map.close() def test_without_filters(self): db_map = DatabaseMapping(self._db_url, self._engine) @@ -149,7 +149,7 @@ def test_without_filters(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("object_class", None, None)]) finally: - db_map.connection.close() + db_map.close() def test_single_renaming_filter(self): path = os.path.join(self._dir.name, "config.json") @@ -161,7 +161,7 @@ def test_single_renaming_filter(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("renamed_once", None, None)]) finally: - db_map.connection.close() + db_map.close() def test_two_renaming_filters(self): path1 = os.path.join(self._dir.name, "config1.json") @@ -177,7 +177,7 @@ def test_two_renaming_filters(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("renamed_twice", None, None)]) finally: - db_map.connection.close() + db_map.close() def test_config_embedded_to_url(self): config = entity_class_renamer_config(object_class="renamed_once") @@ -187,7 +187,7 @@ def test_config_embedded_to_url(self): object_classes = export_object_classes(db_map) self.assertEqual(object_classes, [("renamed_once", None, None)]) finally: - db_map.connection.close() + db_map.close() class TestAppendFilterConfig(unittest.TestCase): diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index 18886a98..3660e85a 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -106,17 +106,17 @@ def 
tearDownClass(cls): def setUp(self): create_new_spine_database(self._db_url) - self._out_map = DatabaseMapping(self._db_url) + self._out_db_map = DatabaseMapping(self._db_url) def tearDown(self): - self._out_map.connection.close() + self._out_db_map.close() def test_negate_manipulator(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", -2.3),)) - self._out_map.commit_session("Add test data.") + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", -2.3),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "negate"}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -125,15 +125,15 @@ def test_negate_manipulator(self): values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] self.assertEqual(values, [2.3]) finally: - db_map.connection.close() + db_map.close() def test_negate_manipulator_with_nested_map(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) value = Map(["A"], [Map(["1"], [2.3])]) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", value),)) - self._out_map.commit_session("Add test data.") + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", value),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "negate"}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -143,14 +143,14 @@ def test_negate_manipulator_with_nested_map(self): expected = Map(["A"], [Map(["1"], [-2.3])]) self.assertEqual(values, [expected]) finally: - db_map.connection.close() + db_map.close() def test_multiply_manipulator(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", -2.3),)) - self._out_map.commit_session("Add test data.") + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", -2.3),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "multiply", "rhs": 10.0}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -159,14 +159,14 @@ def test_multiply_manipulator(self): values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] self.assertEqual(values, [-23.0]) finally: - 
db_map.connection.close() + db_map.close() def test_invert_manipulator(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", -2.3),)) - self._out_map.commit_session("Add test data.") + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", -2.3),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "invert"}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -175,14 +175,14 @@ def test_invert_manipulator(self): values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] self.assertEqual(values, [-1.0 / 2.3]) finally: - db_map.connection.close() + db_map.close() def test_multiple_instructions(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", -2.3),)) - self._out_map.commit_session("Add test data.") + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", -2.3),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "invert"}, {"operation": "negate"}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -191,15 +191,15 @@ def test_multiple_instructions(self): values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] self.assertEqual(values, [1.0 / 2.3]) finally: - db_map.connection.close() + db_map.close() def test_index_generator_on_time_series(self): - import_object_classes(self._out_map, ("class",)) - import_object_parameters(self._out_map, (("class", "parameter"),)) - import_objects(self._out_map, (("class", "object"),)) + import_object_classes(self._out_db_map, ("class",)) + import_object_parameters(self._out_db_map, (("class", "parameter"),)) + import_objects(self._out_db_map, (("class", "object"),)) value = TimeSeriesFixedResolution("2021-06-07T08:00", "1D", [-5.0, -2.3], False, False) - import_object_parameter_values(self._out_map, (("class", "object", "parameter", value),)) - self._out_map.commit_session("Add test data.") + import_object_parameter_values(self._out_db_map, (("class", "object", "parameter", value),)) + self._out_db_map.commit_session("Add test data.") instructions = {"class": {"parameter": [{"operation": "generate_index", "expression": "float(i)"}]}} config = value_transformer_config(instructions) url = append_filter_config(str(self._db_url), config) @@ -209,7 +209,7 @@ def test_index_generator_on_time_series(self): expected = Map([1.0, 2.0], [-5.0, -2.3]) self.assertEqual(values, [expected]) finally: - db_map.connection.close() + db_map.close() if __name__ == "__main__": diff --git a/tests/spine_io/exporters/test_csv_writer.py 
b/tests/spine_io/exporters/test_csv_writer.py index a75f592b..9ddf6dfc 100644 --- a/tests/spine_io/exporters/test_csv_writer.py +++ b/tests/spine_io/exporters/test_csv_writer.py @@ -38,7 +38,7 @@ def test_write_empty_database(self): self.assertTrue(out_path.exists()) with open(out_path) as out_file: self.assertEqual(out_file.readlines(), []) - db_map.connection.close() + db_map.close() def test_write_single_object_class_and_object(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -52,7 +52,7 @@ def test_write_single_object_class_and_object(self): self.assertTrue(out_path.exists()) with open(out_path) as out_file: self.assertEqual(out_file.readlines(), ["oc,o1\n"]) - db_map.connection.close() + db_map.close() def test_tables_are_written_to_separate_files(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -76,7 +76,7 @@ def test_tables_are_written_to_separate_files(self): self.assertEqual(out_file.readlines(), expected) self.assertEqual(len(out_files), 2) self.assertEqual(set(out_files), {"oc1.csv", "oc2.csv"}) - db_map.connection.close() + db_map.close() def test_append_to_table(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -91,7 +91,7 @@ def test_append_to_table(self): self.assertTrue(out_path.exists()) with open(out_path) as out_file: self.assertEqual(out_file.readlines(), ["oc,o1\n", "oc,o1\n"]) - db_map.connection.close() + db_map.close() if __name__ == "__main__": diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index 470401ef..65f8cf02 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -41,7 +41,7 @@ def test_write_empty_database(self): sheet = workbook["Sheet1"] self.assertEqual(sheet.calculate_dimension(), "A1:A1") workbook.close() - db_map.connection.close() + db_map.close() def test_write_single_object_class_and_object(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -57,7 +57,7 @@ def test_write_single_object_class_and_object(self): expected = [["oc", "o1"]] self.check_sheet(workbook, "Sheet1", expected) workbook.close() - db_map.connection.close() + db_map.close() def test_write_to_existing_sheet(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -73,7 +73,7 @@ def test_write_to_existing_sheet(self): expected = [["o1"], ["o2"]] self.check_sheet(workbook, "Sheet1", expected) workbook.close() - db_map.connection.close() + db_map.close() def test_write_to_named_sheets(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -91,7 +91,7 @@ def test_write_to_named_sheets(self): expected = [[None, "o21"]] self.check_sheet(workbook, "oc2", expected) workbook.close() - db_map.connection.close() + db_map.close() def test_append_to_anonymous_table(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -108,7 +108,7 @@ def test_append_to_anonymous_table(self): expected = [["oc", "o1"], ["oc", "o1"]] self.check_sheet(workbook, "Sheet1", expected) workbook.close() - db_map.connection.close() + db_map.close() def test_append_to_named_table(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -125,7 +125,7 @@ def test_append_to_named_table(self): expected = [["o1"], ["o1"]] self.check_sheet(workbook, "oc", expected) workbook.close() - db_map.connection.close() + db_map.close() def check_sheet(self, workbook, sheet_name, expected): """ diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 50d57d21..48407fd4 100644 --- 
a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -51,7 +51,7 @@ def test_write_empty_database(self): write(db_map, writer, root_mapping) with GdxFile(str(file_path), "r", self._gams_dir) as gdx_file: self.assertEqual(len(gdx_file), 0) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_write_single_object_class_and_object(self): @@ -70,7 +70,7 @@ def test_write_single_object_class_and_object(self): gams_set = gdx_file["oc"] self.assertIsNone(gams_set.domain) self.assertEqual(gams_set.elements, ["o1"]) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_write_2D_relationship(self): @@ -90,7 +90,7 @@ def test_write_2D_relationship(self): gams_set = gdx_file["rel"] self.assertEqual(gams_set.domain, ["oc1", "oc2"]) self.assertEqual(gams_set.elements, [("o1", "o2")]) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_write_parameters(self): @@ -114,7 +114,7 @@ def test_write_parameters(self): gams_parameter = gdx_file["oc"] self.assertEqual(len(gams_parameter), 1) self.assertEqual(gams_parameter["o1"], 2.3) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_non_numerical_parameter_value_raises_writer_expection(self): @@ -133,7 +133,7 @@ def test_non_numerical_parameter_value_raises_writer_expection(self): file_path = Path(temp_dir, "test_write_parameters.gdx") writer = GdxWriter(str(file_path), self._gams_dir) self.assertRaises(WriterException, write, db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_empty_parameter(self): @@ -158,7 +158,7 @@ def test_empty_parameter(self): gams_parameter = gdx_file["oc"] self.assertIsInstance(gams_parameter, GAMSParameter) self.assertEqual(len(gams_parameter), 0) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_write_scalars(self): @@ -177,7 +177,7 @@ def test_write_scalars(self): self.assertEqual(len(gdx_file), 1) gams_scalar = gdx_file["oc"] self.assertEqual(float(gams_scalar), 2.3) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_two_tables(self): @@ -199,7 +199,7 @@ def test_two_tables(self): gams_set = gdx_file["oc2"] self.assertIsNone(gams_set.domain) self.assertEqual(gams_set.elements, ["p"]) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_append_to_table(self): @@ -226,7 +226,7 @@ def test_append_to_table(self): gams_set = gdx_file["set_X"] self.assertIsNone(gams_set.domain) self.assertEqual(gams_set.elements, ["o", "p"]) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_parameter_value_non_convertible_to_float_raises_WriterException(self): @@ -247,7 +247,7 @@ def test_parameter_value_non_convertible_to_float_raises_WriterException(self): file_path = Path(temp_dir, "test_two_tables.gdx") writer = GdxWriter(str(file_path), self._gams_dir) self.assertRaises(WriterException, write, db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir 
is None, "No working GAMS installation found.") def test_non_string_set_element_raises_WriterException(self): @@ -268,7 +268,7 @@ def test_non_string_set_element_raises_WriterException(self): file_path = Path(temp_dir, "test_two_tables.gdx") writer = GdxWriter(str(file_path), self._gams_dir) self.assertRaises(WriterException, write, db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() @unittest.skipIf(_gams_dir is None, "No working GAMS installation found.") def test_special_value_conversions(self): @@ -305,7 +305,7 @@ def test_special_value_conversions(self): self.assertEqual(gams_parameter[("o1", "infinity")], math.inf) self.assertEqual(gams_parameter[("o1", "negative_infinity")], -math.inf) self.assertTrue(math.isnan(gams_parameter[("o1", "nan")])) - db_map.connection.close() + db_map.close() if __name__ == '__main__': diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py index fa17f0cb..702adeac 100644 --- a/tests/spine_io/exporters/test_sql_writer.py +++ b/tests/spine_io/exporters/test_sql_writer.py @@ -53,7 +53,7 @@ def test_write_empty_database(self): out_path = Path(self._temp_dir.name, "out.sqlite") writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, settings) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) def test_write_header_only(self): @@ -70,7 +70,7 @@ def test_write_header_only(self): out_path = Path(self._temp_dir.name, "out.sqlite") writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + str(out_path)) connection = engine.connect() @@ -102,7 +102,7 @@ def test_write_single_object_class_and_object(self): out_path = Path(self._temp_dir.name, "out.sqlite") writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + str(out_path)) connection = engine.connect() @@ -141,7 +141,7 @@ def test_write_datetime_value(self): out_path = Path(self._temp_dir.name, "out.sqlite") writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + str(out_path)) connection = engine.connect() @@ -181,7 +181,7 @@ def test_write_duration_value(self): out_path = Path(self._temp_dir.name, "out.sqlite") writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + str(out_path)) connection = engine.connect() @@ -216,7 +216,7 @@ def test_append_to_table(self): writer = SqlWriter(str(out_path), overwrite_existing=True) write(db_map, writer, root_mapping1) write(db_map, writer, root_mapping2) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + str(out_path)) connection = engine.connect() @@ -256,7 +256,7 @@ def test_appending_to_table_in_existing_database(self): root_mapping.child.header = "objects" writer = SqlWriter(str(out_path), overwrite_existing=False) write(db_map, writer, root_mapping) - db_map.connection.close() + db_map.close() self.assertTrue(out_path.exists()) engine = create_engine("sqlite:///" + 
str(out_path)) connection = engine.connect() diff --git a/tests/spine_io/exporters/test_writer.py b/tests/spine_io/exporters/test_writer.py index 5e642b7b..9a9283c0 100644 --- a/tests/spine_io/exporters/test_writer.py +++ b/tests/spine_io/exporters/test_writer.py @@ -44,7 +44,7 @@ def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_max_rows(self): import_object_classes(self._db_map, ("class1", "class2")) diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index 1031f7c4..e1ed5b4b 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -114,7 +114,7 @@ def _check_parameter_value(self, val): path = str(PurePath(directory, _TEMP_EXCEL_FILENAME)) export_spine_database_to_xlsx(db_map, path) output_data, errors = get_mapped_data_from_xlsx(path) - db_map.connection.close() + db_map.close() self.assertEqual([], errors) input_param_vals = input_data.pop("parameter_values") output_param_vals = output_data.pop("parameter_values") diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 7052c013..034fe965 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -30,7 +30,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - cls._db_map.connection.close() + cls._db_map.close() def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" @@ -39,7 +39,7 @@ def test_construction_with_filters(self): "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) - db_map.connection.close() + db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -51,7 +51,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) - db_map.connection.close() + db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -287,7 +287,7 @@ def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def create_object_classes(self): obj_classes = ['class1', 'class2'] @@ -544,7 +544,7 @@ def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_update_wide_relationship_class(self): _ = import_functions.import_object_classes(self._db_map, ("object_class_1",)) @@ -873,7 +873,7 @@ def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_remove_works_when_entity_groups_are_present(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -953,14 +953,14 @@ def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_commit_message(self): """Tests that commit comment ends up in the database.""" self._db_map.add_object_classes({"name": "testclass"}) 
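# commit_session() inserts a row into the commit table (see db_mapping_commit_mixin);
# the assertion below reads that row back through commit_sq.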
self._db_map.commit_session("test commit") self.assertEqual(self._db_map.query(self._db_map.commit_sq).all()[-1].comment, "test commit") - self._db_map.connection.close() + self._db_map.close() def test_commit_session_raise_with_empty_comment(self): import_functions.import_object_classes(self._db_map, ("my_class",)) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index a8df6460..a779d019 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -50,7 +50,7 @@ def test_construction_with_filters(self): "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) - db_map.connection.close() + db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -62,7 +62,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) - db_map.connection.close() + db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -70,17 +70,17 @@ def test_shorthand_filter_query_works(self): with TemporaryDirectory() as temp_dir: url = URL("sqlite") url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") - out_db = DatabaseMapping(url, create=True) - out_db.add_scenarios({"name": "scen1"}) - out_db.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) - out_db.commit_session("Add scen.") - out_db.connection.close() + out_db_map = DatabaseMapping(url, create=True) + out_db_map.add_scenarios({"name": "scen1"}) + out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) + out_db_map.commit_session("Add scen.") + out_db_map.close() try: db_map = DatabaseMapping(url) except: self.fail("DatabaseMapping.__init__() should not raise.") else: - db_map.connection.close() + db_map.close() class TestDatabaseMappingRemove(unittest.TestCase): @@ -88,7 +88,7 @@ def setUp(self): self._db_map = create_diff_db_map() def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_cascade_remove_relationship(self): """Test adding and removing a relationship and committing""" @@ -412,7 +412,7 @@ def setUp(self): self._db_map = create_diff_db_map() def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_add_and_retrieve_many_objects(self): """Tests add many objects into db and retrieving them.""" @@ -1183,7 +1183,7 @@ def setUp(self): self._db_map = create_diff_db_map() def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_update_object_classes(self): """Test that updating object classes works.""" @@ -1289,14 +1289,14 @@ def setUp(self): self._db_map = create_diff_db_map() def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_commit_message(self): """Tests that commit comment ends up in the database.""" self._db_map.add_object_classes({"name": "testclass"}) self._db_map.commit_session("test commit") self.assertEqual(self._db_map.query(self._db_map.commit_sq).all()[-1].comment, "test commit") - self._db_map.connection.close() + self._db_map.close() def test_commit_session_raise_with_empty_comment(self): 
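# An empty comment trips the guard at the top of commit_session(), which raises
# SpineDBAPIError("Commit message cannot be empty.").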
import_functions.import_object_classes(self._db_map, ("my_class",)) diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 385117f2..8179d17c 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -42,7 +42,7 @@ def setUp(self): self._db_map = DatabaseMapping(db_url, username="UnitTest", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_export_alternatives(self): import_alternatives(self._db_map, [("alternative", "Description")]) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 8f44d1dd..86d87261 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -106,7 +106,7 @@ def test_import_data_integration(self): scenarios=scenarios, scenario_alternatives=scenario_alternatives, ) - db_map.connection.close() + db_map.close() self.assertEqual(num_imports, 13) self.assertFalse(errors) @@ -118,7 +118,7 @@ def test_import_object_class(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("new_class", [oc.name for oc in db_map.query(db_map.object_class_sq)]) - db_map.connection.close() + db_map.close() class TestImportObject(unittest.TestCase): @@ -129,13 +129,13 @@ def test_import_valid_objects(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("new_object", [o.name for o in db_map.query(db_map.object_sq)]) - db_map.connection.close() + db_map.close() def test_import_object_with_invalid_object_class_name(self): db_map = create_diff_db_map() _, errors = import_objects(db_map, [["nonexistent_class", "new_object"]]) self.assertTrue(errors) - db_map.connection.close() + db_map.close() def test_import_two_objects_with_same_name(self): db_map = create_diff_db_map() @@ -151,7 +151,7 @@ def test_import_two_objects_with_same_name(self): } expected = {"object_class1": "object", "object_class2": "object"} self.assertEqual(objects, expected) - db_map.connection.close() + db_map.close() def test_import_existing_object(self): db_map = create_diff_db_map() @@ -162,7 +162,7 @@ def test_import_existing_object(self): _, errors = import_objects(db_map, [["object_class", "object"]]) self.assertFalse(errors) self.assertIn("object", [o.name for o in db_map.query(db_map.object_sq)]) - db_map.connection.close() + db_map.close() class TestImportRelationshipClass(unittest.TestCase): @@ -177,7 +177,7 @@ def test_import_valid_relationship_class(self): } expected = {"relationship_class": "object_class1,object_class2"} self.assertEqual(relationship_classes, expected) - db_map.connection.close() + db_map.close() def test_import_relationship_class_with_invalid_object_class_name(self): db_map = create_diff_db_map() @@ -186,7 +186,7 @@ def test_import_relationship_class_with_invalid_object_class_name(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.wide_relationship_class_sq).all()) - db_map.connection.close() + db_map.close() def test_import_relationship_class_name_twice(self): db_map = create_diff_db_map() @@ -201,7 +201,7 @@ def test_import_relationship_class_name_twice(self): } expected = {"new_rc": "object_class1,object_class2"} self.assertEqual(relationship_classes, expected) - db_map.connection.close() + db_map.close() def test_import_existing_relationship_class(self): db_map = create_diff_db_map() @@ -209,7 +209,7 @@ def test_import_existing_relationship_class(self): import_relationship_classes(db_map, [["rc", ["object_class1", "object_class2"]]]) _, errors = 
import_relationship_classes(db_map, [["rc", ["object_class1", "object_class2"]]]) self.assertFalse(errors) - db_map.connection.close() + db_map.close() def test_import_relationship_class_with_one_object_class_as_None(self): db_map = create_diff_db_map() @@ -218,7 +218,7 @@ def test_import_relationship_class_with_one_object_class_as_None(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse([rc for rc in db_map.query(db_map.wide_relationship_class_sq)]) - db_map.connection.close() + db_map.close() class TestImportObjectClassParameter(unittest.TestCase): @@ -229,13 +229,13 @@ def test_import_valid_object_class_parameter(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("new_parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) - db_map.connection.close() + db_map.close() def test_import_parameter_with_invalid_object_class_name(self): db_map = create_diff_db_map() _, errors = import_object_parameters(db_map, [["nonexistent_object_class", "new_parameter"]]) self.assertTrue(errors) - db_map.connection.close() + db_map.close() def test_import_object_class_parameter_name_twice(self): db_map = create_diff_db_map() @@ -251,7 +251,7 @@ def test_import_object_class_parameter_name_twice(self): } expected = {"object_class1": "new_parameter", "object_class2": "new_parameter"} self.assertEqual(definitions, expected) - db_map.connection.close() + db_map.close() def test_import_existing_object_class_parameter(self): db_map = create_diff_db_map() @@ -262,7 +262,7 @@ def test_import_existing_object_class_parameter(self): _, errors = import_object_parameters(db_map, [["object_class", "parameter"]]) self.assertIn("parameter", [p.name for p in db_map.query(db_map.parameter_definition_sq)]) self.assertFalse(errors) - db_map.connection.close() + db_map.close() def test_import_object_class_parameter_with_null_default_value_and_db_server_unparsing(self): db_map = DatabaseMapping("sqlite://", create=True) @@ -276,7 +276,7 @@ def test_import_object_class_parameter_with_null_default_value_and_db_server_unp self.assertEqual(len(parameters), 1) self.assertIsNone(parameters[0].default_value) self.assertIsNone(parameters[0].default_type) - db_map.connection.close() + db_map.close() class TestImportRelationshipClassParameter(unittest.TestCase): @@ -296,13 +296,13 @@ def test_import_valid_relationship_class_parameter(self): } expected = {"relationship_class": "new_parameter"} self.assertEqual(definitions, expected) - db_map.connection.close() + db_map.close() def test_import_parameter_with_invalid_relationship_class_name(self): db_map = create_diff_db_map() _, errors = import_relationship_parameters(db_map, [["nonexistent_relationship_class", "new_parameter"]]) self.assertTrue(errors) - db_map.connection.close() + db_map.close() def test_import_relationship_class_parameter_name_twice(self): db_map = create_diff_db_map() @@ -328,7 +328,7 @@ def test_import_relationship_class_parameter_name_twice(self): } expected = {"relationship_class1": "new_parameter", "relationship_class2": "new_parameter"} self.assertEqual(definitions, expected) - db_map.connection.close() + db_map.close() def test_import_existing_relationship_class_parameter(self): db_map = create_diff_db_map() @@ -337,7 +337,7 @@ def test_import_existing_relationship_class_parameter(self): import_relationship_parameters(db_map, [["relationship_class", "new_parameter"]]) _, errors = import_relationship_parameters(db_map, [["relationship_class", "new_parameter"]]) self.assertFalse(errors) - 
db_map.connection.close() + db_map.close() class TestImportRelationship(unittest.TestCase): @@ -355,7 +355,7 @@ def test_import_relationships(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("relationship_class_object", [r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_valid_relationship(self): db_map = create_diff_db_map() @@ -365,7 +365,7 @@ def test_import_valid_relationship(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_valid_relationship_with_object_name_in_multiple_classes(self): db_map = create_diff_db_map() @@ -376,7 +376,7 @@ def test_import_valid_relationship_with_object_name_in_multiple_classes(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("relationship_class_duplicate__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_relationship_with_invalid_class_name(self): db_map = create_diff_db_map() @@ -385,7 +385,7 @@ def test_import_relationship_with_invalid_class_name(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_relationship_with_invalid_object_name(self): db_map = create_diff_db_map() @@ -395,7 +395,7 @@ def test_import_relationship_with_invalid_object_name(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_existing_relationship(self): db_map = create_diff_db_map() @@ -408,7 +408,7 @@ def test_import_existing_relationship(self): self.assertFalse(errors) db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() def test_import_relationship_with_one_None_object(self): db_map = create_diff_db_map() @@ -418,7 +418,7 @@ def test_import_relationship_with_one_None_object(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) - db_map.connection.close() + db_map.close() class TestImportParameterDefinition(unittest.TestCase): @@ -426,7 +426,7 @@ def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_import_object_parameter_definition(self): import_object_classes(self._db_map, ["my_object_class"]) @@ -539,7 +539,7 @@ def test_import_valid_object_parameter_value(self): values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b"1"} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_valid_object_parameter_value_string(self): db_map = create_diff_db_map() @@ -550,7 +550,7 @@ def test_import_valid_object_parameter_value_string(self): values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"value_string"'} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_valid_object_parameter_value_with_duplicate_object_name(self): db_map = 
create_diff_db_map() @@ -562,7 +562,7 @@ def test_import_valid_object_parameter_value_with_duplicate_object_name(self): values = {v.object_class_name: {v.object_name: v.value} for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object_class1": {"duplicate_object": b"1"}} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_valid_object_parameter_value_with_duplicate_parameter_name(self): db_map = create_diff_db_map() @@ -574,7 +574,7 @@ def test_import_valid_object_parameter_value_with_duplicate_parameter_name(self) values = {v.object_class_name: {v.object_name: v.value} for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object_class1": {"object1": b"1"}} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_object_parameter_value_with_invalid_object(self): db_map = create_diff_db_map() @@ -584,7 +584,7 @@ def test_import_object_parameter_value_with_invalid_object(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.object_parameter_value_sq).all()) - db_map.connection.close() + db_map.close() def test_import_object_parameter_value_with_invalid_parameter(self): db_map = create_diff_db_map() @@ -594,7 +594,7 @@ def test_import_object_parameter_value_with_invalid_parameter(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.object_parameter_value_sq).all()) - db_map.connection.close() + db_map.close() def test_import_existing_object_parameter_value_update_the_value(self): db_map = create_diff_db_map() @@ -606,7 +606,7 @@ def test_import_existing_object_parameter_value_update_the_value(self): values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"new_value"'} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_existing_object_parameter_value_on_conflict_keep(self): db_map = create_diff_db_map() @@ -623,7 +623,7 @@ def test_import_existing_object_parameter_value_on_conflict_keep(self): value = from_database(pv.value, pv.type) self.assertEqual(['2000-01-01T01:00:00', '2000-01-01T02:00:00'], [str(x) for x in value.indexes]) self.assertEqual([1.0, 2.0], list(value.values)) - db_map.connection.close() + db_map.close() def test_import_existing_object_parameter_value_on_conflict_replace(self): db_map = create_diff_db_map() @@ -640,7 +640,7 @@ def test_import_existing_object_parameter_value_on_conflict_replace(self): value = from_database(pv.value, pv.type) self.assertEqual(['2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in value.indexes]) self.assertEqual([3.0, 4.0], list(value.values)) - db_map.connection.close() + db_map.close() def test_import_existing_object_parameter_value_on_conflict_merge(self): db_map = create_diff_db_map() @@ -659,7 +659,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge(self): ['2000-01-01T01:00:00', '2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in value.indexes] ) self.assertEqual([1.0, 3.0, 4.0], list(value.values)) - db_map.connection.close() + db_map.close() def test_import_existing_object_parameter_value_on_conflict_merge_map(self): db_map = create_diff_db_map() @@ -688,7 +688,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge_map(self): ['2000-01-01T01:00:00', '2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in ts.indexes] ) self.assertEqual([1.0, 3.0, 4.0], list(ts.values)) 
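# The 'merge' conflict policy overlays the imported series on the existing one:
# 01:00 keeps the old 1.0, while 02:00 is overwritten with 3.0 and 03:00 adds 4.0.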
- db_map.connection.close() + db_map.close() def test_import_duplicate_object_parameter_value(self): db_map = create_diff_db_map() @@ -702,7 +702,7 @@ def test_import_duplicate_object_parameter_value(self): values = {v.object_name: v.value for v in db_map.query(db_map.object_parameter_value_sq)} expected = {"object1": b'"first"'} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_object_parameter_value_with_alternative(self): db_map = create_diff_db_map() @@ -719,7 +719,7 @@ def test_import_object_parameter_value_with_alternative(self): } expected = {"object1": (b"1", "alternative")} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_object_parameter_value_fails_with_nonexistent_alternative(self): db_map = create_diff_db_map() @@ -729,7 +729,7 @@ def test_import_object_parameter_value_fails_with_nonexistent_alternative(self): ) self.assertTrue(errors) self.assertEqual(count, 0) - db_map.connection.close() + db_map.close() def test_valid_object_parameter_value_from_value_list(self): db_map = create_diff_db_map() @@ -745,7 +745,7 @@ def test_valid_object_parameter_value_from_value_list(self): self.assertEqual(len(values), 1) value = values[0] self.assertEqual(from_database(value.value), 5.0) - db_map.connection.close() + db_map.close() def test_non_existent_object_parameter_value_from_value_list_fails_gracefully(self): db_map = create_diff_db_map() @@ -756,7 +756,7 @@ def test_non_existent_object_parameter_value_from_value_list_fails_gracefully(se count, errors = import_object_parameter_values(db_map, (("object_class", "my_object", "parameter", 2.3),)) self.assertEqual(count, 0) self.assertEqual(len(errors), 1) - db_map.connection.close() + db_map.close() def test_import_valid_relationship_parameter_value(self): db_map = create_diff_db_map() @@ -769,7 +769,7 @@ def test_import_valid_relationship_parameter_value(self): values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b"1"} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_valid_relationship_parameter_value_with_duplicate_parameter_name(self): db_map = create_diff_db_map() @@ -784,7 +784,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_parameter_name values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b"1"} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_valid_relationship_parameter_value_with_duplicate_object_name(self): db_map = create_diff_db_map() @@ -799,7 +799,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_object_name(se values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"duplicate_object,duplicate_object": b"1"} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_relationship_parameter_value_with_invalid_object(self): db_map = create_diff_db_map() @@ -810,7 +810,7 @@ def test_import_relationship_parameter_value_with_invalid_object(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) - db_map.connection.close() + db_map.close() def test_import_relationship_parameter_value_with_invalid_relationship_class(self): db_map = create_diff_db_map() @@ -821,7 +821,7 @@ def 
test_import_relationship_parameter_value_with_invalid_relationship_class(sel self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) - db_map.connection.close() + db_map.close() def test_import_relationship_parameter_value_with_invalid_parameter(self): db_map = create_diff_db_map() @@ -832,7 +832,7 @@ def test_import_relationship_parameter_value_with_invalid_parameter(self): self.assertTrue(errors) db_map.commit_session("test") self.assertFalse(db_map.query(db_map.relationship_parameter_value_sq).all()) - db_map.connection.close() + db_map.close() def test_import_existing_relationship_parameter_value(self): db_map = create_diff_db_map() @@ -848,7 +848,7 @@ def test_import_existing_relationship_parameter_value(self): values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b'"new_value"'} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_duplicate_relationship_parameter_value(self): db_map = create_diff_db_map() @@ -865,7 +865,7 @@ def test_import_duplicate_relationship_parameter_value(self): values = {v.object_name_list: v.value for v in db_map.query(db_map.relationship_parameter_value_sq)} expected = {"object1,object2": b'"first"'} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_relationship_parameter_value_with_alternative(self): db_map = create_diff_db_map() @@ -883,7 +883,7 @@ def test_import_relationship_parameter_value_with_alternative(self): } expected = {"object1,object2": (b"1", "alternative")} self.assertEqual(values, expected) - db_map.connection.close() + db_map.close() def test_import_relationship_parameter_value_fails_with_nonexistent_alternative(self): db_map = create_diff_db_map() @@ -893,7 +893,7 @@ def test_import_relationship_parameter_value_fails_with_nonexistent_alternative( ) self.assertTrue(errors) self.assertEqual(count, 0) - db_map.connection.close() + db_map.close() def test_valid_relationship_parameter_value_from_value_list(self): db_map = create_diff_db_map() @@ -913,7 +913,7 @@ def test_valid_relationship_parameter_value_from_value_list(self): self.assertEqual(len(values), 1) value = values[0] self.assertEqual(from_database(value.value), 5.0) - db_map.connection.close() + db_map.close() def test_non_existent_relationship_parameter_value_from_value_list_fails_gracefully(self): db_map = create_diff_db_map() @@ -928,7 +928,7 @@ def test_non_existent_relationship_parameter_value_from_value_list_fails_gracefu ) self.assertEqual(count, 0) self.assertEqual(len(errors), 1) - db_map.connection.close() + db_map.close() class TestImportParameterValueList(unittest.TestCase): @@ -936,7 +936,7 @@ def setUp(self): self._db_map = DatabaseMapping("sqlite://", create=True) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_list_with_single_value(self): count, errors = import_parameter_value_lists(self._db_map, (("list_1", 23.0),)) @@ -983,7 +983,7 @@ def test_single_alternative(self): self.assertEqual(len(alternatives), 2) self.assertIn("Base", alternatives) self.assertIn("alternative", alternatives) - db_map.connection.close() + db_map.close() def test_alternative_description(self): db_map = create_diff_db_map() @@ -994,7 +994,7 @@ def test_alternative_description(self): alternatives = {a.name: a.description for a in db_map.query(db_map.alternative_sq)} expected = {"Base": "Base alternative", 
"alternative": "description"} self.assertEqual(alternatives, expected) - db_map.connection.close() + db_map.close() def test_update_alternative_description(self): db_map = create_diff_db_map() @@ -1005,7 +1005,7 @@ def test_update_alternative_description(self): alternatives = {a.name: a.description for a in db_map.query(db_map.alternative_sq)} expected = {"Base": "new description"} self.assertEqual(alternatives, expected) - db_map.connection.close() + db_map.close() class TestImportScenario(unittest.TestCase): @@ -1017,7 +1017,7 @@ def test_single_scenario(self): db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": None}) - db_map.connection.close() + db_map.close() def test_scenario_with_description(self): db_map = create_diff_db_map() @@ -1027,7 +1027,7 @@ def test_scenario_with_description(self): db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": "description"}) - db_map.connection.close() + db_map.close() def test_update_scenario_description(self): db_map = create_diff_db_map() @@ -1038,7 +1038,7 @@ def test_update_scenario_description(self): db_map.commit_session("test") scenarios = {s.name: s.description for s in db_map.query(db_map.scenario_sq)} self.assertEqual(scenarios, {"scenario": "new description"}) - db_map.connection.close() + db_map.close() class TestImportScenarioAlternative(unittest.TestCase): @@ -1046,7 +1046,7 @@ def setUp(self): self._db_map = create_diff_db_map() def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_single_scenario_alternative_import(self): count, errors = import_scenario_alternatives(self._db_map, [["scenario", "alternative"]]) @@ -1119,7 +1119,7 @@ def test_import_metadata(self): self.assertIn(("name", "Charly"), metadata) self.assertIn(("age", "17"), metadata) self.assertIn(("age", "90"), metadata) - db_map.connection.close() + db_map.close() def test_import_metadata_with_duplicate_entry(self): db_map = create_diff_db_map() @@ -1132,7 +1132,7 @@ def test_import_metadata_with_duplicate_entry(self): self.assertIn(("name", "John"), metadata) self.assertIn(("name", "Charly"), metadata) self.assertIn(("age", "17"), metadata) - db_map.connection.close() + db_map.close() def test_import_metadata_with_nested_dict(self): db_map = create_diff_db_map() @@ -1144,7 +1144,7 @@ def test_import_metadata_with_nested_dict(self): self.assertEqual(len(metadata), 2) self.assertIn(("name", "John"), metadata) self.assertIn(("info", "{'age': 17, 'city': 'LA'}"), metadata) - db_map.connection.close() + db_map.close() def test_import_metadata_with_nested_list(self): db_map = create_diff_db_map() @@ -1156,7 +1156,7 @@ def test_import_metadata_with_nested_list(self): self.assertEqual(len(metadata), 2) self.assertIn(('contributors', "{'name': 'John'}"), metadata) self.assertIn(('contributors', "{'name': 'Charly'}"), metadata) - db_map.connection.close() + db_map.close() def test_import_unformatted_metadata(self): db_map = create_diff_db_map() @@ -1167,7 +1167,7 @@ def test_import_unformatted_metadata(self): self.assertFalse(errors) self.assertEqual(len(metadata), 1) self.assertIn(("unnamed", "not a JSON object"), metadata) - db_map.connection.close() + db_map.close() class TestImportEntityMetadata(unittest.TestCase): @@ -1204,7 +1204,7 @@ def test_import_object_metadata(self): self.assertIn(('object1', 'age', '90'), metadata) self.assertIn(('object1', 
'co-author', 'Charly'), metadata) self.assertIn(('object1', 'age', '17'), metadata) - db_map.connection.close() + db_map.close() def test_import_relationship_metadata(self): db_map = create_diff_db_map() @@ -1225,7 +1225,7 @@ def test_import_relationship_metadata(self): self.assertIn(('age', '90'), metadata) self.assertIn(('co-author', 'Charly'), metadata) self.assertIn(('age', '17'), metadata) - db_map.connection.close() + db_map.close() class TestImportParameterValueMetadata(unittest.TestCase): @@ -1234,7 +1234,7 @@ def setUp(self): import_metadata(self._db_map, ['{"co-author": "John", "age": 17}']) def tearDown(self): - self._db_map.connection.close() + self._db_map.close() def test_import_object_parameter_value_metadata(self): import_object_classes(self._db_map, ["object_class"]) diff --git a/tests/test_migration.py b/tests/test_migration.py index 4f2cb0f2..c9cf76d1 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -142,4 +142,4 @@ def test_upgrade_content(self): self.assertTrue(('breed', 'pluto', b'"labrador"') in obj_par_vals) self.assertTrue(('relative_speed', 'pluto__nemo', b'100') in rel_par_vals) self.assertTrue(('relative_speed', 'scooby__nemo', b'-1') in rel_par_vals) - db_map.connection.close() + db_map.close() From 7525f0a090b0ed60cddc925563c138cdb722fbb9 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 17 May 2023 15:55:42 +0200 Subject: [PATCH 042/317] Add commit_id to committed items --- spinedb_api/db_cache_base.py | 160 ++++++---------------- spinedb_api/db_mapping_add_mixin.py | 3 +- spinedb_api/db_mapping_base.py | 28 +--- spinedb_api/db_mapping_commit_mixin.py | 18 +-- spinedb_api/db_mapping_remove_mixin.py | 65 +++++---- spinedb_api/filters/alternative_filter.py | 8 +- spinedb_api/helpers.py | 7 + spinedb_api/purge.py | 10 +- spinedb_api/query.py | 4 +- spinedb_api/temp_id.py | 113 +++++++++++++++ tests/filters/test_scenario_filter.py | 38 ++--- tests/test_DatabaseMapping.py | 52 +++---- tests/test_DiffDatabaseMapping.py | 28 ++-- tests/test_import_functions.py | 8 +- 14 files changed, 269 insertions(+), 273 deletions(-) create mode 100644 spinedb_api/temp_id.py diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 72a146a0..89c179c5 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -15,6 +15,7 @@ from contextlib import suppress from enum import Enum, unique, auto from functools import cmp_to_key +from .temp_id import TempIdDict, TempId # TODO: Implement CacheItem.pop() to do lookup? @@ -60,28 +61,30 @@ def _cmp_item_type(self, a, b): def _sorted_item_types(self): sorted(self, key=cmp_to_key(self._cmp_item_type)) - def commit(self): - to_add = {} - to_update = {} - to_remove = {} + def dirty_items(self): + dirty_items = [] for item_type in sorted(self, key=cmp_to_key(self._cmp_item_type)): table_cache = self[item_type] + to_add = [] + to_update = [] + to_remove = [] for item in dict.values(table_cache): _ = item.is_valid() if item.status == Status.to_add: - to_add.setdefault(item_type, []).append(item) + to_add.append(item) elif item.status == Status.to_update: - to_update.setdefault(item_type, []).append(item) + to_update.append(item) elif item.status == Status.to_remove: - to_remove.setdefault(item_type, set()).add(item["id"]) - item.status = Status.committed - if to_remove.get(item_type): + to_remove.append(item) + if to_remove: # Fetch descendants, so that they are validated in next iterations of the loop. - # This allows removal in cascade. + # This ensures cascade removal. 
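# (E.g. removing an entity class must also take out its entities and their parameter
# values; those descendants can only be cascade-removed once they are in the cache.)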
for x in self: if self._cmp_item_type(item_type, x) < 0: self.fetch_all(x) - return to_add, to_update, to_remove + if to_add or to_update or to_remove: + dirty_items.append((item_type, (to_add, to_update, to_remove))) + return dirty_items @property def fetched_item_types(self): @@ -140,6 +143,13 @@ def fetch_all(self, item_type): while self.fetch_more(item_type): pass + def fetch_value(self, item_type, return_fn): + while self.fetch_more(item_type): + return_value = return_fn() + if return_value: + return return_value + return return_fn() + def fetch_ref(self, item_type, id_): while self.fetch_more(item_type): with suppress(KeyError): @@ -152,106 +162,7 @@ def fetch_ref(self, item_type, id_): return None -class _TempId(int): - _next_id = {} - - def __new__(cls, item_type): - id_ = cls._next_id.setdefault(item_type, -1) - cls._next_id[item_type] -= 1 - return super().__new__(cls, id_) - - def __init__(self, item_type): - super().__init__() - self._item_type = item_type - self._value_binds = [] - self._tuple_value_binds = [] - self._key_binds = [] - self._tuple_key_binds = [] - - def add_value_bind(self, item, key): - self._value_binds.append((item, key)) - - def add_tuple_value_bind(self, item, key): - self._tuple_value_binds.append((item, key)) - - def add_key_bind(self, item): - self._key_binds.append(item) - - def add_tuple_key_bind(self, item, key): - self._tuple_key_binds.append((item, key)) - - def remove_key_bind(self, item): - self._key_binds.remove(item) - - def remove_tuple_key_bind(self, item, key): - self._tuple_key_binds.remove((item, key)) - - def resolve(self, new_id): - for item, key in self._value_binds: - item[key] = new_id - for item, key in self._tuple_value_binds: - item[key] = tuple(new_id if v is self else v for v in item[key]) - for item in self._key_binds: - if self in item: - item[new_id] = dict.pop(item, self, None) - for item, key in self._tuple_key_binds: - if key in item: - item[tuple(new_id if k is self else k for k in key)] = dict.pop(item, key, None) - - -class _TempIdDict(dict): - def __init__(self, **kwargs): - super().__init__(**kwargs) - for key, value in kwargs.items(): - self._bind(key, value) - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._bind(key, value) - - def __delitem__(self, key): - super().__delitem__(key) - self._unbind(key) - - def setdefault(self, key, default): - value = super().setdefault(key, default) - self._bind(key, value) - return value - - def update(self, other): - super().update(other) - for key, value in other.items(): - self._bind(key, value) - - def pop(self, key, default): - if key in self: - self._unbind(key) - return super().pop(key, default) - - def _bind(self, key, value): - if isinstance(value, _TempId): - value.add_value_bind(self, key) - elif isinstance(value, tuple): - for v in value: - if isinstance(v, _TempId): - v.add_tuple_value_bind(self, key) - elif isinstance(key, _TempId): - key.add_key_bind(self) - elif isinstance(key, tuple): - for k in key: - if isinstance(k, _TempId): - k.add_tuple_key_bind(self, key) - - def _unbind(self, key): - if isinstance(key, _TempId): - key.remove_key_bind(self) - elif isinstance(key, tuple): - for k in key: - if isinstance(k, _TempId): - k.remove_tuple_key_bind(self, key) - - -class _TableCache(_TempIdDict): +class _TableCache(TempIdDict): def __init__(self, db_cache, item_type, *args, **kwargs): """ Args: @@ -264,7 +175,7 @@ def __init__(self, db_cache, item_type, *args, **kwargs): self._id_by_unique_key_value = {} def _new_id(self): - return 
_TempId(self._item_type) + return TempId(self._item_type) def unique_key_value_to_id(self, key, value, strict=False): """Returns the id that has the given value for the given unique key, or None. @@ -276,9 +187,12 @@ def unique_key_value_to_id(self, key, value, strict=False): Returns: int """ - value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - self._db_cache.fetch_all(self._item_type) id_by_unique_value = self._id_by_unique_key_value.get(key, {}) + if not id_by_unique_value: + id_by_unique_value = self._db_cache.fetch_value( + self._item_type, lambda: self._id_by_unique_key_value.get(key, {}) + ) + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) if strict: return id_by_unique_value[value] return id_by_unique_value.get(value) @@ -312,7 +226,7 @@ def current_item(self, item, skip_keys=()): id_ = item.get("id") if isinstance(id_, int): # id is an int, easy - return self.get(id_) + return self.get(id_) or self._db_cache.fetch_ref(self._item_type, id_) if isinstance(id_, dict): # id is a dict specifying the values for one of the unique constraints key, value = zip(*id_.items()) @@ -371,7 +285,7 @@ def check_item(self, item, for_update=False, skip_keys=()): def _add_unique(self, item): for key, value in item.unique_values(): - self._id_by_unique_key_value.setdefault(key, _TempIdDict())[value] = item["id"] + self._id_by_unique_key_value.setdefault(key, TempIdDict())[value] = item["id"] def _remove_unique(self, item): for key, value in item.unique_values(): @@ -387,7 +301,7 @@ def add_item(self, item, new=False): return new_item def update_item(self, item): - current_item = self[item["id"]] + current_item = self.current_item(item) self._remove_unique(current_item) current_item.update(item) self._add_unique(current_item) @@ -397,7 +311,7 @@ def update_item(self, item): return current_item def remove_item(self, id_): - current_item = self.get(id_) + current_item = self.current_item({"id": id_}) if current_item is not None: self._remove_unique(current_item) current_item.cascade_remove() @@ -411,7 +325,7 @@ def restore_item(self, id_): return current_item -class CacheItemBase(_TempIdDict): +class CacheItemBase(TempIdDict): """A dictionary that represents an db item.""" _defaults = {} @@ -427,8 +341,8 @@ def __init__(self, db_cache, item_type, **kwargs): super().__init__(**kwargs) self._db_cache = db_cache self._item_type = item_type - self._referrers = _TempIdDict() - self._weak_referrers = _TempIdDict() + self._referrers = TempIdDict() + self._weak_referrers = TempIdDict() self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() @@ -662,3 +576,7 @@ def _asdict(self): def is_committed(self): return self.status == Status.committed + + def commit(self, commit_id): + self.status = Status.committed + self["commit_id"] = commit_id diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 17c7c762..f17b46c9 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -16,7 +16,6 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineIntegrityError -from .query import Query class DatabaseMappingAddMixin: @@ -71,7 +70,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): if id_items: connection.execute(table.insert(), [x._asdict() for x in id_items]) if temp_id_items: - current_ids = {x["id"] for x in Query(connection, table)} + current_ids = {x["id"] for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 
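# Gaps below the current maximum (ids freed by earlier deletes) are collected so that
# temp-id items can recycle them; missing_id_count then says how many fresh ids are still needed.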
available_ids = set(range(1, next_id)) - current_ids missing_id_count = len(temp_id_items) - len(available_ids) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 0414485d..2bf0139b 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -11,14 +11,13 @@ """Provides :class:`.DatabaseMappingBase`.""" # TODO: Finish docstrings -import uuid import hashlib import os import logging import time from types import MethodType -from sqlalchemy import create_engine, MetaData, Table, Column, Integer, inspect, case, func, cast, false, and_, or_ -from sqlalchemy.sql.expression import label, Alias +from sqlalchemy import create_engine, MetaData, Table, Integer, inspect, case, func, cast, and_, or_ +from sqlalchemy.sql.expression import Alias, label from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import aliased from sqlalchemy.exc import DatabaseError @@ -302,26 +301,6 @@ def _receive_engine_close(self, dbapi_con, _connection_record): if self._memory_dirty: copy_database_bind(self._original_engine, self.engine) - def in_(self, column, values): - """Returns an expression equivalent to column.in_(values), that circumvents the - 'too many sql variables' problem in sqlite.""" - # FIXME - return column.in_(values) - if not values: - return false() - if not self.sa_url.drivername.startswith("sqlite"): - return column.in_(values) - in_value = Table( - "in_value_" + str(uuid.uuid4()), - MetaData(), - Column("value", column.type, primary_key=True), - prefixes=['TEMPORARY'], - ) - with self.engine.connect() as connection: - in_value.create(connection, checkfirst=True) - connection.execute(in_value.insert(), [{"value": column.type.python_type(val)} for val in set(values)]) - return column.in_(Query(connection, in_value.c.value)) - def _get_table_to_sq_attr(self): if not self._table_to_sq_attr: self._table_to_sq_attr = self._make_table_to_sq_attr() @@ -361,8 +340,7 @@ def _clear_subqueries(self, *tablenames): setattr(self, attr_name, None) def query(self, *args, **kwargs): - """Return a sqlalchemy :class:`~sqlalchemy.orm.query.Query` object applied - to this :class:`.DatabaseMappingBase`. + """Return a sqlalchemy :class:`~Query` object bound to this :class:`.DatabaseMappingBase`. To perform custom ``SELECT`` statements, call this method with one or more of the class documented :class:`~sqlalchemy.sql.expression.Alias` properties. 
For example, to select the object class with diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index e28aa8e6..7d251afb 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -29,24 +29,24 @@ def commit_session(self, comment): """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") - to_add, to_update, to_remove = self.cache.commit() - if not to_add and not to_update and not to_remove: + dirty_items = self.cache.dirty_items() + if not dirty_items: raise SpineDBAPIError("Nothing to commit.") user = self.username date = datetime.now(timezone.utc) ins = self._metadata.tables["commit"].insert() with self.engine.begin() as connection: commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] - for tablename, items in to_add.items(): - self._do_add_items(connection, tablename, *items) - for tablename, items in to_update.items(): - self._do_update_items(connection, tablename, *items) - self._do_remove_items(connection, **to_remove) + for tablename, (to_add, to_update, to_remove) in dirty_items: + for item in to_add + to_update + to_remove: + item.commit(commit_id) + self._do_add_items(connection, tablename, *to_add) + self._do_update_items(connection, tablename, *to_update) + self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) if self._memory: self._memory_dirty = True def rollback_session(self): - to_add, to_update, to_remove = self.cache.commit() - if not to_add and not to_update and not to_remove: + if not self.cache.dirty_items(): raise SpineDBAPIError("Nothing to rollback.") self.cache.reset_queries() diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 71d6d37c..ced5ed6b 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -13,8 +13,10 @@ """ +from sqlalchemy import and_, or_ from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError +from .helpers import Asterisk, group_consecutive # TODO: improve docstrings @@ -22,49 +24,52 @@ class DatabaseMappingRemoveMixin: """Provides methods to perform ``REMOVE`` operations over a Spine db.""" - def restore_items(self, tablename, *ids): - if not ids: - return [] - tablename = self._real_tablename(tablename) - table_cache = self.cache.get(tablename) - if not table_cache: - return [] - return [table_cache.restore_item(id_) for id_ in ids] - def remove_items(self, tablename, *ids): if not ids: return [] tablename = self._real_tablename(tablename) - table_cache = self.cache.get(tablename) - if not table_cache: - return [] + table_cache = self.cache.table_cache(tablename) + if Asterisk in ids: + ids = table_cache ids = set(ids) if tablename == "alternative": # Do not remove the Base alternative - ids -= {1} + ids.discard(1) return [table_cache.remove_item(id_) for id_ in ids] - def _do_remove_items(self, connection, **kwargs): + def restore_items(self, tablename, *ids): + if not ids: + return [] + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + return [table_cache.restore_item(id_) for id_ in ids] + + def purge_items(self, tablename): + return self.remove_items(tablename, Asterisk) + + def _do_remove_items(self, connection, tablename, *ids): """Removes items from the db. 
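Ids are grouped into runs of consecutive values (see ``group_consecutive`` in helpers), so the
delete condition is a short OR of range comparisons instead of one huge IN list; e.g. ids
{1, 2, 3, 7, 8, 10} produce the ranges (1, 3), (7, 8) and (10, 10).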
Args: - **kwargs: keyword is table name, argument is list of ids to remove + *ids: ids to remove """ - for tablename, ids in kwargs.items(): - tablename = self._real_tablename(tablename) - if tablename == "alternative": - # Do not remove the Base alternative - ids -= {1} - if not ids: - continue - id_field = self._id_fields.get(tablename, "id") - table = self._metadata.tables[tablename] - delete = table.delete().where(self.in_(getattr(table.c, id_field), ids)) - try: - connection.execute(delete) - except DBAPIError as e: - msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) from e + tablename = self._real_tablename(tablename) + ids = set(ids) + if tablename == "alternative": + # Do not remove the Base alternative + ids.discard(1) + if not ids: + return + table = self._metadata.tables[tablename] + id_field = self._id_fields.get(tablename, "id") + id_column = getattr(table.c, id_field) + cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) + delete = table.delete().where(cond) + try: + connection.execute(delete) + except DBAPIError as e: + msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" + raise SpineDBAPIError(msg) from e def _get_metadata_ids_to_remove(self): used_metadata_ids = set() diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index 67a01907..8299f7cf 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -138,7 +138,7 @@ def _alternative_ids(db_map, alternatives): alternative_names = [name for name in alternatives if isinstance(name, str)] ids_from_db = ( db_map.query(db_map.alternative_sq.c.id, db_map.alternative_sq.c.name) - .filter(db_map.in_(db_map.alternative_sq.c.name, alternative_names)) + .filter(db_map.alternative_sq.c.name.in_(alternative_names)) .all() ) names_in_db = [i.name for i in ids_from_db] @@ -148,9 +148,7 @@ def _alternative_ids(db_map, alternatives): ids = [i.id for i in ids_from_db] alternative_ids = [id_ for id_ in alternatives if isinstance(id_, int)] ids_from_db = ( - db_map.query(db_map.alternative_sq.c.id) - .filter(db_map.in_(db_map.alternative_sq.c.id, alternative_ids)) - .all() + db_map.query(db_map.alternative_sq.c.id).filter(db_map.alternative_sq.c.id.in_(alternative_ids)).all() ) ids_in_db = [i.id for i in ids_from_db] if len(alternative_ids) != len(ids_from_db): @@ -174,4 +172,4 @@ def _make_alternative_filtered_parameter_value_sq(db_map, state): Alias: a subquery for parameter value filtered by selected alternatives """ subquery = state.original_parameter_value_sq - return db_map.query(subquery).filter(db_map.in_(subquery.c.alternative_id, state.alternatives)).subquery() + return db_map.query(subquery).filter(subquery.c.alternative_id.in_(state.alternatives)).subquery() diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 9a643861..0e1daeaa 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -18,6 +18,7 @@ import json import warnings from operator import itemgetter +from itertools import groupby from urllib.parse import urlparse, urlunparse from sqlalchemy import ( Boolean, @@ -826,3 +827,9 @@ def remove_credentials_from_url(url): if parsed.username is None: return url return urlunparse(parsed._replace(netloc=parsed.netloc.partition("@")[-1])) + + +def group_consecutive(list_of_numbers): + for _k, g in groupby(enumerate(sorted(list_of_numbers)), lambda x: x[0] - x[1]): + group = list(map(itemgetter(1), g)) + 
yield group[0], group[-1]
diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py
index e4eac936..a3f9a491 100644
--- a/spinedb_api/purge.py
+++ b/spinedb_api/purge.py
@@ -61,17 +61,13 @@ def purge(db_map, purge_settings, logger=None):
     if purge_settings is None:
         # Bring all the pain
         purge_settings = {item_type: True for item_type in DatabaseMapping.ITEM_TYPES}
-    removable_db_map_data = {
-        item_type: _ids_for_item_type(db_map, item_type) for item_type, checked in purge_settings.items() if checked
-    }
-    removable_db_map_data = {item_type: ids for item_type, ids in removable_db_map_data.items() if ids}
+    removable_db_map_data = {item_type for item_type, checked in purge_settings.items() if checked}
     if removable_db_map_data:
         try:
             if logger:
                 logger.msg.emit("Purging database...")
-            for item_type, ids in removable_db_map_data.items():
-                db_map.remove_items(item_type, **ids)
-            # FIXME: What do do here? How does one affect the DB directly, bypassing cache?
+            for item_type in removable_db_map_data:
+                db_map.purge_items(item_type)
             db_map.commit_session("Purge database")
             if logger:
                 logger.msg.emit("Database purged")
diff --git a/spinedb_api/query.py b/spinedb_api/query.py
index a304d107..08a299c1 100644
--- a/spinedb_api/query.py
+++ b/spinedb_api/query.py
@@ -35,8 +35,8 @@ def add_columns(self, *columns):
         self._select = select(self._entities)
         return self

-    def filter(self, *args):
-        self._select = self._select.where(*args)
+    def filter(self, clause):
+        self._select = self._select.where(clause)
         return self

     def filter_by(self, **kwargs):
diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py
new file mode 100644
index 00000000..4fbd9cd7
--- /dev/null
+++ b/spinedb_api/temp_id.py
@@ -0,0 +1,113 @@
+######################################################################################################################
+# Copyright (C) 2017-2022 Spine project consortium
+# This file is part of Spine Database API.
+# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
+# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
+# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
+# this program. If not, see <http://www.gnu.org/licenses/>.
+######################################################################################################################
+"""
+Temporary id bookkeeping: negative placeholder ids that are swapped for real database ids at commit time.
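A minimal usage sketch (editorial illustration, not part of the patch):

    d = TempIdDict()
    temp = TempId("entity")   # negative placeholder; the counter runs per item type
    d[temp] = {"name": "x"}   # the dict records that temp is used as a key
    temp.resolve(42)          # once the row is inserted, rebind to the real id
    assert 42 in d and temp not in d

When a TempId occurs as a value, or inside a tuple key or value, resolve() rewrites those
entries in place as well.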
+ +""" + + +class TempId(int): + _next_id = {} + + def __new__(cls, item_type): + id_ = cls._next_id.setdefault(item_type, -1) + cls._next_id[item_type] -= 1 + return super().__new__(cls, id_) + + def __init__(self, item_type): + super().__init__() + self._item_type = item_type + self._value_binds = [] + self._tuple_value_binds = [] + self._key_binds = [] + self._tuple_key_binds = [] + + def add_value_bind(self, item, key): + self._value_binds.append((item, key)) + + def add_tuple_value_bind(self, item, key): + self._tuple_value_binds.append((item, key)) + + def add_key_bind(self, item): + self._key_binds.append(item) + + def add_tuple_key_bind(self, item, key): + self._tuple_key_binds.append((item, key)) + + def remove_key_bind(self, item): + self._key_binds.remove(item) + + def remove_tuple_key_bind(self, item, key): + self._tuple_key_binds.remove((item, key)) + + def resolve(self, new_id): + for item, key in self._value_binds: + item[key] = new_id + for item, key in self._tuple_value_binds: + item[key] = tuple(new_id if v is self else v for v in item[key]) + for item in self._key_binds: + if self in item: + item[new_id] = dict.pop(item, self, None) + for item, key in self._tuple_key_binds: + if key in item: + item[tuple(new_id if k is self else k for k in key)] = dict.pop(item, key, None) + + +class TempIdDict(dict): + def __init__(self, **kwargs): + super().__init__(**kwargs) + for key, value in kwargs.items(): + self._bind(key, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + self._bind(key, value) + + def __delitem__(self, key): + super().__delitem__(key) + self._unbind(key) + + def setdefault(self, key, default): + value = super().setdefault(key, default) + self._bind(key, value) + return value + + def update(self, other): + super().update(other) + for key, value in other.items(): + self._bind(key, value) + + def pop(self, key, default): + if key in self: + self._unbind(key) + return super().pop(key, default) + + def _bind(self, key, value): + if isinstance(value, TempId): + value.add_value_bind(self, key) + elif isinstance(value, tuple): + for v in value: + if isinstance(v, TempId): + v.add_tuple_value_bind(self, key) + elif isinstance(key, TempId): + key.add_key_bind(self) + elif isinstance(key, tuple): + for k in key: + if isinstance(k, TempId): + k.add_tuple_key_bind(self, key) + + def _unbind(self, key): + if isinstance(key, TempId): + key.remove_key_bind(self) + elif isinstance(key, tuple): + for k in key: + if isinstance(k, TempId): + k.remove_tuple_key_bind(self, key) diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index d0210787..97735df5 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -78,7 +78,7 @@ def test_scenario_filter(self): self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] self.assertEqual( scenarios, @@ -90,7 +90,7 @@ def test_scenario_filter(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -114,7 +114,7 @@ def 
test_scenario_filter_works_for_object_parameter_value_sq(self): self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] self.assertEqual( scenarios, @@ -126,7 +126,7 @@ def test_scenario_filter_works_for_object_parameter_value_sq(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -148,7 +148,7 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self): self.assertEqual(len(parameters), 1) self.assertEqual(parameters[0].value, b"23.0") alternatives = [dict(a) for a in self._db_map.query(self._db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] self.assertEqual( scenarios, @@ -160,7 +160,7 @@ def test_scenario_filter_works_for_relationship_parameter_value_sq(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -200,9 +200,9 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): self.assertEqual( alternatives, [ - {"name": "alternative3", "description": None, "id": 2, "commit_id": None}, - {"name": "alternative1", "description": None, "id": 3, "commit_id": None}, - {"name": "alternative2", "description": None, "id": 4, "commit_id": None}, + {"name": "alternative3", "description": None, "id": 2, "commit_id": 2}, + {"name": "alternative1", "description": None, "id": 3, "commit_id": 2}, + {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, ], ) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] @@ -216,7 +216,7 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): "alternative_name_list": "alternative1,alternative3,alternative2", "alternative_id_list": "3,2,4", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -267,9 +267,9 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s self.assertEqual( alternatives, [ - {"name": "alternative3", "description": None, "id": 2, "commit_id": None}, - {"name": "alternative1", "description": None, "id": 3, "commit_id": None}, - {"name": "alternative2", "description": None, "id": 4, "commit_id": None}, + {"name": "alternative3", "description": None, "id": 2, "commit_id": 2}, + {"name": "alternative1", "description": None, "id": 3, "commit_id": 2}, + {"name": "alternative2", "description": None, "id": 4, "commit_id": 2}, ], ) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] @@ -283,7 +283,7 @@ def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(s "alternative_name_list": "alternative1,alternative3,alternative2", "alternative_id_list": "3,2,4", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -333,7 +333,7 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): }, ) alternatives = [dict(a) for a in 
self._db_map.query(self._db_map.alternative_sq)] - self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": None}]) + self.assertEqual(alternatives, [{"name": "alternative", "description": None, "id": 2, "commit_id": 2}]) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] self.assertEqual( scenarios, @@ -345,7 +345,7 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): "alternative_name_list": "alternative", "alternative_id_list": "2", "id": 1, - "commit_id": None, + "commit_id": 2, } ], ) @@ -368,8 +368,8 @@ def test_filters_scenarios_and_alternatives(self): self.assertEqual( alternatives, [ - {"name": "alternative2", "description": None, "id": 3, "commit_id": None}, - {"name": "alternative3", "description": None, "id": 4, "commit_id": None}, + {"name": "alternative2", "description": None, "id": 3, "commit_id": 2}, + {"name": "alternative3", "description": None, "id": 4, "commit_id": 2}, ], ) scenarios = [dict(s) for s in self._db_map.query(self._db_map.wide_scenario_sq).all()] @@ -383,7 +383,7 @@ def test_filters_scenarios_and_alternatives(self): "alternative_name_list": "alternative2,alternative3", "alternative_id_list": "3,4", "id": 2, - "commit_id": None, + "commit_id": 2, } ], ) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 034fe965..5bf328fe 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -643,7 +643,7 @@ def test_update_parameter_definition_value_list(self): self.assertEqual( dict(pdefs[0]), { - "commit_id": None, + "commit_id": 3, "default_type": None, "default_value": None, "description": None, @@ -715,12 +715,10 @@ def test_update_object_metadata(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 2}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual( - dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} - ) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 2}) def test_update_object_metadata_reuses_existing_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -744,16 +742,12 @@ def test_update_object_metadata_reuses_existing_metadata(self): metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) self.assertEqual( - dict(metadata_entries[0]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": None} + dict(metadata_entries[0]), {"id": 2, "name": "key 2", "value": "metadata value 2", "commit_id": 2} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) - self.assertEqual( - dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} - ) - self.assertEqual( - dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": None} - ) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, 
"commit_id": 2}) + self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2}) def test_update_object_metadata_keeps_metadata_still_in_use(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -775,20 +769,12 @@ def test_update_object_metadata_keeps_metadata_still_in_use(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) - self.assertEqual( - dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} - ) - self.assertEqual( - dict(metadata_entries[1]), {"id": 2, "name": "new key", "value": "new value", "commit_id": None} - ) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "new key", "value": "new value", "commit_id": 3}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) - self.assertEqual( - dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": None} - ) - self.assertEqual( - dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": None} - ) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) + self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 1, "commit_id": 2}) def test_update_parameter_value_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -810,11 +796,11 @@ def test_update_parameter_value_metadata(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 2}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": None} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 2} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): @@ -838,20 +824,16 @@ def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata( self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 2) - self.assertEqual( - dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None} - ) - self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": None}) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + self.assertEqual(dict(metadata_entries[1]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, 
"parameter_value_id": 1, "metadata_id": 2, "commit_id": None} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual( - dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": None} - ) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 2}) def test_update_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) @@ -864,7 +846,7 @@ def test_update_metadata(self): metadata_records = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_records), 1) self.assertEqual( - dict(metadata_records[0]), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": None} + dict(metadata_records[0]), {"id": 1, "name": "author", "value": "Prof. T. Est", "commit_id": 3} ) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index a779d019..444815df 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -913,7 +913,7 @@ def test_add_alternative(self): dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} ) self.assertEqual( - dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": None} + dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} ) def test_add_scenario(self): @@ -925,7 +925,7 @@ def test_add_scenario(self): self.assertEqual(len(scenarios), 1) self.assertEqual( dict(scenarios[0]), - {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": None}, + {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, ) def test_add_scenario_alternative(self): @@ -939,7 +939,7 @@ def test_add_scenario_alternative(self): self.assertEqual(len(scenario_alternatives), 1) self.assertEqual( dict(scenario_alternatives[0]), - {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": None}, + {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, ) def test_add_metadata(self): @@ -950,7 +950,7 @@ def test_add_metadata(self): metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) self.assertEqual( - dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": None} + dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} ) def test_add_metadata_that_exists_does_not_add_it(self): @@ -960,7 +960,7 @@ def test_add_metadata_that_exists_does_not_add_it(self): self.assertEqual(items, []) metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": None}) + self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) def test_add_entity_metadata_for_object(self): import_functions.import_object_classes(self._db_map, ("fish",)) @@ -982,7 +982,7 @@ def test_add_entity_metadata_for_object(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1008,7 +1008,7 @@ def test_add_entity_metadata_for_relationship(self): "metadata_value": "My 
metadata.", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1038,7 +1038,7 @@ def test_add_ext_entity_metadata_for_object(self): "metadata_value": "object metadata", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1055,7 +1055,7 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an self._db_map.commit_session("Add entity metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None}) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self.assertEqual( @@ -1067,7 +1067,7 @@ def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_an "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1097,7 +1097,7 @@ def test_add_parameter_value_metadata(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1139,7 +1139,7 @@ def test_add_ext_parameter_value_metadata(self): "metadata_value": "parameter metadata", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) @@ -1159,7 +1159,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): self._db_map.commit_session("Add value metadata") metadata = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": None}) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata), 1) self.assertEqual( @@ -1173,7 +1173,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): "metadata_value": "My metadata.", "metadata_id": 1, "id": 1, - "commit_id": None, + "commit_id": 3, }, ) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 86d87261..9309b565 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1260,7 +1260,7 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) self.assertEqual( @@ -1274,7 +1274,7 @@ def test_import_object_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) @@ -1304,7 +1304,7 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "John", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) self.assertEqual( @@ -1318,7 +1318,7 @@ def test_import_relationship_parameter_value_metadata(self): "metadata_value": "17", "parameter_name": "param", "parameter_value_id": 1, - "commit_id": None, + "commit_id": 2, }, ) From 3b0e1108a1b7438cdb0957f2009d606ff5e0e6c8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 22 May 2023 09:59:14 +0200 Subject: [PATCH 043/317] Fix tests --- spinedb_api/db_cache_impl.py | 8 +- spinedb_api/db_mapping_remove_mixin.py | 5 +- tests/test_DatabaseMapping.py | 13 ++-- 
tests/test_DiffDatabaseMapping.py | 11 ++- tests/test_check_functions.py | 103 ++++++++++++++----------- 5 files changed, 82 insertions(+), 58 deletions(-) diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 30a2a345..d44ed40d 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -221,10 +221,10 @@ def _asdict(self): return d def merge(self, other): - parameter_value_list_id = other.get("parameter_value_list_id") + other_parameter_value_list_id = other.get("parameter_value_list_id") if ( - parameter_value_list_id is not None - and parameter_value_list_id != self["parameter_value_list_id"] + other_parameter_value_list_id is not None + and other_parameter_value_list_id != self["parameter_value_list_id"] and any( x["parameter_definition_id"] == self["id"] for x in self._db_cache.table_cache("parameter_value").values() @@ -299,7 +299,7 @@ def polish(self): if list_value_id is None: return ( f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " - "is not in {list_name}" + f"is not in {list_name}" ) self["value"] = list_value_id self["type"] = "list_value_ref" diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index ced5ed6b..bd1efb3a 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -71,10 +71,11 @@ def _do_remove_items(self, connection, tablename, *ids): msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e - def _get_metadata_ids_to_remove(self): + def remove_unused_metadata(self): used_metadata_ids = set() for x in self.cache.get("entity_metadata", {}).values(): used_metadata_ids.add(x["metadata_id"]) for x in self.cache.get("parameter_value_metadata", {}).values(): used_metadata_ids.add(x["metadata_id"]) - return {x["id"] for x in self.cache.get("metadata", {}).values()} - used_metadata_ids + unused_metadata_ids = {x["id"] for x in self.cache.get("metadata", {}).values()} - used_metadata_ids + self.remove_items("metadata", *unused_metadata_ids) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 5bf328fe..d8f015a6 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -712,13 +712,14 @@ def test_update_object_metadata(self): ) self.assertEqual(errors, []) self.assertEqual(len(items), 2) + self._db_map.remove_unused_metadata() self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 2}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 2}) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) def test_update_object_metadata_reuses_existing_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -738,6 +739,7 @@ def test_update_object_metadata_reuses_existing_metadata(self): ids = {x["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) + self._db_map.remove_unused_metadata() 
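+        # Unused metadata is no longer pruned as a side effect of updates/removals;
+        # it must be removed explicitly (remove_unused_metadata) before committing.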
self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) @@ -746,7 +748,7 @@ def test_update_object_metadata_reuses_existing_metadata(self): ) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 2) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 2}) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) self.assertEqual(dict(entity_metadata_entries[1]), {"id": 2, "entity_id": 2, "metadata_id": 2, "commit_id": 2}) def test_update_object_metadata_keeps_metadata_still_in_use(self): @@ -793,14 +795,15 @@ def test_update_parameter_value_metadata(self): ) self.assertEqual(errors, []) self.assertEqual(len(items), 2) + self._db_map.remove_unused_metadata() self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 2}) + self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 2} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py index 444815df..cd7af14e 100644 --- a/tests/test_DiffDatabaseMapping.py +++ b/tests/test_DiffDatabaseMapping.py @@ -180,8 +180,11 @@ def test_cascade_remove_relationship_class_from_committed_session(self): def test_remove_object_class(self): """Test adding and removing an object class and committing""" items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.remove_items("object_class", *{x["id"] for x in items}) - self._db_map.commit_session("delete") + self.assertEqual(len(items), 2) + self._db_map.remove_items("object_class", 1, 2) + with self.assertRaises(SpineDBAPIError): + # Nothing to commit + self._db_map.commit_session("delete") self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) def test_remove_object_class_from_committed_session(self): @@ -313,6 +316,7 @@ def test_cascade_remove_entity_metadata_removes_corresponding_metadata(self): entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata), 1) self._db_map.remove_items("entity_metadata", entity_metadata[0].id) + self._db_map.remove_unused_metadata() self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) @@ -367,6 +371,7 @@ def test_cascade_remove_object_removes_its_metadata(self): import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) self._db_map.commit_session("Add test data.") self._db_map.remove_items("object", 1) + 
self._db_map.remove_unused_metadata() self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) @@ -383,6 +388,7 @@ def test_cascade_remove_relationship_removes_its_metadata(self): ) self._db_map.commit_session("Add test data.") self._db_map.remove_items("relationship", 2) + self._db_map.remove_unused_metadata() self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) @@ -401,6 +407,7 @@ def test_cascade_remove_parameter_value_removes_its_metadata(self): ) self._db_map.commit_session("Add test data.") self._db_map.remove_items("parameter_value", 1) + self._db_map.remove_unused_metadata() self._db_map.commit_session("Remove test data.") self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) diff --git a/tests/test_check_functions.py b/tests/test_check_functions.py index fb5ce399..ead4e546 100644 --- a/tests/test_check_functions.py +++ b/tests/test_check_functions.py @@ -13,11 +13,17 @@ from numbers import Number import unittest -from spinedb_api.db_cache import DBCache, ParameterValueItem +from spinedb_api import DatabaseMapping +from spinedb_api.parameter_value import to_database from spinedb_api.exception import SpineIntegrityError -@unittest.skip("obsolete, but need to adapt to current check system") +def _val_dict(val): + keys = ("value", "type") + values = to_database(val) + return dict(zip(keys, values)) + + class TestCheckFunctions(unittest.TestCase): def setUp(self): self.data = [ @@ -25,58 +31,65 @@ def setUp(self): (int, (b'32', b'3.14'), (b'42', b'-2')), (str, (b'"FOO"', b'"bar"'), (b'"foo"', b'"Bar"', b'"BAZ"')), ] - self.parameter_definitions = { - 1: {'name': 'par1', 'entity_class_id': 1, 'parameter_value_list_id': 1}, - 2: {'name': 'par2', 'entity_class_id': 1, 'parameter_value_list_id': 2}, - 3: {'name': 'par2', 'entity_class_id': 1, 'parameter_value_list_id': 3}, - } self.value_type = {bool: 1, int: 2, str: 3} - self.parameter_value_lists = {1: (1, 2), 2: (3, 4), 3: (5, 6, 7)} - self.list_values = {1: True, 2: False, 3: 42, 4: -2, 5: 'foo', 6: 'Bar', 7: 'BAZ'} + self.db_map = DatabaseMapping("sqlite://", create=True) + self.db_map.add_items("entity_class", {"id": 1, 'name': 'cat'}) + self.db_map.add_items( + "entity", + {"id": 1, 'name': 'Tom', "class_id": 1}, + {"id": 2, 'name': 'Felix', "class_id": 1}, + {"id": 3, 'name': 'Jansson', "class_id": 1}, + ) + self.db_map.add_items( + "parameter_value_list", {"id": 1, 'name': 'list1'}, {"id": 2, 'name': 'list2'}, {"id": 3, 'name': 'list3'} + ) + self.db_map.add_items( + "list_value", + {"id": 1, **_val_dict(True), "index": 0, "parameter_value_list_id": 1}, + {"id": 2, **_val_dict(False), "index": 1, "parameter_value_list_id": 1}, + {"id": 3, **_val_dict(42), "index": 0, "parameter_value_list_id": 2}, + {"id": 4, **_val_dict(-2), "index": 1, "parameter_value_list_id": 2}, + {"id": 5, **_val_dict("foo"), "index": 0, "parameter_value_list_id": 3}, + {"id": 6, **_val_dict("Bar"), "index": 1, "parameter_value_list_id": 3}, + {"id": 7, **_val_dict("BAZ"), "index": 2, "parameter_value_list_id": 3}, + ) + self.db_map.add_items( + "parameter_definition", + {"id": 1, 'name': 'par1', 'entity_class_id': 1, 'parameter_value_list_id': 1}, 
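+            # note: each definition's id equals its value list's id; the tests below rely on this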
+ {"id": 2, 'name': 'par2', 'entity_class_id': 1, 'parameter_value_list_id': 2}, + {"id": 3, 'name': 'par3', 'entity_class_id': 1, 'parameter_value_list_id': 3}, + ) - def get_item(self, _type: type, val: bytes): - _id = self.value_type[_type] # setup: parameter definition/value list ids are equal - kwargs = { + @staticmethod + def get_item(id_: int, val: bytes, entity_id: int): + return { 'id': 1, - 'parameter_definition_id': _id, + 'parameter_definition_id': id_, 'entity_class_id': 1, - 'entity_id': 1, - 'object_class_id': 1, - 'object_id': 1, + 'entity_id': entity_id, 'value': val, - 'commit_id': 3, + 'type': None, 'alternative_id': 1, - 'object_class_name': 'test_objcls', - 'alternative_name': 'Base', - 'object_name': 'obj1', } - return ParameterValueItem(DBCache(lambda *_, **__: None), item_type="value", **kwargs) def test_replace_parameter_or_default_values_with_list_references(self): # regression test for spine-tools/Spine-Toolbox#1878 - for _type, _fail, _pass in self.data: - for data in _fail: - with self.subTest(_type=_type, data=data): - expect_in = json.loads(data.decode('utf8')) + for type_, fail, pass_ in self.data: + id_ = self.value_type[type_] # setup: parameter definition/value list ids are equal + for k, value in enumerate(fail): + with self.subTest(type=type_, value=value): + expect_in = json.loads(value.decode('utf8')) if isinstance(expect_in, Number): expect_in = float(expect_in) - ref = [self.list_values[i] for i in self.parameter_value_lists[self.value_type[_type]]] - expect_ref = ", ".join(f"{json.dumps(i)!r}" for i in ref) - self.assertRaisesRegex( - SpineIntegrityError, - fr"{expect_in!r}.+{expect_ref}", - replace_parameter_values_with_list_references, - self.get_item(_type, data), - self.parameter_definitions, - self.parameter_value_lists, - self.list_values, - ) - - for data in _pass: - with self.subTest(_type=_type, data=data): - replace_parameter_values_with_list_references( - self.get_item(_type, data), - self.parameter_definitions, - self.parameter_value_lists, - self.list_values, - ) + item = self.get_item(id_, value, 1) + _, errors = self.db_map.add_items("parameter_value", item) + self.assertEqual(len(errors), 1) + parsed_value = json.loads(value.decode('utf8')) + if isinstance(parsed_value, Number): + parsed_value = float(parsed_value) + self.assertEqual(errors[0], f"value {parsed_value} of par{id_} for ('Tom',) is not in list{id_}") + for k, value in enumerate(pass_): + with self.subTest(type=type_, value=value): + item = self.get_item(id_, value, k + 1) + _, errors = self.db_map.add_items("parameter_value", item) + self.assertEqual(errors, []) From 3172ee1eb0daf0fc26f6f3fa4282d946baad2a41 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 22 May 2023 11:06:04 +0300 Subject: [PATCH 044/317] Black formatting Re #215 --- spinedb_api/export_mapping/export_mapping.py | 6 ++---- spinedb_api/query.py | 1 - tests/export_mapping/test_export_mapping.py | 7 ++++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 925c00e4..4b4783f2 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -644,9 +644,7 @@ def add_query_columns(self, db_map, query): db_map.wide_entity_class_sq.c.dimension_name_list.label("dimension_name_list"), ) if self.highlight_position is not None: - query = query.add_columns( - db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id") - ) + query = 
query.add_columns(db_map.entity_class_dimension_sq.c.dimension_id.label("highlighted_dimension_id")) return query def filter_query(self, db_map, query): @@ -901,7 +899,7 @@ def add_query_columns(self, db_map, query): def filter_query(self, db_map, query): if self.query_parents("highlight_position") is not None: - return query.outerjoin( + return query.outerjoin( db_map.parameter_definition_sq, db_map.parameter_definition_sq.c.entity_class_id == db_map.entity_class_dimension_sq.c.dimension_id, ) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index 04e5536a..df75daff 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -30,7 +30,6 @@ def column_descriptions(self): def column_names(self): yield from (c.name for c in self._select.columns) - def subquery(self, name=None): return self._select.alias(name) diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 9623cbe0..ccafe03f 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -1448,7 +1448,12 @@ def test_relationship_class_object_classes_parameters_multiple_dimensions(self): import_relationship_classes(db_map, (("rc", ("oc1", "oc2")),)) db_map.commit_session("Add test data") root_mapping = unflatten( - [EntityClassMapping(0, highlight_position=0), DimensionMapping(1), DimensionMapping(3), ParameterDefinitionMapping(2)] + [ + EntityClassMapping(0, highlight_position=0), + DimensionMapping(1), + DimensionMapping(3), + ParameterDefinitionMapping(2), + ] ) expected = [["rc", "oc1", "p11", "oc2"], ["rc", "oc1", "p12", "oc2"]] self.assertEqual(list(rows(root_mapping, db_map)), expected) From 8e95ce8d5018212603cd3c3685595874767d84bd Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 23 May 2023 12:54:33 +0200 Subject: [PATCH 045/317] Implement rollback --- spinedb_api/db_cache_base.py | 74 +++++++++++++++++++++----- spinedb_api/db_cache_impl.py | 55 +++++++++++++++++-- spinedb_api/db_mapping_commit_mixin.py | 7 ++- 3 files changed, 115 insertions(+), 21 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 89c179c5..d9c876c8 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -79,6 +79,8 @@ def dirty_items(self): if to_remove: # Fetch descendants, so that they are validated in next iterations of the loop. # This ensures cascade removal. 
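+                # ("Descendants" here are item types that reference this one, e.g. entity
+                # references entity_class, so removing a class also sweeps in its entities.)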
+ # FIXME: We should also fetch the current item type because of multi-dimensional entities and + # classes which also depend on no-dimensional ones for x in self: if self._cmp_item_type(item_type, x) < 0: self.fetch_all(x) @@ -86,6 +88,32 @@ def dirty_items(self): dirty_items.append((item_type, (to_add, to_update, to_remove))) return dirty_items + def rollback(self): + dirty_items = self.dirty_items() + if not dirty_items: + return False + to_add_by_type = [] + to_update_by_type = [] + to_remove_by_type = [] + for item_type, (to_add, to_update, to_remove) in reversed(dirty_items): + to_add_by_type.append((item_type, to_add)) + to_update_by_type.append((item_type, to_update)) + to_remove_by_type.append((item_type, to_remove)) + for item_type, to_remove in to_remove_by_type: + table_cache = self.table_cache(item_type) + for item in to_remove: + table_cache.restore_item(item["id"]) + for item_type, to_update in to_update_by_type: + table_cache = self.table_cache(item_type) + for item in to_update: + table_cache.update_item(item.backup) + for item_type, to_add in to_add_by_type: + table_cache = self.table_cache(item_type) + for item in to_add: + if table_cache.remove_item(item["id"]) is not None: + del item["id"] + return True + @property def fetched_item_types(self): return self._fetched_item_types @@ -306,8 +334,6 @@ def update_item(self, item): current_item.update(item) self._add_unique(current_item) current_item.cascade_update() - if current_item.status == Status.committed: - current_item.status = Status.to_update return current_item def remove_item(self, id_): @@ -350,12 +376,25 @@ def __init__(self, db_cache, item_type, **kwargs): self._removed = False self._corrupted = False self._valid = None - self.status = Status.committed + self._status = Status.committed + self._backup = None @classmethod def ref_types(cls): return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values()) + @property + def status(self): + return self._status + + @status.setter + def status(self, status): + self._status = status + + @property + def backup(self): + return self._backup + @property def removed(self): return self._removed @@ -396,6 +435,9 @@ def get(self, key, default=None): return default def update(self, other): + if self._status == Status.committed: + self._status = Status.to_update + self._backup = self._asdict() for src_key, (ref_type, _ref_key) in self._references.values(): ref_id = self[src_key] if src_key in other and other[src_key] != ref_id: @@ -406,11 +448,13 @@ def update(self, other): else: self._forget_ref(ref_type, ref_id) super().update(other) + if self._asdict() == self._backup: + self._status = Status.committed def merge(self, other): if all(self.get(key) == value for key, value in other.items()): return None, "" - merged = {**self, **other} + merged = {**self._extended(), **other} merged["id"] = self["id"] return merged, "" @@ -428,7 +472,7 @@ def resolve_inverse_references(self, skip_keys=()): for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): if src_key in skip_keys: continue - id_value = tuple(dict.get(self, k) or self.get(k) for k in id_key) + id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) if None in id_value: continue table_cache = self._db_cache.table_cache(ref_type) @@ -523,8 +567,10 @@ def _update_weak_referrers(self): def cascade_restore(self): if not self._removed: return - if self.status == Status.committed: - self.status = Status.to_add + if self._status == Status.committed: + self._status = 
Status.to_add + else: + self._status = Status.committed self._removed = False for referrer in self._referrers.values(): referrer.cascade_restore() @@ -538,10 +584,10 @@ def cascade_restore(self): def cascade_remove(self): if self._removed: return - if self.status == Status.committed: - self.status = Status.to_remove + if self._status == Status.committed: + self._status = Status.to_remove else: - self.status = Status.committed + self._status = Status.committed self._removed = True self._to_remove = False self._valid = None @@ -561,7 +607,6 @@ def cascade_update(self): self._update_weak_referrers() def call_update_callbacks(self): - self.pop("parsed_value", None) obsolete = set() for callback in self.update_callbacks: if not callback(self): @@ -575,8 +620,9 @@ def _asdict(self): return dict(self) def is_committed(self): - return self.status == Status.committed + return self._status == Status.committed def commit(self, commit_id): - self.status = Status.committed - self["commit_id"] = commit_id + self._status = Status.committed + if commit_id: + self["commit_id"] = commit_id diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index d44ed40d..805f18c3 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -14,7 +14,7 @@ """ import uuid from operator import itemgetter -from .parameter_value import from_database +from .parameter_value import from_database, ParameterValueFormatError from .db_cache_base import DBCacheBase, CacheItemBase @@ -65,7 +65,7 @@ def _query(self, item_type): class EntityClassItem(CacheItemBase): - _defaults = {"description": None, "display_icon": None} + _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = (("name",),) _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))} @@ -89,6 +89,9 @@ def merge(self, other): merged, super_error = super().merge(other) return merged, " and ".join([x for x in (super_error, error) if x]) + def commit(self, commit_id): + super().commit(None) + class EntityItem(CacheItemBase): _defaults = {"description": None} @@ -154,7 +157,31 @@ def __getitem__(self, key): return super().__getitem__(key) -class ParameterDefinitionItem(CacheItemBase): +class ParsedValueBase(CacheItemBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._parsed_value = None + + @property + def parsed_value(self): + if self._parsed_value is None: + self._parsed_value = self._make_parsed_value() + return self._parsed_value + + def _make_parsed_value(self): + raise NotImplementedError() + + def update(self, other): + self._parsed_value = None + super().update(other) + + def __getitem__(self, key): + if key == "parsed_value": + return self.parsed_value + return super().__getitem__(key) + + +class ParameterDefinitionItem(ParsedValueBase): _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} _unique_keys = (("entity_class_name", "name"),) _references = { @@ -173,6 +200,12 @@ def list_value_id(self): return int(dict.__getitem__(self, "default_value")) return None + def _make_parsed_value(self): + try: + return from_database(self["default_value"], self["default_type"]) + except ParameterValueFormatError as error: + return error + def __getitem__(self, key): if key == "parameter_name": return super().__getitem__("name") @@ -238,7 +271,7 @@ def merge(self, 
other): return merged, " and ".join([x for x in (super_error, error) if x]) -class ParameterValueItem(CacheItemBase): +class ParameterValueItem(ParsedValueBase): _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_name": ("entity_class_id", ("entity_class", "name")), @@ -269,6 +302,12 @@ def list_value_id(self): return int(dict.__getitem__(self, "value")) return None + def _make_parsed_value(self): + try: + return from_database(self["value"], self["type"]) + except ParameterValueFormatError as error: + return error + def __getitem__(self, key): if key == "parameter_id": return super().__getitem__("parameter_definition_id") @@ -315,13 +354,19 @@ class ParameterValueListItem(CacheItemBase): _unique_keys = (("name",),) -class ListValueItem(CacheItemBase): +class ListValueItem(ParsedValueBase): _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} _inverse_references = { "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), } + def _make_parsed_value(self): + try: + return from_database(self["value"], self["type"]) + except ParameterValueFormatError as error: + return error + class AlternativeItem(CacheItemBase): _defaults = {"description": None} diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 7d251afb..349be3fc 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -47,6 +47,9 @@ def commit_session(self, comment): self._memory_dirty = True def rollback_session(self): - if not self.cache.dirty_items(): + if not self.cache.rollback(): raise SpineDBAPIError("Nothing to rollback.") - self.cache.reset_queries() + + def refresh_session(self): + # TODO + pass From 4c6a1eaf6ad10bef145032165100610410cadad4 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 23 May 2023 12:55:11 +0200 Subject: [PATCH 046/317] Import zero dim entities before the others, so it works --- spinedb_api/import_functions.py | 53 ++++++++++++++++++++------------- tests/test_check_functions.py | 4 --- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 4cc5c0f1..912812d5 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -184,11 +184,12 @@ def get_data_for_import( yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) if entity_classes: - yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes)) + yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=True)) + yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=False)) if object_classes: yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) if relationship_classes: - yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes)) + yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes, zero_dim=False)) if parameter_value_lists: yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists)) yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, unparse_value)) 
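Note on the two-pass yield above: a multi-dimensional entity class can only resolve its
dimension_name_list once the zero-dimensional classes it names already exist, so the
zero-dimensional batch is yielded first. A minimal sketch of the resulting behavior,
assuming import_data() forwards the entity_classes keyword to get_data_for_import() as
it does for the other item types:

    from spinedb_api import DatabaseMapping, import_functions

    db_map = DatabaseMapping("sqlite://", create=True)
    # listed out of order on purpose; the zero-dimensional classes are still imported first
    import_functions.import_data(
        db_map, entity_classes=(("unit__node", ("unit", "node")), ("unit",), ("node",))
    )
    db_map.commit_session("Add entity classes")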
@@ -205,11 +206,12 @@ def get_data_for_import( _get_parameter_definitions_for_import(db_map, relationship_parameters, unparse_value), ) if entities: - yield ("entity", _get_entities_for_import(db_map, entities)) + yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=True)) + yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=False)) if objects: - yield ("object", _get_entities_for_import(db_map, objects)) + yield ("object", _get_entities_for_import(db_map, objects, zero_dim=True)) if relationships: - yield ("relationship", _get_entities_for_import(db_map, relationships)) + yield ("relationship", _get_entities_for_import(db_map, relationships, zero_dim=False)) if entity_groups: yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups)) if object_groups: @@ -742,20 +744,20 @@ def import_relationship_parameter_value_metadata(db_map, data): return import_data(db_map, relationship_parameter_value_metadata=data) -def _get_items_for_import(db_map, item_type, data, skip_keys=()): +def _get_items_for_import(db_map, item_type, data, check_skip_keys=()): table_cache = db_map.cache.table_cache(item_type) errors = [] to_add = [] to_update = [] seen = {} for item in data: - checked_item, add_error = table_cache.check_item(item, skip_keys=skip_keys) + checked_item, add_error = table_cache.check_item(item, skip_keys=check_skip_keys) if not add_error: if not _check_unique(item_type, checked_item, seen, errors): continue to_add.append(checked_item) continue - checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=skip_keys) + checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=check_skip_keys) if not update_error: if checked_item: if not _check_unique(item_type, checked_item, seen, errors): @@ -782,17 +784,28 @@ def _add_to_seen(checked_item, seen): seen.setdefault(key, set()).add(value) -def _get_entity_classes_for_import(db_map, data): +def _get_entity_classes_for_import(db_map, data, zero_dim): + def _data_iterator(): + for x in data: + if isinstance(x, str): + x = x, () + name, *optionals = x + dim_name_list = optionals.pop(0) if optionals else () + if (dim_name_list and zero_dim) or (not dim_name_list and not zero_dim): + continue + yield name, dim_name_list, *optionals + key = ("name", "dimension_name_list", "description", "display_icon") - return _get_items_for_import( - db_map, "entity_class", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) - ) + return _get_items_for_import(db_map, "entity_class", (dict(zip(key, x)) for x in _data_iterator())) -def _get_entities_for_import(db_map, data): +def _get_entities_for_import(db_map, data, zero_dim): def _data_iterator(): for class_name, name_or_element_name_list, *optionals in data: - byname_key = "name" if isinstance(name_or_element_name_list, str) else "element_name_list" + is_zero_dim = isinstance(name_or_element_name_list, str) + if (is_zero_dim and not zero_dim) or (not is_zero_dim and zero_dim): + continue + byname_key = "name" if is_zero_dim else "element_name_list" key = ("class_name", byname_key, "description") yield dict(zip(key, (class_name, name_or_element_name_list, *optionals))) @@ -882,7 +895,7 @@ def _data_iterator(): yield {"scenario_name": scen_name, "alternative_name": alt_name, "rank": k + 1} to_add, to_update, more_errors = _get_items_for_import( - db_map, "scenario_alternative", _data_iterator(), skip_keys=(("scenario_name", "rank"),) + db_map, "scenario_alternative", _data_iterator(), 
check_skip_keys=(("scenario_name", "rank"),) ) return to_add, to_update, errors + more_errors @@ -960,9 +973,9 @@ def _get_object_classes_for_import(db_map, data): def _data_iterator(): for x in data: if isinstance(x, str): - yield x - continue - name, *optionals = x - yield name, (), *optionals + yield x, () + else: + name, *optionals = x + yield name, (), *optionals - return _get_entity_classes_for_import(db_map, _data_iterator()) + return _get_entity_classes_for_import(db_map, _data_iterator(), zero_dim=True) diff --git a/tests/test_check_functions.py b/tests/test_check_functions.py index ead4e546..ea4cdbeb 100644 --- a/tests/test_check_functions.py +++ b/tests/test_check_functions.py @@ -15,7 +15,6 @@ from spinedb_api import DatabaseMapping from spinedb_api.parameter_value import to_database -from spinedb_api.exception import SpineIntegrityError def _val_dict(val): @@ -78,9 +77,6 @@ def test_replace_parameter_or_default_values_with_list_references(self): id_ = self.value_type[type_] # setup: parameter definition/value list ids are equal for k, value in enumerate(fail): with self.subTest(type=type_, value=value): - expect_in = json.loads(value.decode('utf8')) - if isinstance(expect_in, Number): - expect_in = float(expect_in) item = self.get_item(id_, value, 1) _, errors = self.db_map.add_items("parameter_value", item) self.assertEqual(len(errors), 1) From c29b4b68b5a7a53512fa1003af14e90ee3bec0ad Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 23 May 2023 17:11:00 +0200 Subject: [PATCH 047/317] Consolidate DatabaseMapping tests --- tests/test_DatabaseMapping.py | 1264 ++++++++++++++++++++++++++- tests/test_DiffDatabaseMapping.py | 1317 ----------------------------- 2 files changed, 1253 insertions(+), 1328 deletions(-) delete mode 100644 tests/test_DiffDatabaseMapping.py diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index d8f015a6..ba096fe0 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -13,14 +13,77 @@ Unit tests for DatabaseMapping class. 
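+Consolidates the former DiffDatabaseMapping tests; tests/test_DiffDatabaseMapping.py is
+removed in this commit.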
""" +import os.path +from tempfile import TemporaryDirectory import unittest +from unittest import mock from unittest.mock import patch -from sqlalchemy.engine.url import URL -from spinedb_api import DatabaseMapping, to_database, import_functions, from_database, SpineDBAPIError +from sqlalchemy.engine.url import make_url, URL +from sqlalchemy.util import KeyedTuple +from spinedb_api import ( + DatabaseMapping, + import_functions, + from_database, + to_database, + SpineDBAPIError, + SpineIntegrityError, +) + + +def create_query_wrapper(db_map): + def query_wrapper(*args, orig_query=db_map.query, **kwargs): + arg = args[0] + if isinstance(arg, mock.Mock): + return arg.value + return orig_query(*args, **kwargs) + + return query_wrapper + IN_MEMORY_DB_URL = "sqlite://" +class TestDatabaseMappingConstruction(unittest.TestCase): + def test_construction_with_filters(self): + db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with mock.patch( + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + ) as mock_load: + db_map = DatabaseMapping(db_url, create=True) + db_map.close() + mock_load.assert_called_once_with(["fltr1", "fltr2"]) + mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) + + def test_construction_with_sqlalchemy_url_and_filters(self): + db_url = IN_MEMORY_DB_URL + "/?spinedbfilter=fltr1&spinedbfilter=fltr2" + sa_url = make_url(db_url) + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with mock.patch( + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + ) as mock_load: + db_map = DatabaseMapping(sa_url, create=True) + db_map.close() + mock_load.assert_called_once_with(["fltr1", "fltr2"]) + mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) + + def test_shorthand_filter_query_works(self): + with TemporaryDirectory() as temp_dir: + url = URL("sqlite") + url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") + out_db_map = DatabaseMapping(url, create=True) + out_db_map.add_scenarios({"name": "scen1"}) + out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) + out_db_map.commit_session("Add scen.") + out_db_map.close() + try: + db_map = DatabaseMapping(url) + except: + self.fail("DatabaseMapping.__init__() should not raise.") + else: + db_map.close() + + class TestDatabaseMappingBase(unittest.TestCase): _db_map = None @@ -539,14 +602,844 @@ def test_wide_parameter_value_list_sq(self): self.assertEqual(value_lists[1].name, "list2") -class TestDatabaseMappingUpdateMixin(unittest.TestCase): +class TestDatabaseMappingAdd(unittest.TestCase): + def setUp(self): + self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + + def tearDown(self): + self._db_map.close() + + def test_add_and_retrieve_many_objects(self): + """Tests add many objects into db and retrieving them.""" + items, _ = self._db_map.add_object_classes({"name": "testclass"}) + class_id = next(iter(items))["id"] + added = self._db_map.add_objects(*[{"name": str(i), "class_id": class_id} for i in range(1001)])[0] + self.assertEqual(len(added), 1001) + self._db_map.commit_session("test_commit") + self.assertEqual(self._db_map.query(self._db_map.entity_sq).count(), 1001) + + def test_add_object_classes(self): + """Test that adding object classes works.""" + 
self._db_map.add_object_classes({"name": "fish"}, {"name": "dog"})
+        self._db_map.commit_session("add")
+        object_classes = self._db_map.query(self._db_map.object_class_sq).all()
+        self.assertEqual(len(object_classes), 2)
+        self.assertEqual(object_classes[0].name, "fish")
+        self.assertEqual(object_classes[1].name, "dog")
+
+    def test_add_object_class_with_invalid_name(self):
+        """Test that adding an object class with an empty name raises an error."""
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_object_classes({"name": ""}, strict=True)
+
+    def test_add_object_classes_with_same_name(self):
+        """Test that adding two object classes with the same name only adds one of them."""
+        self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"})
+        self._db_map.commit_session("add")
+        object_classes = self._db_map.query(self._db_map.object_class_sq).all()
+        self.assertEqual(len(object_classes), 1)
+        self.assertEqual(object_classes[0].name, "fish")
+
+    def test_add_object_class_with_same_name_as_existing_one(self):
+        """Test that adding an object class with an already taken name raises an integrity error."""
+        self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_object_classes({"name": "fish"}, strict=True)
+
+    def test_add_objects(self):
+        """Test that adding objects works."""
+        self._db_map.add_object_classes({"name": "fish", "id": 1})
+        self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "dory", "class_id": 1})
+        self._db_map.commit_session("add")
+        objects = self._db_map.query(self._db_map.object_sq).all()
+        self.assertEqual(len(objects), 2)
+        self.assertEqual(objects[0].name, "nemo")
+        self.assertEqual(objects[0].class_id, 1)
+        self.assertEqual(objects[1].name, "dory")
+        self.assertEqual(objects[1].class_id, 1)
+
+    def test_add_object_with_invalid_name(self):
+        """Test that adding an object with an empty name raises an error."""
+        self._db_map.add_object_classes({"name": "fish"})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_objects({"name": "", "class_id": 1}, strict=True)
+
+    def test_add_objects_with_same_name(self):
+        """Test that adding two objects with the same name only adds one of them."""
+        self._db_map.add_object_classes({"name": "fish", "id": 1})
+        self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "nemo", "class_id": 1})
+        self._db_map.commit_session("add")
+        objects = self._db_map.query(self._db_map.object_sq).all()
+        self.assertEqual(len(objects), 1)
+        self.assertEqual(objects[0].name, "nemo")
+        self.assertEqual(objects[0].class_id, 1)
+
+    def test_add_object_with_same_name_as_existing_one(self):
+        """Test that adding an object with an already taken name raises an integrity error."""
+        self._db_map.add_object_classes({"name": "fish"})
+        self._db_map.add_objects({"name": "nemo", "class_id": 1})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_objects({"name": "nemo", "class_id": 1}, strict=True)
+
+    def test_add_object_with_invalid_class(self):
+        """Test that adding an object with a non-existing class raises an integrity error."""
+        self._db_map.add_object_classes({"name": "fish"})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_objects({"name": "pluto", "class_id": 2}, strict=True)
+
+    def test_add_relationship_classes(self):
+        """Test that adding relationship classes works."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
+        
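# A relationship class is stored as one entity_class row plus one
+        # entity_class_dimension row per member object class; the dimension_id
+        # order is expected to follow object_class_id_list (see the asserts below).
+        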
self._db_map.add_wide_relationship_classes(
+            {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]}
+        )
+        self._db_map.commit_session("add")
+        table = self._db_map.get_table("entity_class_dimension")
+        ent_cls_dims = self._db_map.query(table).all()
+        rel_clss = self._db_map.query(self._db_map.wide_relationship_class_sq).all()
+        self.assertEqual(len(ent_cls_dims), 4)
+        self.assertEqual(rel_clss[0].name, "rc1")
+        self.assertEqual(ent_cls_dims[0].dimension_id, 1)
+        self.assertEqual(ent_cls_dims[1].dimension_id, 2)
+        self.assertEqual(rel_clss[1].name, "rc2")
+        self.assertEqual(ent_cls_dims[2].dimension_id, 2)
+        self.assertEqual(ent_cls_dims[3].dimension_id, 1)
+
+    def test_add_relationship_classes_with_invalid_name(self):
+        """Test that adding a relationship class with an empty name raises an error."""
+        self._db_map.add_object_classes({"name": "fish"})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_wide_relationship_classes({"name": "", "object_class_id_list": [1]}, strict=True)
+
+    def test_add_relationship_classes_with_same_name(self):
+        """Test that adding two relationship classes with the same name only adds one of them."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
+        self._db_map.add_wide_relationship_classes(
+            {"name": "rc1", "object_class_id_list": [1, 2]},
+            {"name": "rc1", "object_class_id_list": [1, 2]},
+            strict=False,
+        )
+        self._db_map.commit_session("add")
+        table = self._db_map.get_table("entity_class_dimension")
+        ecs_dims = self._db_map.query(table).all()
+        relationship_classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all()
+        self.assertEqual(len(ecs_dims), 2)
+        self.assertEqual(len(relationship_classes), 1)
+        self.assertEqual(relationship_classes[0].name, "rc1")
+        self.assertEqual(ecs_dims[0].dimension_id, 1)
+        self.assertEqual(ecs_dims[1].dimension_id, 2)
+
+    def test_add_relationship_class_with_same_name_as_existing_one(self):
+        """Test that adding a relationship class with an already taken name raises an integrity error."""
+        query_wrapper = create_query_wrapper(self._db_map)
+        with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object(
+            DatabaseMapping, "object_class_sq"
+        ) as mock_object_class_sq, mock.patch.object(
+            DatabaseMapping, "wide_relationship_class_sq"
+        ) as mock_wide_rel_cls_sq:
+            mock_query.side_effect = query_wrapper
+            mock_object_class_sq.return_value = [
+                KeyedTuple([1, "fish"], labels=["id", "name"]),
+                KeyedTuple([2, "dog"], labels=["id", "name"]),
+            ]
+            mock_wide_rel_cls_sq.return_value = [
+                KeyedTuple([1, "1,2", "fish__dog"], labels=["id", "object_class_id_list", "name"])
+            ]
+            with self.assertRaises(SpineIntegrityError):
+                self._db_map.add_wide_relationship_classes(
+                    {"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True
+                )
+
+    def test_add_relationship_class_with_invalid_object_class(self):
+        """Test that adding a relationship class with a non-existing object class raises an integrity error."""
+        query_wrapper = create_query_wrapper(self._db_map)
+        with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object(
+            DatabaseMapping, "object_class_sq"
+        ) as mock_object_class_sq, mock.patch.object(DatabaseMapping, "wide_relationship_class_sq"):
+            mock_query.side_effect = query_wrapper
+            mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])]
+            with self.assertRaises(SpineIntegrityError):
+                self._db_map.add_wide_relationship_classes(
+                    {"name": 
"fish__dog", "object_class_id_list": [1, 2]}, strict=True + ) + + def test_add_relationships(self): + """Test that adding relationships works.""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3}) + self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2}) + self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}) + self._db_map.commit_session("add") + ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all() + relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() + self.assertEqual(len(ent_els), 2) + self.assertEqual(len(relationships), 1) + self.assertEqual(relationships[0].name, "nemo__pluto") + self.assertEqual(ent_els[0].entity_class_id, 3) + self.assertEqual(ent_els[0].element_id, 1) + self.assertEqual(ent_els[1].entity_class_id, 3) + self.assertEqual(ent_els[1].element_id, 2) + + def test_add_relationship_with_invalid_name(self): + """Test that adding object classes with empty name raises error""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1]}, strict=True) + self._db_map.add_objects({"name": "o1", "class_id": 1}, strict=True) + with self.assertRaises(SpineIntegrityError): + self._db_map.add_wide_relationships({"name": "", "class_id": 2, "object_id_list": [1]}, strict=True) + + def test_add_identical_relationships(self): + """Test that adding two relationships with the same class and same objects only adds the first one.""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3}) + self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2}) + self._db_map.add_wide_relationships( + {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]}, + {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]}, + ) + self._db_map.commit_session("add") + relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() + self.assertEqual(len(relationships), 1) + + def test_add_relationship_identical_to_existing_one(self): + """Test that adding a relationship with the same class and same objects as an existing one + raises an integrity error. 
+ """ + query_wrapper = create_query_wrapper(self._db_map) + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" + ) as mock_object_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_class_sq" + ) as mock_wide_rel_cls_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_sq" + ) as mock_wide_rel_sq: + mock_query.side_effect = query_wrapper + mock_object_sq.return_value = [ + KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), + KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), + ] + mock_wide_rel_cls_sq.return_value = [ + KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ] + mock_wide_rel_sq.return_value = [ + KeyedTuple([1, 1, "1,2", "nemo__pluto"], labels=["id", "class_id", "object_id_list", "name"]) + ] + with self.assertRaises(SpineIntegrityError): + self._db_map.add_wide_relationships( + {"name": "nemoy__plutoy", "class_id": 1, "object_id_list": [1, 2]}, strict=True + ) + + def test_add_relationship_with_invalid_class(self): + """Test that adding a relationship with an invalid class raises an integrity error.""" + query_wrapper = create_query_wrapper(self._db_map) + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" + ) as mock_object_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_class_sq" + ) as mock_wide_rel_cls_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_sq" + ): + mock_query.side_effect = query_wrapper + mock_object_sq.return_value = [ + KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), + KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), + ] + mock_wide_rel_cls_sq.return_value = [ + KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ] + with self.assertRaises(SpineIntegrityError): + self._db_map.add_wide_relationships( + {"name": "nemo__pluto", "class_id": 2, "object_id_list": [1, 2]}, strict=True + ) + + def test_add_relationship_with_invalid_object(self): + """Test that adding a relationship with an invalid object raises an integrity error.""" + query_wrapper = create_query_wrapper(self._db_map) + with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( + DatabaseMapping, "object_sq" + ) as mock_object_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_class_sq" + ) as mock_wide_rel_cls_sq, mock.patch.object( + DatabaseMapping, "wide_relationship_sq" + ): + mock_query.side_effect = query_wrapper + mock_object_sq.return_value = [ + KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), + KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), + ] + mock_wide_rel_cls_sq.return_value = [ + KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) + ] + with self.assertRaises(SpineIntegrityError): + self._db_map.add_wide_relationships( + {"name": "nemo__pluto", "class_id": 1, "object_id_list": [1, 3]}, strict=True + ) + + def test_add_entity_groups(self): + """Test that adding group entities works.""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) + self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + self._db_map.commit_session("add") + table = self._db_map.get_table("entity_group") + entity_groups = self._db_map.query(table).all() + 
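# In an entity_group row, entity_id is the group and member_id the grouped
+        # entity; both are expected to belong to entity_class_id, as the
+        # invalid-class/entity/member tests below exercise.
+        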
self.assertEqual(len(entity_groups), 1)
+        self.assertEqual(entity_groups[0].entity_id, 1)
+        self.assertEqual(entity_groups[0].entity_class_id, 1)
+        self.assertEqual(entity_groups[0].member_id, 2)
+
+    def test_add_entity_groups_with_invalid_class(self):
+        """Test that adding group entities with an invalid class fails."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1})
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2}, strict=True)
+
+    def test_add_entity_groups_with_invalid_entity(self):
+        """Test that adding group entities with an invalid entity fails."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1})
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_entity_groups({"entity_id": 3, "entity_class_id": 2, "member_id": 2}, strict=True)
+
+    def test_add_entity_groups_with_invalid_member(self):
+        """Test that adding group entities with an invalid member fails."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1})
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 3}, strict=True)
+
+    def test_add_repeated_entity_groups(self):
+        """Test that adding repeated group entities fails."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1})
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
+        self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2})
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2}, strict=True)
+
+    def test_add_parameter_definitions(self):
+        """Test that adding parameter definitions works."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
+        self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
+        self._db_map.add_parameter_definitions(
+            {"name": "color", "object_class_id": 1, "description": "test1"},
+            {"name": "relative_speed", "relationship_class_id": 3, "description": "test2"},
+        )
+        self._db_map.commit_session("add")
+        table = self._db_map.get_table("parameter_definition")
+        parameter_definitions = self._db_map.query(table).all()
+        self.assertEqual(len(parameter_definitions), 2)
+        self.assertEqual(parameter_definitions[0].name, "color")
+        self.assertEqual(parameter_definitions[0].entity_class_id, 1)
+        self.assertEqual(parameter_definitions[0].description, "test1")
+        self.assertEqual(parameter_definitions[1].name, "relative_speed")
+        self.assertEqual(parameter_definitions[1].entity_class_id, 3)
+        self.assertEqual(parameter_definitions[1].description, "test2")
+
+    def test_add_parameter_with_invalid_name(self):
+        """Test that adding a parameter definition with an empty name raises an error."""
+        self._db_map.add_object_classes({"name": "oc1"}, strict=True)
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_parameter_definitions({"name": "", "object_class_id": 1}, strict=True)
+
+    def test_add_parameter_definitions_with_same_name(self):
+        """Test that adding two parameter_definitions with the same
name adds both of them."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
+        self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
+        self._db_map.add_parameter_definitions(
+            {"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3}
+        )
+        self._db_map.commit_session("add")
+        table = self._db_map.get_table("parameter_definition")
+        parameter_definitions = self._db_map.query(table).all()
+        self.assertEqual(len(parameter_definitions), 2)
+        self.assertEqual(parameter_definitions[0].name, "color")
+        self.assertEqual(parameter_definitions[1].name, "color")
+        self.assertEqual(parameter_definitions[0].entity_class_id, 1)
+
+    def test_add_parameter_with_same_name_as_existing_one(self):
+        """Test that adding parameter_definitions with an already taken name raises an integrity error."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
+        self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
+        self._db_map.add_parameter_definitions(
+            {"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3}
+        )
+        with self.assertRaises(SpineIntegrityError):
+            self._db_map.add_parameter_definitions({"name": "color", "object_class_id": 1}, strict=True)
+
+    def test_add_parameter_values(self):
+        """Test that adding parameter values works."""
+        import_functions.import_object_classes(self._db_map, ["fish", "dog"])
+        import_functions.import_relationship_classes(self._db_map, [("fish_dog", ["fish", "dog"])])
+        import_functions.import_objects(self._db_map, [("fish", "nemo"), ("dog", "pluto")])
+        import_functions.import_relationships(self._db_map, [("fish_dog", ("nemo", "pluto"))])
+        import_functions.import_object_parameters(self._db_map, [("fish", "color")])
+        import_functions.import_relationship_parameters(self._db_map, [("fish_dog", "rel_speed")])
+        self._db_map.commit_session("add")
+        color_id = (
+            self._db_map.query(self._db_map.parameter_definition_sq)
+            .filter(self._db_map.parameter_definition_sq.c.name == "color")
+            .first()
+            .id
+        )
+        rel_speed_id = (
+            self._db_map.query(self._db_map.parameter_definition_sq)
+            .filter(self._db_map.parameter_definition_sq.c.name == "rel_speed")
+            .first()
+            .id
+        )
+        nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.object_sq.c.name == "nemo").first()
+        nemo__pluto_row = self._db_map.query(self._db_map.wide_relationship_sq).first()
+        self._db_map.add_parameter_values(
+            {
+                "parameter_definition_id": color_id,
+                "entity_id": nemo_row.id,
+                "entity_class_id": nemo_row.class_id,
+                "value": b'"orange"',
+                "alternative_id": 1,
+            },
+            {
+                "parameter_definition_id": rel_speed_id,
+                "entity_id": nemo__pluto_row.id,
+                "entity_class_id": nemo__pluto_row.class_id,
+                "value": b"125",
+                "alternative_id": 1,
+            },
+        )
+        self._db_map.commit_session("add")
+        table = self._db_map.get_table("parameter_value")
+        parameter_values = self._db_map.query(table).all()
+        self.assertEqual(len(parameter_values), 2)
+        self.assertEqual(parameter_values[0].parameter_definition_id, 1)
+        self.assertEqual(parameter_values[0].entity_id, 1)
+        self.assertEqual(parameter_values[0].value, b'"orange"')
+        self.assertEqual(parameter_values[1].parameter_definition_id, 2)
+        self.assertEqual(parameter_values[1].entity_id, 3)
+        self.assertEqual(parameter_values[1].value, b"125")
+
+    def test_add_parameter_value_with_invalid_object_or_relationship(self):
+        """Test that adding a
parameter value with an invalid object or relationship raises an + integrity error.""" + import_functions.import_object_classes(self._db_map, ["fish", "dog"]) + import_functions.import_relationship_classes(self._db_map, [("fish_dog", ["fish", "dog"])]) + import_functions.import_objects(self._db_map, [("fish", "nemo"), ("dog", "pluto")]) + import_functions.import_relationships(self._db_map, [("fish_dog", ("nemo", "pluto"))]) + import_functions.import_object_parameters(self._db_map, [("fish", "color")]) + import_functions.import_relationship_parameters(self._db_map, [("fish_dog", "rel_speed")]) + _, errors = self._db_map.add_parameter_values( + {"parameter_definition_id": 1, "object_id": 3, "value": b'"orange"', "alternative_id": 1}, strict=False + ) + self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) + _, errors = self._db_map.add_parameter_values( + {"parameter_definition_id": 2, "relationship_id": 2, "value": b"125", "alternative_id": 1}, strict=False + ) + self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) + + def test_add_same_parameter_value_twice(self): + """Test that adding a parameter value twice only adds the first one.""" + import_functions.import_object_classes(self._db_map, ["fish"]) + import_functions.import_objects(self._db_map, [("fish", "nemo")]) + import_functions.import_object_parameters(self._db_map, [("fish", "color")]) + self._db_map.commit_session("add") + color_id = ( + self._db_map.query(self._db_map.parameter_definition_sq) + .filter(self._db_map.parameter_definition_sq.c.name == "color") + .first() + .id + ) + nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.entity_sq.c.name == "nemo").first() + self._db_map.add_parameter_values( + { + "parameter_definition_id": color_id, + "entity_id": nemo_row.id, + "entity_class_id": nemo_row.class_id, + "value": b'"orange"', + "alternative_id": 1, + }, + { + "parameter_definition_id": color_id, + "entity_id": nemo_row.id, + "entity_class_id": nemo_row.class_id, + "value": b'"blue"', + "alternative_id": 1, + }, + ) + self._db_map.commit_session("add") + table = self._db_map.get_table("parameter_value") + parameter_values = self._db_map.query(table).all() + self.assertEqual(len(parameter_values), 1) + self.assertEqual(parameter_values[0].parameter_definition_id, 1) + self.assertEqual(parameter_values[0].entity_id, 1) + self.assertEqual(parameter_values[0].value, b'"orange"') + + def test_add_existing_parameter_value(self): + """Test that adding an existing parameter value raises an integrity error.""" + import_functions.import_object_classes(self._db_map, ["fish"]) + import_functions.import_objects(self._db_map, [("fish", "nemo")]) + import_functions.import_object_parameters(self._db_map, [("fish", "color")]) + import_functions.import_object_parameter_values(self._db_map, [("fish", "nemo", "color", "orange")]) + self._db_map.commit_session("add") + _, errors = self._db_map.add_parameter_values( + { + "parameter_definition_id": 1, + "entity_class_id": 1, + "entity_id": 1, + "value": b'"blue"', + "alternative_id": 1, + }, + strict=False, + ) + self.assertEqual( + [str(e) for e in errors], + [ + "there's already a parameter_value with " + "{'parameter_definition_name': 'color', 'entity_byname': ('nemo',), 'alternative_name': 'Base'}" + ], + ) + + def test_add_alternative(self): + items, errors = self._db_map.add_alternatives({"name": "my_alternative"}) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + 
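# A "Base" alternative (id 1) is created automatically with the database,
+        # so two alternative rows are expected after adding a single one here.
+        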
self._db_map.commit_session("Add test data.") + alternatives = self._db_map.query(self._db_map.alternative_sq).all() + self.assertEqual(len(alternatives), 2) + self.assertEqual( + dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} + ) + self.assertEqual( + dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} + ) + + def test_add_scenario(self): + items, errors = self._db_map.add_scenarios({"name": "my_scenario"}) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add test data.") + scenarios = self._db_map.query(self._db_map.scenario_sq).all() + self.assertEqual(len(scenarios), 1) + self.assertEqual( + dict(scenarios[0]), + {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, + ) + + def test_add_scenario_alternative(self): + import_functions.import_scenarios(self._db_map, ("my_scenario",)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_scenario_alternatives({"scenario_id": 1, "alternative_id": 1, "rank": 0}) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add test data.") + scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all() + self.assertEqual(len(scenario_alternatives), 1) + self.assertEqual( + dict(scenario_alternatives[0]), + {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, + ) + + def test_add_metadata(self): + items, errors = self._db_map.add_metadata({"name": "test name", "value": "test_add_metadata"}, strict=False) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add metadata") + metadata = self._db_map.query(self._db_map.metadata_sq).all() + self.assertEqual(len(metadata), 1) + self.assertEqual( + dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} + ) + + def test_add_metadata_that_exists_does_not_add_it(self): + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, _ = self._db_map.add_metadata({"name": "title", "value": "My metadata."}, strict=False) + self.assertEqual(items, []) + metadata = self._db_map.query(self._db_map.metadata_sq).all() + self.assertEqual(len(metadata), 1) + self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) + + def test_add_entity_metadata_for_object(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add entity metadata") + entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self.assertEqual( + dict(entity_metadata[0]), + { + "entity_id": 1, + "entity_name": "leviathan", + "metadata_name": "title", + "metadata_value": "My metadata.", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_add_entity_metadata_for_relationship(self): + import_functions.import_object_classes(self._db_map, ("my_object_class",)) + 
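# Relationships are entities too and appear to share the entity id
+        # sequence: the object imported below gets entity id 1, the
+        # relationship id 2.
+        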
import_functions.import_objects(self._db_map, (("my_object_class", "my_object"),)) + import_functions.import_relationship_classes(self._db_map, (("my_relationship_class", ("my_object_class",)),)) + import_functions.import_relationships(self._db_map, (("my_relationship_class", ("my_object",)),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_entity_metadata({"entity_id": 2, "metadata_id": 1}, strict=False) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add entity metadata") + entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self.assertEqual( + dict(entity_metadata[0]), + { + "entity_id": 2, + "entity_name": "my_relationship_class_my_object", + "metadata_name": "title", + "metadata_value": "My metadata.", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_add_entity_metadata_doesnt_raise_with_empty_cache(self): + items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) + self.assertEqual(items, []) + self.assertEqual(len(errors), 1) + + def test_add_ext_entity_metadata_for_object(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_ext_entity_metadata( + {"entity_id": 1, "metadata_name": "key", "metadata_value": "object metadata"}, strict=False + ) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add entity metadata") + entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self.assertEqual( + dict(entity_metadata[0]), + { + "entity_id": 1, + "entity_name": "leviathan", + "metadata_name": "key", + "metadata_value": "object metadata", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_and_values(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_ext_entity_metadata( + {"entity_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}, strict=False + ) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add entity metadata") + metadata = self._db_map.query(self._db_map.metadata_sq).all() + self.assertEqual(len(metadata), 1) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) + entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self.assertEqual( + dict(entity_metadata[0]), + { + "entity_id": 1, + "entity_name": "leviathan", + "metadata_name": "title", + "metadata_value": "My metadata.", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_add_parameter_value_metadata(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + 
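# Parameter value metadata is tied to a single (value, alternative) pair,
+        # hence the alternative_id passed along with the ids below.
+        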
import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) + import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_parameter_value_metadata( + {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, strict=False + ) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add value metadata") + value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() + self.assertEqual(len(value_metadata), 1) + self.assertEqual( + dict(value_metadata[0]), + { + "alternative_name": "Base", + "entity_name": "leviathan", + "parameter_value_id": 1, + "parameter_name": "paranormality", + "metadata_name": "title", + "metadata_value": "My metadata.", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_add_parameter_value_metadata_doesnt_raise_with_empty_cache(self): + items, errors = self._db_map.add_parameter_value_metadata( + {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1} + ) + self.assertEqual(len(items), 0) + self.assertEqual(len(errors), 1) + + def test_add_ext_parameter_value_metadata(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) + import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_ext_parameter_value_metadata( + { + "parameter_value_id": 1, + "metadata_name": "key", + "metadata_value": "parameter metadata", + "alternative_id": 1, + }, + strict=False, + ) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add value metadata") + value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() + self.assertEqual(len(value_metadata), 1) + self.assertEqual( + dict(value_metadata[0]), + { + "alternative_name": "Base", + "entity_name": "leviathan", + "parameter_value_id": 1, + "parameter_name": "paranormality", + "metadata_name": "key", + "metadata_value": "parameter metadata", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): + import_functions.import_object_classes(self._db_map, ("fish",)) + import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) + import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) + import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_ext_parameter_value_metadata( + {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata.", "alternative_id": 1}, + strict=False, + ) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Add value metadata") + metadata = self._db_map.query(self._db_map.metadata_sq).all() + self.assertEqual(len(metadata), 1) + self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", 
"commit_id": 2}) + value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() + self.assertEqual(len(value_metadata), 1) + self.assertEqual( + dict(value_metadata[0]), + { + "alternative_name": "Base", + "entity_name": "leviathan", + "parameter_value_id": 1, + "parameter_name": "paranormality", + "metadata_name": "title", + "metadata_value": "My metadata.", + "metadata_id": 1, + "id": 1, + "commit_id": 3, + }, + ) + + +class TestDatabaseMappingUpdate(unittest.TestCase): def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() - def test_update_wide_relationship_class(self): + def test_update_object_classes(self): + """Test that updating object classes works.""" + self._db_map.add_object_classes({"id": 1, "name": "fish"}, {"id": 2, "name": "dog"}) + items, intgr_error_log = self._db_map.update_object_classes( + {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} + ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") + sq = self._db_map.object_class_sq + object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} + self.assertEqual(intgr_error_log, []) + self.assertEqual(object_classes[1], "octopus") + self.assertEqual(object_classes[2], "god") + + def test_update_objects(self): + """Test that updating objects works.""" + self._db_map.add_object_classes({"id": 1, "name": "fish"}) + self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) + items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") + sq = self._db_map.object_sq + objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} + self.assertEqual(intgr_error_log, []) + self.assertEqual(objects[1], "klaus") + self.assertEqual(objects[2], "squidward") + + def test_update_committed_object(self): + """Test that updating objects works.""" + self._db_map.add_object_classes({"id": 1, "name": "some_class"}) + self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) + self._db_map.commit_session("update") + items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") + sq = self._db_map.object_sq + objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} + self.assertEqual(intgr_error_log, []) + self.assertEqual(objects[1], "klaus") + self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") + + def test_update_relationship_classes(self): + """Test that updating relationship classes works.""" + self._db_map.add_object_classes({"name": "dog", "id": 1}, {"name": "fish", "id": 2}) + self._db_map.add_wide_relationship_classes( + {"id": 3, "name": "dog__fish", "object_class_id_list": [1, 2]}, + {"id": 4, "name": "fish__dog", "object_class_id_list": [2, 1]}, + ) + items, intgr_error_log = self._db_map.update_wide_relationship_classes( + {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} + ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") + sq = self._db_map.wide_relationship_class_sq + rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} + self.assertEqual(intgr_error_log, []) + self.assertEqual(rel_clss[3], "god__octopus") + 
self.assertEqual(rel_clss[4], "octopus__dog") + + def test_update_committed_relationship_class(self): _ = import_functions.import_object_classes(self._db_map, ("object_class_1",)) _ = import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") @@ -559,7 +1452,7 @@ def test_update_wide_relationship_class(self): self.assertEqual(len(classes), 1) self.assertEqual(classes[0].name, "renamed") - def test_update_wide_relationship_class_does_not_update_member_class_id(self): + def test_update_relationship_class_does_not_update_member_class_id(self): import_functions.import_object_classes(self._db_map, ("object_class_1", "object_class_2")) import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") @@ -574,7 +1467,33 @@ def test_update_wide_relationship_class_does_not_update_member_class_id(self): self.assertEqual(classes[0].name, "renamed") self.assertEqual(classes[0].object_class_name_list, "object_class_1") - def test_update_wide_relationship(self): + def test_update_relationships(self): + """Test that updating relationships works.""" + self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.add_objects( + {"name": "nemo", "id": 1, "class_id": 1}, + {"name": "pluto", "id": 2, "class_id": 2}, + {"name": "scooby", "id": 3, "class_id": 2}, + ) + self._db_map.add_wide_relationships( + {"id": 4, "name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2], "object_class_id_list": [1, 2]} + ) + items, intgr_error_log = self._db_map.update_wide_relationships( + {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} + ) + ids = {x["id"] for x in items} + self._db_map.commit_session("test commit") + sq = self._db_map.wide_relationship_sq + rels = { + x.id: {"name": x.name, "object_id_list": x.object_id_list} + for x in self._db_map.query(sq).filter(sq.c.id.in_(ids)) + } + self.assertEqual(intgr_error_log, []) + self.assertEqual(rels[4]["name"], "nemo__scooby") + self.assertEqual(rels[4]["object_id_list"], "1,3") + + def test_update_committed_relationship(self): import_functions.import_object_classes(self._db_map, ("object_class_1", "object_class_2")) import_functions.import_objects( self._db_map, @@ -860,6 +1779,329 @@ def setUp(self): def tearDown(self): self._db_map.close() + def test_remove_object_class(self): + """Test adding and removing an object class and committing""" + items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self.assertEqual(len(items), 2) + self._db_map.remove_items("object_class", 1, 2) + with self.assertRaises(SpineDBAPIError): + # Nothing to commit + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) + + def test_remove_object_class_from_committed_session(self): + """Test removing an object class from a committed session""" + items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 2) + self._db_map.remove_items("object_class", *{x["id"] for x in items}) + self._db_map.commit_session("Add test data.") + 
self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) + + def test_remove_object(self): + """Test adding and removing an object and committing""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + items, _ = self._db_map.add_objects( + {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} + ) + self._db_map.remove_items("object", *{x["id"] for x in items}) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) + + def test_remove_object_from_committed_session(self): + """Test removing an object from a committed session""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + items, _ = self._db_map.add_objects( + {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} + ) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 2) + self._db_map.remove_items("object", *{x["id"] for x in items}) + self._db_map.commit_session("Add test data.") + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) + + def test_remove_entity_group(self): + """Test adding and removing an entity group and committing""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) + items, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + self._db_map.remove_items("entity_group", *{x["id"] for x in items}) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) + + def test_remove_entity_group_from_committed_session(self): + """Test removing an entity group from a committed session""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) + self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 1) + self._db_map.remove_items("entity_group", 1) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) + + def test_cascade_remove_relationship_class(self): + """Test adding and removing a relationship class and committing""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) + + def test_cascade_remove_relationship_class_from_committed_session(self): + """Test removing a relationship class from a committed session""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 1) + self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) + 
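# Like additions, removals are staged in memory and written out by the
+        # next commit_session(); the cascade is expected to clear the class's
+        # entity_class_dimension rows as well.
+        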
self._db_map.commit_session("remove") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) + + def test_cascade_remove_relationship(self): + """Test adding and removing a relationship and committing""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) + items, _ = self._db_map.add_wide_relationships( + {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} + ) + self._db_map.remove_items("relationship", *{x["id"] for x in items}) + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + + def test_cascade_remove_relationship_from_committed_session(self): + """Test removing a relationship from a committed session""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) + items, _ = self._db_map.add_wide_relationships( + {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} + ) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 1) + self._db_map.remove_items("relationship", *{x["id"] for x in items}) + self._db_map.commit_session("Add test data.") + self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) + + def test_remove_parameter_value(self): + """Test adding and removing a parameter value and committing""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True) + self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True) + self._db_map.add_parameter_values( + { + "value": b"0", + "id": 1, + "parameter_definition_id": 1, + "object_id": 1, + "object_class_id": 1, + "alternative_id": 1, + }, + strict=True, + ) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) + self._db_map.remove_items("parameter_value", 1) + self._db_map.commit_session("delete") + self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + + def test_remove_parameter_value_from_committed_session(self): + """Test adding and committing a parameter value and then removing it""" + self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) + self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True) + self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True) + self._db_map.add_parameter_values( + { + "value": b"0", + "id": 1, + "parameter_definition_id": 1, + "object_id": 1, + "object_class_id": 1, + "alternative_id": 1, + }, + strict=True, + ) + self._db_map.commit_session("add") + self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) + self._db_map.remove_items("parameter_value", 1) + self._db_map.commit_session("delete") + 
self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
+
+    def test_cascade_remove_object_removes_parameter_value_as_well(self):
+        """Test that removing an object also removes its parameter values."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True)
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True)
+        self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True)
+        self._db_map.add_parameter_values(
+            {
+                "value": b"0",
+                "id": 1,
+                "parameter_definition_id": 1,
+                "object_id": 1,
+                "object_class_id": 1,
+                "alternative_id": 1,
+            },
+            strict=True,
+        )
+        self._db_map.commit_session("add")
+        self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1)
+        self._db_map.remove_items("object", 1)
+        self._db_map.commit_session("delete")
+        self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
+
+    def test_cascade_remove_object_from_committed_session_removes_parameter_value_as_well(self):
+        """Test that removing a committed object also removes its parameter values."""
+        self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True)
+        self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True)
+        self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True)
+        self._db_map.add_parameter_values(
+            {
+                "value": b"0",
+                "id": 1,
+                "parameter_definition_id": 1,
+                "object_id": 1,
+                "object_class_id": 1,
+                "alternative_id": 1,
+            },
+            strict=True,
+        )
+        self._db_map.commit_session("add")
+        self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1)
+        self._db_map.remove_items("object", 1)
+        self._db_map.commit_session("delete")
+        self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
+
+    def test_cascade_remove_metadata_removes_corresponding_entity_and_value_metadata(self):
+        import_functions.import_object_classes(self._db_map, ("my_class",))
+        import_functions.import_objects(self._db_map, (("my_class", "my_object"),))
+        import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),))
+        import_functions.import_object_parameter_values(
+            self._db_map, (("my_class", "my_object", "my_parameter", 99.0),)
+        )
+        import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
+        import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),))
+        import_functions.import_object_parameter_value_metadata(
+            self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),)
+        )
+        self._db_map.commit_session("Add test data.")
+        metadata = self._db_map.query(self._db_map.metadata_sq).all()
+        self.assertEqual(len(metadata), 1)
+        self._db_map.remove_items("metadata", metadata[0].id)
+        self._db_map.commit_session("Remove test data.")
+        self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0)
+        self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0)
+        self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0)
+        self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 1)
+        self.assertEqual(len(self._db_map.query(self._db_map.object_parameter_definition_sq).all()), 1)
+
+    def test_cascade_remove_entity_metadata_removes_corresponding_metadata(self):
+        import_functions.import_object_classes(self._db_map, ("my_class",))
+        
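# remove_unused_metadata() below purges metadata rows that are no longer
+        # referenced by any entity or parameter value metadata.
+        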
import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) + self._db_map.commit_session("Add test data.") + entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self._db_map.remove_items("entity_metadata", entity_metadata[0].id) + self._db_map.remove_unused_metadata() + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 1) + + def test_cascade_remove_entity_metadata_leaves_metadata_used_by_value_intact(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) + import_functions.import_object_parameter_values( + self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) + ) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) + import_functions.import_object_parameter_value_metadata( + self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) + ) + self._db_map.commit_session("Add test data.") + entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() + self.assertEqual(len(entity_metadata), 1) + self._db_map.remove_items("entity_metadata", entity_metadata[0].id) + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 1) + + def test_cascade_remove_value_metadata_leaves_metadata_used_by_entity_intact(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) + import_functions.import_object_parameter_values( + self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) + ) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) + import_functions.import_object_parameter_value_metadata( + self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) + ) + self._db_map.commit_session("Add test data.") + parameter_value_metadata = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() + self.assertEqual(len(parameter_value_metadata), 1) + self._db_map.remove_items("parameter_value_metadata", parameter_value_metadata[0].id) + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 1) + 
self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0) + + def test_cascade_remove_object_removes_its_metadata(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) + self._db_map.commit_session("Add test data.") + self._db_map.remove_items("object", 1) + self._db_map.remove_unused_metadata() + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) + + def test_cascade_remove_relationship_removes_its_metadata(self): + import_functions.import_object_classes(self._db_map, ("my_object_class",)) + import_functions.import_objects(self._db_map, (("my_object_class", "my_object"),)) + import_functions.import_relationship_classes(self._db_map, (("my_class", ("my_object_class",)),)) + import_functions.import_relationships(self._db_map, (("my_class", ("my_object",)),)) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_relationship_metadata( + self._db_map, (("my_class", ("my_object",), '{"title": "My metadata."}'),) + ) + self._db_map.commit_session("Add test data.") + self._db_map.remove_items("relationship", 2) + self._db_map.remove_unused_metadata() + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.relationship_sq).all()), 0) + + def test_cascade_remove_parameter_value_removes_its_metadata(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) + import_functions.import_object_parameter_values( + self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) + ) + import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) + import_functions.import_object_parameter_value_metadata( + self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) + ) + self._db_map.commit_session("Add test data.") + self._db_map.remove_items("parameter_value", 1) + self._db_map.remove_unused_metadata() + self._db_map.commit_session("Remove test data.") + self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) + self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) + def test_remove_works_when_entity_groups_are_present(self): import_functions.import_object_classes(self._db_map, ("my_class",)) import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) @@ -872,7 +2114,7 @@ def test_remove_works_when_entity_groups_are_present(self): self.assertEqual(len(objects), 1) self.assertEqual(objects[0].name, "my_group") - def test_remove_object_class(self): + def test_remove_object_class2(self): 
import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("Add test data.") my_class = self._db_map.query(self._db_map.object_class_sq).one_or_none() @@ -882,7 +2124,7 @@ def test_remove_object_class(self): my_class = self._db_map.query(self._db_map.object_class_sq).one_or_none() self.assertIsNone(my_class) - def test_remove_relationship_class(self): + def test_remove_relationship_class2(self): import_functions.import_object_classes(self._db_map, ("my_class",)) import_functions.import_relationship_classes(self._db_map, (("my_relationship_class", ("my_class",)),)) self._db_map.commit_session("Add test data.") @@ -893,7 +2135,7 @@ def test_remove_relationship_class(self): my_class = self._db_map.query(self._db_map.relationship_class_sq).one_or_none() self.assertIsNone(my_class) - def test_remove_object(self): + def test_remove_object2(self): import_functions.import_object_classes(self._db_map, ("my_class",)) import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) self._db_map.commit_session("Add test data.") @@ -904,7 +2146,7 @@ def test_remove_object(self): my_object = self._db_map.query(self._db_map.object_sq).one_or_none() self.assertIsNone(my_object) - def test_remove_relationship(self): + def test_remove_relationship2(self): import_functions.import_object_classes(self._db_map, ("my_class",)) import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) import_functions.import_relationship_classes(self._db_map, (("my_relationship_class", ("my_class",)),)) @@ -917,7 +2159,7 @@ def test_remove_relationship(self): my_relationship = self._db_map.query(self._db_map.relationship_sq).one_or_none() self.assertIsNone(my_relationship) - def test_remove_parameter_value(self): + def test_remove_parameter_value2(self): import_functions.import_object_classes(self._db_map, ("my_class",)) import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) diff --git a/tests/test_DiffDatabaseMapping.py b/tests/test_DiffDatabaseMapping.py deleted file mode 100644 index cd7af14e..00000000 --- a/tests/test_DiffDatabaseMapping.py +++ /dev/null @@ -1,1317 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -""" -Unit tests for DatabaseMapping class. 
- -""" - -import os.path -from tempfile import TemporaryDirectory -import unittest -from unittest import mock -from sqlalchemy.engine.url import make_url, URL -from sqlalchemy.util import KeyedTuple -from spinedb_api.db_mapping import DatabaseMapping -from spinedb_api.exception import SpineIntegrityError -from spinedb_api import import_functions, SpineDBAPIError - - -def create_query_wrapper(db_map): - def query_wrapper(*args, orig_query=db_map.query, **kwargs): - arg = args[0] - if isinstance(arg, mock.Mock): - return arg.value - return orig_query(*args, **kwargs) - - return query_wrapper - - -IN_MEMORY_DB_URL = "sqlite://" - - -def create_diff_db_map(): - return DatabaseMapping(IN_MEMORY_DB_URL, username="UnitTest", create=True) - - -class TestDatabaseMappingConstruction(unittest.TestCase): - def test_construction_with_filters(self): - db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: - with mock.patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] - ) as mock_load: - db_map = DatabaseMapping(db_url, create=True) - db_map.close() - mock_load.assert_called_once_with(["fltr1", "fltr2"]) - mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) - - def test_construction_with_sqlalchemy_url_and_filters(self): - db_url = IN_MEMORY_DB_URL + "/?spinedbfilter=fltr1&spinedbfilter=fltr2" - sa_url = make_url(db_url) - with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: - with mock.patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] - ) as mock_load: - db_map = DatabaseMapping(sa_url, create=True) - db_map.close() - mock_load.assert_called_once_with(["fltr1", "fltr2"]) - mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) - - def test_shorthand_filter_query_works(self): - with TemporaryDirectory() as temp_dir: - url = URL("sqlite") - url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") - out_db_map = DatabaseMapping(url, create=True) - out_db_map.add_scenarios({"name": "scen1"}) - out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) - out_db_map.commit_session("Add scen.") - out_db_map.close() - try: - db_map = DatabaseMapping(url) - except: - self.fail("DatabaseMapping.__init__() should not raise.") - else: - db_map.close() - - -class TestDatabaseMappingRemove(unittest.TestCase): - def setUp(self): - self._db_map = create_diff_db_map() - - def tearDown(self): - self._db_map.close() - - def test_cascade_remove_relationship(self): - """Test adding and removing a relationship and committing""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - items, _ = self._db_map.add_wide_relationships( - {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} - ) - self._db_map.remove_items("relationship", *{x["id"] for x in items}) - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) - - def 
test_cascade_remove_relationship_from_committed_session(self): - """Test removing a relationship from a committed session""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2}) - items, _ = self._db_map.add_wide_relationships( - {"id": 3, "name": "remove_me", "class_id": 3, "object_id_list": [1, 2]} - ) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 1) - self._db_map.remove_items("relationship", *{x["id"] for x in items}) - self._db_map.commit_session("Add test data.") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_sq).all()), 0) - - def test_remove_object(self): - """Test adding and removing an object and committing""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - items, _ = self._db_map.add_objects( - {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} - ) - self._db_map.remove_items("object", *{x["id"] for x in items}) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) - - def test_remove_object_from_committed_session(self): - """Test removing an object from a committed session""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - items, _ = self._db_map.add_objects( - {"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 2} - ) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 2) - self._db_map.remove_items("object", *{x["id"] for x in items}) - self._db_map.commit_session("Add test data.") - self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0) - - def test_remove_entity_group(self): - """Test adding and removing an entity group and committing""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - items, _ = self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - self._db_map.remove_items("entity_group", *{x["id"] for x in items}) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) - - def test_remove_entity_group_from_committed_session(self): - """Test removing an entity group from a committed session""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 1) - self._db_map.remove_items("entity_group", 1) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.entity_group_sq).all()), 0) - - def test_cascade_remove_relationship_class(self): - """Test adding and removing a relationship class and committing""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - 
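The cascade semantics exercised by the "cascade" tests in this class: remove_items() on a parent row also drops the rows that depend on it. A minimal sketch, assuming a db_map (the test's self._db_map) populated as above with one object holding one parameter value:

    db_map.remove_items("object", 1)  # removing the object...
    db_map.commit_session("delete")
    # ...also removes the parameter values that referenced it:
    assert len(db_map.query(db_map.parameter_value_sq).all()) == 0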
self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) - - def test_cascade_remove_relationship_class_from_committed_session(self): - """Test removing a relationship class from a committed session""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - items, _ = self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 1) - self._db_map.remove_items("relationship_class", *{x["id"] for x in items}) - self._db_map.commit_session("remove") - self.assertEqual(len(self._db_map.query(self._db_map.wide_relationship_class_sq).all()), 0) - - def test_remove_object_class(self): - """Test adding and removing an object class and committing""" - items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self.assertEqual(len(items), 2) - self._db_map.remove_items("object_class", 1, 2) - with self.assertRaises(SpineDBAPIError): - # Nothing to commit - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) - - def test_remove_object_class_from_committed_session(self): - """Test removing an object class from a committed session""" - items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 2) - self._db_map.remove_items("object_class", *{x["id"] for x in items}) - self._db_map.commit_session("Add test data.") - self.assertEqual(len(self._db_map.query(self._db_map.object_class_sq).all()), 0) - - def test_remove_parameter_value(self): - """Test adding and removing a parameter value and committing""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True) - self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True) - self._db_map.add_parameter_values( - { - "value": b"0", - "id": 1, - "parameter_definition_id": 1, - "object_id": 1, - "object_class_id": 1, - "alternative_id": 1, - }, - strict=True, - ) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.remove_items("parameter_value", 1) - self._db_map.commit_session("delete") - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0) - - def test_remove_parameter_value_from_committed_session(self): - """Test adding and committing a parameter value and then removing it""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True) - self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True) - self._db_map.add_parameter_values( - { - "value": b"0", - "id": 1, - "parameter_definition_id": 1, - "object_id": 1, - "object_class_id": 1, - "alternative_id": 1, - }, - strict=True, - ) - self._db_map.commit_session("add") - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1) - self._db_map.remove_items("parameter_value", 1) - 
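Removals are staged in the session: queries reflect them right away, but the database only changes at commit_session(), and if adds and removes cancel out the commit raises, as test_remove_object_class above shows. A sketch of that edge case, assuming a fresh db_map:

    db_map.add_object_classes({"name": "oc1", "id": 1})
    db_map.remove_items("object_class", 1)  # cancels the uncommitted add
    try:
        db_map.commit_session("delete")
    except SpineDBAPIError:
        pass  # nothing left to commit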
self._db_map.commit_session("delete")
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
-
- def test_cascade_remove_object_removes_parameter_value_as_well(self):
- """Test that removing an object also removes its parameter values"""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True)
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True)
- self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True)
- self._db_map.add_parameter_values(
- {
- "value": b"0",
- "id": 1,
- "parameter_definition_id": 1,
- "object_id": 1,
- "object_class_id": 1,
- "alternative_id": 1,
- },
- strict=True,
- )
- self._db_map.commit_session("add")
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1)
- self._db_map.remove_items("object", 1)
- self._db_map.commit_session("delete")
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
-
- def test_cascade_remove_object_from_committed_session_removes_parameter_value_as_well(self):
- """Test that removing an object from a committed session also removes its parameter values"""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True)
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, strict=True)
- self._db_map.add_parameter_definitions({"name": "param", "id": 1, "object_class_id": 1}, strict=True)
- self._db_map.add_parameter_values(
- {
- "value": b"0",
- "id": 1,
- "parameter_definition_id": 1,
- "object_id": 1,
- "object_class_id": 1,
- "alternative_id": 1,
- },
- strict=True,
- )
- self._db_map.commit_session("add")
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 1)
- self._db_map.remove_items("object", 1)
- self._db_map.commit_session("delete")
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
-
- def test_cascade_remove_metadata_removes_corresponding_entity_and_value_metadata(self):
- import_functions.import_object_classes(self._db_map, ("my_class",))
- import_functions.import_objects(self._db_map, (("my_class", "my_object"),))
- import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),))
- import_functions.import_object_parameter_values(
- self._db_map, (("my_class", "my_object", "my_parameter", 99.0),)
- )
- import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
- import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),))
- import_functions.import_object_parameter_value_metadata(
- self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),)
- )
- self._db_map.commit_session("Add test data.")
- metadata = self._db_map.query(self._db_map.metadata_sq).all()
- self.assertEqual(len(metadata), 1)
- self._db_map.remove_items("metadata", metadata[0].id)
- self._db_map.commit_session("Remove test data.")
- self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 1)
- self.assertEqual(len(self._db_map.query(self._db_map.object_parameter_definition_sq).all()), 1)
-
- def test_cascade_remove_entity_metadata_removes_corresponding_metadata(self):
- import_functions.import_object_classes(self._db_map, ("my_class",))
- import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) - self._db_map.commit_session("Add test data.") - entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self._db_map.remove_items("entity_metadata", entity_metadata[0].id) - self._db_map.remove_unused_metadata() - self._db_map.commit_session("Remove test data.") - self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0) - self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) - self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 1) - - def test_cascade_remove_entity_metadata_leaves_metadata_used_by_value_intact(self): - import_functions.import_object_classes(self._db_map, ("my_class",)) - import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) - import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) - import_functions.import_object_parameter_values( - self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) - ) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) - import_functions.import_object_parameter_value_metadata( - self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) - ) - self._db_map.commit_session("Add test data.") - entity_metadata = self._db_map.query(self._db_map.entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self._db_map.remove_items("entity_metadata", entity_metadata[0].id) - self._db_map.commit_session("Remove test data.") - self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) - self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0) - self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 1) - - def test_cascade_remove_value_metadata_leaves_metadata_used_by_entity_intact(self): - import_functions.import_object_classes(self._db_map, ("my_class",)) - import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) - import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),)) - import_functions.import_object_parameter_values( - self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) - ) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) - import_functions.import_object_parameter_value_metadata( - self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) - ) - self._db_map.commit_session("Add test data.") - parameter_value_metadata = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() - self.assertEqual(len(parameter_value_metadata), 1) - self._db_map.remove_items("parameter_value_metadata", parameter_value_metadata[0].id) - self._db_map.commit_session("Remove test data.") - self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 1) - self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 1) - 
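One metadata row can be linked from both entity_metadata and parameter_value_metadata; removing one link must leave the shared row intact for the other, which is exactly what the two ...leaves_metadata_used_by..._intact tests assert. In sketch form, with db_map standing in for self._db_map and the fixtures set up as above:

    db_map.remove_items("entity_metadata", entity_metadata[0].id)
    db_map.commit_session("Remove test data.")
    # The shared metadata row survives; only the entity link is gone:
    assert len(db_map.query(db_map.metadata_sq).all()) == 1
    assert len(db_map.query(db_map.entity_metadata_sq).all()) == 0
    assert len(db_map.query(db_map.parameter_value_metadata_sq).all()) == 1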
self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_metadata_sq).all()), 0)
-
- def test_cascade_remove_object_removes_its_metadata(self):
- import_functions.import_object_classes(self._db_map, ("my_class",))
- import_functions.import_objects(self._db_map, (("my_class", "my_object"),))
- import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
- import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),))
- self._db_map.commit_session("Add test data.")
- self._db_map.remove_items("object", 1)
- self._db_map.remove_unused_metadata()
- self._db_map.commit_session("Remove test data.")
- self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.object_sq).all()), 0)
-
- def test_cascade_remove_relationship_removes_its_metadata(self):
- import_functions.import_object_classes(self._db_map, ("my_object_class",))
- import_functions.import_objects(self._db_map, (("my_object_class", "my_object"),))
- import_functions.import_relationship_classes(self._db_map, (("my_class", ("my_object_class",)),))
- import_functions.import_relationships(self._db_map, (("my_class", ("my_object",)),))
- import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
- import_functions.import_relationship_metadata(
- self._db_map, (("my_class", ("my_object",), '{"title": "My metadata."}'),)
- )
- self._db_map.commit_session("Add test data.")
- self._db_map.remove_items("relationship", 2)
- self._db_map.remove_unused_metadata()
- self._db_map.commit_session("Remove test data.")
- self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.relationship_sq).all()), 0)
-
- def test_cascade_remove_parameter_value_removes_its_metadata(self):
- import_functions.import_object_classes(self._db_map, ("my_class",))
- import_functions.import_objects(self._db_map, (("my_class", "my_object"),))
- import_functions.import_object_parameters(self._db_map, (("my_class", "my_parameter"),))
- import_functions.import_object_parameter_values(
- self._db_map, (("my_class", "my_object", "my_parameter", 99.0),)
- )
- import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',))
- import_functions.import_object_parameter_value_metadata(
- self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),)
- )
- self._db_map.commit_session("Add test data.")
- self._db_map.remove_items("parameter_value", 1)
- self._db_map.remove_unused_metadata()
- self._db_map.commit_session("Remove test data.")
- self.assertEqual(len(self._db_map.query(self._db_map.metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.entity_metadata_sq).all()), 0)
- self.assertEqual(len(self._db_map.query(self._db_map.parameter_value_sq).all()), 0)
-
-
-class TestDatabaseMappingAdd(unittest.TestCase):
- def setUp(self):
- self._db_map = create_diff_db_map()
-
- def tearDown(self):
- self._db_map.close()
-
- def test_add_and_retrieve_many_objects(self):
- """Tests adding many objects to the db and retrieving them."""
- items, _ = self._db_map.add_object_classes({"name": "testclass"})
- class_id = next(iter(items))["id"]
- added = self._db_map.add_objects(*[{"name": str(i), "class_id": class_id} for i in
range(1001)])[0]
- self.assertEqual(len(added), 1001)
- self._db_map.commit_session("test_commit")
- self.assertEqual(self._db_map.query(self._db_map.entity_sq).count(), 1001)
-
- def test_add_object_classes(self):
- """Test that adding object classes works."""
- self._db_map.add_object_classes({"name": "fish"}, {"name": "dog"})
- self._db_map.commit_session("add")
- object_classes = self._db_map.query(self._db_map.object_class_sq).all()
- self.assertEqual(len(object_classes), 2)
- self.assertEqual(object_classes[0].name, "fish")
- self.assertEqual(object_classes[1].name, "dog")
-
- def test_add_object_class_with_invalid_name(self):
- """Test that adding object classes with empty name raises error"""
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_object_classes({"name": ""}, strict=True)
-
- def test_add_object_classes_with_same_name(self):
- """Test that adding two object classes with the same name only adds one of them."""
- self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"})
- self._db_map.commit_session("add")
- object_classes = self._db_map.query(self._db_map.object_class_sq).all()
- self.assertEqual(len(object_classes), 1)
- self.assertEqual(object_classes[0].name, "fish")
-
- def test_add_object_class_with_same_name_as_existing_one(self):
- """Test that adding an object class with an already taken name raises an integrity error."""
- self._db_map.add_object_classes({"name": "fish"}, {"name": "fish"})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_object_classes({"name": "fish"}, strict=True)
-
- def test_add_objects(self):
- """Test that adding objects works."""
- self._db_map.add_object_classes({"name": "fish", "id": 1})
- self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "dory", "class_id": 1})
- self._db_map.commit_session("add")
- objects = self._db_map.query(self._db_map.object_sq).all()
- self.assertEqual(len(objects), 2)
- self.assertEqual(objects[0].name, "nemo")
- self.assertEqual(objects[0].class_id, 1)
- self.assertEqual(objects[1].name, "dory")
- self.assertEqual(objects[1].class_id, 1)
-
- def test_add_object_with_invalid_name(self):
- """Test that adding objects with empty name raises error"""
- self._db_map.add_object_classes({"name": "fish"})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_objects({"name": "", "class_id": 1}, strict=True)
-
- def test_add_objects_with_same_name(self):
- """Test that adding two objects with the same name only adds one of them."""
- self._db_map.add_object_classes({"name": "fish", "id": 1})
- self._db_map.add_objects({"name": "nemo", "class_id": 1}, {"name": "nemo", "class_id": 1})
- self._db_map.commit_session("add")
- objects = self._db_map.query(self._db_map.object_sq).all()
- self.assertEqual(len(objects), 1)
- self.assertEqual(objects[0].name, "nemo")
- self.assertEqual(objects[0].class_id, 1)
-
- def test_add_object_with_same_name_as_existing_one(self):
- """Test that adding an object with an already taken name raises an integrity error."""
- self._db_map.add_object_classes({"name": "fish"})
- self._db_map.add_objects({"name": "nemo", "class_id": 1})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_objects({"name": "nemo", "class_id": 1}, strict=True)
-
- def test_add_object_with_invalid_class(self):
- """Test that adding an object with a nonexistent class raises an integrity error."""
- self._db_map.add_object_classes({"name": "fish"})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_objects({"name":
"pluto", "class_id": 2}, strict=True) - - def test_add_relationship_classes(self): - """Test that adding relationship classes works.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes( - {"name": "rc1", "object_class_id_list": [1, 2]}, {"name": "rc2", "object_class_id_list": [2, 1]} - ) - self._db_map.commit_session("add") - table = self._db_map.get_table("entity_class_dimension") - ent_cls_dims = self._db_map.query(table).all() - rel_clss = self._db_map.query(self._db_map.wide_relationship_class_sq).all() - self.assertEqual(len(ent_cls_dims), 4) - self.assertEqual(rel_clss[0].name, "rc1") - self.assertEqual(ent_cls_dims[0].dimension_id, 1) - self.assertEqual(ent_cls_dims[1].dimension_id, 2) - self.assertEqual(rel_clss[1].name, "rc2") - self.assertEqual(ent_cls_dims[2].dimension_id, 2) - self.assertEqual(ent_cls_dims[3].dimension_id, 1) - - def test_add_relationship_classes_with_invalid_name(self): - """Test that adding object classes with empty name raises error""" - self._db_map.add_object_classes({"name": "fish"}) - with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationship_classes({"name": "", "object_class_id_list": [1]}, strict=True) - - def test_add_relationship_classes_with_same_name(self): - """Test that adding two relationship classes with the same name only adds one of them.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) - self._db_map.add_wide_relationship_classes( - {"name": "rc1", "object_class_id_list": [1, 2]}, - {"name": "rc1", "object_class_id_list": [1, 2]}, - strict=False, - ) - self._db_map.commit_session("add") - table = self._db_map.get_table("entity_class_dimension") - ecs_dims = self._db_map.query(table).all() - relationship_classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() - self.assertEqual(len(ecs_dims), 2) - self.assertEqual(len(relationship_classes), 1) - self.assertEqual(relationship_classes[0].name, "rc1") - self.assertEqual(ecs_dims[0].dimension_id, 1) - self.assertEqual(ecs_dims[1].dimension_id, 2) - - def test_add_relationship_class_with_same_name_as_existing_one(self): - """Test that adding a relationship class with an already taken name raises an integrity error.""" - query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_class_sq" - ) as mock_object_class_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" - ) as mock_wide_rel_cls_sq: - mock_query.side_effect = query_wrapper - mock_object_class_sq.return_value = [ - KeyedTuple([1, "fish"], labels=["id", "name"]), - KeyedTuple([2, "dog"], labels=["id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "1,2", "fish__dog"], labels=["id", "object_class_id_list", "name"]) - ] - with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationship_classes( - {"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True - ) - - def test_add_relationship_class_with_invalid_object_class(self): - """Test that adding a relationship class with a non existing object class raises an integrity error.""" - query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_class_sq" - ) as mock_object_class_sq, mock.patch.object(DatabaseMapping, "wide_relationship_class_sq"): - 
mock_query.side_effect = query_wrapper
- mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])]
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_wide_relationship_classes(
- {"name": "fish__dog", "object_class_id_list": [1, 2]}, strict=True
- )
-
- def test_add_relationships(self):
- """Test that adding relationships works."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
- self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3})
- self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2})
- self._db_map.add_wide_relationships({"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]})
- self._db_map.commit_session("add")
- ent_els = self._db_map.query(self._db_map.get_table("entity_element")).all()
- relationships = self._db_map.query(self._db_map.wide_relationship_sq).all()
- self.assertEqual(len(ent_els), 2)
- self.assertEqual(len(relationships), 1)
- self.assertEqual(relationships[0].name, "nemo__pluto")
- self.assertEqual(ent_els[0].entity_class_id, 3)
- self.assertEqual(ent_els[0].element_id, 1)
- self.assertEqual(ent_els[1].entity_class_id, 3)
- self.assertEqual(ent_els[1].element_id, 2)
-
- def test_add_relationship_with_invalid_name(self):
- """Test that adding relationships with empty name raises error"""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, strict=True)
- self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1]}, strict=True)
- self._db_map.add_objects({"name": "o1", "class_id": 1}, strict=True)
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_wide_relationships({"name": "", "class_id": 2, "object_id_list": [1]}, strict=True)
-
- def test_add_identical_relationships(self):
- """Test that adding two relationships with the same class and same objects only adds the first one."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
- self._db_map.add_wide_relationship_classes({"name": "rc1", "object_class_id_list": [1, 2], "id": 3})
- self._db_map.add_objects({"name": "o1", "class_id": 1, "id": 1}, {"name": "o2", "class_id": 2, "id": 2})
- self._db_map.add_wide_relationships(
- {"name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2]},
- {"name": "nemo__pluto_duplicate", "class_id": 3, "object_id_list": [1, 2]},
- )
- self._db_map.commit_session("add")
- relationships = self._db_map.query(self._db_map.wide_relationship_sq).all()
- self.assertEqual(len(relationships), 1)
-
- def test_add_relationship_identical_to_existing_one(self):
- """Test that adding a relationship with the same class and same objects as an existing one
- raises an integrity error.
- """ - query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" - ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" - ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" - ) as mock_wide_rel_sq: - mock_query.side_effect = query_wrapper - mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) - ] - mock_wide_rel_sq.return_value = [ - KeyedTuple([1, 1, "1,2", "nemo__pluto"], labels=["id", "class_id", "object_id_list", "name"]) - ] - with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationships( - {"name": "nemoy__plutoy", "class_id": 1, "object_id_list": [1, 2]}, strict=True - ) - - def test_add_relationship_with_invalid_class(self): - """Test that adding a relationship with an invalid class raises an integrity error.""" - query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" - ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" - ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" - ): - mock_query.side_effect = query_wrapper - mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) - ] - with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationships( - {"name": "nemo__pluto", "class_id": 2, "object_id_list": [1, 2]}, strict=True - ) - - def test_add_relationship_with_invalid_object(self): - """Test that adding a relationship with an invalid object raises an integrity error.""" - query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" - ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" - ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" - ): - mock_query.side_effect = query_wrapper - mock_object_sq.return_value = [ - KeyedTuple([1, 10, "nemo"], labels=["id", "class_id", "name"]), - KeyedTuple([2, 20, "pluto"], labels=["id", "class_id", "name"]), - ] - mock_wide_rel_cls_sq.return_value = [ - KeyedTuple([1, "10,20", "fish__dog"], labels=["id", "object_class_id_list", "name"]) - ] - with self.assertRaises(SpineIntegrityError): - self._db_map.add_wide_relationships( - {"name": "nemo__pluto", "class_id": 1, "object_id_list": [1, 3]}, strict=True - ) - - def test_add_entity_groups(self): - """Test that adding group entities works.""" - self._db_map.add_object_classes({"name": "oc1", "id": 1}) - self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1}) - self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 1, "member_id": 2}) - self._db_map.commit_session("add") - table = self._db_map.get_table("entity_group") - entity_groups = self._db_map.query(table).all() - 
self.assertEqual(len(entity_groups), 1)
- self.assertEqual(entity_groups[0].entity_id, 1)
- self.assertEqual(entity_groups[0].entity_class_id, 1)
- self.assertEqual(entity_groups[0].member_id, 2)
-
- def test_add_entity_groups_with_invalid_class(self):
- """Test that adding group entities with an invalid class fails."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1})
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2}, strict=True)
-
- def test_add_entity_groups_with_invalid_entity(self):
- """Test that adding group entities with an invalid entity fails."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1})
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_entity_groups({"entity_id": 3, "entity_class_id": 2, "member_id": 2}, strict=True)
-
- def test_add_entity_groups_with_invalid_member(self):
- """Test that adding group entities with an invalid member fails."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1})
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 3}, strict=True)
-
- def test_add_repeated_entity_groups(self):
- """Test that adding repeated group entities fails."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1})
- self._db_map.add_objects({"name": "o1", "id": 1, "class_id": 1}, {"name": "o2", "id": 2, "class_id": 1})
- self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2})
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_entity_groups({"entity_id": 1, "entity_class_id": 2, "member_id": 2}, strict=True)
-
- def test_add_parameter_definitions(self):
- """Test that adding parameter definitions works."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
- self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
- self._db_map.add_parameter_definitions(
- {"name": "color", "object_class_id": 1, "description": "test1"},
- {"name": "relative_speed", "relationship_class_id": 3, "description": "test2"},
- )
- self._db_map.commit_session("add")
- table = self._db_map.get_table("parameter_definition")
- parameter_definitions = self._db_map.query(table).all()
- self.assertEqual(len(parameter_definitions), 2)
- self.assertEqual(parameter_definitions[0].name, "color")
- self.assertEqual(parameter_definitions[0].entity_class_id, 1)
- self.assertEqual(parameter_definitions[0].description, "test1")
- self.assertEqual(parameter_definitions[1].name, "relative_speed")
- self.assertEqual(parameter_definitions[1].entity_class_id, 3)
- self.assertEqual(parameter_definitions[1].description, "test2")
-
- def test_add_parameter_with_invalid_name(self):
- """Test that adding parameter definitions with empty name raises error"""
- self._db_map.add_object_classes({"name": "oc1"}, strict=True)
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_parameter_definitions({"name": "", "object_class_id": 1}, strict=True)
-
- def test_add_parameter_definitions_with_same_name(self):
- """Test that adding two parameter_definitions with the same
name adds both of them."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
- self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
- self._db_map.add_parameter_definitions(
- {"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3}
- )
- self._db_map.commit_session("add")
- table = self._db_map.get_table("parameter_definition")
- parameter_definitions = self._db_map.query(table).all()
- self.assertEqual(len(parameter_definitions), 2)
- self.assertEqual(parameter_definitions[0].name, "color")
- self.assertEqual(parameter_definitions[1].name, "color")
- self.assertEqual(parameter_definitions[0].entity_class_id, 1)
-
- def test_add_parameter_with_same_name_as_existing_one(self):
- """Test that adding parameter_definitions with an already taken name raises an integrity error."""
- self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2})
- self._db_map.add_wide_relationship_classes({"name": "rc1", "id": 3, "object_class_id_list": [1, 2]})
- self._db_map.add_parameter_definitions(
- {"name": "color", "object_class_id": 1}, {"name": "color", "relationship_class_id": 3}
- )
- with self.assertRaises(SpineIntegrityError):
- self._db_map.add_parameter_definitions({"name": "color", "object_class_id": 1}, strict=True)
-
- def test_add_parameter_values(self):
- """Test that adding parameter values works."""
- import_functions.import_object_classes(self._db_map, ["fish", "dog"])
- import_functions.import_relationship_classes(self._db_map, [("fish_dog", ["fish", "dog"])])
- import_functions.import_objects(self._db_map, [("fish", "nemo"), ("dog", "pluto")])
- import_functions.import_relationships(self._db_map, [("fish_dog", ("nemo", "pluto"))])
- import_functions.import_object_parameters(self._db_map, [("fish", "color")])
- import_functions.import_relationship_parameters(self._db_map, [("fish_dog", "rel_speed")])
- self._db_map.commit_session("add")
- color_id = (
- self._db_map.query(self._db_map.parameter_definition_sq)
- .filter(self._db_map.parameter_definition_sq.c.name == "color")
- .first()
- .id
- )
- rel_speed_id = (
- self._db_map.query(self._db_map.parameter_definition_sq)
- .filter(self._db_map.parameter_definition_sq.c.name == "rel_speed")
- .first()
- .id
- )
- nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.object_sq.c.name == "nemo").first()
- nemo__pluto_row = self._db_map.query(self._db_map.wide_relationship_sq).first()
- self._db_map.add_parameter_values(
- {
- "parameter_definition_id": color_id,
- "entity_id": nemo_row.id,
- "entity_class_id": nemo_row.class_id,
- "value": b'"orange"',
- "alternative_id": 1,
- },
- {
- "parameter_definition_id": rel_speed_id,
- "entity_id": nemo__pluto_row.id,
- "entity_class_id": nemo__pluto_row.class_id,
- "value": b"125",
- "alternative_id": 1,
- },
- )
- self._db_map.commit_session("add")
- table = self._db_map.get_table("parameter_value")
- parameter_values = self._db_map.query(table).all()
- self.assertEqual(len(parameter_values), 2)
- self.assertEqual(parameter_values[0].parameter_definition_id, 1)
- self.assertEqual(parameter_values[0].entity_id, 1)
- self.assertEqual(parameter_values[0].value, b'"orange"')
- self.assertEqual(parameter_values[1].parameter_definition_id, 2)
- self.assertEqual(parameter_values[1].entity_id, 3)
- self.assertEqual(parameter_values[1].value, b"125")
-
- def test_add_parameter_value_with_invalid_object_or_relationship(self):
- """Test that adding a
parameter value with an invalid object or relationship raises an - integrity error.""" - import_functions.import_object_classes(self._db_map, ["fish", "dog"]) - import_functions.import_relationship_classes(self._db_map, [("fish_dog", ["fish", "dog"])]) - import_functions.import_objects(self._db_map, [("fish", "nemo"), ("dog", "pluto")]) - import_functions.import_relationships(self._db_map, [("fish_dog", ("nemo", "pluto"))]) - import_functions.import_object_parameters(self._db_map, [("fish", "color")]) - import_functions.import_relationship_parameters(self._db_map, [("fish_dog", "rel_speed")]) - _, errors = self._db_map.add_parameter_values( - {"parameter_definition_id": 1, "object_id": 3, "value": b'"orange"', "alternative_id": 1}, strict=False - ) - self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) - _, errors = self._db_map.add_parameter_values( - {"parameter_definition_id": 2, "relationship_id": 2, "value": b"125", "alternative_id": 1}, strict=False - ) - self.assertEqual([str(e) for e in errors], ["invalid entity_class_id for parameter_value"]) - - def test_add_same_parameter_value_twice(self): - """Test that adding a parameter value twice only adds the first one.""" - import_functions.import_object_classes(self._db_map, ["fish"]) - import_functions.import_objects(self._db_map, [("fish", "nemo")]) - import_functions.import_object_parameters(self._db_map, [("fish", "color")]) - self._db_map.commit_session("add") - color_id = ( - self._db_map.query(self._db_map.parameter_definition_sq) - .filter(self._db_map.parameter_definition_sq.c.name == "color") - .first() - .id - ) - nemo_row = self._db_map.query(self._db_map.object_sq).filter(self._db_map.entity_sq.c.name == "nemo").first() - self._db_map.add_parameter_values( - { - "parameter_definition_id": color_id, - "entity_id": nemo_row.id, - "entity_class_id": nemo_row.class_id, - "value": b'"orange"', - "alternative_id": 1, - }, - { - "parameter_definition_id": color_id, - "entity_id": nemo_row.id, - "entity_class_id": nemo_row.class_id, - "value": b'"blue"', - "alternative_id": 1, - }, - ) - self._db_map.commit_session("add") - table = self._db_map.get_table("parameter_value") - parameter_values = self._db_map.query(table).all() - self.assertEqual(len(parameter_values), 1) - self.assertEqual(parameter_values[0].parameter_definition_id, 1) - self.assertEqual(parameter_values[0].entity_id, 1) - self.assertEqual(parameter_values[0].value, b'"orange"') - - def test_add_existing_parameter_value(self): - """Test that adding an existing parameter value raises an integrity error.""" - import_functions.import_object_classes(self._db_map, ["fish"]) - import_functions.import_objects(self._db_map, [("fish", "nemo")]) - import_functions.import_object_parameters(self._db_map, [("fish", "color")]) - import_functions.import_object_parameter_values(self._db_map, [("fish", "nemo", "color", "orange")]) - self._db_map.commit_session("add") - _, errors = self._db_map.add_parameter_values( - { - "parameter_definition_id": 1, - "entity_class_id": 1, - "entity_id": 1, - "value": b'"blue"', - "alternative_id": 1, - }, - strict=False, - ) - self.assertEqual( - [str(e) for e in errors], - [ - "there's already a parameter_value with " - "{'parameter_definition_name': 'color', 'entity_byname': ('nemo',), 'alternative_name': 'Base'}" - ], - ) - - def test_add_alternative(self): - items, errors = self._db_map.add_alternatives({"name": "my_alternative"}) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - 
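All add_* methods in these tests follow one convention: they return an (items, errors) pair, and the strict flag decides whether integrity violations are collected into errors or raised as SpineIntegrityError. A minimal sketch of both modes, assuming a fresh db_map in place of self._db_map:

    items, errors = db_map.add_alternatives({"name": "my_alternative"})  # default: collect errors
    assert errors == [] and len(items) == 1
    try:
        db_map.add_object_classes({"name": ""}, strict=True)  # strict=True: raise instead
    except SpineIntegrityError:
        pass  # empty names are rejected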
self._db_map.commit_session("Add test data.") - alternatives = self._db_map.query(self._db_map.alternative_sq).all() - self.assertEqual(len(alternatives), 2) - self.assertEqual( - dict(alternatives[0]), {"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1} - ) - self.assertEqual( - dict(alternatives[1]), {"id": 2, "name": "my_alternative", "description": None, "commit_id": 2} - ) - - def test_add_scenario(self): - items, errors = self._db_map.add_scenarios({"name": "my_scenario"}) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add test data.") - scenarios = self._db_map.query(self._db_map.scenario_sq).all() - self.assertEqual(len(scenarios), 1) - self.assertEqual( - dict(scenarios[0]), - {"id": 1, "name": "my_scenario", "description": None, "active": False, "commit_id": 2}, - ) - - def test_add_scenario_alternative(self): - import_functions.import_scenarios(self._db_map, ("my_scenario",)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_scenario_alternatives({"scenario_id": 1, "alternative_id": 1, "rank": 0}) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add test data.") - scenario_alternatives = self._db_map.query(self._db_map.scenario_alternative_sq).all() - self.assertEqual(len(scenario_alternatives), 1) - self.assertEqual( - dict(scenario_alternatives[0]), - {"id": 1, "scenario_id": 1, "alternative_id": 1, "rank": 0, "commit_id": 3}, - ) - - def test_add_metadata(self): - items, errors = self._db_map.add_metadata({"name": "test name", "value": "test_add_metadata"}, strict=False) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add metadata") - metadata = self._db_map.query(self._db_map.metadata_sq).all() - self.assertEqual(len(metadata), 1) - self.assertEqual( - dict(metadata[0]), {"name": "test name", "id": 1, "value": "test_add_metadata", "commit_id": 2} - ) - - def test_add_metadata_that_exists_does_not_add_it(self): - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, _ = self._db_map.add_metadata({"name": "title", "value": "My metadata."}, strict=False) - self.assertEqual(items, []) - metadata = self._db_map.query(self._db_map.metadata_sq).all() - self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"name": "title", "id": 1, "value": "My metadata.", "commit_id": 2}) - - def test_add_entity_metadata_for_object(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add entity metadata") - entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self.assertEqual( - dict(entity_metadata[0]), - { - "entity_id": 1, - "entity_name": "leviathan", - "metadata_name": "title", - "metadata_value": "My metadata.", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_add_entity_metadata_for_relationship(self): - import_functions.import_object_classes(self._db_map, ("my_object_class",)) - 
import_functions.import_objects(self._db_map, (("my_object_class", "my_object"),)) - import_functions.import_relationship_classes(self._db_map, (("my_relationship_class", ("my_object_class",)),)) - import_functions.import_relationships(self._db_map, (("my_relationship_class", ("my_object",)),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_entity_metadata({"entity_id": 2, "metadata_id": 1}, strict=False) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add entity metadata") - entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self.assertEqual( - dict(entity_metadata[0]), - { - "entity_id": 2, - "entity_name": "my_relationship_class_my_object", - "metadata_name": "title", - "metadata_value": "My metadata.", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_add_entity_metadata_doesnt_raise_with_empty_cache(self): - items, errors = self._db_map.add_entity_metadata({"entity_id": 1, "metadata_id": 1}, strict=False) - self.assertEqual(items, []) - self.assertEqual(len(errors), 1) - - def test_add_ext_entity_metadata_for_object(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_ext_entity_metadata( - {"entity_id": 1, "metadata_name": "key", "metadata_value": "object metadata"}, strict=False - ) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add entity metadata") - entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self.assertEqual( - dict(entity_metadata[0]), - { - "entity_id": 1, - "entity_name": "leviathan", - "metadata_name": "key", - "metadata_value": "object metadata", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_adding_ext_entity_metadata_for_object_reuses_existing_metadata_names_and_values(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_ext_entity_metadata( - {"entity_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}, strict=False - ) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add entity metadata") - metadata = self._db_map.query(self._db_map.metadata_sq).all() - self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", "commit_id": 2}) - entity_metadata = self._db_map.query(self._db_map.ext_entity_metadata_sq).all() - self.assertEqual(len(entity_metadata), 1) - self.assertEqual( - dict(entity_metadata[0]), - { - "entity_id": 1, - "entity_name": "leviathan", - "metadata_name": "title", - "metadata_value": "My metadata.", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_add_parameter_value_metadata(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - 
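The add_ext_* variants take metadata_name and metadata_value directly, creating the underlying metadata row on demand or reusing an identical existing one instead of duplicating it, which the ...reuses_existing_metadata... tests here verify. Roughly, with db_map standing in for self._db_map:

    items, errors = db_map.add_ext_entity_metadata(
        {"entity_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}
    )
    # An existing ("title", "My metadata.") row is linked, not duplicated.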
import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) - import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_parameter_value_metadata( - {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, strict=False - ) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add value metadata") - value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() - self.assertEqual(len(value_metadata), 1) - self.assertEqual( - dict(value_metadata[0]), - { - "alternative_name": "Base", - "entity_name": "leviathan", - "parameter_value_id": 1, - "parameter_name": "paranormality", - "metadata_name": "title", - "metadata_value": "My metadata.", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_add_parameter_value_metadata_doesnt_raise_with_empty_cache(self): - items, errors = self._db_map.add_parameter_value_metadata( - {"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1} - ) - self.assertEqual(len(items), 0) - self.assertEqual(len(errors), 1) - - def test_add_ext_parameter_value_metadata(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) - import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_ext_parameter_value_metadata( - { - "parameter_value_id": 1, - "metadata_name": "key", - "metadata_value": "parameter metadata", - "alternative_id": 1, - }, - strict=False, - ) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add value metadata") - value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() - self.assertEqual(len(value_metadata), 1) - self.assertEqual( - dict(value_metadata[0]), - { - "alternative_name": "Base", - "entity_name": "leviathan", - "parameter_value_id": 1, - "parameter_name": "paranormality", - "metadata_name": "key", - "metadata_value": "parameter metadata", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): - import_functions.import_object_classes(self._db_map, ("fish",)) - import_functions.import_objects(self._db_map, (("fish", "leviathan"),)) - import_functions.import_object_parameters(self._db_map, (("fish", "paranormality"),)) - import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) - import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_ext_parameter_value_metadata( - {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata.", "alternative_id": 1}, - strict=False, - ) - self.assertEqual(errors, []) - self.assertEqual(len(items), 1) - self._db_map.commit_session("Add value metadata") - metadata = self._db_map.query(self._db_map.metadata_sq).all() - self.assertEqual(len(metadata), 1) - self.assertEqual(dict(metadata[0]), {"id": 1, "name": "title", "value": "My metadata.", 
"commit_id": 2}) - value_metadata = self._db_map.query(self._db_map.ext_parameter_value_metadata_sq).all() - self.assertEqual(len(value_metadata), 1) - self.assertEqual( - dict(value_metadata[0]), - { - "alternative_name": "Base", - "entity_name": "leviathan", - "parameter_value_id": 1, - "parameter_name": "paranormality", - "metadata_name": "title", - "metadata_value": "My metadata.", - "metadata_id": 1, - "id": 1, - "commit_id": 3, - }, - ) - - -class TestDatabaseMappingUpdate(unittest.TestCase): - def setUp(self): - self._db_map = create_diff_db_map() - - def tearDown(self): - self._db_map.close() - - def test_update_object_classes(self): - """Test that updating object classes works.""" - self._db_map.add_object_classes({"id": 1, "name": "fish"}, {"id": 2, "name": "dog"}) - items, intgr_error_log = self._db_map.update_object_classes( - {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} - ) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.object_class_sq - object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} - self.assertEqual(intgr_error_log, []) - self.assertEqual(object_classes[1], "octopus") - self.assertEqual(object_classes[2], "god") - - def test_update_objects(self): - """Test that updating objects works.""" - self._db_map.add_object_classes({"id": 1, "name": "fish"}) - self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) - items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.object_sq - objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} - self.assertEqual(intgr_error_log, []) - self.assertEqual(objects[1], "klaus") - self.assertEqual(objects[2], "squidward") - - def test_update_objects_not_committed(self): - """Test that updating objects works.""" - self._db_map.add_object_classes({"id": 1, "name": "some_class"}) - self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) - items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.object_sq - objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} - self.assertEqual(intgr_error_log, []) - self.assertEqual(objects[1], "klaus") - self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") - - def test_update_committed_object(self): - """Test that updating objects works.""" - self._db_map.add_object_classes({"id": 1, "name": "some_class"}) - self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) - self._db_map.commit_session("update") - items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.object_sq - objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} - self.assertEqual(intgr_error_log, []) - self.assertEqual(objects[1], "klaus") - self.assertEqual(self._db_map.query(self._db_map.object_sq).filter_by(id=1).first().name, "klaus") - - def test_update_relationship_classes(self): - """Test that updating relationship classes works.""" - self._db_map.add_object_classes({"name": "dog", "id": 1}, {"name": "fish", "id": 2}) - 
self._db_map.add_wide_relationship_classes( - {"id": 3, "name": "dog__fish", "object_class_id_list": [1, 2]}, - {"id": 4, "name": "fish__dog", "object_class_id_list": [2, 1]}, - ) - items, intgr_error_log = self._db_map.update_wide_relationship_classes( - {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} - ) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.wide_relationship_class_sq - rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} - self.assertEqual(intgr_error_log, []) - self.assertEqual(rel_clss[3], "god__octopus") - self.assertEqual(rel_clss[4], "octopus__dog") - - def test_update_relationships(self): - """Test that updating relationships works.""" - self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects( - {"name": "nemo", "id": 1, "class_id": 1}, - {"name": "pluto", "id": 2, "class_id": 2}, - {"name": "scooby", "id": 3, "class_id": 2}, - ) - self._db_map.add_wide_relationships( - {"id": 4, "name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2], "object_class_id_list": [1, 2]} - ) - items, intgr_error_log = self._db_map.update_wide_relationships( - {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} - ) - ids = {x["id"] for x in items} - self._db_map.commit_session("test commit") - sq = self._db_map.wide_relationship_sq - rels = { - x.id: {"name": x.name, "object_id_list": x.object_id_list} - for x in self._db_map.query(sq).filter(sq.c.id.in_(ids)) - } - self.assertEqual(intgr_error_log, []) - self.assertEqual(rels[4]["name"], "nemo__scooby") - self.assertEqual(rels[4]["object_id_list"], "1,3") - - -class TestDatabaseMappingCommit(unittest.TestCase): - def setUp(self): - self._db_map = create_diff_db_map() - - def tearDown(self): - self._db_map.close() - - def test_commit_message(self): - """Tests that commit comment ends up in the database.""" - self._db_map.add_object_classes({"name": "testclass"}) - self._db_map.commit_session("test commit") - self.assertEqual(self._db_map.query(self._db_map.commit_sq).all()[-1].comment, "test commit") - self._db_map.close() - - def test_commit_session_raise_with_empty_comment(self): - import_functions.import_object_classes(self._db_map, ("my_class",)) - self.assertRaisesRegex(SpineDBAPIError, "Commit message cannot be empty.", self._db_map.commit_session, "") - - def test_commit_session_raise_when_nothing_to_commit(self): - self.assertRaisesRegex(SpineDBAPIError, "Nothing to commit.", self._db_map.commit_session, "No changes.") - - -if __name__ == "__main__": - unittest.main() From 68f3ffec74e900555c60e5ef6520b8f57b649706 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 23 May 2023 17:24:14 +0200 Subject: [PATCH 048/317] Implement refresh_session, fix query and add rollback_session tests --- spinedb_api/db_cache_base.py | 59 +++++++++++++++++++++++--- spinedb_api/db_mapping_commit_mixin.py | 3 +- spinedb_api/query.py | 8 +++- tests/test_DatabaseMapping.py | 37 ++++++++++++++++ 4 files changed, 96 insertions(+), 11 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index d9c876c8..e0b4c0a1 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -35,10 +35,16 @@ class DBCacheBase(dict): def __init__(self, chunk_size=None): super().__init__() + self._updated_items = 
{}
+        self._removed_items = {}
         self._offsets = {}
         self._fetched_item_types = set()
         self._chunk_size = chunk_size
 
+    @property
+    def fetched_item_types(self):
+        return self._fetched_item_types
+
     def _item_factory(self, item_type):
         raise NotImplementedError()
 
@@ -62,6 +68,12 @@ def _sorted_item_types(self):
         sorted(self, key=cmp_to_key(self._cmp_item_type))
 
     def dirty_items(self):
+        """Returns a list of tuples of the form (item_type, (to_add, to_update, to_remove)) corresponding to
+        items that have been modified but not yet committed.
+
+        Returns:
+            list
+        """
         dirty_items = []
         for item_type in sorted(self, key=cmp_to_key(self._cmp_item_type)):
             table_cache = self[item_type]
@@ -89,6 +101,13 @@ def dirty_items(self):
         return dirty_items
 
     def rollback(self):
+        """Discards uncommitted changes.
+
+        Namely, removes all the added items, resets all the updated items, and restores all the removed items.
+
+        Returns:
+            bool: False if there are no uncommitted items, True if successful.
+        """
         dirty_items = self.dirty_items()
         if not dirty_items:
             return False
@@ -114,12 +133,24 @@ def rollback(self):
                 del item["id"]
         return True
 
-    @property
-    def fetched_item_types(self):
-        return self._fetched_item_types
-
-    def reset_queries(self):
-        """Resets queries and clears caches."""
+    def refresh(self):
+        """Stores dirty items in internal dictionaries and clears the cache, so the DB can be fetched again.
+        Conflicts between new contents of the DB and dirty items are solved in favor of the latter
+        (see ``advance_query``, where we resolve those conflicts as we consume the queries).
+        """
+        dirty_items = self.dirty_items()  # Get dirty items before clearing
+        self.clear()
+        self._updated_items.clear()
+        self._removed_items.clear()
+        for item_type, (to_add, to_update, to_remove) in dirty_items:
+            # Add new items directly
+            table_cache = self.table_cache(item_type)
+            for item in to_add:
+                table_cache.add_item(item, new=True)
+            # Store updated and removed so we can take the proper action
+            # when we see their equivalents coming from the DB
+            self._updated_items[item_type] = {x["id"]: x for x in to_update}
+            self._removed_items[item_type] = {x["id"]: x for x in to_remove}
         self._offsets.clear()
         self._fetched_item_types.clear()
 
@@ -147,8 +178,17 @@ def advance_query(self, item_type):
             self._fetched_item_types.add(item_type)
             return []
         table_cache = self.table_cache(item_type)
+        updated_items = self._updated_items.get(item_type, {})
+        removed_items = self._removed_items.get(item_type, {})
         for item in chunk:
-            # FIXME: This will overwrite working changes after a refresh
+            updated_item = updated_items.get(item["id"])
+            if updated_item:
+                table_cache.persist_item(updated_item)
+                continue
+            removed_item = removed_items.get(item["id"])
+            if removed_item:
+                table_cache.persist_item(removed_item, removed=True)
+                continue
             table_cache.add_item(item)
         return chunk
 
@@ -319,6 +359,11 @@ def _remove_unique(self, item):
         for key, value in item.unique_values():
             self._id_by_unique_key_value.get(key, {}).pop(value, None)
 
+    def persist_item(self, item, removed=False):
+        self[item["id"]] = item
+        if not removed:
+            self._add_unique(item)
+
     def add_item(self, item, new=False):
         if "id" not in item:
             item["id"] = self._new_id()
diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py
index 349be3fc..aa0700ea 100644
--- a/spinedb_api/db_mapping_commit_mixin.py
+++ b/spinedb_api/db_mapping_commit_mixin.py
@@ -51,5 +51,4 @@ def rollback_session(self):
         raise SpineDBAPIError("Nothing to rollback.")
 
     def refresh_session(self):
-        # 
TODO - pass + self.cache.refresh() diff --git a/spinedb_api/query.py b/spinedb_api/query.py index df75daff..707724a5 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -13,6 +13,7 @@ from sqlalchemy import select, and_ from sqlalchemy.sql.functions import count +from sqlalchemy.exc import OperationalError from .exception import SpineDBAPIError @@ -88,7 +89,10 @@ def having(self, *args): return self def _result(self): - return self._bind.execute(self._select) + try: + return self._bind.execute(self._select) + except OperationalError: + return None def all(self): return self._result().fetchall() @@ -113,7 +117,7 @@ def count(self): return self._bind.execute(select([count()]).select_from(self._select)).scalar() def __iter__(self): - return self._result() + return self._result() or iter([]) def _get_leaves(parent): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index ba096fe0..d80909d6 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2196,6 +2196,43 @@ def test_commit_session_raise_with_empty_comment(self): def test_commit_session_raise_when_nothing_to_commit(self): self.assertRaisesRegex(SpineDBAPIError, "Nothing to commit.", self._db_map.commit_session, "No changes.") + def test_rollback_addition(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + import_functions.import_object_classes(self._db_map, ("second_class",)) + self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 2) + self._db_map.rollback_session() + self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 1) + with self.assertRaises(SpineDBAPIError): + # Nothing to commit + self._db_map.commit_session("test commit") + + def test_rollback_removal(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + self._db_map.remove_items("entity_class", 1) + self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 0) + self._db_map.rollback_session() + self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 1) + with self.assertRaises(SpineDBAPIError): + # Nothing to commit + self._db_map.commit_session("test commit") + + def test_rollback_update(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + self._db_map.update_items("entity_class", {"id": {"name": "my_class"}, "name": "new_name"}) + entity_classes = list(self._db_map.cache.table_cache("entity_class").values()) + self.assertEqual(len(entity_classes), 1) + self.assertEqual(entity_classes[0]["name"], "new_name") + self._db_map.rollback_session() + entity_classes = list(self._db_map.cache.table_cache("entity_class").values()) + self.assertEqual(len(entity_classes), 1) + self.assertEqual(entity_classes[0]["name"], "my_class") + with self.assertRaises(SpineDBAPIError): + # Nothing to commit + self._db_map.commit_session("test commit") + if __name__ == "__main__": unittest.main() From 4da0da268245357fcda75093838a38076ef5e84d Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 24 May 2023 09:28:42 +0200 Subject: [PATCH 049/317] Fix minor things and add tests for refresh_session --- spinedb_api/db_cache_base.py | 44 ++++++++++++--- spinedb_api/db_cache_impl.py | 7 ++- spinedb_api/db_mapping_add_mixin.py | 4 +- spinedb_api/db_mapping_base.py | 12 ++-- spinedb_api/db_mapping_commit_mixin.py | 3 + 
spinedb_api/db_mapping_update_mixin.py |  4 +-
 spinedb_api/helpers.py                 |  4 +-
 spinedb_api/query.py                   |  6 +-
 tests/test_DatabaseMapping.py          | 55 +++++++++++++++----
 ...k_functions.py => test_check_integrity.py} |  4 +-
 10 files changed, 104 insertions(+), 39 deletions(-)
 rename tests/{test_check_functions.py => test_check_integrity.py} (97%)

diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
index e0b4c0a1..3d056c05 100644
--- a/spinedb_api/db_cache_base.py
+++ b/spinedb_api/db_cache_base.py
@@ -45,7 +45,8 @@ def __init__(self, chunk_size=None):
     def fetched_item_types(self):
         return self._fetched_item_types
 
-    def _item_factory(self, item_type):
+    @staticmethod
+    def _item_factory(item_type):
         raise NotImplementedError()
 
     def _query(self, item_type):
@@ -100,6 +101,11 @@ def dirty_items(self):
             dirty_items.append((item_type, (to_add, to_update, to_remove)))
         return dirty_items
 
+    def commit(self):
+        """Clears the internal storage of dirty items created by ``refresh``."""
+        self._updated_items.clear()
+        self._removed_items.clear()
+
     def rollback(self):
         """Discards uncommitted changes.
 
@@ -140,8 +146,6 @@ def refresh(self):
         """
         dirty_items = self.dirty_items()  # Get dirty items before clearing
         self.clear()
-        self._updated_items.clear()
-        self._removed_items.clear()
         for item_type, (to_add, to_update, to_remove) in dirty_items:
             # Add new items directly
             table_cache = self.table_cache(item_type)
@@ -149,13 +153,15 @@ def refresh(self):
                 table_cache.add_item(item, new=True)
             # Store updated and removed so we can take the proper action
             # when we see their equivalents coming from the DB
-            self._updated_items[item_type] = {x["id"]: x for x in to_update}
-            self._removed_items[item_type] = {x["id"]: x for x in to_remove}
+            self._updated_items.setdefault(item_type, {}).update({x["id"]: x for x in to_update})
+            self._removed_items.setdefault(item_type, {}).update({x["id"]: x for x in to_remove})
         self._offsets.clear()
         self._fetched_item_types.clear()
 
     def _get_next_chunk(self, item_type):
         qry = self._query(item_type)
+        if not qry:
+            return []
         if not self._chunk_size:
             self._fetched_item_types.add(item_type)
             return [dict(x) for x in qry]
@@ -165,7 +171,8 @@ def _get_next_chunk(self, item_type):
         return chunk
 
     def advance_query(self, item_type):
-        """Advances the DB query that fetches items of given type and caches the results.
+        """Advances the DB query that fetches items of given type
+        and adds the results to the corresponding table cache.
 
         Args:
             item_type (str)
@@ -181,11 +188,11 @@ def advance_query(self, item_type):
         updated_items = self._updated_items.get(item_type, {})
         removed_items = self._removed_items.get(item_type, {})
         for item in chunk:
-            updated_item = updated_items.get(item["id"])
+            updated_item = updated_items.pop(item["id"], None)
             if updated_item:
                 table_cache.persist_item(updated_item)
                 continue
-            removed_item = removed_items.get(item["id"])
+            removed_item = removed_items.pop(item["id"], None)
             if removed_item:
                 table_cache.persist_item(removed_item, removed=True)
                 continue
             table_cache.add_item(item)
@@ -400,9 +407,30 @@ class CacheItemBase(TempIdDict):
     """A dictionary that represents a db item."""
 
     _defaults = {}
+    """A dictionary mapping fields to their default values"""
     _unique_keys = ()
+    """A tuple where each element is itself a tuple of fields indicating a unique key"""
     _references = {}
+    """A dictionary mapping fields that are not in the original dictionary,
+    to a recipe for finding the field they reference in another item.
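+    For illustration only, a hypothetical recipe (the key and item type names
+    are made up, not taken from this codebase):
+
+        _references = {"class_name": ("class_id", ("entity_class", "name"))}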
+ + The recipe is a tuple of the form (original_field, (ref_item_type, ref_field)), + to be interpreted as follows: + 1. take the value from the original_field of this item, which should be an id, + 2. locate the item of type ref_item_type that has that id, + 3. return the value from the ref_field of that item. + """ _inverse_references = {} + """Another dictionary mapping fields that are not in the original dictionary, + to a recipe for finding the field they reference in another item. + Used only for creating new items, when the user provides names and we want to find the ids. + + The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)), + to be interpreted as follows: + 1. take the values from the src_unique_key of this item, to form a tuple, + 2. locate the item of type ref_item_type where the ref_unique_key is exactly that tuple of values, + 3. return the id of that item. + """ def __init__(self, db_cache, item_type, **kwargs): """ diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 805f18c3..2d204b1e 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -27,7 +27,8 @@ def __init__(self, db_map, chunk_size=None): super().__init__(chunk_size=chunk_size) self._db_map = db_map - def _item_factory(self, item_type): + @staticmethod + def _item_factory(item_type): return { "entity_class": EntityClassItem, "entity": EntityItem, @@ -45,6 +46,8 @@ def _item_factory(self, item_type): }.get(item_type, CacheItemBase) def _query(self, item_type): + if self._db_map.closed: + return None sq_name = { "entity_class": "wide_entity_class_sq", "entity": "wide_entity_sq", @@ -89,7 +92,7 @@ def merge(self, other): merged, super_error = super().merge(other) return merged, " and ".join([x for x in (super_error, error) if x]) - def commit(self, commit_id): + def commit(self, _commit_id): super().commit(None) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index f17b46c9..d12bb19c 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -15,7 +15,7 @@ # TODO: improve docstrings from sqlalchemy.exc import DBAPIError -from .exception import SpineIntegrityError +from .exception import SpineIntegrityError, SpineDBAPIError class DatabaseMappingAddMixin: @@ -87,7 +87,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): connection.execute(table.insert(), items_to_add_) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" - raise SpineIntegrityError(msg) from e + raise SpineDBAPIError(msg) from e @staticmethod def _extra_items_to_add_per_table(tablename, items_to_add): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 2bf0139b..e37eb012 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -308,6 +308,11 @@ def _get_table_to_sq_attr(self): def _make_table_to_sq_attr(self): """Returns a dict mapping table names to subquery attribute names, involving that table.""" + + def _func(x, tables): + if isinstance(x, Table): + tables.add(x.name) # pylint: disable=cell-var-from-loop + # This 'loads' our subquery attributes for attr in dir(self): getattr(self, attr) @@ -316,12 +321,7 @@ def _make_table_to_sq_attr(self): if not isinstance(val, Alias): continue tables = set() - - def _func(x): - if isinstance(x, Table): - tables.add(x.name) # pylint: disable=cell-var-from-loop - - forward_sweep(val, _func) + forward_sweep(val, _func, tables) # Now `tables` 
contains all tables related to `val` for table in tables: table_to_sq_attr.setdefault(table, set()).add(attr) diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index aa0700ea..1097a623 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -43,12 +43,15 @@ def commit_session(self, comment): self._do_add_items(connection, tablename, *to_add) self._do_update_items(connection, tablename, *to_update) self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) + self.cache.commit() if self._memory: self._memory_dirty = True def rollback_session(self): if not self.cache.rollback(): raise SpineDBAPIError("Nothing to rollback.") + if self._memory: + self._memory_dirty = False def refresh_session(self): self.cache.refresh() diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 3d0e0bdf..7960af9a 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -14,7 +14,7 @@ """ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam -from .exception import SpineIntegrityError +from .exception import SpineIntegrityError, SpineDBAPIError class DatabaseMappingUpdateMixin: @@ -41,7 +41,7 @@ def _do_update_items(self, connection, tablename, *items_to_update): connection.execute(upd, items_to_update_) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" - raise SpineIntegrityError(msg) from e + raise SpineDBAPIError(msg) from e @staticmethod def _extra_items_to_update_per_table(tablename, items_to_update): diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 0e1daeaa..f83d3224 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -744,14 +744,14 @@ def _create_first_spine_database(db_url): return engine -def forward_sweep(root, fn): +def forward_sweep(root, fn, *args): """Recursively visit, using `get_children()`, the given sqlalchemy object. 
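+    Any extra positional ``args`` are forwarded to `fn` on each call.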
Apply `fn` on every visited node.""" current = root parent = {} children = {current: iter(current.get_children(column_collections=False))} while True: - fn(current) + fn(current, *args) # Try and visit next children next_ = next(children[current], None) if next_ is not None: diff --git a/spinedb_api/query.py b/spinedb_api/query.py index 707724a5..ac375c37 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -13,7 +13,6 @@ from sqlalchemy import select, and_ from sqlalchemy.sql.functions import count -from sqlalchemy.exc import OperationalError from .exception import SpineDBAPIError @@ -89,10 +88,7 @@ def having(self, *args): return self def _result(self): - try: - return self._bind.execute(self._select) - except OperationalError: - return None + return self._bind.execute(self._select) def all(self): return self._result().fetchall() diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index d80909d6..43cdd93c 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2200,9 +2200,11 @@ def test_rollback_addition(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") import_functions.import_object_classes(self._db_map, ("second_class",)) - self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 2) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class", "second_class"}) self._db_map.rollback_session() - self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 1) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit self._db_map.commit_session("test commit") @@ -2211,9 +2213,11 @@ def test_rollback_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.remove_items("entity_class", 1) - self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 0) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, set()) self._db_map.rollback_session() - self.assertEqual(len(list(self._db_map.cache.table_cache("entity_class").values())), 1) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit self._db_map.commit_session("test commit") @@ -2222,17 +2226,48 @@ def test_rollback_update(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.update_items("entity_class", {"id": {"name": "my_class"}, "name": "new_name"}) - entity_classes = list(self._db_map.cache.table_cache("entity_class").values()) - self.assertEqual(len(entity_classes), 1) - self.assertEqual(entity_classes[0]["name"], "new_name") + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"new_name"}) self._db_map.rollback_session() - entity_classes = list(self._db_map.cache.table_cache("entity_class").values()) - self.assertEqual(len(entity_classes), 1) - self.assertEqual(entity_classes[0]["name"], "my_class") + 
entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit self._db_map.commit_session("test commit") + def test_refresh_addition(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + import_functions.import_object_classes(self._db_map, ("second_class",)) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class", "second_class"}) + self._db_map.refresh_session() + self._db_map.fetch_all() + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"my_class", "second_class"}) + + def test_refresh_removal(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + self._db_map.remove_items("entity_class", 1) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, set()) + self._db_map.refresh_session() + self._db_map.fetch_all() + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, set()) + + def test_refresh_update(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + self._db_map.commit_session("test commit") + self._db_map.update_items("entity_class", {"id": {"name": "my_class"}, "name": "new_name"}) + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"new_name"}) + self._db_map.refresh_session() + self._db_map.fetch_all() + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self.assertEqual(entity_class_names, {"new_name"}) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_check_functions.py b/tests/test_check_integrity.py similarity index 97% rename from tests/test_check_functions.py rename to tests/test_check_integrity.py index ea4cdbeb..9d54caae 100644 --- a/tests/test_check_functions.py +++ b/tests/test_check_integrity.py @@ -23,7 +23,7 @@ def _val_dict(val): return dict(zip(keys, values)) -class TestCheckFunctions(unittest.TestCase): +class TestCheckIntegrity(unittest.TestCase): def setUp(self): self.data = [ (bool, (b'"TRUE"', b'"FALSE"', b'"T"', b'"True"', b'"False"'), (b'true', b'false')), @@ -71,7 +71,7 @@ def get_item(id_: int, val: bytes, entity_id: int): 'alternative_id': 1, } - def test_replace_parameter_or_default_values_with_list_references(self): + def test_parameter_values_and_default_values_with_list_references(self): # regression test for spine-tools/Spine-Toolbox#1878 for type_, fail, pass_ in self.data: id_ = self.value_type[type_] # setup: parameter definition/value list ids are equal From 1ff78a4e07014eaa1f1dd54bae6f6df27a286a46 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 24 May 2023 15:18:18 +0200 Subject: [PATCH 050/317] Fix refresh and implement dirty_ids --- spinedb_api/db_cache_base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 3d056c05..2b7bed10 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -68,6 +68,11 @@ def 
_cmp_item_type(self, a, b):
 
     def _sorted_item_types(self):
         sorted(self, key=cmp_to_key(self._cmp_item_type))
 
+    def dirty_ids(self, item_type):
+        return {
+            item["id"] for item in self.get(item_type, {}).values() if item.status in (Status.to_add, Status.to_update)
+        }
+
     def dirty_items(self):
         """Returns a list of tuples of the form (item_type, (to_add, to_update, to_remove)) corresponding to
         items that have been modified but not yet committed.
@@ -146,6 +151,10 @@ def refresh(self):
         """
         dirty_items = self.dirty_items()  # Get dirty items before clearing
         self.clear()
+        # Clear _offsets and _fetched_item_types before adding dirty items below,
+        # so those items are able to properly fetch their references from the DB
+        self._offsets.clear()
+        self._fetched_item_types.clear()
         for item_type, (to_add, to_update, to_remove) in dirty_items:
             # Add new items directly
             table_cache = self.table_cache(item_type)
@@ -155,8 +164,6 @@ def refresh(self):
             # when we see their equivalents coming from the DB
             self._updated_items.setdefault(item_type, {}).update({x["id"]: x for x in to_update})
             self._removed_items.setdefault(item_type, {}).update({x["id"]: x for x in to_remove})
-        self._offsets.clear()
-        self._fetched_item_types.clear()

From c483aca2a7334303196d4a2c4a3848f7853a389f Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 24 May 2023 15:49:03 +0200
Subject: [PATCH 051/317] Fix management of unique keys

---
 spinedb_api/db_cache_base.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
index 2b7bed10..73e5127b 100644
--- a/spinedb_api/db_cache_base.py
+++ b/spinedb_api/db_cache_base.py
@@ -279,7 +279,7 @@ def unique_key_value_to_id(self, key, value, strict=False):
             return id_by_unique_value[value]
         return id_by_unique_value.get(value)
 
-    def _unique_key_value_to_item(self, key, value, strict=False):
+    def _unique_key_value_to_item(self, key, value):
         return self.get(self.unique_key_value_to_id(key, value))
 
     def values(self):
@@ -371,7 +371,9 @@ def _add_unique(self, item):
 
     def _remove_unique(self, item):
         for key, value in item.unique_values():
-            self._id_by_unique_key_value.get(key, {}).pop(value, None)
+            id_by_value = self._id_by_unique_key_value.get(key, {})
+            if id_by_value.get(value) == item["id"]:
+                del id_by_value[value]
 
     def persist_item(self, item, removed=False):
         self[item["id"]] = item

From 69d15ce7cf0f088ef72d47e8a55546fd154f75e5 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 25 May 2023 09:18:48 +0200
Subject: [PATCH 052/317] Document CacheItemBase

---
 spinedb_api/db_cache_base.py | 293 ++++++++++++++++++++++++++---------
 1 file changed, 217 insertions(+), 76 deletions(-)

diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
index 73e5127b..05cbb543 100644
--- a/spinedb_api/db_cache_base.py
+++ b/spinedb_api/db_cache_base.py
@@ -348,9 +348,9 @@ def check_item(self, item, for_update=False, skip_keys=()):
         error = candidate_item.polish()
         if error:
             return None, error
-        invalid_ref = candidate_item.invalid_ref()
-        if invalid_ref:
-            return None, f"invalid {invalid_ref} for {self._item_type}"
+        first_invalid_key = candidate_item.first_invalid_key()
+        if first_invalid_key:
+            return None, f"invalid {first_invalid_key} for {self._item_type}"
         try:
             for key, value in candidate_item.unique_values(skip_keys=skip_keys):
                 empty = {k for k, v in zip(key, value) if v == ""}
@@ -416,11 +416,11 @@ class CacheItemBase(TempIdDict):
     """A dictionary that represents a db item."""
 
     _defaults = {}
-    """A dictionary mapping fields to their default values"""
+    """A dictionary mapping keys to their default values"""
     _unique_keys = ()
-    """A tuple where each element is itself a tuple of fields indicating a unique key"""
+    """A tuple where each element is itself a tuple of keys that are unique"""
     _references = {}
-    """A dictionary mapping fields that are not in the original dictionary,
+    """A dictionary mapping keys that are not in the original dictionary,
     to a recipe for finding the field they reference in another item.
 
     The recipe is a tuple of the form (original_field, (ref_item_type, ref_field)),
     to be interpreted as follows:
     1. take the value from the original_field of this item, which should be an id,
     2. locate the item of type ref_item_type that has that id,
     3. return the value from the ref_field of that item.
     """
     _inverse_references = {}
-    """Another dictionary mapping fields that are not in the original dictionary,
+    """Another dictionary mapping keys that are not in the original dictionary,
     to a recipe for finding the field they reference in another item.
     Used only for creating new items, when the user provides names and we want to find the ids.
 
     The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)),
     to be interpreted as follows:
     1. take the values from the src_unique_key of this item, to form a tuple,
     2. locate the item of type ref_item_type where the ref_unique_key is exactly that tuple of values,
     3. return the id of that item.
     """
 
     def __init__(self, db_cache, item_type, **kwargs):
         """
@@ -463,94 +463,147 @@ def __init__(self, db_cache, item_type, **kwargs):
 
     @classmethod
     def ref_types(cls):
+        """Returns a set of item types that this class refers to.
+
+        Returns:
+            set(str)
+        """
         return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values())
 
     @property
     def status(self):
+        """Returns the status of this item.
+
+        Returns:
+            Status
+        """
         return self._status
 
     @status.setter
     def status(self, status):
+        """Sets the status of this item.
+
+        Args:
+            status (Status)
+        """
         self._status = status
 
     @property
     def backup(self):
+        """Returns the committed version of this item.
+
+        Returns:
+            dict or None
+        """
         return self._backup
 
     @property
     def removed(self):
+        """Returns whether or not this item has been removed.
+
+        Returns:
+            bool
+        """
         return self._removed
 
     @property
     def item_type(self):
+        """Returns this item's type.
+
+        Returns:
+            str
+        """
         return self._item_type
 
     @property
     def key(self):
+        """Returns a tuple (item_type, id) for convenience, or None if this item doesn't yet have an id.
+        TODO: When does the latter happen?
+
+        Returns:
+            tuple(str,int) or None
+        """
         id_ = dict.get(self, "id")
         if id_ is None:
             return None
         return (self._item_type, id_)
 
-    def __repr__(self):
-        return f"{self._item_type}{self._extended()}"
-
-    def __getattr__(self, name):
-        """Overridden method to return the dictionary key named after the attribute, or None if it doesn't exist."""
-        # FIXME: We should try and get rid of this one
-        return self.get(name)
-
-    def __getitem__(self, key):
-        ref = self._references.get(key)
-        if ref:
-            src_key, (ref_type, ref_key) = ref
-            ref_id = self[src_key]
-            if isinstance(ref_id, tuple):
-                return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id)
-            return self._get_ref(ref_type, ref_id).get(ref_key)
-        return super().__getitem__(key)
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            return default
-
-    def update(self, other):
-        if self._status == Status.committed:
-            self._status = Status.to_update
-            self._backup = self._asdict()
-        for src_key, (ref_type, _ref_key) in self._references.values():
-            ref_id = self[src_key]
-            if src_key in other and other[src_key] != ref_id:
-                # Forget references
-                if isinstance(ref_id, tuple):
-                    for x in ref_id:
-                        self._forget_ref(ref_type, x)
-                else:
-                    self._forget_ref(ref_type, ref_id)
-        super().update(other)
-        if self._asdict() == self._backup:
-            self._status = Status.committed
+    def _extended(self):
+        """Returns a dict from this item's original fields plus all the references resolved statically.
+
+        Returns:
+            dict
+        """
+        return {**self, **{key: self[key] for key in self._references}}
+
+    def _asdict(self):
+        """Returns a dict from this item's original fields.
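+        Unlike ``_extended``, references are not resolved.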
- def update(self, other): - if self._status == Status.committed: - self._status = Status.to_update - self._backup = self._asdict() - for src_key, (ref_type, _ref_key) in self._references.values(): - ref_id = self[src_key] - if src_key in other and other[src_key] != ref_id: - # Forget references - if isinstance(ref_id, tuple): - for x in ref_id: - self._forget_ref(ref_type, x) - else: - self._forget_ref(ref_type, ref_id) - super().update(other) - if self._asdict() == self._backup: - self._status = Status.committed + Returns: + dict + """ + return dict(self) def merge(self, other): + """Merges this item with another and returns the merged item together with any errors. + Used for updating items. + + Args: + other (dict): the item to merge into this. + + Returns: + dict: merged item. + str: error description if any. + """ if all(self.get(key) == value for key, value in other.items()): return None, "" merged = {**self._extended(), **other} merged["id"] = self["id"] return merged, "" - def polish(self): - """Polishes this item once all it's references are resolved. Returns any errors. + def first_invalid_key(self): + """Goes through the ``_references`` class attribute and returns the key of the first one + that cannot be resolved. Returns: - str or None + str or None: unresolved reference's key if any. """ - for key, default_value in self._defaults.items(): - self.setdefault(key, default_value) - return "" + for src_key, (ref_type, _ref_key) in self._references.values(): + try: + ref_id = self[src_key] + except KeyError: + return src_key + if isinstance(ref_id, tuple): + for x in ref_id: + if not self._get_ref(ref_type, x): + return src_key + elif not self._get_ref(ref_type, ref_id): + return src_key + + def unique_values(self, skip_keys=()): + """Yields tuples of unique keys and their values. + + Args: + skip_keys: Don't yield these keys + + Yields: + tuple(tuple,tuple): the first element is the unique key, the second is the values. + """ + for key in self._unique_keys: + if key not in skip_keys: + yield key, tuple(self.get(k) for k in key) def resolve_inverse_references(self, skip_keys=()): + """Goes through the ``_inverse_references`` class attribute and updates this item + by resolving those references. + Returns any error. + + Args: + skip_keys (tuple): don't resolve references for these keys. + + Returns: + str or None: error description if any. + """ for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): if src_key in skip_keys: continue @@ -568,25 +621,32 @@ def resolve_inverse_references(self, skip_keys=()): # Happens at unique_key_value_to_id(..., strict=True) return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" - def invalid_ref(self): - for src_key, (ref_type, _ref_key) in self._references.values(): - try: - ref_id = self[src_key] - except KeyError: - return src_key - if isinstance(ref_id, tuple): - for x in ref_id: - if not self._get_ref(ref_type, x): - return src_key - elif not self._get_ref(ref_type, ref_id): - return src_key + def polish(self): + """Polishes this item once all it's references have been resolved. Returns any error. - def unique_values(self, skip_keys=()): - for key in self._unique_keys: - if key not in skip_keys: - yield key, tuple(self.get(k) for k in key) + The base implementation sets defaults but subclasses can do more work if needed. + + Returns: + str or None: error description if any. 
+ """ + for key, default_value in self._defaults.items(): + self.setdefault(key, default_value) + return "" def _get_ref(self, ref_type, ref_id, strong=True): + """Collects a reference from the cache. + Adds this item to the reference's list of referrers if strong is True; + or weak referrers if strong is False. + If the reference is not found, sets some flags. + + Args: + ref_type (str): The references's type + ref_id (int): The references's id + strong (bool): True if the reference corresponds to a foreign key, False otherwise + + Returns: + CacheItemBase or dict + """ ref = self._db_cache.get_item(ref_type, ref_id) if not ref: if not strong: @@ -595,9 +655,7 @@ def _get_ref(self, ref_type, ref_id, strong=True): if not ref: self._corrupted = True return {} - return self._handle_ref(ref, strong) - - def _handle_ref(self, ref, strong): + # Here we have a ref if strong: ref.add_referrer(self) if ref.removed: @@ -608,11 +666,23 @@ def _handle_ref(self, ref, strong): return {} return ref - def _forget_ref(self, ref_type, ref_id): + def _invalidate_ref(self, ref_type, ref_id): + """Invalidates a reference previously collected from the cache. + + Args: + ref_type (str): The references's type + ref_id (int): The references's id + """ ref = self._db_cache.get_item(ref_type, ref_id) ref.remove_referrer(self) def is_valid(self): + """Checks if this item has all its references. + Removes the item from the cache if not valid by calling ``cascade_remove``. + + Returns: + bool + """ if self._valid is not None: return self._valid if self._removed or self._corrupted: @@ -627,16 +697,33 @@ def is_valid(self): return self._valid def add_referrer(self, referrer): + """Adds a strong referrer to this item. Strong referrers are removed, updated and restored + in cascade with this item. + + Args: + referrer (CacheItemBase) + """ if referrer.key is None: return self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer) def remove_referrer(self, referrer): + """Removes a strong referrer. + + Args: + referrer (CacheItemBase) + """ if referrer.key is None: return self._referrers.pop(referrer.key, None) def add_weak_referrer(self, referrer): + """Adds a weak referrer to this item. + Weak referrers' update callbacks are called whenever this item changes. + + Args: + referrer (CacheItemBase) + """ if referrer.key is None: return if referrer.key not in self._referrers: @@ -647,6 +734,9 @@ def _update_weak_referrers(self): weak_referrer.call_update_callbacks() def cascade_restore(self): + """Restores this item (if removed) and all its referrers in cascade. + Also, updates items' status and calls their restore callbacks. + """ if not self._removed: return if self._status == Status.committed: @@ -664,6 +754,9 @@ def cascade_restore(self): self.restore_callbacks -= obsolete def cascade_remove(self): + """Removes this item and all its referrers in cascade. + Also, updates items' status and calls their remove callbacks. + """ if self._removed: return if self._status == Status.committed: @@ -683,6 +776,9 @@ def cascade_remove(self): self._update_weak_referrers() def cascade_update(self): + """Updates this item and all its referrers in cascade. + Also, calls items' update callbacks. 
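+        Callbacks that return False are discarded (see ``call_update_callbacks``).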
+ """ self.call_update_callbacks() for referrer in self._referrers.values(): referrer.cascade_update() @@ -695,16 +791,61 @@ def call_update_callbacks(self): obsolete.add(callback) self.update_callbacks -= obsolete - def _extended(self): - return {**self, **{key: self[key] for key in self._references}} - - def _asdict(self): - return dict(self) - def is_committed(self): + """Returns whether or not this item is committed to the DB. + + Returns: + bool + """ return self._status == Status.committed def commit(self, commit_id): + """Sets this item as committed with the given commit id.""" self._status = Status.committed if commit_id: self["commit_id"] = commit_id + + def __repr__(self): + """Overridden to return a more verbose representation.""" + return f"{self._item_type}{self._extended()}" + + def __getattr__(self, name): + """Overridden to return the dictionary key named after the attribute, or None if it doesn't exist.""" + # FIXME: We should try and get rid of this one + return self.get(name) + + def __getitem__(self, key): + """Overridden to return references.""" + ref = self._references.get(key) + if ref: + src_key, (ref_type, ref_key) = ref + ref_id = self[src_key] + if isinstance(ref_id, tuple): + return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) + return self._get_ref(ref_type, ref_id).get(ref_key) + return super().__getitem__(key) + + def get(self, key, default=None): + """Overridden to return references.""" + try: + return self[key] + except KeyError: + return default + + def update(self, other): + """Overridden to update the item status and also to invalidate references that become obsolete.""" + if self._status == Status.committed: + self._status = Status.to_update + self._backup = self._asdict() + for src_key, (ref_type, _ref_key) in self._references.values(): + ref_id = self[src_key] + if src_key in other and other[src_key] != ref_id: + # Invalidate references + if isinstance(ref_id, tuple): + for x in ref_id: + self._invalidate_ref(ref_type, x) + else: + self._invalidate_ref(ref_type, ref_id) + super().update(other) + if self._asdict() == self._backup: + self._status = Status.committed From 2a2c7a648f2c0231f6010d21e7e48231180a3672 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 25 May 2023 09:21:23 +0200 Subject: [PATCH 053/317] Fix db server query entry point --- spinedb_api/spine_db_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 5b3df2fc..c6e92698 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -294,7 +294,7 @@ def _do_query(self, *args): sq = getattr(self._db_map, sq_name, None) if sq is None: continue - result[sq_name] = [x._asdict() for x in self._db_map.query(sq)] + result[sq_name] = [dict(x) for x in self._db_map.query(sq)] return dict(result=result) def _do_filtered_query(self, **kwargs): From ad5ec655b382f51b00446e1c0360e506cb854b97 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 12 Jun 2023 15:55:49 +0200 Subject: [PATCH 054/317] Fix calling callbacks while others are added from another thread --- spinedb_api/db_cache_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 05cbb543..b79294f2 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -748,7 +748,7 @@ def cascade_restore(self): referrer.cascade_restore() self._update_weak_referrers() obsolete = set() - for callback in 
self.restore_callbacks: + for callback in list(self.restore_callbacks): if not callback(self): obsolete.add(callback) self.restore_callbacks -= obsolete @@ -767,7 +767,7 @@ def cascade_remove(self): self._to_remove = False self._valid = None obsolete = set() - for callback in self.remove_callbacks: + for callback in list(self.remove_callbacks): if not callback(self): obsolete.add(callback) self.remove_callbacks -= obsolete @@ -786,7 +786,7 @@ def cascade_update(self): def call_update_callbacks(self): obsolete = set() - for callback in self.update_callbacks: + for callback in list(self.update_callbacks): if not callback(self): obsolete.add(callback) self.update_callbacks -= obsolete From 74009793596cb8442498dac61d523a0da918ba72 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Jun 2023 09:15:52 +0200 Subject: [PATCH 055/317] Fix some issues with temp id replacement --- spinedb_api/db_cache_base.py | 8 ++-- spinedb_api/db_cache_impl.py | 6 +-- spinedb_api/db_mapping_add_mixin.py | 2 +- spinedb_api/db_mapping_commit_mixin.py | 5 ++- spinedb_api/db_mapping_update_mixin.py | 20 ++++++--- spinedb_api/filters/execution_filter.py | 4 +- spinedb_api/import_functions.py | 6 +-- spinedb_api/temp_id.py | 56 +++++++++++++++---------- 8 files changed, 65 insertions(+), 42 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index b79294f2..52470d31 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -296,7 +296,7 @@ def _make_item(self, item): """ return self._db_cache.make_item(self._item_type, **item) - def current_item(self, item, skip_keys=()): + def find_item(self, item, skip_keys=()): """Returns a CacheItemBase that matches the given dictionary-item. Args: @@ -332,7 +332,7 @@ def check_item(self, item, for_update=False, skip_keys=()): # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) if for_update: - current_item = self.current_item(item, skip_keys=skip_keys) + current_item = self.find_item(item, skip_keys=skip_keys) if current_item is None: return None, f"no {self._item_type} matching {item} to update" full_item, merge_error = current_item.merge(item) @@ -390,7 +390,7 @@ def add_item(self, item, new=False): return new_item def update_item(self, item): - current_item = self.current_item(item) + current_item = self.find_item(item) self._remove_unique(current_item) current_item.update(item) self._add_unique(current_item) @@ -398,7 +398,7 @@ def update_item(self, item): return current_item def remove_item(self, id_): - current_item = self.current_item({"id": id_}) + current_item = self.find_item({"id": id_}) if current_item is not None: self._remove_unique(current_item) current_item.cascade_remove() diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 2d204b1e..f5c0a6b7 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -381,7 +381,7 @@ class ScenarioItem(CacheItemBase): _unique_keys = (("name",),) @property - def sorted_alternatives(self): + def sorted_scenario_alternatives(self): self._db_cache.fetch_all("scenario_alternative") return sorted( (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), @@ -390,9 +390,9 @@ def sorted_alternatives(self): def __getitem__(self, key): if key == "alternative_id_list": - return [x["alternative_id"] for x in self.sorted_alternatives] + return [x["alternative_id"] for 
x in self.sorted_scenario_alternatives] if key == "alternative_name_list": - return [x["alternative_name"] for x in self.sorted_alternatives] + return [x["alternative_name"] for x in self.sorted_scenario_alternatives] return super().__getitem__(key) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index d12bb19c..3ae54159 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -188,4 +188,4 @@ def add_ext_parameter_value_metadata(self, *items, **kwargs): def get_metadata_to_add_with_entity_metadata_items(self, *items): metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items) - return [x for x in metadata_items if not self.cache.table_cache("metadata").current_item(x)] + return [x for x in metadata_items if not self.cache.table_cache("metadata").find_item(x)] diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 1097a623..6611a8a2 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -40,9 +40,10 @@ def commit_session(self, comment): for tablename, (to_add, to_update, to_remove) in dirty_items: for item in to_add + to_update + to_remove: item.commit(commit_id) - self._do_add_items(connection, tablename, *to_add) - self._do_update_items(connection, tablename, *to_update) + # Remove before add, to help with keeping integrity constraints self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) + self._do_update_items(connection, tablename, *to_update) + self._do_add_items(connection, tablename, *to_add) self.cache.commit() if self._memory: self._memory_dirty = True diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 7960af9a..5335b862 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -181,10 +181,10 @@ def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): set: integer scenario_alternative ids to remove """ scen_alts_to_add = [] - scen_alt_ids_to_remove = set() + scen_alt_ids_to_remove = {} errors = [] for scen in scenarios: - current_scen = self.cache.table_cache("scenario").current_item(scen) + current_scen = self.cache.table_cache("scenario").find_item(scen) if current_scen is None: error = f"no scenario matching {scen} to set alternatives for" if strict: @@ -199,6 +199,16 @@ def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): scen_alts_to_add.append(item_to_add) for alternative_id in current_scen["alternative_id_list"]: scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id} - current_scen_alt = self.cache.table_cache("scenario_alternative").current_item(scen_alt) - scen_alt_ids_to_remove.add(current_scen_alt["id"]) - return scen_alts_to_add, scen_alt_ids_to_remove, errors + current_scen_alt = self.cache.table_cache("scenario_alternative").find_item(scen_alt) + scen_alt_ids_to_remove[current_scen_alt["id"]] = current_scen_alt + # Remove items that are both to add and to remove + for id_, to_rm in list(scen_alt_ids_to_remove.items()): + i = next((i for i, to_add in enumerate(scen_alts_to_add) if _is_equal(to_add, to_rm)), None) + if i is not None: + del scen_alts_to_add[i] + del scen_alt_ids_to_remove[id_] + return scen_alts_to_add, set(scen_alt_ids_to_remove), errors + + +def _is_equal(to_add, to_rm): + return all(to_rm[k] == v for k, v in to_add.items()) diff --git 
a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index 8e53e1ea..d5d37fa2 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -160,8 +160,8 @@ def _create_import_alternative(db_map, state): scenarios = [{"name": scen_name} for scen_name in scenarios] db_map.add_scenarios(*scenarios, _strict=True) for scen_name in scenarios: - scen = db_map.cache.table_cache("scenario").current_item({"name": scen_name}) - rank = len(scen.sorted_alternatives) + 1 # ranks are 1-based + scen = db_map.cache.table_cache("scenario").find_item({"name": scen_name}) + rank = len(scen.sorted_scenario_alternatives) + 1 # ranks are 1-based db_map.add_scenario_alternatives( {"scenario_name": scen_name, "alternative_name": db_map._import_alternative_name, "rank": rank} ) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 912812d5..a1bddbbb 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -846,7 +846,7 @@ def _data_iterator(): "value": None, "type": None, } - pv = db_map.cache.table_cache("parameter_value").current_item(item) + pv = db_map.cache.table_cache("parameter_value").find_item(item) if pv is not None: value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) item.update({"value": value, "type": type_}) @@ -872,7 +872,7 @@ def _get_scenarios_for_import(db_map, data): def _get_scenario_alternatives_for_import(db_map, data): alt_name_list_by_scen_name, errors = {}, [] for scen_name, alt_name, *optionals in data: - scen = db_map.cache.table_cache("scenario").current_item({"name": scen_name}) + scen = db_map.cache.table_cache("scenario").find_item({"name": scen_name}) if scen is None: errors.append(f"no scenario with name {scen_name} to set alternatives for") continue @@ -911,7 +911,7 @@ def _data_iterator(): value, type_ = unparse_value(value) index = index_by_list_name.get(list_name) if index is None: - current_list = db_map.cache.table_cache("parameter_value_list").current_item({"name": list_name}) + current_list = db_map.cache.table_cache("parameter_value_list").find_item({"name": list_name}) index = max( ( x["index"] diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 4fbd9cd7..44fa5a14 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -30,43 +30,55 @@ def __init__(self, item_type): self._key_binds = [] self._tuple_key_binds = [] - def add_value_bind(self, item, key): - self._value_binds.append((item, key)) + def add_value_bind(self, collection, key): + self._value_binds.append((collection, key)) - def add_tuple_value_bind(self, item, key): - self._tuple_value_binds.append((item, key)) + def add_tuple_value_bind(self, collection, key): + self._tuple_value_binds.append((collection, key)) - def add_key_bind(self, item): - self._key_binds.append(item) + def add_key_bind(self, collection): + self._key_binds.append(collection) - def add_tuple_key_bind(self, item, key): - self._tuple_key_binds.append((item, key)) + def add_tuple_key_bind(self, collection, key): + self._tuple_key_binds.append((collection, key)) - def remove_key_bind(self, item): - self._key_binds.remove(item) + def remove_key_bind(self, collection): + self._key_binds.remove(collection) - def remove_tuple_key_bind(self, item, key): - self._tuple_key_binds.remove((item, key)) + def remove_tuple_key_bind(self, collection, key): + self._tuple_key_binds.remove((collection, key)) def resolve(self, new_id): - for item, key in self._value_binds: - 
item[key] = new_id - for item, key in self._tuple_value_binds: - item[key] = tuple(new_id if v is self else v for v in item[key]) - for item in self._key_binds: - if self in item: - item[new_id] = dict.pop(item, self, None) - for item, key in self._tuple_key_binds: - if key in item: - item[tuple(new_id if k is self else k for k in key)] = dict.pop(item, key, None) + for collection, key in self._value_binds: + collection[key] = new_id + for collection, key in self._tuple_value_binds: + collection[key] = tuple(new_id if v is self else v for v in collection[key]) + for collection in self._key_binds: + if self in collection: + collection.key_map[self] = new_id + collection[new_id] = dict.pop(collection, self, None) + for collection, key in self._tuple_key_binds: + if key in collection: + new_key = tuple(new_id if k is self else k for k in key) + collection[new_key] = dict.pop(collection, key, None) + collection.key_map[key] = new_key class TempIdDict(dict): def __init__(self, **kwargs): super().__init__(**kwargs) + self.key_map = {} for key, value in kwargs.items(): self._bind(key, value) + def __getitem__(self, key): + key = self.key_map.get(key, key) + return super().__getitem__(key) + + def get(self, key, default=None): + key = self.key_map.get(key, key) + return super().get(key, default) + def __setitem__(self, key, value): super().__setitem__(key, value) self._bind(key, value) From 176cac4791a5ced9688209a7e0c47cadbc6d52d6 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Jun 2023 13:36:37 +0200 Subject: [PATCH 056/317] Return layout coordinates --- spinedb_api/graph_layout_generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index 8943b242..2e1d5569 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -111,10 +111,10 @@ def compute_layout(self): if self.vertex_count <= 1: x, y = np.array([0.0]), np.array([0.0]) self._layout_available(x, y) - return + return x, y matrix = self.shortest_path_matrix() if matrix is None: - return + return [], [] mask = np.ones((self.vertex_count, self.vertex_count)) == 1 - np.tril( np.ones((self.vertex_count, self.vertex_count)) ) # Upper triangular except diagonal @@ -159,3 +159,4 @@ def compute_layout(self): layout[heavy_ind, :] = heavy_pos x, y = layout[:, 0], layout[:, 1] self._layout_available(x, y) + return x, y From 13c8238c9386023478079c68be6abad7606e0ee8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Jun 2023 15:08:28 +0200 Subject: [PATCH 057/317] Fix CacheItem status change --- spinedb_api/db_cache_base.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 52470d31..0b507d67 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -28,6 +28,7 @@ class Status(Enum): to_add = auto() to_update = auto() to_remove = auto() + added_and_removed = auto() class DBCacheBase(dict): @@ -459,6 +460,7 @@ def __init__(self, db_cache, item_type, **kwargs): self._corrupted = False self._valid = None self._status = Status.committed + self._status_when_removed = None self._backup = None @classmethod @@ -739,10 +741,10 @@ def cascade_restore(self): """ if not self._removed: return - if self._status == Status.committed: - self._status = Status.to_add + if self.status in (Status.added_and_removed, Status.to_remove): + self._status = self._status_when_removed else: - self._status = 
Status.committed + raise RuntimeError("invalid status for item being restored") self._removed = False for referrer in self._referrers.values(): referrer.cascade_restore() @@ -759,10 +761,13 @@ def cascade_remove(self): """ if self._removed: return - if self._status == Status.committed: + self._status_when_removed = self._status + if self._status == Status.to_add: + self._status = Status.added_and_removed + elif self._status in (Status.committed, Status.to_update): self._status = Status.to_remove else: - self._status = Status.committed + raise RuntimeError("invalid status for item being removed") self._removed = True self._to_remove = False self._valid = None @@ -837,6 +842,8 @@ def update(self, other): if self._status == Status.committed: self._status = Status.to_update self._backup = self._asdict() + elif self._status in (Status.to_remove, Status.added_and_removed): + raise RuntimeError("invalid status of item being updated") for src_key, (ref_type, _ref_key) in self._references.values(): ref_id = self[src_key] if src_key in other and other[src_key] != ref_id: From a33d52953972febbb40657b0f64701f3a5cc69d2 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Jun 2023 16:56:32 +0200 Subject: [PATCH 058/317] Fix issues with list_value_id --- spinedb_api/db_cache_impl.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index f5c0a6b7..7c81d2a9 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -14,7 +14,7 @@ """ import uuid from operator import itemgetter -from .parameter_value import from_database, ParameterValueFormatError +from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_cache_base import DBCacheBase, CacheItemBase @@ -247,15 +247,9 @@ def polish(self): ) if list_value_id is None: return f"default value {parsed_value} of {self['name']} is not in {list_name}" - self["default_value"] = list_value_id + self["default_value"] = to_database(list_value_id)[0] self["default_type"] = "list_value_ref" - def _asdict(self): - d = super()._asdict() - if d.get("default_type") == "list_value_ref": - d["default_value"] = str(d["default_value"]).encode() - return d - def merge(self, other): other_parameter_value_list_id = other.get("parameter_value_list_id") if ( @@ -343,15 +337,9 @@ def polish(self): f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " f"is not in {list_name}" ) - self["value"] = list_value_id + self["value"] = to_database(list_value_id)[0] self["type"] = "list_value_ref" - def _asdict(self): - d = super()._asdict() - if d.get("type") == "list_value_ref": - d["value"] = str(d["value"]).encode() - return d - class ParameterValueListItem(CacheItemBase): _unique_keys = (("name",),) From 9fb98cf75f5b1bfa8e681fefc499796e79f54dfd Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 15 Jun 2023 10:52:17 +0200 Subject: [PATCH 059/317] Make cascade restore sensitive to the removal source --- spinedb_api/db_cache_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 0b507d67..5a586f62 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -460,6 +460,7 @@ def __init__(self, db_cache, item_type, **kwargs): self._corrupted = False self._valid = None self._status = Status.committed + self._removal_source = None self._status_when_removed = None self._backup = None @@ -735,19 
+736,21 @@ def _update_weak_referrers(self): for weak_referrer in self._weak_referrers.values(): weak_referrer.call_update_callbacks() - def cascade_restore(self): + def cascade_restore(self, source=None): """Restores this item (if removed) and all its referrers in cascade. Also, updates items' status and calls their restore callbacks. """ if not self._removed: return + if source is not self._removal_source: + return if self.status in (Status.added_and_removed, Status.to_remove): self._status = self._status_when_removed else: raise RuntimeError("invalid status for item being restored") self._removed = False for referrer in self._referrers.values(): - referrer.cascade_restore() + referrer.cascade_restore(source=self) self._update_weak_referrers() obsolete = set() for callback in list(self.restore_callbacks): @@ -755,7 +758,7 @@ def cascade_restore(self): obsolete.add(callback) self.restore_callbacks -= obsolete - def cascade_remove(self): + def cascade_remove(self, source=None): """Removes this item and all its referrers in cascade. Also, updates items' status and calls their remove callbacks. """ @@ -768,6 +771,7 @@ def cascade_remove(self): self._status = Status.to_remove else: raise RuntimeError("invalid status for item being removed") + self._removal_source = source self._removed = True self._to_remove = False self._valid = None @@ -777,7 +781,7 @@ def cascade_remove(self): obsolete.add(callback) self.remove_callbacks -= obsolete for referrer in self._referrers.values(): - referrer.cascade_remove() + referrer.cascade_remove(source=self) self._update_weak_referrers() def cascade_update(self): From 61de5ff84582db93d6c9ce8283f43eb6a460c176 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 20 Jun 2023 09:15:28 +0200 Subject: [PATCH 060/317] Fix ParameterValue unique key --- spinedb_api/db_cache_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 7c81d2a9..6aafcb11 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -269,7 +269,7 @@ def merge(self, other): class ParameterValueItem(ParsedValueBase): - _unique_keys = (("parameter_definition_name", "entity_byname", "alternative_name"),) + _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_name": ("entity_class_id", ("entity_class", "name")), "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), From 36acf4f2023493ad459e2f9e43be9b40acff1f0b Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 13 Jul 2023 09:34:50 +0200 Subject: [PATCH 061/317] Simplify TempId resolving and parameter queries --- spinedb_api/db_cache_base.py | 15 ++-- spinedb_api/db_cache_impl.py | 20 ++++- spinedb_api/db_mapping_add_mixin.py | 12 +-- spinedb_api/db_mapping_base.py | 32 +++---- spinedb_api/db_mapping_update_mixin.py | 2 +- spinedb_api/filters/renamer.py | 2 - spinedb_api/temp_id.py | 111 +++++++++++++++---------- tests/filters/test_renamer.py | 4 - tests/filters/test_scenario_filter.py | 2 +- tests/test_DatabaseMapping.py | 29 +++---- 10 files changed, 126 insertions(+), 103 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 5a586f62..72a4cf03 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -382,13 +382,16 @@ def persist_item(self, item, removed=False): self._add_unique(item) def add_item(self, item, new=False): + if not isinstance(item, CacheItemBase): + item = 
self._make_item(item) + item.polish() if "id" not in item: item["id"] = self._new_id() - self[item["id"]] = new_item = self._make_item(item) - self._add_unique(new_item) + self[item["id"]] = item + self._add_unique(item) if new: - new_item.status = Status.to_add - return new_item + item.status = Status.to_add + return item def update_item(self, item): current_item = self.find_item(item) @@ -537,7 +540,9 @@ def _extended(self): Returns: dict """ - return {**self, **{key: self[key] for key in self._references}} + d = self._asdict() + d.update({key: self[key] for key in self._references}) + return d def _asdict(self): """Returns a dict from this item's original fields. diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 6aafcb11..497f4dab 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -16,6 +16,7 @@ from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_cache_base import DBCacheBase, CacheItemBase +from .temp_id import TempId class DBCache(DBCacheBase): @@ -249,6 +250,12 @@ def polish(self): return f"default value {parsed_value} of {self['name']} is not in {list_name}" self["default_value"] = to_database(list_value_id)[0] self["default_type"] = "list_value_ref" + if isinstance(list_value_id, TempId): + + def callback(new_id): + self["default_value"] = to_database(new_id)[0] + + list_value_id.add_resolve_callback(callback) def merge(self, other): other_parameter_value_list_id = other.get("parameter_value_list_id") @@ -339,6 +346,12 @@ def polish(self): ) self["value"] = to_database(list_value_id)[0] self["type"] = "list_value_ref" + if isinstance(list_value_id, TempId): + + def callback(new_id): + self["value"] = to_database(new_id)[0] + + list_value_id.add_resolve_callback(callback) class ParameterValueListItem(CacheItemBase): @@ -441,8 +454,11 @@ class ParameterValueMetadataItem(CacheItemBase): } _inverse_references = { "parameter_value_id": ( - ("parameter_definition_name", "entity_byname", "alternative_name"), - ("parameter_value", ("parameter_definition_name", "entity_byname", "alternative_name")), + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), + ( + "parameter_value", + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), + ), ), "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), } diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 3ae54159..ef0e21a7 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -16,6 +16,7 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineIntegrityError, SpineDBAPIError +from .temp_id import TempId class DatabaseMappingAddMixin: @@ -51,8 +52,7 @@ def add_items(self, tablename, *items, check=True, strict=False): raise SpineIntegrityError(error) errors.append(error) continue - item = checked_item._asdict() - added.append(table_cache.add_item(item, new=True)._asdict()) + added.append(table_cache.add_item(checked_item, new=True)._asdict()) return added, errors def _do_add_items(self, connection, tablename, *items_to_add): @@ -63,7 +63,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): table = self._metadata.tables[self._real_tablename(tablename)] id_items, temp_id_items = [], [] for item in items_to_add: - if hasattr(item["id"], "resolve"): + if isinstance(item["id"], TempId): temp_id_items.append(item) else: 
id_items.append(item) @@ -177,15 +177,15 @@ def add_parameter_value_metadata(self, *items, **kwargs): return self.add_items("parameter_value_metadata", *items, **kwargs) def add_ext_entity_metadata(self, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) self.add_items("metadata", *metadata_items, **kwargs) return self.add_items("entity_metadata", *items, **kwargs) def add_ext_parameter_value_metadata(self, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) self.add_items("metadata", *metadata_items, **kwargs) return self.add_items("parameter_value_metadata", *items, **kwargs) - def get_metadata_to_add_with_entity_metadata_items(self, *items): + def get_metadata_to_add_with_item_metadata_items(self, *items): metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items) return [x for x in metadata_items if not self.cache.table_cache("metadata").find_item(x)] diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e37eb012..82167301 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1201,9 +1201,9 @@ def entity_parameter_definition_sq(self): self.query( self.parameter_definition_sq.c.id.label("id"), self.parameter_definition_sq.c.entity_class_id, - self.parameter_definition_sq.c.object_class_id, - self.parameter_definition_sq.c.relationship_class_id, self.wide_entity_class_sq.c.name.label("entity_class_name"), + label("object_class_id", self._object_class_id()), + label("relationship_class_id", self._relationship_class_id()), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), label("object_class_id_list", self._object_class_id_list()), @@ -1282,7 +1282,7 @@ def object_parameter_definition_sq(self): self.parameter_definition_sq.c.default_type, self.parameter_definition_sq.c.description, ) - .filter(self.object_class_sq.c.id == self.parameter_definition_sq.c.object_class_id) + .filter(self.object_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id) .outerjoin( self.parameter_value_list_sq, self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, @@ -1363,7 +1363,7 @@ def relationship_parameter_definition_sq(self): self.parameter_definition_sq.c.default_type, self.parameter_definition_sq.c.description, ) - .filter(self.parameter_definition_sq.c.relationship_class_id == self.wide_relationship_class_sq.c.id) + .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) .outerjoin( self.parameter_value_list_sq, self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, @@ -1383,17 +1383,17 @@ def entity_parameter_value_sq(self): self.query( self.parameter_value_sq.c.id.label("id"), self.parameter_definition_sq.c.entity_class_id, - self.parameter_definition_sq.c.object_class_id, - self.parameter_definition_sq.c.relationship_class_id, self.wide_entity_class_sq.c.name.label("entity_class_name"), + label("object_class_id", self._object_class_id()), + label("relationship_class_id", self._relationship_class_id()), label("object_class_name", self._object_class_name()), label("relationship_class_name", self._relationship_class_name()), label("object_class_id_list", 
self._object_class_id_list()), label("object_class_name_list", self._object_class_name_list()), self.parameter_value_sq.c.entity_id, self.wide_entity_sq.c.name.label("entity_name"), - self.parameter_value_sq.c.object_id, - self.parameter_value_sq.c.relationship_id, + label("object_id", self._object_id()), + label("relationship_id", self._relationship_id()), label("object_name", self._object_name()), label("object_id_list", self._object_id_list()), label("object_name_list", self._object_name_list()), @@ -1424,7 +1424,7 @@ def entity_parameter_value_sq(self): # object_id_list might be None when objects have been filtered out .filter( or_( - self.parameter_value_sq.c.relationship_id.is_(None), + self.wide_relationship_sq.c.id.is_(None), self.wide_relationship_sq.c.object_id_list.isnot(None), ) ) @@ -1457,8 +1457,8 @@ def object_parameter_value_sq(self): self.parameter_value_sq.c.type, ) .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) - .filter(self.parameter_value_sq.c.object_id == self.object_sq.c.id) - .filter(self.parameter_definition_sq.c.object_class_id == self.object_class_sq.c.id) + .filter(self.parameter_value_sq.c.entity_id == self.object_sq.c.id) + .filter(self.parameter_definition_sq.c.entity_class_id == self.object_class_sq.c.id) .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) .subquery() ) @@ -1492,8 +1492,8 @@ def relationship_parameter_value_sq(self): self.parameter_value_sq.c.type, ) .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) - .filter(self.parameter_value_sq.c.relationship_id == self.wide_relationship_sq.c.id) - .filter(self.parameter_definition_sq.c.relationship_class_id == self.wide_relationship_class_sq.c.id) + .filter(self.parameter_value_sq.c.entity_id == self.wide_relationship_sq.c.id) + .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) .subquery() ) @@ -1593,8 +1593,6 @@ def _make_parameter_definition_sq(self): par_def_sq.c.name.label("name"), par_def_sq.c.description.label("description"), par_def_sq.c.entity_class_id, - label("object_class_id", self._object_class_id()), - label("relationship_class_id", self._relationship_class_id()), label("default_value", default_value), label("default_type", default_type), label("list_value_id", list_value_id), @@ -1623,10 +1621,6 @@ def _make_parameter_value_sq(self): par_val_sq.c.parameter_definition_id, par_val_sq.c.entity_class_id, par_val_sq.c.entity_id, - label("object_class_id", self._object_class_id()), - label("relationship_class_id", self._relationship_class_id()), - label("object_id", self._object_id()), - label("relationship_id", self._relationship_id()), label("value", value), label("type", type_), label("list_value_id", list_value_id), diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 5335b862..0bfcf678 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -155,7 +155,7 @@ def update_parameter_value_metadata(self, *items, **kwargs): return self.update_items("parameter_value_metadata", *items, **kwargs) def _update_ext_item_metadata(self, tablename, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_entity_metadata_items(*items) + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) added, errors = self.add_items("metadata", 
*metadata_items, **kwargs) updated, more_errors = self.update_items(tablename, *items, **kwargs) return added + updated, errors + more_errors diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index b31ab9b9..f729e8df 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -272,8 +272,6 @@ def _make_renaming_parameter_definition_sq(db_map, state): new_parameter_name.label("name"), subquery.c.description, subquery.c.entity_class_id, - subquery.c.object_class_id, - subquery.c.relationship_class_id, subquery.c.default_value, subquery.c.default_type, subquery.c.list_value_id, diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 44fa5a14..0ebf5663 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -25,49 +25,30 @@ def __new__(cls, item_type): def __init__(self, item_type): super().__init__() self._item_type = item_type - self._value_binds = [] - self._tuple_value_binds = [] - self._key_binds = [] - self._tuple_key_binds = [] + self._resolve_callbacks = [] - def add_value_bind(self, collection, key): - self._value_binds.append((collection, key)) + def __repr__(self): + return f"TempId({self._item_type}, {super().__repr__()})" - def add_tuple_value_bind(self, collection, key): - self._tuple_value_binds.append((collection, key)) + def add_resolve_callback(self, callback): + self._resolve_callbacks.append(callback) - def add_key_bind(self, collection): - self._key_binds.append(collection) - - def add_tuple_key_bind(self, collection, key): - self._tuple_key_binds.append((collection, key)) - - def remove_key_bind(self, collection): - self._key_binds.remove(collection) - - def remove_tuple_key_bind(self, collection, key): - self._tuple_key_binds.remove((collection, key)) + def remove_resolve_callback(self, callback): + try: + self._resolve_callbacks.remove(callback) + except ValueError: + pass def resolve(self, new_id): - for collection, key in self._value_binds: - collection[key] = new_id - for collection, key in self._tuple_value_binds: - collection[key] = tuple(new_id if v is self else v for v in collection[key]) - for collection in self._key_binds: - if self in collection: - collection.key_map[self] = new_id - collection[new_id] = dict.pop(collection, self, None) - for collection, key in self._tuple_key_binds: - if key in collection: - new_key = tuple(new_id if k is self else k for k in key) - collection[new_key] = dict.pop(collection, key, None) - collection.key_map[key] = new_key + while self._resolve_callbacks: + self._resolve_callbacks.pop(0)(new_id) class TempIdDict(dict): def __init__(self, **kwargs): super().__init__(**kwargs) self.key_map = {} + self._unbind_callbacks_by_key = {} for key, value in kwargs.items(): self._bind(key, value) @@ -102,24 +83,66 @@ def pop(self, key, default): self._unbind(key) return super().pop(key, default) + def _make_value_resolve_callback(self, key): + def callback(new_id): + self[key] = new_id + + return callback + + def _make_value_component_resolve_callback(self, key, value, i): + """Returns a callback to call when the given key is resolved. 
+ + Args: + key (TempId) + """ + + def callback(new_id, i=i): + new_value = list(value) + new_value[i] = new_id + new_value = tuple(new_value) + self[key] = new_value + + return callback + + def _make_key_resolve_callback(self, key): + def callback(new_id): + if key in self: + self.key_map[key] = new_id + self[new_id] = self.pop(key, None) + + return callback + + def _make_key_component_resolve_callback(self, key, i): + def callback(new_id, i=i): + if key in self: + new_key = list(key) + new_key[i] = new_id + new_key = tuple(new_key) + self.key_map[key] = new_key + self[new_key] = self.pop(key, None) + + return callback + def _bind(self, key, value): if isinstance(value, TempId): - value.add_value_bind(self, key) + value.add_resolve_callback(self._make_value_resolve_callback(key)) elif isinstance(value, tuple): - for v in value: + for (i, v) in enumerate(value): if isinstance(v, TempId): - v.add_tuple_value_bind(self, key) + v.add_resolve_callback(self._make_value_component_resolve_callback(key, value, i)) elif isinstance(key, TempId): - key.add_key_bind(self) + callback = self._make_key_resolve_callback(key) + key.add_resolve_callback(callback) + self._unbind_callbacks_by_key.setdefault(key, []).append(lambda: key.remove_resolve_callback(callback)) elif isinstance(key, tuple): - for k in key: + for i, k in enumerate(key): if isinstance(k, TempId): - k.add_tuple_key_bind(self, key) + callback = self._make_key_component_resolve_callback(key, i) + k.add_resolve_callback(callback) + self._unbind_callbacks_by_key.setdefault(key, []).append( + lambda k=k, callback=callback: k.remove_resolve_callback(callback) + ) def _unbind(self, key): - if isinstance(key, TempId): - key.remove_key_bind(self) - elif isinstance(key, tuple): - for k in key: - if isinstance(k, TempId): - k.remove_tuple_key_bind(self, key) + for callback in self._unbind_callbacks_by_key.pop(key, ()): + callback() diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index 8f2b68e0..d92dc634 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -178,8 +178,6 @@ def test_renaming_single_parameter(self): "name", "description", "entity_class_id", - "object_class_id", - "relationship_class_id", "default_value", "default_type", "list_value_id", @@ -226,8 +224,6 @@ def test_parameter_renamer_from_dict(self): "name", "description", "entity_class_id", - "object_class_id", - "relationship_class_id", "default_value", "default_type", "list_value_id", diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 97735df5..13120381 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -323,7 +323,7 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self): datamined_values = dict() for parameter in parameters: self.assertEqual(alternative_names[parameter.alternative_id], "alternative") - parameter_values = datamined_values.setdefault(object_names[parameter.object_id], dict()) + parameter_values = datamined_values.setdefault(object_names[parameter.entity_id], dict()) parameter_values[parameter_names[parameter.parameter_definition_id]] = parameter.value self.assertEqual( datamined_values, diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 43cdd93c..1272d5d6 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -160,8 +160,6 @@ def test_parameter_definition_sq(self): "name", "description", "entity_class_id", - "object_class_id", - 
"relationship_class_id", "default_value", "default_type", "list_value_id", @@ -178,10 +176,6 @@ def test_parameter_value_sq(self): "parameter_definition_id", "entity_class_id", "entity_id", - "object_class_id", - "relationship_class_id", - "object_id", - "relationship_id", "value", "type", "list_value_id", @@ -498,8 +492,7 @@ def test_parameter_definition_sq_for_object_class(self): definition_rows = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].name, "par1") - self.assertIsNotNone(definition_rows[0].object_class_id) - self.assertIsNone(definition_rows[0].relationship_class_id) + self.assertIsNotNone(definition_rows[0].entity_class_id) def test_parameter_definition_sq_for_relationship_class(self): self.create_object_classes() @@ -509,8 +502,7 @@ def test_parameter_definition_sq_for_relationship_class(self): definition_rows = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(definition_rows), 1) self.assertEqual(definition_rows[0].name, "par1") - self.assertIsNone(definition_rows[0].object_class_id) - self.assertIsNotNone(definition_rows[0].relationship_class_id) + self.assertIsNotNone(definition_rows[0].entity_class_id) def test_entity_parameter_definition_sq_for_object_class(self): self.create_object_classes() @@ -1093,7 +1085,8 @@ def test_add_existing_parameter_value(self): [str(e) for e in errors], [ "there's already a parameter_value with " - "{'parameter_definition_name': 'color', 'entity_byname': ('nemo',), 'alternative_name': 'Base'}" + "{'entity_class_name': 'fish', 'parameter_definition_name': 'color', " + "'entity_byname': ('nemo',), 'alternative_name': 'Base'}" ], ) @@ -1570,9 +1563,7 @@ def test_update_parameter_definition_value_list(self): "id": 1, "list_value_id": None, "name": "my_parameter", - "object_class_id": 1, "parameter_value_list_id": 1, - "relationship_class_id": None, }, ) @@ -1627,7 +1618,7 @@ def test_update_object_metadata(self): import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) self._db_map.commit_session("Add test data") items, errors = self._db_map.update_ext_entity_metadata( - *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] + {"id": 1, "metadata_name": "key_2", "metadata_value": "new value"} ) self.assertEqual(errors, []) self.assertEqual(len(items), 2) @@ -1635,10 +1626,10 @@ def test_update_object_metadata(self): self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}) entity_metadata_entries = self._db_map.query(self._db_map.entity_metadata_sq).all() self.assertEqual(len(entity_metadata_entries), 1) - self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 2, "commit_id": 3}) + self.assertEqual(dict(entity_metadata_entries[0]), {"id": 1, "entity_id": 1, "metadata_id": 1, "commit_id": 3}) def test_update_object_metadata_reuses_existing_metadata(self): import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -1710,7 +1701,7 @@ def test_update_parameter_value_metadata(self): ) self._db_map.commit_session("Add test data") items, errors = 
self._db_map.update_ext_parameter_value_metadata( - *[{"id": 1, "metadata_name": "key_2", "metadata_value": "new value"}] + {"id": 1, "metadata_name": "key_2", "metadata_value": "new value"} ) self.assertEqual(errors, []) self.assertEqual(len(items), 2) @@ -1718,11 +1709,11 @@ self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_entries), 1) - self.assertEqual(dict(metadata_entries[0]), {"id": 2, "name": "key_2", "value": "new value", "commit_id": 3}) + self.assertEqual(dict(metadata_entries[0]), {"id": 1, "name": "key_2", "value": "new value", "commit_id": 3}) value_metadata_entries = self._db_map.query(self._db_map.parameter_value_metadata_sq).all() self.assertEqual(len(value_metadata_entries), 1) self.assertEqual( - dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 2, "commit_id": 3} + dict(value_metadata_entries[0]), {"id": 1, "parameter_value_id": 1, "metadata_id": 1, "commit_id": 3} ) def test_update_parameter_value_metadata_will_not_delete_shared_entity_metadata(self): From 33ebc1e6642ec03e31c7705e2758ecf17df68bed Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Wed, 30 Aug 2023 12:36:45 +0300 Subject: [PATCH 062/317] Fix locked db editing and committing Reimplement the fix on the 0.8-dev branch. Re spine-tools/Spine-Toolbox#2201 --- spinedb_api/db_mapping_commit_mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index ecd9b12b..c991a1ec 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -37,7 +37,10 @@ def commit_session(self, comment): date = datetime.now(timezone.utc) ins = self._metadata.tables["commit"].insert() with self.engine.begin() as connection: - commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + try: + commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + except sqlalchemy.exc.DBAPIError as e: + raise SpineDBAPIError(f"Failed to commit: {e.orig.args}") from e for tablename, (to_add, to_update, to_remove) in dirty_items: for item in to_add + to_update + to_remove: item.commit(commit_id) From 72fd575850f2aecb927792e430602b25f5a1ff8d Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 30 Aug 2023 13:35:54 +0300 Subject: [PATCH 063/317] Remove unneeded Travis CI config file --- .travis.yml | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1544467f..00000000 --- a/.travis.yml +++ /dev/null @@ -1,22 +0,0 @@ -dist: xenial # required for Python >= 3.7 -language: python -python: - - "3.7" -notifications: - email: false -install: - - npm install gh-pages - - pip install -U pip - - pip install sphinx sphinx_rtd_theme -script: - - pip install . -after_success: - - openssl aes-256-cbc -K $encrypted_151fbad3b0ea_key -iv $encrypted_151fbad3b0ea_iv -in deploy-key.enc -out deploy-key -d - - chmod 600 deploy-key - - eval `ssh-agent -s` - - ssh-add deploy-key - - cd docs - - make html - - cd ..
- - touch docs/build/html/.nojekyll - - ./node_modules/.bin/gh-pages -t -d docs/build/html -b gh-pages -r git@github.com:${TRAVIS_REPO_SLUG}.git From 3d0989906344305bf4cc6a153675244315d88144 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 30 Aug 2023 13:59:58 +0200 Subject: [PATCH 064/317] Handle obsolete filters in DB server --- spinedb_api/spine_db_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index c6e92698..ceba23fa 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -460,9 +460,11 @@ def call_method(self, method_name, *args, **kwargs): return _db_manager.call_method(self.server_address, method_name, *args, **kwargs) def apply_filters(self, filters): + obsolete = ("tool",) configs = [ {"scenario": scenario_filter_config, "alternatives": alternative_filter_config}[key](value) for key, value in filters.items() + if key not in obsolete ] return _db_manager.apply_filters(self.server_address, configs) From ed2fc6d76d0c6d9aa9a2a401f485b68fa862fa4e Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 31 Aug 2023 10:17:54 +0200 Subject: [PATCH 065/317] Update query offsets with a lock Obviously if two threads want to run the same query we need to synchronize offset updating --- spinedb_api/db_cache_base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 72a4cf03..7b6fd7b3 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -12,6 +12,7 @@ DB cache base. """ +import threading from contextlib import suppress from enum import Enum, unique, auto from functools import cmp_to_key @@ -39,6 +40,7 @@ def __init__(self, chunk_size=None): self._updated_items = {} self._removed_items = {} self._offsets = {} + self._offset_lock = threading.Lock() self._fetched_item_types = set() self._chunk_size = chunk_size @@ -99,7 +101,7 @@ def dirty_items(self): # Fetch descendants, so that they are validated in next iterations of the loop. # This ensures cascade removal. 
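A sketch of the race the new lock rules out; an illustrative interleaving based on _get_next_chunk further down in this diff, not code from the patch itself:

    # Two threads advancing the same query without _offset_lock:
    #   thread A reads offset 0           thread B reads offset 0
    #   thread A fetches rows 0..n-1      thread B fetches rows 0..n-1 again
    #   thread A bumps the offset to n    thread B bumps it to 2n
    # Net effect: one chunk is served twice while rows n..2n-1 are never served.
    # Taking the lock around the read-fetch-bump sequence makes it atomic:
    with self._offset_lock:
        offset = self._offsets.setdefault(item_type, 0)
        chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)]
        self._offsets[item_type] += len(chunk)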
# FIXME: We should also fetch the current item type because of multi-dimensional entities and - # classes which also depend on no-dimensional ones + # classes which also depend on zero-dimensional ones for x in self: if self._cmp_item_type(item_type, x) < 0: self.fetch_all(x) @@ -173,9 +175,10 @@ def _get_next_chunk(self, item_type): if not self._chunk_size: self._fetched_item_types.add(item_type) return [dict(x) for x in qry] - offset = self._offsets.setdefault(item_type, 0) - chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)] - self._offsets[item_type] += len(chunk) + with self._offset_lock: + offset = self._offsets.setdefault(item_type, 0) + chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)] + self._offsets[item_type] += len(chunk) return chunk def advance_query(self, item_type): From ea6416a1c1d15ba935df213880d122857e7d5ba3 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 1 Sep 2023 09:07:26 +0200 Subject: [PATCH 066/317] Remove unnecessary joins in parameter value and definition queries --- spinedb_api/db_mapping_base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 82167301..6e3e5a49 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1599,7 +1599,6 @@ def _make_parameter_definition_sq(self): par_def_sq.c.commit_id.label("commit_id"), par_def_sq.c.parameter_value_list_id.label("parameter_value_list_id"), ) - .join(self.wide_entity_class_sq, self.wide_entity_class_sq.c.id == par_def_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) .subquery("clean_parameter_definition_sq") ) @@ -1627,8 +1626,6 @@ def _make_parameter_value_sq(self): par_val_sq.c.commit_id.label("commit_id"), par_val_sq.c.alternative_id, ) - .join(self.wide_entity_sq, self.wide_entity_sq.c.id == par_val_sq.c.entity_id) - .join(self.wide_entity_class_sq, self.wide_entity_class_sq.c.id == par_val_sq.c.entity_class_id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) .subquery("clean_parameter_value_sq") ) From b7e07fbf36321e44e47f72c9c69a08f325098fa3 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 1 Sep 2023 09:25:43 +0200 Subject: [PATCH 067/317] Simplify messaging/progress reporting in graph layout generator --- spinedb_api/graph_layout_generator.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index 2e1d5569..dae7737b 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -36,7 +36,6 @@ def __init__( preview_available=lambda x, y: None, layout_available=lambda x, y: None, layout_progressed=lambda iter: None, - message_available=lambda msg: None, ): super().__init__() if vertex_count == 0: @@ -55,7 +54,6 @@ def __init__( self._preview_available = preview_available self._layout_available = layout_available self._layout_progressed = layout_progressed - self._message_available = message_available def shortest_path_matrix(self): """Returns the shortest-path matrix.""" @@ -73,17 +71,13 @@ def shortest_path_matrix(self): pass start = 0 slices = [] - iteration = 0 - self._message_available("Step 1 of 2: Computing shortest-path matrix...") while start < self.vertex_count: if self._is_stopped(): return None - self._layout_progressed(iteration) stop = min(self.vertex_count, start + math.ceil(self.vertex_count / 10)) slice_ = dijkstra(dist, directed=False, 
indices=range(start, stop)) slices.append(slice_) start = stop - iteration += 1 matrix = np.vstack(slices) # Remove infinites and zeros matrix[matrix == np.inf] = self.spread * self.vertex_count ** (0.5) @@ -113,6 +107,7 @@ def compute_layout(self): self._layout_available(x, y) return x, y matrix = self.shortest_path_matrix() + self._layout_progressed(1) if matrix is None: return [], [] mask = np.ones((self.vertex_count, self.vertex_count)) == 1 - np.tril( @@ -134,13 +129,13 @@ minstep = 1 / np.max(weights[mask]) lambda_ = np.log(minstep / maxstep) / (self.max_iters - 1) # exponential decay of allowed adjustment sets = self.sets() # construct sets of bus pairs - self._message_available("Step 2 of 2: Generating layout...") + self._layout_progressed(2) for iteration in range(self.max_iters): if self._is_stopped(): break x, y = layout[:, 0], layout[:, 1] self._preview_available(x, y) - self._layout_progressed(iteration) + self._layout_progressed(3 + iteration) # FIXME step = maxstep * np.exp(lambda_ * iteration) # how big adjustments are allowed? rand_order = np.random.permutation( From 7473e466d8a51e91d461cb0dc337e341e78721aa Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 4 Sep 2023 12:54:53 +0300 Subject: [PATCH 068/317] Add one() to Query one() is implemented in SQLAlchemy's Query and is very handy in unit tests. --- spinedb_api/query.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index ac375c37..5ffcf47c 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -96,6 +96,16 @@ def all(self): def first(self): return self._result().first() + def one(self): + result = self._result() + first = result.fetchone() + if first is None: + raise SpineDBAPIError("no results found for one()") + second = result.fetchone() + if second is not None: + raise SpineDBAPIError("multiple results found for one()") + return first + def one_or_none(self): result = self._result() first = result.fetchone() From a5dd68af4ee84ece3e67bbbd8d5e1be39865ea36 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 5 Sep 2023 15:49:07 +0200 Subject: [PATCH 069/317] Simplify TempId resolution Re Spine-Toolbox/issues#2295 TempIds are no longer replaced when resolved. Instead, we keep a mapping from TempId to db-id that we use whenever we interact with the db, and only then. The TempId stays in place on the cache side for as long as the cache lives. This removes the need for TempIdDict. --- spinedb_api/db_cache_base.py | 44 ++++++--- spinedb_api/db_mapping_add_mixin.py | 12 +-- spinedb_api/db_mapping_remove_mixin.py | 3 +- spinedb_api/db_mapping_update_mixin.py | 5 +- spinedb_api/temp_id.py | 130 ++++--------------------- 5 files changed, 59 insertions(+), 135 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 7b6fd7b3..8ec0f11a 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -13,10 +13,9 @@ """ import threading -from contextlib import suppress from enum import Enum, unique, auto from functools import cmp_to_key -from .temp_id import TempIdDict, TempId +from .temp_id import TempId, resolve # TODO: Implement CacheItem.pop() to do lookup?
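To make the new resolution scheme concrete, a rough usage sketch (illustrative only; TempId and the module-level resolve() are as defined in temp_id.py further down in this patch):

    temp_id = TempId("entity")              # an int subclass that stays in the cache for its whole life
    item = {"id": temp_id, "name": "nemo"}
    temp_id.add_resolve_callback(lambda db_id: print("resolved to", db_id))
    temp_id.resolve(42)                     # called once the row is inserted; fires the callback
    assert temp_id == 42                    # a resolved TempId compares equal to its db-id
    assert resolve(item) == {"id": 42, "name": "nemo"}  # deep-converts values to db-ids for SQL statements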
@@ -165,8 +164,8 @@ def refresh(self): table_cache.add_item(item, new=True) # Store updated and removed so we can take the proper action # when we see their equivalents comming from the DB - self._updated_items.setdefault(item_type, {}).update({x["id"]: x for x in to_update}) - self._removed_items.setdefault(item_type, {}).update({x["id"]: x for x in to_remove}) + self._updated_items.setdefault(item_type, {}).update({resolve(x["id"]): x for x in to_update}) + self._removed_items.setdefault(item_type, {}).update({resolve(x["id"]): x for x in to_remove}) def _get_next_chunk(self, item_type): qry = self._query(item_type) @@ -238,17 +237,18 @@ def fetch_value(self, item_type, return_fn): def fetch_ref(self, item_type, id_): while self.fetch_more(item_type): - with suppress(KeyError): - return self[item_type][id_] + ref = self.get_item(item_type, id_) + if ref: + return ref # It is possible that fetching was completed between deciding to call this function # and starting the while loop above resulting in self.fetch_more() to return False immediately. # Therefore, we should try one last time if the ref is available. - with suppress(KeyError): - return self[item_type][id_] - return None + ref = self.get_item(item_type, id_) + if ref: + return ref -class _TableCache(TempIdDict): +class _TableCache(dict): def __init__(self, db_cache, item_type, *args, **kwargs): """ Args: @@ -259,9 +259,20 @@ def __init__(self, db_cache, item_type, *args, **kwargs): self._db_cache = db_cache self._item_type = item_type self._id_by_unique_key_value = {} + self._temp_id_by_db_id = {} + + def get(self, id_, default=None): + id_ = self._temp_id_by_db_id.get(id_, id_) + return super().get(id_, default) def _new_id(self): - return TempId(self._item_type) + temp_id = TempId(self._item_type) + + def _callback(db_id): + self._temp_id_by_db_id[db_id] = temp_id + + temp_id.add_resolve_callback(_callback) + return temp_id def unique_key_value_to_id(self, key, value, strict=False): """Returns the id that has the given value for the given unique key, or None. 
@@ -371,7 +382,7 @@ def check_item(self, item, for_update=False, skip_keys=()): def _add_unique(self, item): for key, value in item.unique_values(): - self._id_by_unique_key_value.setdefault(key, TempIdDict())[value] = item["id"] + self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"] def _remove_unique(self, item): for key, value in item.unique_values(): @@ -419,7 +430,7 @@ def restore_item(self, id_): return current_item -class CacheItemBase(TempIdDict): +class CacheItemBase(dict): """A dictionary that represents an db item.""" _defaults = {} @@ -456,8 +467,8 @@ def __init__(self, db_cache, item_type, **kwargs): super().__init__(**kwargs) self._db_cache = db_cache self._item_type = item_type - self._referrers = TempIdDict() - self._weak_referrers = TempIdDict() + self._referrers = {} + self._weak_referrers = {} self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() @@ -569,7 +580,8 @@ def merge(self, other): if all(self.get(key) == value for key, value in other.items()): return None, "" merged = {**self._extended(), **other} - merged["id"] = self["id"] + if not isinstance(merged["id"], int): + merged["id"] = self["id"] return merged, "" def first_invalid_key(self): diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index ef0e21a7..6fbe6a66 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -16,7 +16,7 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineIntegrityError, SpineDBAPIError -from .temp_id import TempId +from .temp_id import TempId, resolve class DatabaseMappingAddMixin: @@ -68,23 +68,23 @@ def _do_add_items(self, connection, tablename, *items_to_add): else: id_items.append(item) if id_items: - connection.execute(table.insert(), [x._asdict() for x in id_items]) + connection.execute(table.insert(), [resolve(x._asdict()) for x in id_items]) if temp_id_items: current_ids = {x["id"] for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 available_ids = set(range(1, next_id)) - current_ids - missing_id_count = len(temp_id_items) - len(available_ids) - new_ids = set(range(next_id, next_id + missing_id_count)) + required_id_count = len(temp_id_items) - len(available_ids) + new_ids = set(range(next_id, next_id + required_id_count)) ids = sorted(available_ids | new_ids) for id_, item in zip(ids, temp_id_items): temp_id = item["id"] temp_id.resolve(id_) - connection.execute(table.insert(), [x._asdict() for x in temp_id_items]) + connection.execute(table.insert(), [resolve(x._asdict()) for x in temp_id_items]) for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue table = self._metadata.tables[self._real_tablename(tablename_)] - connection.execute(table.insert(), items_to_add_) + connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index bd1efb3a..8033bdb4 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -17,6 +17,7 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError from .helpers import Asterisk, group_consecutive +from .temp_id import resolve # TODO: improve docstrings @@ -54,7 +55,7 @@ def _do_remove_items(self, connection, 
tablename, *ids): *ids: ids to remove """ tablename = self._real_tablename(tablename) - ids = set(ids) + ids = {resolve(id_) for id_ in ids} if tablename == "alternative": # Do not remove the Base alternative ids.discard(1) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 0bfcf678..f6834e31 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -15,6 +15,7 @@ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam from .exception import SpineIntegrityError, SpineDBAPIError +from .temp_id import resolve class DatabaseMappingUpdateMixin: @@ -33,12 +34,12 @@ def _do_update_items(self, connection, tablename, *items_to_update): return try: upd = self._make_update_stmt(tablename, items_to_update[0].keys()) - connection.execute(upd, [item._asdict() for item in items_to_update]) + connection.execute(upd, [resolve(item._asdict()) for item in items_to_update]) for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): if not items_to_update_: continue upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) - connection.execute(upd, items_to_update_) + connection.execute(upd, [resolve(x) for x in items_to_update_]) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 0ebf5663..84eb3464 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -26,6 +26,17 @@ def __init__(self, item_type): super().__init__() self._item_type = item_type self._resolve_callbacks = [] + self._db_id = None + + def __eq__(self, other): + return super().__eq__(other) or (self._db_id is not None and other == self._db_id) + + def __hash__(self): + return int.__hash__(self) + + @property + def db_id(self): + return self._db_id def __repr__(self): return f"TempId({self._item_type}, {super().__repr__()})" @@ -33,116 +44,15 @@ def __repr__(self): def add_resolve_callback(self, callback): self._resolve_callbacks.append(callback) - def remove_resolve_callback(self, callback): - try: - self._resolve_callbacks.remove(callback) - except ValueError: - pass - - def resolve(self, new_id): + def resolve(self, db_id): + self._db_id = db_id while self._resolve_callbacks: - self._resolve_callbacks.pop(0)(new_id) - - -class TempIdDict(dict): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.key_map = {} - self._unbind_callbacks_by_key = {} - for key, value in kwargs.items(): - self._bind(key, value) - - def __getitem__(self, key): - key = self.key_map.get(key, key) - return super().__getitem__(key) - - def get(self, key, default=None): - key = self.key_map.get(key, key) - return super().get(key, default) - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._bind(key, value) - - def __delitem__(self, key): - super().__delitem__(key) - self._unbind(key) - - def setdefault(self, key, default): - value = super().setdefault(key, default) - self._bind(key, value) - return value - - def update(self, other): - super().update(other) - for key, value in other.items(): - self._bind(key, value) - - def pop(self, key, default): - if key in self: - self._unbind(key) - return super().pop(key, default) - - def _make_value_resolve_callback(self, key): - def callback(new_id): - self[key] = new_id - - return callback - - def _make_value_component_resolve_callback(self, key, value, i): - 
"""Returns a callback to call when the given key is resolved. - - Args: - key (TempId) - """ - - def callback(new_id, i=i): - new_value = list(value) - new_value[i] = new_id - new_value = tuple(new_value) - self[key] = new_value - - return callback - - def _make_key_resolve_callback(self, key): - def callback(new_id): - if key in self: - self.key_map[key] = new_id - self[new_id] = self.pop(key, None) - - return callback - - def _make_key_component_resolve_callback(self, key, i): - def callback(new_id, i=i): - if key in self: - new_key = list(key) - new_key[i] = new_id - new_key = tuple(new_key) - self.key_map[key] = new_key - self[new_key] = self.pop(key, None) - - return callback + self._resolve_callbacks.pop(0)(db_id) - def _bind(self, key, value): - if isinstance(value, TempId): - value.add_resolve_callback(self._make_value_resolve_callback(key)) - elif isinstance(value, tuple): - for (i, v) in enumerate(value): - if isinstance(v, TempId): - v.add_resolve_callback(self._make_value_component_resolve_callback(key, value, i)) - elif isinstance(key, TempId): - callback = self._make_key_resolve_callback(key) - key.add_resolve_callback(callback) - self._unbind_callbacks_by_key.setdefault(key, []).append(lambda: key.remove_resolve_callback(callback)) - elif isinstance(key, tuple): - for i, k in enumerate(key): - if isinstance(k, TempId): - callback = self._make_key_component_resolve_callback(key, i) - k.add_resolve_callback(callback) - self._unbind_callbacks_by_key.setdefault(key, []).append( - lambda k=k, callback=callback: k.remove_resolve_callback(callback) - ) - def _unbind(self, key): - for callback in self._unbind_callbacks_by_key.pop(key, ()): - callback() +def resolve(value): + if isinstance(value, dict): + return {k: resolve(v) for k, v in value.items()} + if isinstance(value, TempId): + return value.db_id + return value From 4a141bd5c5afc253f681b72b16aaf874d6cbc38a Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 6 Sep 2023 08:40:12 +0200 Subject: [PATCH 070/317] Simplify refresh Spine-Toolbox#2300 Instead of clearing the cache and persisting the items, we just keep the cache intact and then check, as items are fetched from the DB, whether those items are already in the cache and in that case skip them. These guarantees that items never lose their TempID. Maybe we should rename it to CacheId. --- spinedb_api/db_cache_base.py | 49 ++++++-------------------- spinedb_api/db_mapping_commit_mixin.py | 1 - 2 files changed, 10 insertions(+), 40 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 8ec0f11a..e9f4165d 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -36,8 +36,6 @@ class DBCacheBase(dict): def __init__(self, chunk_size=None): super().__init__() - self._updated_items = {} - self._removed_items = {} self._offsets = {} self._offset_lock = threading.Lock() self._fetched_item_types = set() @@ -108,11 +106,6 @@ def dirty_items(self): dirty_items.append((item_type, (to_add, to_update, to_remove))) return dirty_items - def commit(self): - """Clears the internal storage of dirty items created by ``refresh``.""" - self._updated_items.clear() - self._removed_items.clear() - def rollback(self): """Discards uncommitted changes. @@ -147,25 +140,9 @@ def rollback(self): return True def refresh(self): - """Stores dirty items in internal dictionaries and clears the cache, so the DB can be fetched again. 
- Conflicts between new contents of the DB and dirty items are solved in favor of the latter - (See ``advance_query`` where we resolve those conflicts as consuming the queries). - """ - dirty_items = self.dirty_items() # Get dirty items before clearing - self.clear() - # Clear _offsets and _fetched_item_types before adding dirty items below, - # so those items are able to properly fetch their references from the DB + """Clears fetch progress, so the DB is queried again.""" self._offsets.clear() self._fetched_item_types.clear() - for item_type, (to_add, to_update, to_remove) in dirty_items: - # Add new items directly - table_cache = self.table_cache(item_type) - for item in to_add: - table_cache.add_item(item, new=True) - # Store updated and removed so we can take the proper action - # when we see their equivalents comming from the DB - self._updated_items.setdefault(item_type, {}).update({resolve(x["id"]): x for x in to_update}) - self._removed_items.setdefault(item_type, {}).update({resolve(x["id"]): x for x in to_remove}) def _get_next_chunk(self, item_type): qry = self._query(item_type) @@ -195,17 +172,7 @@ def advance_query(self, item_type): self._fetched_item_types.add(item_type) return [] table_cache = self.table_cache(item_type) - updated_items = self._updated_items.get(item_type, {}) - removed_items = self._removed_items.get(item_type, {}) for item in chunk: - updated_item = updated_items.pop(item["id"], None) - if updated_item: - table_cache.persist_item(updated_item) - continue - removed_item = removed_items.pop(item["id"], None) - if removed_item: - table_cache.persist_item(removed_item, removed=True) - continue table_cache.add_item(item) return chunk @@ -390,15 +357,19 @@ def _remove_unique(self, item): if id_by_value.get(value) == item["id"]: del id_by_value[value] - def persist_item(self, item, removed=False): - self[item["id"]] = item - if not removed: - self._add_unique(item) - def add_item(self, item, new=False): if not isinstance(item, CacheItemBase): item = self._make_item(item) item.polish() + if not new: + # Item comes from the DB + id_ = item["id"] + if id_ in self or id_ in self._temp_id_by_db_id: + # The item is already in the cache + return + if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()): + # An item with the same unique key is already in the cache + return if "id" not in item: item["id"] = self._new_id() self[item["id"]] = item diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index c991a1ec..f0bcfed1 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -48,7 +48,6 @@ def commit_session(self, comment): self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) self._do_update_items(connection, tablename, *to_update) self._do_add_items(connection, tablename, *to_add) - self.cache.commit() if self._memory: self._memory_dirty = True From 2736e13c5b4556293946dae39f32b0227d3c0cd0 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 6 Sep 2023 09:47:34 +0300 Subject: [PATCH 071/317] Fix unit tests Re spine-tools/Spine-Toolbox#2294 --- tests/filters/test_alternative_filter.py | 5 +++-- tests/import_mapping/test_generator.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index 00160a6d..a677bbe2 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ 
-21,6 +21,7 @@
     apply_alternative_filter_to_parameter_value_sq,
     create_new_spine_database,
     DatabaseMapping,
+    from_database,
     import_alternatives,
     import_object_classes,
     import_object_parameter_values,
@@ -136,8 +137,8 @@ def test_multiple_alternatives(self):
         alternative_filter_from_dict(self._db_map, config)
         parameters = self._db_map.query(self._db_map.parameter_value_sq).all()
         self.assertEqual(len(parameters), 2)
-        self.assertEqual(parameters[0].value, b"101.1")
-        self.assertEqual(parameters[1].value, b"23.0")
+        values = {from_database(p.value) for p in parameters}
+        self.assertEqual(values, {23.0, 101.1})

     def _add_value_in_alternative(self, value, alternative):
         import_alternatives(self._db_map, [alternative])
diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py
index 3f5305d3..df0a8dac 100644
--- a/tests/import_mapping/test_generator.py
+++ b/tests/import_mapping/test_generator.py
@@ -429,15 +429,15 @@ def test_header_position_is_ignored_in_last_mapping_if_other_mappings_are_in_hea
             mapped_data,
             {
                 "alternatives": {"Base"},
-                "object_classes": {"Data"},
-                "object_parameter_values": [
+                "entity_classes": [("Data",)],
+                "parameter_values": [
                     ["Data", "d1", "parameter1", 1.1, "Base"],
                     ["Data", "d1", "parameter2", -2.3, "Base"],
                     ["Data", "d2", "parameter1", -1.1, "Base"],
                     ["Data", "d2", "parameter2", 2.3, "Base"],
                 ],
-                "object_parameters": [("Data", "parameter1"), ("Data", "parameter2")],
-                "objects": {("Data", "d1"), ("Data", "d2")},
+                "parameter_definitions": [("Data", "parameter1"), ("Data", "parameter2")],
+                "entities": [("Data", "d1"), ("Data", "d2")],
             },
         )

From 856f5e70f6d93e811fe0c5c0dd8214461a0b4dc9 Mon Sep 17 00:00:00 2001
From: Henrik Koski
Date: Wed, 6 Sep 2023 11:40:02 +0300
Subject: [PATCH 072/317] Fix db requiring committing twice

The need for double committing arose from the fact that if an entity or
entity class was re-added after it had already been deleted from the db,
some of its data was still present in the db. That caused some unique
constraint errors.

- Previously, when an entity was deleted, only the entity table was
  considered. Now the data corresponding to the entity is also deleted
  from the entity_element table.
- When deleting an entity class, the data from the entity_class_dimension
  table corresponding to the class's id is also deleted.
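
For the record, a rough sketch of the kind of round trip that used to
need two commits (the URL, class names and ids below are illustrative
only, not taken from the test suite):

    from spinedb_api import DatabaseMapping, import_object_classes, import_relationship_classes

    with DatabaseMapping("sqlite:///test.sqlite", create=True) as db_map:
        import_object_classes(db_map, ("unit",))
        import_relationship_classes(db_map, (("unit__unit", ("unit", "unit")),))
        db_map.commit_session("Add classes")
        db_map.remove_items("entity_class", 1, 2)  # ids of the two classes above
        db_map.commit_session("Remove classes")  # now also clears entity_class_dimension rows
        import_object_classes(db_map, ("unit",))
        import_relationship_classes(db_map, (("unit__unit", ("unit", "unit")),))
        db_map.commit_session("Re-add classes")  # no longer trips a unique constraint error
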
Re spine-tools/Spine-Toolbox#2296
---
 spinedb_api/db_mapping_remove_mixin.py | 31 ++++++++++++++++----------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py
index 8033bdb4..f7999f99 100644
--- a/spinedb_api/db_mapping_remove_mixin.py
+++ b/spinedb_api/db_mapping_remove_mixin.py
@@ -54,23 +54,30 @@ def _do_remove_items(self, connection, tablename, *ids):
         Args:
             *ids: ids to remove
         """
-        tablename = self._real_tablename(tablename)
+        tablenames = [self._real_tablename(tablename)]
         ids = {resolve(id_) for id_ in ids}
-        if tablename == "alternative":
+        if tablenames[0] == "alternative":
             # Do not remove the Base alternative
             ids.discard(1)
         if not ids:
             return
-        table = self._metadata.tables[tablename]
-        id_field = self._id_fields.get(tablename, "id")
-        id_column = getattr(table.c, id_field)
-        cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids)))
-        delete = table.delete().where(cond)
-        try:
-            connection.execute(delete)
-        except DBAPIError as e:
-            msg = f"DBAPIError while removing {tablename} items: {e.orig.args}"
-            raise SpineDBAPIError(msg) from e
+        if tablenames[0] == "entity_class":
+            # Also remove the items corresponding to the id in entity_class_dimension
+            tablenames.append("entity_class_dimension")
+        elif tablenames[0] == "entity":
+            # Also remove the items corresponding to the id in entity_element
+            tablenames.append("entity_element")
+        for tablename in tablenames:
+            table = self._metadata.tables[tablename]
+            id_field = self._id_fields.get(tablename, "id")
+            id_column = getattr(table.c, id_field)
+            cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids)))
+            delete = table.delete().where(cond)
+            try:
+                connection.execute(delete)
+            except DBAPIError as e:
+                msg = f"DBAPIError while removing {tablename} items: {e.orig.args}"
+                raise SpineDBAPIError(msg) from e

     def remove_unused_metadata(self):
         used_metadata_ids = set()
From aabf9d11b7f34735e73fb494f1966677374e4410 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Wed, 6 Sep 2023 12:11:00 +0300
Subject: [PATCH 074/317] Fix purging and removing items by Asterisk

We need to fetch all items of given type before removing items from
cache if Asterisk is used as id.

Re #261
---
 spinedb_api/db_mapping_remove_mixin.py |  9 ++++++
 spinedb_api/purge.py                   | 10 +++++++
 tests/test_purge.py                    | 41 ++++++++++++++++++++++++++
 3 files changed, 60 insertions(+)
 create mode 100644 tests/test_purge.py

diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py
index 8033bdb4..e932b14e 100644
--- a/spinedb_api/db_mapping_remove_mixin.py
+++ b/spinedb_api/db_mapping_remove_mixin.py
@@ -31,6 +31,7 @@ def remove_items(self, tablename, *ids):
         tablename = self._real_tablename(tablename)
         table_cache = self.cache.table_cache(tablename)
         if Asterisk in ids:
+            self.cache.fetch_all(tablename)
             ids = table_cache
         ids = set(ids)
         if tablename == "alternative":
@@ -46,6 +47,14 @@ def restore_items(self, tablename, *ids):
         return [table_cache.restore_item(id_) for id_ in ids]

     def purge_items(self, tablename):
+        """Removes all items from given table.
+ + Args: + tablename (str): name of table + + Returns: + bool: True if operation was successful, False otherwise + """ return self.remove_items(tablename, Asterisk) def _do_remove_items(self, connection, tablename, *ids): diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index a3f9a491..99790306 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -35,6 +35,16 @@ def _ids_for_item_type(db_map, item_type): def purge_url(url, purge_settings, logger=None): + """Removes all given types of items from database. + + Args: + url (str): database URL + purge_settings (dict): mapping from item type to boolean + logger (LoggerInterface, optional): logger + + Returns: + bool: True if operation was successful, False otherwise + """ try: db_map = DatabaseMapping(url) except (SpineDBAPIError, SpineDBVersionError) as err: diff --git a/tests/test_purge.py b/tests/test_purge.py new file mode 100644 index 00000000..1aacdd8e --- /dev/null +++ b/tests/test_purge.py @@ -0,0 +1,41 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General +# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) +# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . 
+###################################################################################################################### +import pathlib +import tempfile +import unittest + +from spinedb_api import DatabaseMapping +from spinedb_api.purge import purge_url + + +class TestPurgeUrl(unittest.TestCase): + def setUp(self): + self._temp_dir = tempfile.TemporaryDirectory() + path = pathlib.Path(self._temp_dir.name, "database.sqlite") + self._url = "sqlite:///" + str(path) + + def tearDown(self): + self._temp_dir.cleanup() + + def test_purge_entity_classes(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_entity_classes({"name": "Soup"}) + db_map.commit_session("Add test data") + purge_url(self._url, {"alternative": False, "entity_class": True}) + with DatabaseMapping(self._url) as db_map: + entities = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(entities, []) + alternatives = db_map.query(db_map.alternative_sq).all() + self.assertEqual(len(alternatives), 1) + + +if __name__ == '__main__': + unittest.main() From 7e0a1f3b669becd99d7ad121c0236104013ec97c Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 6 Sep 2023 11:33:49 +0200 Subject: [PATCH 075/317] Remove unused import and unnecessary condition --- spinedb_api/db_cache_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index e9f4165d..abaa0ab0 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -15,7 +15,7 @@ import threading from enum import Enum, unique, auto from functools import cmp_to_key -from .temp_id import TempId, resolve +from .temp_id import TempId # TODO: Implement CacheItem.pop() to do lookup? @@ -370,12 +370,12 @@ def add_item(self, item, new=False): if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()): # An item with the same unique key is already in the cache return + else: + item.status = Status.to_add if "id" not in item: item["id"] = self._new_id() self[item["id"]] = item self._add_unique(item) - if new: - item.status = Status.to_add return item def update_item(self, item): From c48477b1424f131bc75408e74b8eff1b5ad1a838 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 6 Sep 2023 12:15:15 +0200 Subject: [PATCH 076/317] Go through all tables (not only fetched ones) to collect dirty items Re #264 --- spinedb_api/db_cache_base.py | 22 +++++++++++++-------- spinedb_api/db_cache_impl.py | 38 +++++++++++++++++++++--------------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index abaa0ab0..cdbe59ce 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -45,6 +45,10 @@ def __init__(self, chunk_size=None): def fetched_item_types(self): return self._fetched_item_types + @property + def _item_types(self): + raise NotImplementedError() + @staticmethod def _item_factory(item_type): raise NotImplementedError() @@ -65,9 +69,6 @@ def _cmp_item_type(self, a, b): return 1 return 0 - def _sorted_item_types(self): - sorted(self, key=cmp_to_key(self._cmp_item_type)) - def dirty_ids(self, item_type): return { item["id"] for item in self.get(item_type, {}).values() if item.status in (Status.to_add, Status.to_update) @@ -81,8 +82,10 @@ def dirty_items(self): list """ dirty_items = [] - for item_type in sorted(self, key=cmp_to_key(self._cmp_item_type)): - table_cache = self[item_type] + for item_type in sorted(self._item_types, 
key=cmp_to_key(self._cmp_item_type)): + table_cache = self.get(item_type) + if table_cache is None: + continue to_add = [] to_update = [] to_remove = [] @@ -99,9 +102,12 @@ def dirty_items(self): # This ensures cascade removal. # FIXME: We should also fetch the current item type because of multi-dimensional entities and # classes which also depend on zero-dimensional ones - for x in self: - if self._cmp_item_type(item_type, x) < 0: - self.fetch_all(x) + for other_item_type in self._item_types: + if ( + other_item_type not in self.fetched_item_types + and self._cmp_item_type(item_type, other_item_type) < 0 + ): + self.fetch_all(other_item_type) if to_add or to_update or to_remove: dirty_items.append((item_type, (to_add, to_update, to_remove))) return dirty_items diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 497f4dab..dc1daa16 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -20,6 +20,23 @@ class DBCache(DBCacheBase): + _sq_name_by_item_type = { + "entity_class": "wide_entity_class_sq", + "entity": "wide_entity_sq", + "parameter_value_list": "parameter_value_list_sq", + "list_value": "list_value_sq", + "alternative": "alternative_sq", + "scenario": "scenario_sq", + "scenario_alternative": "scenario_alternative_sq", + "entity_group": "entity_group_sq", + "parameter_definition": "parameter_definition_sq", + "parameter_value": "parameter_value_sq", + "metadata": "metadata_sq", + "entity_metadata": "entity_metadata_sq", + "parameter_value_metadata": "parameter_value_metadata_sq", + "commit": "commit_sq", + } + def __init__(self, db_map, chunk_size=None): """ Args: @@ -28,6 +45,10 @@ def __init__(self, db_map, chunk_size=None): super().__init__(chunk_size=chunk_size) self._db_map = db_map + @property + def _item_types(self): + return list(self._sq_name_by_item_type) + @staticmethod def _item_factory(item_type): return { @@ -49,22 +70,7 @@ def _item_factory(item_type): def _query(self, item_type): if self._db_map.closed: return None - sq_name = { - "entity_class": "wide_entity_class_sq", - "entity": "wide_entity_sq", - "parameter_value_list": "parameter_value_list_sq", - "list_value": "list_value_sq", - "alternative": "alternative_sq", - "scenario": "scenario_sq", - "scenario_alternative": "scenario_alternative_sq", - "entity_group": "entity_group_sq", - "parameter_definition": "parameter_definition_sq", - "parameter_value": "parameter_value_sq", - "metadata": "metadata_sq", - "entity_metadata": "entity_metadata_sq", - "parameter_value_metadata": "parameter_value_metadata_sq", - "commit": "commit_sq", - }[item_type] + sq_name = self._sq_name_by_item_type[item_type] return self._db_map.query(getattr(self._db_map, sq_name)) From 7e5db073c6460b5b72d51e6cdcb857f7ff661c79 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 6 Sep 2023 12:17:18 +0200 Subject: [PATCH 077/317] Add docstrings --- spinedb_api/db_cache_base.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index cdbe59ce..7653cc9a 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -47,13 +47,34 @@ def fetched_item_types(self): @property def _item_types(self): + """Returns a list of supported item type strings. + + Returns: + list + """ raise NotImplementedError() @staticmethod def _item_factory(item_type): + """Returns a subclass of CacheItemBase to build items of given type. 
+
+        Args:
+            item_type (str)
+
+        Returns:
+            CacheItemBase
+        """
         raise NotImplementedError()

     def _query(self, item_type):
+        """Returns a Query object to fetch items of given type.
+
+        Args:
+            item_type (str)
+
+        Returns:
+            Query
+        """
         raise NotImplementedError()

     def make_item(self, item_type, **item):

From cdfb966cff02e6c524b27b01cf255055f8585814 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Wed, 6 Sep 2023 15:04:01 +0300
Subject: [PATCH 078/317] Fix execution filter

Import alternative creation was broken thanks to changes in
DatabaseMapping.

Re #265
---
 spinedb_api/db_cache_base.py            |  1 +
 spinedb_api/db_mapping_base.py          |  2 +-
 spinedb_api/filters/execution_filter.py | 16 +++++++-----
 tests/filters/test_execution_filter.py  | 34 +++++++++++++++++++++++++
 tests/test_DatabaseMapping.py           |  4 +++
 5 files changed, 49 insertions(+), 8 deletions(-)
 create mode 100644 tests/filters/test_execution_filter.py

diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
index abaa0ab0..2b7a14ed 100644
--- a/spinedb_api/db_cache_base.py
+++ b/spinedb_api/db_cache_base.py
@@ -247,6 +247,7 @@ def unique_key_value_to_id(self, key, value, strict=False):
         Args:
             key (tuple)
             value (tuple)
+            strict (bool): if True, raise a KeyError if id is not found

         Returns:
             int
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 6e3e5a49..ca6673de 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -1661,7 +1661,7 @@ def get_import_alternative_name(self):
         """Returns the name of the alternative to use as default for all import operations.

         Returns:
-            str
+            str: import alternative name
         """
         if self._import_alternative_name is None:
             self._create_import_alternative()
diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py
index d5d37fa2..e76fc6c3 100644
--- a/spinedb_api/filters/execution_filter.py
+++ b/spinedb_api/filters/execution_filter.py
@@ -122,15 +122,18 @@ def __init__(self, db_map, execution):
         self.original_create_import_alternative = db_map._create_import_alternative
         self.execution_item, self.scenarios, self.timestamp = self._parse_execution_descriptor(execution)

-    def _parse_execution_descriptor(self, execution):
-        """Raises ``SpineDBAPIError`` if descriptor not good.
+    @staticmethod
+    def _parse_execution_descriptor(execution):
+        """Parses data from execution descriptor.
Args: execution (dict): execution descriptor Returns: - str: the execution item - list: scenarios + tuple: execution item name, list of scenario names, timestamp string + + Raises: + SpineDBAPIError: raised when execution descriptor is invalid """ try: execution_item = execution["execution_item"] @@ -156,9 +159,8 @@ def _create_import_alternative(db_map, state): timestamp = state.timestamp sep = "__" if scenarios else "" db_map._import_alternative_name = f"{'_'.join(scenarios)}{sep}{execution_item}@{timestamp}" - db_map.add_alternatives({"name": db_map._import_alternative_name}, _strict=False) - scenarios = [{"name": scen_name} for scen_name in scenarios] - db_map.add_scenarios(*scenarios, _strict=True) + db_map.add_alternatives({"name": db_map._import_alternative_name}) + db_map.add_scenarios(*({"name": scen_name} for scen_name in scenarios)) for scen_name in scenarios: scen = db_map.cache.table_cache("scenario").find_item({"name": scen_name}) rank = len(scen.sorted_scenario_alternatives) + 1 # ranks are 1-based diff --git a/tests/filters/test_execution_filter.py b/tests/filters/test_execution_filter.py new file mode 100644 index 00000000..bc44c7a7 --- /dev/null +++ b/tests/filters/test_execution_filter.py @@ -0,0 +1,34 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . 
+######################################################################################################################
+import unittest
+
+from spinedb_api import apply_execution_filter, DatabaseMapping
+
+
+class TestExecutionFilter(unittest.TestCase):
+    def test_import_alternative_after_applying_execution_filter(self):
+        execution = {
+            "execution_item": "Importing importer",
+            "scenarios": ["low_on_steam", "wasting_my_time"],
+            "timestamp": "2023-09-06T01:23:45",
+        }
+        with DatabaseMapping("sqlite:///", create=True) as db_map:
+            apply_execution_filter(db_map, execution)
+            alternative_name = db_map.get_import_alternative_name()
+            self.assertEqual(alternative_name, "low_on_steam_wasting_my_time__Importing importer@2023-09-06T01:23:45")
+            alternatives = {item["name"] for item in db_map.cache["alternative"].values()}
+            self.assertIn(alternative_name, alternatives)
+            scenarios = {item["name"] for item in db_map.cache["scenario"].values()}
+            self.assertEqual(scenarios, {"low_on_steam", "wasting_my_time"})
+
+
+if __name__ == '__main__':
+    unittest.main()

From c447319611976d493c238ac6a98797682f1598db Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Wed, 6 Sep 2023 16:11:40 +0300
Subject: [PATCH 079/317] Make Query.filter() accept any number of criteria

This makes Query.filter() consistent with SQLAlchemy and restores
backwards compatibility that had been broken.
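
For illustration, both call styles below now work (the subquery and
criteria are just examples, mirroring the added test):

    sq = db_map.wide_entity_sq
    db_map.query(sq).filter(sq.c.name == "entity 1").filter(sq.c.class_id == 1)  # chaining, as before
    db_map.query(sq).filter(sq.c.name == "entity 1", sq.c.class_id == 1)  # several criteria at once, as in SQLAlchemy
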
Re #267 --- spinedb_api/query.py | 5 +++-- tests/test_DatabaseMapping.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index 5ffcf47c..3df938e6 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -38,8 +38,9 @@ def add_columns(self, *columns): self._select = select(self._entities) return self - def filter(self, clause): - self._select = self._select.where(clause) + def filter(self, *clauses): + for clause in clauses: + self._select = self._select.where(clause) return self def filter_by(self, **kwargs): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index dec8c41c..6b0929c7 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -597,6 +597,28 @@ def test_wide_parameter_value_list_sq(self): self.assertEqual(value_lists[0].name, "list1") self.assertEqual(value_lists[1].name, "list2") + def test_filter_query_accepts_multiple_criteria(self): + classes, errors = self._db_map.add_entity_classes({"name": "Real"}, {"name": "Fake"}) + self.assertEqual(errors, []) + self.assertEqual(len(classes), 2) + self.assertEqual(classes[0]["name"], "Real") + self.assertEqual(classes[1]["name"], "Fake") + real_class_id = classes[0]["id"] + fake_class_id = classes[1]["id"] + _, errors = self._db_map.add_entities( + {"name": "entity 1", "class_id": real_class_id}, + {"name": "entity_2", "class_id": real_class_id}, + {"name": "entity_1", "class_id": fake_class_id}, + ) + self.assertEqual(errors, []) + self._db_map.commit_session("Add test data") + sq = self._db_map.wide_entity_class_sq + real_class_id = self._db_map.query(sq).filter(sq.c.name == "Real").one().id + sq = self._db_map.wide_entity_sq + entity = self._db_map.query(sq).filter(sq.c.name == "entity 1", sq.c.class_id == 1).one() + self.assertEqual(entity.name, "entity 1") + self.assertEqual(entity.class_id, real_class_id) + class TestDatabaseMappingAdd(unittest.TestCase): def setUp(self): From 15c12ad2a0d5020fc2c39536c4961abc584cd87c Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 6 Sep 2023 16:25:39 +0200 Subject: [PATCH 080/317] Add test for cascade removal of unfetched items Re #264 --- tests/test_DatabaseMapping.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 6b0929c7..19cc8a6e 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2285,6 +2285,17 @@ def test_refresh_update(self): entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} self.assertEqual(entity_class_names, {"new_name"}) + def test_cascade_remove_unfetched(self): + import_functions.import_object_classes(self._db_map, ("my_class",)) + import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) + self._db_map.commit_session("test commit") + self._db_map.refresh_session() + self._db_map.cache.clear() + self._db_map.remove_items("entity_class", 1) + self._db_map.commit_session("test commit") + ents = self._db_map.query(self._db_map.entity_sq).all() + self.assertEqual(ents, []) + if __name__ == "__main__": unittest.main() From 7dc3c59f9721885653f7f2952dd9103900712a3d Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 7 Sep 2023 14:58:30 +0200 Subject: [PATCH 081/317] Graph layout generator: Optimize case where all positions are heavy --- spinedb_api/graph_layout_generator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spinedb_api/graph_layout_generator.py 
b/spinedb_api/graph_layout_generator.py index dae7737b..074df3e9 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -102,6 +102,10 @@ def sets(self): def compute_layout(self): """Computes and returns x and y coordinates for each vertex in the graph, using VSGD-MS.""" + if len(self.heavy_positions) == self.vertex_count: + x, y = zip(*[(pos["x"], pos["y"]) for pos in self.heavy_positions.values()]) + self._layout_available(x, y) + return x, y if self.vertex_count <= 1: x, y = np.array([0.0]), np.array([0.0]) self._layout_available(x, y) From 834c6489695545b56f0d5c0672ef8c4bb5357e20 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 7 Sep 2023 17:03:35 +0200 Subject: [PATCH 082/317] Introduce EntityAlternativeItem Re #269 --- spinedb_api/db_cache_impl.py | 22 ++++++++++++++++++++++ spinedb_api/db_mapping_base.py | 1 - 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index dc1daa16..0a451382 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -23,6 +23,7 @@ class DBCache(DBCacheBase): _sq_name_by_item_type = { "entity_class": "wide_entity_class_sq", "entity": "wide_entity_sq", + "entity_alternative": "entity_alternative_sq", "parameter_value_list": "parameter_value_list_sq", "list_value": "list_value_sq", "alternative": "alternative_sq", @@ -54,6 +55,7 @@ def _item_factory(item_type): return { "entity_class": EntityClassItem, "entity": EntityItem, + "entity_alternative": EntityAlternativeItem, "entity_group": EntityGroupItem, "parameter_definition": ParameterDefinitionItem, "parameter_value": ParameterValueItem, @@ -167,6 +169,26 @@ def __getitem__(self, key): return super().__getitem__(key) +class EntityAlternativeItem(CacheItemBase): + _defaults = {"active": True} + _unique_keys = (("entity_class_name", "entity_byname", "alternative_name"),) + _references = { + "entity_class_id": ("entity_id", ("entity", "class_id")), + "entity_class_name": ("entity_class_id", ("entity_class", "name")), + "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + "entity_name": ("entity_id", ("entity", "name")), + "entity_byname": ("entity_id", ("entity", "byname")), + "element_id_list": ("entity_id", ("entity", "element_id_list")), + "element_name_list": ("entity_id", ("entity", "element_name_list")), + "alternative_name": ("alternative_id", ("alternative", "name")), + } + _inverse_references = { + "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), + "alternative_id": (("alternative_name",), ("alternative", ("name",))), + } + + class ParsedValueBase(CacheItemBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index ca6673de..c9b0f10c 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -184,7 +184,6 @@ def __init__( } self.composite_pks = { "entity_element": ("entity_id", "position"), - "entity_alternative": ("entity_id", "alternative_id"), "entity_class_dimension": ("entity_class_id", "position"), } From 43619430be717df206ae9f1ec3f834238ad4835a Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 7 Sep 2023 17:04:04 +0200 Subject: [PATCH 083/317] Fix commit to proceed in the right order of tables --- spinedb_api/db_cache_base.py | 21 ++++++++++----------- 1 file changed, 10 
insertions(+), 11 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 68539690..98e54c7f 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -40,6 +40,14 @@ def __init__(self, chunk_size=None): self._offset_lock = threading.Lock() self._fetched_item_types = set() self._chunk_size = chunk_size + item_types = self._item_types + self._sorted_item_types = [] + while item_types: + item_type = item_types.pop(0) + if self._item_factory(item_type).ref_types() & set(item_types): + item_types.append(item_type) + else: + self._sorted_item_types.append(item_type) @property def fetched_item_types(self): @@ -81,15 +89,6 @@ def make_item(self, item_type, **item): factory = self._item_factory(item_type) return factory(self, item_type, **item) - def _cmp_item_type(self, a, b): - if a in self._item_factory(b).ref_types(): - # a should come before b - return -1 - if b in self._item_factory(a).ref_types(): - # a should come after b - return 1 - return 0 - def dirty_ids(self, item_type): return { item["id"] for item in self.get(item_type, {}).values() if item.status in (Status.to_add, Status.to_update) @@ -103,7 +102,7 @@ def dirty_items(self): list """ dirty_items = [] - for item_type in sorted(self._item_types, key=cmp_to_key(self._cmp_item_type)): + for item_type in self._sorted_item_types: table_cache = self.get(item_type) if table_cache is None: continue @@ -126,7 +125,7 @@ def dirty_items(self): for other_item_type in self._item_types: if ( other_item_type not in self.fetched_item_types - and self._cmp_item_type(item_type, other_item_type) < 0 + and item_type in self._item_factory(other_item_type).ref_types() ): self.fetch_all(other_item_type) if to_add or to_update or to_remove: From 57df64746d24864615a84bf71515b2a0a2ab93ed Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 8 Sep 2023 10:10:46 +0200 Subject: [PATCH 084/317] Filter entities with the scenario filter, first attempt Re #269 --- spinedb_api/db_mapping_base.py | 5 +++ spinedb_api/filters/scenario_filter.py | 42 ++++++++++++++++++-------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c9b0f10c..e1d42347 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -505,6 +505,7 @@ def wide_entity_class_sq(self): ecd_sq.c.hidden, group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"), group_concat(ecd_sq.c.dimension_name, ecd_sq.c.position).label("dimension_name_list"), + func.count(ecd_sq.c.dimension_id).label("dimension_count"), ) .group_by( ecd_sq.c.id, @@ -562,6 +563,9 @@ def wide_entity_sq(self): group_concat(ext_entity_sq.c.element_id, ext_entity_sq.c.position).label("element_id_list"), group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"), ) + # element count might be lower than dimension count when element-entities have been filtered out + .filter(self.wide_entity_class_sq.c.id == ext_entity_sq.c.class_id) + .having(self.wide_entity_class_sq.c.dimension_count == func.count(ext_entity_sq.c.element_id)) .group_by( ext_entity_sq.c.id, ext_entity_sq.c.class_id, @@ -1625,6 +1629,7 @@ def _make_parameter_value_sq(self): par_val_sq.c.commit_id.label("commit_id"), par_val_sq.c.alternative_id, ) + .filter(par_val_sq.c.entity_id == self.entity_sq.c.id) .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) .subquery("clean_parameter_value_sq") ) diff --git 
a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index f3cd90a8..d33310ed 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -15,7 +15,7 @@ """ from functools import partial -from sqlalchemy import desc, func +from sqlalchemy import desc, func, or_ from ..exception import SpineDBAPIError SCENARIO_FILTER_TYPE = "scenario_filter" @@ -32,8 +32,8 @@ def apply_scenario_filter_to_subqueries(db_map, scenario): """ state = _ScenarioFilterState(db_map, scenario) # FIXME - # make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) - # db_map.override_entity_sq_maker(make_entity_sq) + make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) + db_map.override_entity_sq_maker(make_entity_sq) make_parameter_value_sq = partial(_make_scenario_filtered_parameter_value_sq, state=state) db_map.override_parameter_value_sq_maker(make_parameter_value_sq) make_alternative_sq = partial(_make_scenario_filtered_alternative_sq, state=state) @@ -196,7 +196,7 @@ def _make_scenario_filtered_entity_sq(db_map, state): Returns: Alias: a subquery for entity filtered by selected scenario """ - wide_entity_sq = ( + ext_entity_sq = ( db_map.query( state.original_entity_sq, func.row_number() @@ -204,15 +204,33 @@ def _make_scenario_filtered_entity_sq(db_map, state): partition_by=[state.original_entity_sq.c.id], order_by=desc(db_map.scenario_alternative_sq.c.rank), ) - .label("max_rank_row_number"), - db_map.entity_alternative_sq.c.active.label("active"), + .label("desc_rank_row_number"), + db_map.entity_alternative_sq.c.active, + db_map.scenario_alternative_sq.c.scenario_id, + ) + .outerjoin( + db_map.entity_alternative_sq, state.original_entity_sq.c.id == db_map.entity_alternative_sq.c.entity_id + ) + .outerjoin( + db_map.scenario_alternative_sq, + db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id, + ) + .filter( + or_( + db_map.scenario_alternative_sq.c.scenario_id == None, + db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id, + ) ) - .filter(state.original_entity_sq.c.id == db_map.entity_alternative_sq.c.entity_id) - .filter(db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id) - .filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id) ).subquery() # TODO: Maybe we want to filter multi-dimensional entities involving filtered entities right here too? 
- return db_map.query(wide_entity_sq).filter_by(max_rank_row_number=1, active=True).subquery() + return ( + db_map.query(ext_entity_sq) + .filter( + ext_entity_sq.c.desc_rank_row_number == 1, + or_(ext_entity_sq.c.active == True, ext_entity_sq.c.active == None), + ) + .subquery() + ) def _make_scenario_filtered_parameter_value_sq(db_map, state): @@ -239,12 +257,12 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state): ], order_by=desc(db_map.scenario_alternative_sq.c.rank), ) # the one with the highest rank will have row_number equal to 1, so it will 'win' in the filter below - .label("max_rank_row_number"), + .label("desc_rank_row_number"), ) .filter(state.original_parameter_value_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id) .filter(db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id) ).subquery() - return db_map.query(ext_parameter_value_sq).filter(ext_parameter_value_sq.c.max_rank_row_number == 1).subquery() + return db_map.query(ext_parameter_value_sq).filter(ext_parameter_value_sq.c.desc_rank_row_number == 1).subquery() def _make_scenario_filtered_alternative_sq(db_map, state): From 5c1c1c21d8740001e8e3ae3b4359cb7ac37a5587 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 8 Sep 2023 19:28:15 +0200 Subject: [PATCH 085/317] Fix entity_sq --- spinedb_api/db_mapping_base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e1d42347..c59322d3 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -564,16 +564,15 @@ def wide_entity_sq(self): group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"), ) # element count might be lower than dimension count when element-entities have been filtered out - .filter(self.wide_entity_class_sq.c.id == ext_entity_sq.c.class_id) - .having(self.wide_entity_class_sq.c.dimension_count == func.count(ext_entity_sq.c.element_id)) + # .filter(self.wide_entity_class_sq.c.id == ext_entity_sq.c.class_id) + # .having(self.wide_entity_class_sq.c.dimension_count == func.count(ext_entity_sq.c.element_id)) .group_by( ext_entity_sq.c.id, ext_entity_sq.c.class_id, ext_entity_sq.c.name, ext_entity_sq.c.description, ext_entity_sq.c.commit_id, - ) - .subquery("wide_entity_sq") + ).subquery("wide_entity_sq") ) return self._wide_entity_sq From 3f0be54ce7361b3c7b0cf585cdcd2ee8e65290f1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 11 Sep 2023 08:49:35 +0200 Subject: [PATCH 086/317] Memoize ts fixed res indexes and introduce IndexedValue.get_nearest --- spinedb_api/parameter_value.py | 42 ++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 9c4194f0..d6e3741b 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -846,6 +846,10 @@ def values(self, values): """Sets the values.""" self._values = values + def get_nearest(self, index): + pos = np.searchsorted(self.indexes, index) + return self.values[pos] + def get_value(self, index): """Returns the value at the given index.""" pos = self.indexes.position_lookup.get(index) @@ -1147,6 +1151,8 @@ class TimeSeriesFixedResolution(TimeSeries): other than having getters for their values. 
""" + _memoized_indexes = {} + def __init__(self, start, resolution, values, ignore_year, repeat, index_name=""): """ Args: @@ -1176,24 +1182,32 @@ def __eq__(self, other): and self.index_name == other.index_name ) + def _get_memoized_indexes(self): + key = (self.start, tuple(self.resolution), len(self)) + memoized_indexes = self._memoized_indexes.get(key) + if memoized_indexes is not None: + return memoized_indexes + step_index = 0 + step_cycle_index = 0 + full_cycle_duration = sum(self._resolution, relativedelta()) + stamps = np.empty(len(self), dtype=_NUMPY_DATETIME_DTYPE) + stamps[0] = self._start + for stamp_index in range(1, len(self._values)): + if step_index >= len(self._resolution): + step_index = 0 + step_cycle_index += 1 + current_cycle_duration = sum(self._resolution[: step_index + 1], relativedelta()) + duration_from_start = step_cycle_index * full_cycle_duration + current_cycle_duration + stamps[stamp_index] = self._start + duration_from_start + step_index += 1 + memoized_indexes = self._memoized_indexes[key] = np.array(stamps, dtype=_NUMPY_DATETIME_DTYPE) + return memoized_indexes + @property def indexes(self): """Returns the time stamps as a numpy.ndarray of numpy.datetime64 objects.""" if self._indexes is None: - step_index = 0 - step_cycle_index = 0 - full_cycle_duration = sum(self._resolution, relativedelta()) - stamps = np.empty(len(self), dtype=_NUMPY_DATETIME_DTYPE) - stamps[0] = self._start - for stamp_index in range(1, len(self._values)): - if step_index >= len(self._resolution): - step_index = 0 - step_cycle_index += 1 - current_cycle_duration = sum(self._resolution[: step_index + 1], relativedelta()) - duration_from_start = step_cycle_index * full_cycle_duration + current_cycle_duration - stamps[stamp_index] = self._start + duration_from_start - step_index += 1 - self.indexes = np.array(stamps, dtype=_NUMPY_DATETIME_DTYPE) + self.indexes = self._get_memoized_indexes() return IndexedValue.indexes.fget(self) @indexes.setter From b702c024af227eb9dc08f631c3878db03dccd0ad Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 13 Sep 2023 16:12:45 +0300 Subject: [PATCH 087/317] Make DataPackageConnector resilient to reading errors An error while iterating rows in datapackage table could cause the import process to traceback. We now catch and handle reading errors a bit better. 
Re spine-tools/Spine-Toolbox#2284 --- spinedb_api/exception.py | 4 ++ .../spine_io/importers/datapackage_reader.py | 13 ++++-- spinedb_api/spine_io/importers/reader.py | 4 +- .../importers/test_datapackage_reader.py | 21 ++++++++++ tests/spine_io/importers/test_reader.py | 42 +++++++++++++++++++ 5 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 tests/spine_io/importers/test_reader.py diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index cd899da0..ea9f8980 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -103,3 +103,7 @@ def __init__(self, msg, rank=None, key=None): super().__init__(msg) self.rank = rank self.key = key + + +class ConnectorError(SpineDBAPIError): + """Failure in import connector.""" diff --git a/spinedb_api/spine_io/importers/datapackage_reader.py b/spinedb_api/spine_io/importers/datapackage_reader.py index bcf7e609..dadb4349 100644 --- a/spinedb_api/spine_io/importers/datapackage_reader.py +++ b/spinedb_api/spine_io/importers/datapackage_reader.py @@ -16,8 +16,10 @@ import threading from itertools import chain +import tabulator.exceptions from datapackage import Package from .reader import SourceConnection +from ...exception import ConnectorError class DataPackageConnector(SourceConnection): @@ -96,16 +98,21 @@ def get_data_iterator(self, table, options, max_rows=-1): if not self._datapackage: return iter([]), [] + def iterator(r): + try: + yield from (item for row, item in enumerate(r.iter(cast=False)) if row != max_rows) + except tabulator.exceptions.TabulatorException as error: + raise ConnectorError(str(error)) from error + has_header = options.get("has_header", True) for resource in self._datapackage.resources: with self._resource_name_lock: if resource.name is None: resource.infer() if table == resource.name: - iterator = (item for row, item in enumerate(resource.iter(cast=False)) if row != max_rows) if has_header: header = resource.schema.field_names - return iterator, header - return chain([resource.headers], iterator), None + return iterator(resource), header + return chain([resource.headers], iterator(resource)), None # table not found return iter([]), [] diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index 9fe7a3e5..92eda12f 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -15,6 +15,8 @@ """ from itertools import islice + +from spinedb_api.exception import ConnectorError from spinedb_api.import_mapping.generator import get_mapped_data, identity from spinedb_api.import_mapping.import_mapping_compat import parse_named_mapping_spec from spinedb_api import DateTime, Duration, ParameterValueFormatError @@ -149,7 +151,7 @@ def get_mapped_data( row_convert_fns, unparse_value, ) - except ParameterValueFormatError as error: + except (ConnectorError, ParameterValueFormatError) as error: errors.append(str(error)) continue for key, value in data.items(): diff --git a/tests/spine_io/importers/test_datapackage_reader.py b/tests/spine_io/importers/test_datapackage_reader.py index 7e7b03db..75906b29 100644 --- a/tests/spine_io/importers/test_datapackage_reader.py +++ b/tests/spine_io/importers/test_datapackage_reader.py @@ -20,6 +20,8 @@ import pickle from tempfile import TemporaryDirectory from datapackage import Package + +from spinedb_api.exception import ConnectorError from spinedb_api.spine_io.importers.datapackage_reader import DataPackageConnector @@ -56,6 +58,25 @@ def 
test_header_off_does_not_append_numbers_to_duplicate_cells(self): self.assertIsNone(header) self.assertEqual(list(data_iterator), data) + def test_wrong_datapackage_encoding_raises_connector_error(self): + broken_text = b"Slagn\xe4s" + # Fool the datapackage sniffing algorithm by hiding the broken line behind a large number of UTF-8 lines. + data = 1000 * [b"normal_text\n"] + [broken_text] + with TemporaryDirectory() as temp_dir: + csv_file_path = Path(temp_dir, "test_data.csv") + with open(csv_file_path, "wb") as csv_file: + for row in data: + csv_file.write(row) + package = Package(base_path=temp_dir) + package.add_resource({"path": str(csv_file_path.relative_to(temp_dir))}) + package_path = Path(temp_dir, "datapackage.json") + package.save(package_path) + reader = DataPackageConnector(None) + reader.connect_to_source(str(package_path)) + data_iterator, header = reader.get_data_iterator("test_data", {"has_header": False}) + self.assertIsNone(header) + self.assertRaises(ConnectorError, list, data_iterator) + @contextmanager def test_datapackage(rows): diff --git a/tests/spine_io/importers/test_reader.py b/tests/spine_io/importers/test_reader.py new file mode 100644 index 00000000..71e90535 --- /dev/null +++ b/tests/spine_io/importers/test_reader.py @@ -0,0 +1,42 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +import unittest + +from spinedb_api.exception import ConnectorError +from spinedb_api.spine_io.importers.reader import SourceConnection + + +class TestSourceConnection(unittest.TestCase): + def test_get_mapped_data_can_handle_connector_error_in_data_iterator(self): + def failing_iterator(): + if True: + raise ConnectorError("error in iterator") + yield from [] + + reader = SourceConnection(None) + reader.get_data_iterator = lambda *args: (failing_iterator(), []) + table_mappings = {"table 1": []} + table_options = {} + table_column_convert_specs = {} + table_default_column_convert_fns = {} + table_row_convert_specs = {} + mapped_data, errors = reader.get_mapped_data( + table_mappings, + table_options, + table_column_convert_specs, + table_default_column_convert_fns, + table_row_convert_specs, + ) + self.assertEqual(errors, ["error in iterator"]) + + +if __name__ == '__main__': + unittest.main() From e2f1699eca5e0aef4b38e5a774da3d4d592ec9d2 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 14 Sep 2023 10:40:34 +0300 Subject: [PATCH 088/317] Invalidate cache item id instead of deleting it We shouldn't delete cache item's "id" key when rolling back changes because clients expect the id to be present. 
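
Roughly, the new contract is the following (this mirrors the unit test
added below; the cache and table names are illustrative):

    item = table_cache.add_item({}, new=True)
    id_ = item["id"]
    cache.rollback()
    item.is_id_valid  # False from now on
    item["id"] == id_  # but the id itself can still be read
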
This was an issue in Toolbox where the Database editor's views would handle removed items expecting to have the "id" key available after it was deleted in spinedb_api. We don't touch the id anymore but use a new 'is_id_valid' flag to signal that the id has gone out of scope. Re spine-tools/Spine-Toolbox#2291 --- spinedb_api/db_cache_base.py | 26 ++++++++---- tests/test_db_cache_base.py | 76 ++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 tests/test_db_cache_base.py diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 98e54c7f..cd5e577c 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -DB cache base. - -""" +"""DB cache base.""" import threading from enum import Enum, unique, auto from functools import cmp_to_key @@ -162,7 +159,7 @@ def rollback(self): table_cache = self.table_cache(item_type) for item in to_add: if table_cache.remove_item(item["id"]) is not None: - del item["id"] + item.invalidate_id() return True def refresh(self): @@ -399,7 +396,7 @@ def add_item(self, item, new=False): return else: item.status = Status.to_add - if "id" not in item: + if "id" not in item or not item.is_id_valid: item["id"] = self._new_id() self[item["id"]] = item self._add_unique(item) @@ -429,7 +426,7 @@ def restore_item(self, id_): class CacheItemBase(dict): - """A dictionary that represents an db item.""" + """A dictionary that represents a db item.""" _defaults = {} """A dictionary mapping keys to their default values""" @@ -470,6 +467,7 @@ def __init__(self, db_cache, item_type, **kwargs): self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() + self._is_id_valid = True self._to_remove = False self._removed = False self._corrupted = False @@ -546,6 +544,14 @@ def key(self): return None return (self._item_type, id_) + @property + def is_id_valid(self): + return self._is_id_valid + + def invalidate_id(self): + """Sets id as invalid.""" + self._is_id_valid = False + def _extended(self): """Returns a dict from this item's original fields plus all the references resolved statically. @@ -852,6 +858,12 @@ def __getitem__(self, key): return self._get_ref(ref_type, ref_id).get(ref_key) return super().__getitem__(key) + def __setitem__(self, key, value): + """Sets id valid if key is 'id'.""" + if key == "id": + self._is_id_valid = True + super().__setitem__(key, value) + def get(self, key, default=None): """Overridden to return references.""" try: diff --git a/tests/test_db_cache_base.py b/tests/test_db_cache_base.py new file mode 100644 index 00000000..012fe647 --- /dev/null +++ b/tests/test_db_cache_base.py @@ -0,0 +1,76 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +import unittest + +from spinedb_api.db_cache_base import CacheItemBase, DBCacheBase + + +class TestCache(DBCacheBase): + @property + def _item_types(self): + return ["cutlery"] + + @staticmethod + def _item_factory(item_type): + if item_type == "cutlery": + return CacheItemBase + raise RuntimeError(f"unknown item_type '{item_type}'") + + +class TestDBCacheBase(unittest.TestCase): + def test_rolling_back_new_item_invalidates_its_id(self): + cache = TestCache() + table_cache = cache.table_cache("cutlery") + item = table_cache.add_item({}, new=True) + self.assertTrue(item.is_id_valid) + self.assertIn("id", item) + id_ = item["id"] + cache.rollback() + self.assertFalse(item.is_id_valid) + self.assertEqual(item["id"], id_) + + +class TestTableCache(unittest.TestCase): + def test_readding_item_with_invalid_id_creates_new_id(self): + cache = TestCache() + table_cache = cache.table_cache("cutlery") + item = table_cache.add_item({}, new=True) + id_ = item["id"] + cache.rollback() + self.assertFalse(item.is_id_valid) + table_cache.add_item(item, new=True) + self.assertTrue(item.is_id_valid) + self.assertNotEqual(item["id"], id_) + + +class TestCacheItemBase(unittest.TestCase): + def test_id_is_valid_initially(self): + cache = TestCache() + item = CacheItemBase(cache, "cutlery") + self.assertTrue(item.is_id_valid) + + def test_id_can_be_invalidated(self): + cache = TestCache() + item = CacheItemBase(cache, "cutlery") + item.invalidate_id() + self.assertFalse(item.is_id_valid) + + def test_setting_new_id_validates_it(self): + cache = TestCache() + item = CacheItemBase(cache, "cutlery") + item.invalidate_id() + self.assertFalse(item.is_id_valid) + item["id"] = 23 + self.assertTrue(item.is_id_valid) + + +if __name__ == '__main__': + unittest.main() From 8b6d8b55b762de3bb8e39cfd34a8160498bc8ac8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 21 Sep 2023 09:14:39 +0200 Subject: [PATCH 089/317] Let add_items and update_items return CacheItem instances Rather than just dictionaries. 
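
A rough usage sketch (the item type and field below are illustrative):

    added, errors = db_map.add_items("entity_class", {"name": "fish"})
    added[0]["name"]  # dict-style access still works, since cache items are dicts
    added[0].is_valid()  # but CacheItem methods are now available too
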
--- spinedb_api/db_mapping_add_mixin.py | 4 ++-- spinedb_api/db_mapping_update_mixin.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 6fbe6a66..5416f2bc 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -42,7 +42,7 @@ def add_items(self, tablename, *items, check=True, strict=False): if not check: for item in items: self._convert_legacy(tablename, item) - added.append(table_cache.add_item(item, new=True)._asdict()) + added.append(table_cache.add_item(item, new=True)) else: for item in items: self._convert_legacy(tablename, item) @@ -52,7 +52,7 @@ def add_items(self, tablename, *items, check=True, strict=False): raise SpineIntegrityError(error) errors.append(error) continue - added.append(table_cache.add_item(checked_item, new=True)._asdict()) + added.append(table_cache.add_item(checked_item, new=True)) return added, errors def _do_add_items(self, connection, tablename, *items_to_add): diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index f6834e31..78619b76 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -93,7 +93,7 @@ def update_items(self, tablename, *items, check=True, strict=False): if not check: for item in items: self._convert_legacy(tablename, item) - updated.append(table_cache.update_item(item)._asdict()) + updated.append(table_cache.update_item(item)) else: for item in items: self._convert_legacy(tablename, item) @@ -104,7 +104,7 @@ def update_items(self, tablename, *items, check=True, strict=False): errors.append(error) if checked_item: item = checked_item._asdict() - updated.append(table_cache.update_item(item)._asdict()) + updated.append(table_cache.update_item(item)) return updated, errors def update_alternatives(self, *items, **kwargs): From 460a1bb478e8d2ba176532939a3c29644d956c3d Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 21 Sep 2023 13:52:50 +0200 Subject: [PATCH 090/317] Fix export_data to only export entity stuff and be a little lazier --- spinedb_api/__init__.py | 19 ++--- spinedb_api/db_cache.py | 0 spinedb_api/export_functions.py | 144 +++----------------------------- tests/filters/test_tools.py | 31 ++++--- tests/test_export_functions.py | 38 +++++---- 5 files changed, 56 insertions(+), 176 deletions(-) delete mode 100644 spinedb_api/db_cache.py diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 601b58f6..1c1315bf 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -58,19 +58,16 @@ get_data_for_import, ) from .export_functions import ( - export_alternatives, export_data, - export_object_classes, - export_object_groups, - export_object_parameters, - export_object_parameter_values, - export_objects, - export_relationship_classes, - export_relationship_parameter_values, - export_relationship_parameters, - export_relationships, - export_scenario_alternatives, + export_entity_classes, + export_entity_groups, + export_entities, + export_parameter_value_lists, + export_parameter_definitions, + export_parameter_values, export_scenarios, + export_alternatives, + export_scenario_alternatives, ) from .import_mapping.import_mapping_compat import import_mapping_from_dict from .import_mapping.generator import get_mapped_data diff --git a/spinedb_api/db_cache.py b/spinedb_api/db_cache.py deleted file mode 100644 index e69de29b..00000000 diff --git a/spinedb_api/export_functions.py 
b/spinedb_api/export_functions.py index 93b13d97..4eeefd63 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -24,19 +24,10 @@ def export_data( db_map, entity_class_ids=Asterisk, entity_ids=Asterisk, - parameter_definition_ids=Asterisk, - parameter_value_ids=Asterisk, entity_group_ids=Asterisk, - object_class_ids=Asterisk, - relationship_class_ids=Asterisk, parameter_value_list_ids=Asterisk, - object_parameter_ids=Asterisk, - relationship_parameter_ids=Asterisk, - object_ids=Asterisk, - object_group_ids=Asterisk, - relationship_ids=Asterisk, - object_parameter_value_ids=Asterisk, - relationship_parameter_value_ids=Asterisk, + parameter_definition_ids=Asterisk, + parameter_value_ids=Asterisk, alternative_ids=Asterisk, scenario_ids=Asterisk, scenario_alternative_ids=Asterisk, @@ -47,15 +38,12 @@ def export_data( Args: db_map (DiffDatabaseMapping): The db to pull stuff from. - object_class_ids (Iterable, optional): A collection of ids to pick from the database table - relationship_class_ids (Iterable, optional): A collection of ids to pick from the database table + entity_class_ids (Iterable, optional): A collection of ids to pick from the database table + entity_ids (Iterable, optional): A collection of ids to pick from the database table + entity_group_ids (Iterable, optional): A collection of ids to pick from the database table parameter_value_list_ids (Iterable, optional): A collection of ids to pick from the database table - object_parameter_ids (Iterable, optional): A collection of ids to pick from the database table - relationship_parameter_ids (Iterable, optional): A collection of ids to pick from the database table - object_ids (Iterable, optional): A collection of ids to pick from the database table - relationship_ids (Iterable, optional): A collection of ids to pick from the database table - object_parameter_value_ids (Iterable, optional): A collection of ids to pick from the database table - relationship_parameter_value_ids (Iterable, optional): A collection of ids to pick from the database table + parameter_definition_ids (Iterable, optional): A collection of ids to pick from the database table + parameter_value_ids (Iterable, optional): A collection of ids to pick from the database table alternative_ids (Iterable, optional): A collection of ids to pick from the database table scenario_ids (Iterable, optional): A collection of ids to pick from the database table scenario_alternative_ids (Iterable, optional): A collection of ids to pick from the database table @@ -67,28 +55,13 @@ def export_data( "entity_classes": export_entity_classes(db_map, entity_class_ids), "entities": export_entities(db_map, entity_ids), "entity_groups": export_entity_groups(db_map, entity_group_ids), - "parameter_definitions": export_parameter_definitions( - db_map, parameter_definition_ids, parse_value=parse_value - ), - "parameter_values": export_parameter_values(db_map, parameter_value_ids, parse_value=parse_value), - "object_classes": export_object_classes(db_map, object_class_ids), - "relationship_classes": export_relationship_classes(db_map, relationship_class_ids), "parameter_value_lists": export_parameter_value_lists( db_map, parameter_value_list_ids, parse_value=parse_value ), - "object_parameters": export_object_parameters(db_map, object_parameter_ids, parse_value=parse_value), - "relationship_parameters": export_relationship_parameters( - db_map, relationship_parameter_ids, parse_value=parse_value - ), - "objects": export_objects(db_map, object_ids), - 
"relationships": export_relationships(db_map, relationship_ids), - "object_groups": export_object_groups(db_map, object_group_ids), - "object_parameter_values": export_object_parameter_values( - db_map, object_parameter_value_ids, parse_value=parse_value - ), - "relationship_parameter_values": export_relationship_parameter_values( - db_map, relationship_parameter_value_ids, parse_value=parse_value + "parameter_definitions": export_parameter_definitions( + db_map, parameter_definition_ids, parse_value=parse_value ), + "parameter_values": export_parameter_values(db_map, parameter_value_ids, parse_value=parse_value), "alternatives": export_alternatives(db_map, alternative_ids), "scenarios": export_scenarios(db_map, scenario_ids), "scenario_alternatives": export_scenario_alternatives(db_map, scenario_alternative_ids), @@ -99,19 +72,18 @@ def export_data( def _get_items(db_map, tablename, ids): if not ids: return () - db_map.fetch_all({tablename}) _process_item = _make_item_processor(db_map, tablename) for item in _get_items_from_cache(db_map.cache, tablename, ids): yield from _process_item(item) def _get_items_from_cache(cache, tablename, ids): - items = cache.get(tablename, {}) if ids is Asterisk: - yield from items.values() + cache.fetch_all(tablename) + yield from cache.get(tablename, {}).values() return for id_ in ids: - item = items[id_] + item = cache.get_item(tablename, id_) or cache.fetch_ref(tablename, id_) if item.is_valid(): yield item @@ -191,96 +163,6 @@ def export_parameter_values(db_map, ids=Asterisk, parse_value=from_database): ) -def export_object_classes(db_map, ids=Asterisk): - return sorted( - (x.name, x.description, x.display_icon) - for x in _get_items(db_map, "entity_class", ids) - if not x.dimension_id_list - ) - - -def export_relationship_classes(db_map, ids=Asterisk): - return sorted( - (x.name, x.dimension_name_list, x.description, x.display_icon) - for x in _get_items(db_map, "entity_class", ids) - if x.dimension_id_list - ) - - -def export_objects(db_map, ids=Asterisk): - return sorted( - (x.class_name, x.name, x.description) for x in _get_items(db_map, "entity", ids) if not x.element_id_list - ) - - -def export_relationships(db_map, ids=Asterisk): - return sorted((x.class_name, x.element_name_list) for x in _get_items(db_map, "entity", ids) if x.element_id_list) - - -def export_object_groups(db_map, ids=Asterisk): - return sorted( - (x.class_name, x.group_name, x.member_name) - for x in _get_items(db_map, "entity_group", ids) - if not x.dimension_id_list - ) - - -def export_object_parameters(db_map, ids=Asterisk, parse_value=from_database): - return sorted( - ( - x.entity_class_name, - x.parameter_name, - parse_value(x.default_value, x.default_type), - x.value_list_name, - x.description, - ) - for x in _get_items(db_map, "parameter_definition", ids) - if not x.dimension_id_list - ) - - -def export_relationship_parameters(db_map, ids=Asterisk, parse_value=from_database): - return sorted( - ( - x.entity_class_name, - x.parameter_name, - parse_value(x.default_value, x.default_type), - x.value_list_name, - x.description, - ) - for x in _get_items(db_map, "parameter_definition", ids) - if x.dimension_id_list - ) - - -def export_object_parameter_values(db_map, ids=Asterisk, parse_value=from_database): - return sorted( - ( - (x.entity_class_name, x.entity_name, x.parameter_name, parse_value(x.value, x.type), x.alternative_name) - for x in _get_items(db_map, "parameter_value", ids) - if not x.element_id_list - ), - key=lambda x: x[:3] + (x[-1],), - ) - - -def 
export_relationship_parameter_values(db_map, ids=Asterisk, parse_value=from_database): - return sorted( - ( - ( - x.entity_class_name, - x.element_name_list, - x.parameter_name, - parse_value(x.value, x.type), - x.alternative_name, - ) - for x in _get_items(db_map, "parameter_value", ids) - if x.element_id_list - ), - key=lambda x: x[:3] + (x[-1],), - ) - - def export_alternatives(db_map, ids=Asterisk): """ Exports alternatives from database. diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index 41ecad30..f9a65abf 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -20,8 +20,7 @@ append_filter_config, clear_filter_configs, DatabaseMapping, - DatabaseMapping, - export_object_classes, + export_entity_classes, import_object_classes, pop_filter_configs, ) @@ -100,8 +99,8 @@ def test_empty_stack(self): db_map = DatabaseMapping(self._db_url) try: apply_filter_stack(db_map, []) - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("object_class", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("object_class", (), None, None)]) finally: db_map.close() @@ -110,8 +109,8 @@ def test_single_renaming_filter(self): try: stack = [entity_class_renamer_config(object_class="renamed_once")] apply_filter_stack(db_map, stack) - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("renamed_once", (), None, None)]) finally: db_map.close() @@ -123,8 +122,8 @@ def test_two_renaming_filters(self): entity_class_renamer_config(renamed_once="renamed_twice"), ] apply_filter_stack(db_map, stack) - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("renamed_twice", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("renamed_twice", (), None, None)]) finally: db_map.close() @@ -146,8 +145,8 @@ def setUpClass(cls): def test_without_filters(self): db_map = DatabaseMapping(self._db_url, self._engine) try: - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("object_class", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("object_class", (), None, None)]) finally: db_map.close() @@ -158,8 +157,8 @@ def test_single_renaming_filter(self): url = append_filter_config(str(self._db_url), path) db_map = DatabaseMapping(url, self._engine) try: - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("renamed_once", (), None, None)]) finally: db_map.close() @@ -174,8 +173,8 @@ def test_two_renaming_filters(self): url = append_filter_config(url, path2) db_map = DatabaseMapping(url, self._engine) try: - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("renamed_twice", None, None)]) + object_classes = export_entity_classes(db_map) + self.assertEqual(object_classes, [("renamed_twice", (), None, None)]) finally: db_map.close() @@ -184,8 +183,8 @@ def test_config_embedded_to_url(self): url = append_filter_config(str(self._db_url), config) db_map = DatabaseMapping(url, self._engine) try: - object_classes = export_object_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", None, None)]) + object_classes = 
export_entity_classes(db_map) + self.assertEqual(object_classes, [("renamed_once", (), None, None)]) finally: db_map.close() diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 8179d17c..489d549a 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -88,29 +88,31 @@ def test_export_data(self): import_scenarios(self._db_map, ["scenario"]) import_scenario_alternatives(self._db_map, [("scenario", "alternative")]) exported = export_data(self._db_map) - self.assertEqual(len(exported), 16) - self.assertIn("object_classes", exported) - self.assertEqual(exported["object_classes"], [("object_class", None, None)]) - self.assertIn("object_parameters", exported) - self.assertEqual(exported["object_parameters"], [("object_class", "object_parameter", None, None, None)]) - self.assertIn("objects", exported) - self.assertEqual(exported["objects"], [("object_class", "object", None)]) - self.assertIn("object_parameter_values", exported) + self.assertEqual(len(exported), 8) + self.assertIn("entity_classes", exported) self.assertEqual( - exported["object_parameter_values"], [("object_class", "object", "object_parameter", 2.3, "Base")] + exported["entity_classes"], + [("object_class", (), None, None), ("relationship_class", ("object_class",), None, None)], ) - self.assertIn("relationship_classes", exported) - self.assertEqual(exported["relationship_classes"], [("relationship_class", ("object_class",), None, None)]) - self.assertIn("relationship_parameters", exported) + self.assertIn("parameter_definitions", exported) self.assertEqual( - exported["relationship_parameters"], [("relationship_class", "relationship_parameter", None, None, None)] + exported["parameter_definitions"], + [ + ("object_class", "object_parameter", None, None, None), + ("relationship_class", "relationship_parameter", None, None, None), + ], ) - self.assertIn("relationships", exported) - self.assertEqual(exported["relationships"], [("relationship_class", ("object",))]) - self.assertIn("relationship_parameter_values", exported) + self.assertIn("entities", exported) self.assertEqual( - exported["relationship_parameter_values"], - [("relationship_class", ("object",), "relationship_parameter", 3.14, "Base")], + exported["entities"], [("object_class", "object", None), ("relationship_class", ("object",), None)] + ) + self.assertIn("parameter_values", exported) + self.assertEqual( + exported["parameter_values"], + [ + ("object_class", "object", "object_parameter", 2.3, "Base"), + ("relationship_class", ("object",), "relationship_parameter", 3.14, "Base"), + ], ) self.assertIn("parameter_value_lists", exported) self.assertEqual(exported["parameter_value_lists"], [("value_list", "5.5"), ("value_list", "6.4")]) From 4df2b048bfcbc1149f2db19c6a1be9ced589d2bd Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 22 Sep 2023 13:29:26 +0200 Subject: [PATCH 091/317] Fix exporting value list name of parameter definitions --- spinedb_api/export_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 4eeefd63..a2ccf395 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -140,7 +140,7 @@ def export_parameter_definitions(db_map, ids=Asterisk, parse_value=from_database x.entity_class_name, x.parameter_name, parse_value(x.default_value, x.default_type), - x.value_list_name, + x.parameter_value_list_name, x.description, ) for x in _get_items(db_map, "parameter_definition", ids) From 
66f6682b295635c96358092a6e265db238c27c71 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sat, 23 Sep 2023 17:16:06 +0200 Subject: [PATCH 092/317] Add script to create the entity_alternative table --- ...a82ed59_create_entity_alternative_table.py | 51 +++++++++ spinedb_api/compatibility.py | 101 ++++++++++++++++++ spinedb_api/db_mapping_commit_mixin.py | 4 +- spinedb_api/helpers.py | 2 +- 4 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py create mode 100644 spinedb_api/compatibility.py diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py new file mode 100644 index 00000000..8671ab4e --- /dev/null +++ b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -0,0 +1,51 @@ +"""create entity_alternative table + +Revision ID: ce9faa82ed59 +Revises: 6b7c994c1c61 +Create Date: 2023-09-21 14:35:28.867803 + +""" +from alembic import op +import sqlalchemy as sa +from spinedb_api.compatibility import convert_tool_feature_method_to_entity_alternative + + +# revision identifiers, used by Alembic. +revision = 'ce9faa82ed59' +down_revision = '6b7c994c1c61' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'entity_alternative', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('entity_id', sa.Integer(), nullable=False), + sa.Column('alternative_id', sa.Integer(), nullable=False), + sa.Column('active', sa.Boolean(name='active'), server_default=sa.text('1'), nullable=False), + sa.Column('commit_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ['alternative_id'], + ['alternative.id'], + name=op.f('fk_entity_alternative_alternative_id_alternative'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint(['commit_id'], ['commit.id'], name=op.f('fk_entity_alternative_commit_id_commit')), + sa.ForeignKeyConstraint( + ['entity_id'], + ['entity.id'], + name=op.f('fk_entity_alternative_entity_id_entity'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('id', name=op.f('pk_entity_alternative')), + sa.UniqueConstraint('entity_id', 'alternative_id', name=op.f('uq_entity_alternative_entity_idalternative_id')), + ) + op.drop_table('next_id') + convert_tool_feature_method_to_entity_alternative(op.get_bind()) + + +def downgrade(): + pass diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py new file mode 100644 index 00000000..5c3d5055 --- /dev/null +++ b/spinedb_api/compatibility.py @@ -0,0 +1,101 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see <http://www.gnu.org/licenses/>.
+###################################################################################################################### + +"""Provides dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it.""" + +import sqlalchemy as sa + + +def convert_tool_feature_method_to_entity_alternative(conn, db_map=None): + """Transforms parameter_value rows into entity_alternative rows, whenever the former are used in a tool filter + to control entity activity. + + Args: + conn (Connection) + """ + meta = sa.MetaData(conn) + meta.reflect() + ea_table = meta.tables["entity_alternative"] + lv_table = meta.tables["list_value"] + pv_table = meta.tables["parameter_value"] + try: + # Compute list-value id by parameter definition id for all features and methods + tfm_table = meta.tables["tool_feature_method"] + tf_table = meta.tables["tool_feature"] + f_table = meta.tables["feature"] + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) + .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) + .where(tfm_table.c.method_index == lv_table.c.index) + .where(tf_table.c.id == tfm_table.c.tool_feature_id) + .where(f_table.c.id == tf_table.c.feature_id) + ) + } + except KeyError: + # It's a new DB without tool/feature/method + # we take 'is_active' as feature and JSON "yes" and true as methods + pd_table = meta.tables["parameter_definition"] + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")]) + .where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id) + .where(pd_table.c.name == "is_active") + .where(lv_table.c.value.in_((b'"yes"', b"true"))) + ) + } + # Collect 'is_active' parameter values + list_value_id = sa.case( + [(pv_table.c.type == "list_value_ref", sa.cast(pv_table.c.value, sa.Integer()))], else_=None + ) + is_active_pvals = [ + {c: x[c] for c in ("id", "entity_id", "alternative_id", "parameter_definition_id", "list_value_id")} + for x in conn.execute( + sa.select([pv_table, list_value_id.label("list_value_id")]).where( + pv_table.c.parameter_definition_id.in_(lv_id_by_pdef_id) + ) + ) + ] + # Compute new entity_alternative items from 'is_active' parameter values, + # where 'active' is True if the value of 'is_active' is the one from the tool_feature_method specification + current_ea_keys = {(x["entity_id"], x["alternative_id"]) for x in conn.execute(sa.select([ea_table]))} + new_ea_items = { + (x["entity_id"], x["alternative_id"]): { + "entity_id": x["entity_id"], + "alternative_id": x["alternative_id"], + "active": x["list_value_id"] == lv_id_by_pdef_id[x["parameter_definition_id"]], + } + for x in is_active_pvals + } + # Add or update entity_alternative records + ea_items_to_add = [new_ea_items[key] for key in set(new_ea_items) - current_ea_keys] + ea_items_to_update = [new_ea_items[key] for key in set(new_ea_items) & current_ea_keys] + if ea_items_to_add: + conn.execute(ea_table.insert(), ea_items_to_add) + if ea_items_to_update: + conn.execute( + ea_table.update() + .where(ea_table.c.entity_id == sa.bindparam("b_entity_id")) + .where(ea_table.c.alternative_id == sa.bindparam("b_alternative_id")) + .values(active=sa.bindparam("b_active")), + [{"b_" + k: v for k, v in x.items()} for x in ea_items_to_update], + ) + # Delete pvals 499 at a time to avoid too many sql variables + is_active_pval_ids = [x["id"] 
for x in is_active_pvals] + size = 499 + for i in range(0, len(is_active_pval_ids), size): + ids = is_active_pval_ids[i : i + size] + conn.execute(pv_table.delete().where(pv_table.c.id.in_(ids))) + if db_map is not None: + db_map.remove_items("parameter_value", *is_active_pval_ids) + # TODO: add and update ea items diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index f0bcfed1..00781d3d 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -17,6 +17,7 @@ from datetime import datetime, timezone import sqlalchemy.exc from .exception import SpineDBAPIError +from .compatibility import convert_tool_feature_method_to_entity_alternative class DatabaseMappingCommitMixin: @@ -40,7 +41,7 @@ def commit_session(self, comment): try: commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] except sqlalchemy.exc.DBAPIError as e: - raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") + raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e for tablename, (to_add, to_update, to_remove) in dirty_items: for item in to_add + to_update + to_remove: item.commit(commit_id) @@ -48,6 +49,7 @@ def commit_session(self, comment): self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) self._do_update_items(connection, tablename, *to_update) self._do_add_items(connection, tablename, *to_add) + convert_tool_feature_method_to_entity_alternative(connection, self) if self._memory: self._memory_dirty = True diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index f83d3224..f5bbbfc2 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -585,7 +585,7 @@ def create_new_spine_database(db_url): meta.create_all(engine) engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute("INSERT INTO alembic_version VALUES ('6b7c994c1c61')") + engine.execute("INSERT INTO alembic_version VALUES ('ce9faa82ed59')") except DatabaseError as e: raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None return engine From 91ef8d8e40c01ba5f6a24e4de634a58693d3418f Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 24 Sep 2023 10:19:28 +0200 Subject: [PATCH 093/317] Fix order of statements in cascade remove and restore from cache --- spinedb_api/db_cache_base.py | 15 ++++++++------- spinedb_api/import_functions.py | 24 +++++++++++++----------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index cd5e577c..2e2fc873 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -11,7 +11,6 @@ """DB cache base.""" import threading from enum import Enum, unique, auto -from functools import cmp_to_key from .temp_id import TempId # TODO: Implement CacheItem.pop() to do lookup? 
@@ -773,14 +772,15 @@ def cascade_restore(self, source=None): else: raise RuntimeError("invalid status for item being restored") self._removed = False - for referrer in self._referrers.values(): - referrer.cascade_restore(source=self) - self._update_weak_referrers() + # First restore this, then referrers obsolete = set() for callback in list(self.restore_callbacks): if not callback(self): obsolete.add(callback) self.restore_callbacks -= obsolete + for referrer in self._referrers.values(): + referrer.cascade_restore(source=self) + self._update_weak_referrers() def cascade_remove(self, source=None): """Removes this item and all its referrers in cascade. @@ -799,14 +799,15 @@ def cascade_remove(self, source=None): self._removed = True self._to_remove = False self._valid = None + # First remove referrers, then this + for referrer in self._referrers.values(): + referrer.cascade_remove(source=self) + self._update_weak_referrers() obsolete = set() for callback in list(self.remove_callbacks): if not callback(self): obsolete.add(callback) self.remove_callbacks -= obsolete - for referrer in self._referrers.values(): - referrer.cascade_remove(source=self) - self._update_weak_referrers() def cascade_update(self): """Updates this item and all its referrers in cascade. diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index a1bddbbb..d3badb8e 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -112,12 +112,20 @@ def get_data_for_import( on_conflict="merge", entity_classes=(), entities=(), + entity_groups=(), + entity_alternatives=(), # TODO parameter_definitions=(), parameter_values=(), - entity_groups=(), + parameter_value_lists=(), + alternatives=(), + scenarios=(), + scenario_alternatives=(), + metadata=(), + entity_metadata=(), + parameter_value_metadata=(), + # legacy object_classes=(), relationship_classes=(), - parameter_value_lists=(), object_parameters=(), relationship_parameters=(), objects=(), @@ -125,17 +133,11 @@ def get_data_for_import( object_groups=(), object_parameter_values=(), relationship_parameter_values=(), - alternatives=(), - scenarios=(), - scenario_alternatives=(), - metadata=(), - entity_metadata=(), - parameter_value_metadata=(), object_metadata=(), relationship_metadata=(), object_parameter_value_metadata=(), relationship_parameter_value_metadata=(), - # legacy + # removed tools=(), features=(), tool_features=(), @@ -167,8 +169,8 @@ def get_data_for_import( list of lists with relationship class name, list of object names, parameter name, parameter value - Returns: - dict(str, list) + Yields: + tuple(str, list) """ # NOTE: The order is important, because of references. 
E.g., we want to import alternatives before parameter_values if alternatives: From 3b4387208396f98207b8ad84e842677bc8deff31 Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Mon, 25 Sep 2023 11:47:04 +0300 Subject: [PATCH 094/317] Fix imports in tests --- tests/spine_io/test_excel_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index e1ed5b4b..8b500b07 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -21,7 +21,7 @@ from spinedb_api import DatabaseMapping, import_data, from_database from spinedb_api.spine_io.exporters.excel import export_spine_database_to_xlsx from spinedb_api.spine_io.importers.excel_reader import get_mapped_data_from_xlsx -from ..test_import_functions import assert_import_equivalent +from tests.test_import_functions import assert_import_equivalent _TEMP_EXCEL_FILENAME = "excel.xlsx" From 6e16d0730aa604a3006990d37e8f68931d2f57a0 Mon Sep 17 00:00:00 2001 From: Henrik Koski <98282892+PiispaH@users.noreply.github.com> Date: Mon, 25 Sep 2023 12:08:39 +0300 Subject: [PATCH 095/317] Fix commit and redo causes Traceback (#275) --- spinedb_api/db_cache_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 2e2fc873..f6cdd520 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -769,6 +769,8 @@ def cascade_restore(self, source=None): return if self.status in (Status.added_and_removed, Status.to_remove): self._status = self._status_when_removed + elif self.status == Status.committed: + self._status = Status.to_add else: raise RuntimeError("invalid status for item being restored") self._removed = False From 7a4536dcaa3c72442707732e2dbe30c786886fb5 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 25 Sep 2023 21:57:52 +0200 Subject: [PATCH 096/317] Introduce refit_data and call it on commit ...so clients can 'refresh their views'. --- spinedb_api/compatibility.py | 58 ++++++++++++++++++-------- spinedb_api/db_cache_base.py | 2 +- spinedb_api/db_mapping_commit_mixin.py | 8 ++-- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 5c3d5055..f2b64372 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -9,17 +9,22 @@ # this program. If not, see <http://www.gnu.org/licenses/>. ###################################################################################################################### -"""Provides dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it.""" +"""Dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it.""" import sqlalchemy as sa -def convert_tool_feature_method_to_entity_alternative(conn, db_map=None): +def convert_tool_feature_method_to_entity_alternative(conn): """Transforms parameter_value rows into entity_alternative rows, whenever the former are used in a tool filter to control entity activity.
Args: conn (Connection) + + Returns: + list: entity_alternative items to add + list: entity_alternative items to update + list: parameter_value ids to remove """ meta = sa.MetaData(conn) meta.reflect() @@ -68,7 +73,7 @@ def convert_tool_feature_method_to_entity_alternative(conn, db_map=None): ] # Compute new entity_alternative items from 'is_active' parameter values, # where 'active' is True if the value of 'is_active' is the one from the tool_feature_method specification - current_ea_keys = {(x["entity_id"], x["alternative_id"]) for x in conn.execute(sa.select([ea_table]))} + current_ea_ids = {(x["entity_id"], x["alternative_id"]): x["id"] for x in conn.execute(sa.select([ea_table]))} new_ea_items = { (x["entity_id"], x["alternative_id"]): { "entity_id": x["entity_id"], @@ -78,24 +83,41 @@ def convert_tool_feature_method_to_entity_alternative(conn, db_map=None): for x in is_active_pvals } # Add or update entity_alternative records - ea_items_to_add = [new_ea_items[key] for key in set(new_ea_items) - current_ea_keys] - ea_items_to_update = [new_ea_items[key] for key in set(new_ea_items) & current_ea_keys] + ea_items_to_add = [new_ea_items[key] for key in set(new_ea_items) - set(current_ea_ids)] + ea_items_to_update = [ + {"id": current_ea_ids[key], "active": new_ea_items[key]["active"]} + for key in set(new_ea_items) & set(current_ea_ids) + ] + pval_ids_to_remove = [x["id"] for x in is_active_pvals] if ea_items_to_add: conn.execute(ea_table.insert(), ea_items_to_add) if ea_items_to_update: - conn.execute( - ea_table.update() - .where(ea_table.c.entity_id == sa.bindparam("b_entity_id")) - .where(ea_table.c.alternative_id == sa.bindparam("b_alternative_id")) - .values(active=sa.bindparam("b_active")), - [{"b_" + k: v for k, v in x.items()} for x in ea_items_to_update], - ) + conn.execute(ea_table.update(), ea_items_to_update) # Delete pvals 499 at a time to avoid too many sql variables - is_active_pval_ids = [x["id"] for x in is_active_pvals] size = 499 - for i in range(0, len(is_active_pval_ids), size): - ids = is_active_pval_ids[i : i + size] + for i in range(0, len(pval_ids_to_remove), size): + ids = pval_ids_to_remove[i : i + size] conn.execute(pv_table.delete().where(pv_table.c.id.in_(ids))) - if db_map is not None: - db_map.remove_items("parameter_value", *is_active_pval_ids) - # TODO: add and update ea items + return ea_items_to_add, ea_items_to_update, set(pval_ids_to_remove) + + +def refit_data(connection): + """Refits any data having an old format and returns changes made. 
+ + Args: + connection (Connection) + + Returns: + list: list of strings indicating the changes + list: list of tuples (tablename, (items_added, items_updated, ids_removed)) + """ + ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative(connection) + info = [] + refits = [] + if ea_items_added or ea_items_updated: + refits.append(("entity_alternative", (ea_items_added, ea_items_updated, ()))) + if pval_ids_removed: + refits.append(("parameter_value", ((), (), pval_ids_removed))) + if ea_items_added or ea_items_updated or pval_ids_removed: + info.append("Convert entity activity control using tool/feature/method into entity_alternative") + return info, refits diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index f6cdd520..8571046f 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -417,7 +417,7 @@ def remove_item(self, id_): return current_item def restore_item(self, id_): - current_item = self.get(id_) + current_item = self.find_item({"id": id_}) if current_item is not None: self._add_unique(current_item) current_item.cascade_restore() diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 00781d3d..d1acd8de 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -17,7 +17,7 @@ from datetime import datetime, timezone import sqlalchemy.exc from .exception import SpineDBAPIError -from .compatibility import convert_tool_feature_method_to_entity_alternative +from .compatibility import refit_data class DatabaseMappingCommitMixin: @@ -49,9 +49,9 @@ def commit_session(self, comment): self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) self._do_update_items(connection, tablename, *to_update) self._do_add_items(connection, tablename, *to_add) - convert_tool_feature_method_to_entity_alternative(connection, self) - if self._memory: - self._memory_dirty = True + if self._memory: + self._memory_dirty = True + return refit_data(connection) def rollback_session(self): if not self.cache.rollback(): From 133b9c553eaf31176dbd005d6ba26e1d779c9e95 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 25 Sep 2023 21:59:23 +0200 Subject: [PATCH 097/317] Fix TempId.__hash__ Weirdly, hash(-1) == hash(-2) in my Python?? 
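This is standard CPython behaviour rather than a quirk of one installation: -1 is reserved as the error return value of the interpreter's C-level hash function, so hash(-1) is silently remapped to -2 and collides with hash(-2):

    >>> hash(-1), hash(-2)
    (-2, -2)

Assuming temp ids are handed out as -1, -2, -3, ... (as the subject suggests), the inherited int.__hash__ made the first two TempId instances hash alike; negating the value yields the distinct hashes 1, 2, 3, ... instead.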
--- spinedb_api/temp_id.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 84eb3464..3a57ce6f 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -32,7 +32,7 @@ def __eq__(self, other): return super().__eq__(other) or (self._db_id is not None and other == self._db_id) def __hash__(self): - return int.__hash__(self) + return -int(self) @property def db_id(self): From a1b0303ffd827ccfc59618cf08ffc63d29bf7824 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 26 Sep 2023 07:46:44 +0200 Subject: [PATCH 098/317] Minor renaming --- spinedb_api/compatibility.py | 12 ++++++------ spinedb_api/db_mapping_commit_mixin.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index f2b64372..4fb2e1c0 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -101,23 +101,23 @@ def convert_tool_feature_method_to_entity_alternative(conn): return ea_items_to_add, ea_items_to_update, set(pval_ids_to_remove) -def refit_data(connection): +def compatibility_transformations(connection): """Refits any data having an old format and returns changes made. Args: connection (Connection) Returns: - list: list of strings indicating the changes list: list of tuples (tablename, (items_added, items_updated, ids_removed)) + list: list of strings indicating the changes """ ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative(connection) + transformations = [] info = [] - refits = [] if ea_items_added or ea_items_updated: - refits.append(("entity_alternative", (ea_items_added, ea_items_updated, ()))) + transformations.append(("entity_alternative", (ea_items_added, ea_items_updated, ()))) if pval_ids_removed: - refits.append(("parameter_value", ((), (), pval_ids_removed))) + transformations.append(("parameter_value", ((), (), pval_ids_removed))) if ea_items_added or ea_items_updated or pval_ids_removed: info.append("Convert entity activity control using tool/feature/method into entity_alternative") - return info, refits + return transformations, info diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index d1acd8de..7658641e 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -17,7 +17,7 @@ from datetime import datetime, timezone import sqlalchemy.exc from .exception import SpineDBAPIError -from .compatibility import refit_data +from .compatibility import compatibility_transformations class DatabaseMappingCommitMixin: @@ -51,7 +51,7 @@ def commit_session(self, comment): self._do_add_items(connection, tablename, *to_add) if self._memory: self._memory_dirty = True - return refit_data(connection) + return compatibility_transformations(connection) def rollback_session(self): if not self.cache.rollback(): From 81322739ffb7113fb6e7ae09f81a0e5fb0ce82ae Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 26 Sep 2023 07:53:46 +0200 Subject: [PATCH 099/317] Fix tests --- .../versions/ce9faa82ed59_create_entity_alternative_table.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py index 8671ab4e..cb011771 100644 --- a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py +++ 
b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -43,7 +43,10 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name=op.f('pk_entity_alternative')), sa.UniqueConstraint('entity_id', 'alternative_id', name=op.f('uq_entity_alternative_entity_idalternative_id')), ) - op.drop_table('next_id') + try: + op.drop_table('next_id') + except sa.exc.OperationalError: + pass convert_tool_feature_method_to_entity_alternative(op.get_bind()) From 49530291394be8f88c9540582e8234252954b016 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 28 Sep 2023 10:32:32 +0200 Subject: [PATCH 100/317] Fix scenario filter including entity_alternative --- spinedb_api/__init__.py | 1 + spinedb_api/db_mapping_base.py | 27 +++++++- spinedb_api/filters/scenario_filter.py | 95 +++++++++++++++++++++----- spinedb_api/import_functions.py | 79 +++++++++++++++------ spinedb_api/spine_db_server.py | 1 + tests/filters/test_scenario_filter.py | 46 +++++++++++++ 6 files changed, 211 insertions(+), 38 deletions(-) diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 1c1315bf..5e823c50 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -37,6 +37,7 @@ import_data, import_entity_classes, import_entities, + import_entity_alternatives, import_parameter_definitions, import_parameter_values, import_object_classes, diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c59322d3..1545878b 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -437,7 +437,7 @@ def entity_sq(self): @property def entity_element_sq(self): if self._entity_element_sq is None: - self._entity_element_sq = self._subquery("entity_element") + self._entity_element_sq = self._make_entity_element_sq() return self._entity_element_sq @property @@ -1570,6 +1570,15 @@ def _make_entity_sq(self): """ return self._subquery("entity") + def _make_entity_element_sq(self): + """ + Creates a subquery for entity-elements. + + Returns: + Alias: an entity_element subquery + """ + return self._subquery("entity_element") + def _make_parameter_definition_sq(self): """ Creates a subquery for parameter definitions. @@ -1706,6 +1715,17 @@ def override_entity_sq_maker(self, method): self._make_entity_sq = MethodType(method, self) self._clear_subqueries("entity") + def override_entity_element_sq_maker(self, method): + """ + Overrides the function that creates the ``entity_element_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and + returns entity_element subquery as an :class:`Alias` object + """ + self._make_entity_element_sq = MethodType(method, self) + self._clear_subqueries("entity_element") + def override_parameter_definition_sq_maker(self, method): """ Overrides the function that creates the ``parameter_definition_sq`` property. 
@@ -1771,6 +1791,11 @@ def restore_entity_sq_maker(self): self._make_entity_sq = MethodType(DatabaseMappingBase._make_entity_sq, self) self._clear_subqueries("entity") + def restore_entity_element_sq_maker(self): + """Restores the original function that creates the ``entity_element_sq`` property.""" + self._make_entity_element_sq = MethodType(DatabaseMappingBase._make_entity_element_sq, self) + self._clear_subqueries("entity_element") + def restore_parameter_definition_sq_maker(self): """Restores the original function that creates the ``parameter_definition_sq`` property.""" self._make_parameter_definition_sq = MethodType(DatabaseMappingBase._make_parameter_definition_sq, self) diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index d33310ed..858bcb40 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -31,7 +31,8 @@ def apply_scenario_filter_to_subqueries(db_map, scenario): scenario (str or int): scenario name or id """ state = _ScenarioFilterState(db_map, scenario) - # FIXME + make_entity_element_sq = partial(_make_scenario_filtered_entity_element_sq, state=state) + db_map.override_entity_element_sq_maker(make_entity_element_sq) make_entity_sq = partial(_make_scenario_filtered_entity_sq, state=state) db_map.override_entity_sq_maker(make_entity_sq) make_parameter_value_sq = partial(_make_scenario_filtered_parameter_value_sq, state=state) @@ -131,6 +132,7 @@ def __init__(self, db_map, scenario): scenario (str or int): scenario name or id """ self.original_entity_sq = db_map.entity_sq + self.original_entity_element_sq = db_map.entity_element_sq self.original_parameter_value_sq = db_map.parameter_value_sq self.original_scenario_sq = db_map.scenario_sq self.original_scenario_alternative_sq = db_map.scenario_alternative_sq @@ -184,19 +186,8 @@ def _scenario_alternative_ids(self, db_map): return scenario_alternative_ids, alternative_ids -def _make_scenario_filtered_entity_sq(db_map, state): - """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_sq`. - - This function can be used as replacement for entity subquery maker in :class:`DatabaseMappingBase`. - - Args: - db_map (DatabaseMappingBase): a database map - state (_ScenarioFilterState): a state bound to ``db_map`` - - Returns: - Alias: a subquery for entity filtered by selected scenario - """ - ext_entity_sq = ( +def _ext_entity_sq(db_map, state): + return ( db_map.query( state.original_entity_sq, func.row_number() @@ -222,13 +213,85 @@ def _make_scenario_filtered_entity_sq(db_map, state): ) ) ).subquery() - # TODO: Maybe we want to filter multi-dimensional entities involving filtered entities right here too? + + +def _make_scenario_filtered_entity_element_sq(db_map, state): + """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_element_sq`. + + This function can be used as replacement for entity_element subquery maker in :class:`DatabaseMappingBase`.
+ + Args: + db_map (DatabaseMappingBase): a database map + state (_ScenarioFilterState): a state bound to ``db_map`` + + Returns: + Alias: a subquery for entity_element filtered by selected scenario + """ + ext_entity_sq = _ext_entity_sq(db_map, state) + entity_sq = ext_entity_sq.alias() + element_sq = ext_entity_sq.alias() return ( - db_map.query(ext_entity_sq) + db_map.query(state.original_entity_element_sq) + .filter(state.original_entity_element_sq.c.entity_id == entity_sq.c.id) + .filter(state.original_entity_element_sq.c.element_id == element_sq.c.id) + .filter( + entity_sq.c.desc_rank_row_number == 1, + or_(entity_sq.c.active == True, entity_sq.c.active == None), + ) + .filter( + element_sq.c.desc_rank_row_number == 1, + or_(element_sq.c.active == True, element_sq.c.active == None), + ) + .subquery() + ) + + +def _make_scenario_filtered_entity_sq(db_map, state): + """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_sq`. + + This function can be used as replacement for entity subquery maker in :class:`DatabaseMappingBase`. + + Args: + db_map (DatabaseMappingBase): a database map + state (_ScenarioFilterState): a state bound to ``db_map`` + + Returns: + Alias: a subquery for entity filtered by selected scenario + """ + ext_entity_sq = _ext_entity_sq(db_map, state) + ext_entity_class_dimension_count_sq = ( + db_map.query( + db_map.entity_class_dimension_sq.c.entity_class_id, + func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"), + ) + .group_by(db_map.entity_class_dimension_sq.c.entity_class_id) + .subquery() + ) + return ( + db_map.query( + ext_entity_sq.c.id, + ext_entity_sq.c.class_id, + ext_entity_sq.c.name, + ext_entity_sq.c.description, + ext_entity_sq.c.commit_id, + ) .filter( ext_entity_sq.c.desc_rank_row_number == 1, or_(ext_entity_sq.c.active == True, ext_entity_sq.c.active == None), ) + .outerjoin( + ext_entity_class_dimension_count_sq, + ext_entity_class_dimension_count_sq.c.entity_class_id == ext_entity_sq.c.class_id, + ) + .outerjoin(db_map.entity_element_sq, ext_entity_sq.c.id == db_map.entity_element_sq.c.entity_id) + .group_by(ext_entity_sq.c.id) + .having( + or_( + ext_entity_class_dimension_count_sq.c.dimension_count == None, + ext_entity_class_dimension_count_sq.c.dimension_count + == func.count(db_map.entity_element_sq.c.element_id), + ) + ) .subquery() ) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index d3badb8e..8b3bec9b 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -188,10 +188,13 @@ def get_data_for_import( if entity_classes: yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=True)) yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=False)) - if object_classes: - yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) - if relationship_classes: - yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes, zero_dim=False)) + if entities: + yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=True)) + yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=False)) + if entity_alternatives: + yield ("entity_alternative", _get_entity_alternatives_for_import(db_map, entity_alternatives)) + if entity_groups: + yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups)) if parameter_value_lists: yield ("parameter_value_list", 
_get_parameter_value_lists_for_import(db_map, parameter_value_lists)) yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, unparse_value)) @@ -200,6 +203,23 @@ def get_data_for_import( "parameter_definition", _get_parameter_definitions_for_import(db_map, parameter_definitions, unparse_value), ) + if parameter_values: + yield ( + "parameter_value", + _get_parameter_values_for_import(db_map, parameter_values, unparse_value, on_conflict), + ) + if metadata: + yield ("metadata", _get_metadata_for_import(db_map, metadata)) + if entity_metadata: + yield ("metadata", _get_metadata_for_import(db_map, (metadata for _, _, metadata in entity_metadata))) + yield ("entity_metadata", _get_entity_metadata_for_import(db_map, entity_metadata)) + if parameter_value_metadata: + yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) + # Legacy + if object_classes: + yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) + if relationship_classes: + yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes, zero_dim=False)) if object_parameters: yield ("parameter_definition", _get_parameter_definitions_for_import(db_map, object_parameters, unparse_value)) if relationship_parameters: @@ -207,22 +227,12 @@ def get_data_for_import( "parameter_definition", _get_parameter_definitions_for_import(db_map, relationship_parameters, unparse_value), ) - if entities: - yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=True)) - yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=False)) if objects: yield ("object", _get_entities_for_import(db_map, objects, zero_dim=True)) if relationships: yield ("relationship", _get_entities_for_import(db_map, relationships, zero_dim=False)) - if entity_groups: - yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups)) if object_groups: yield ("entity_group", _get_entity_groups_for_import(db_map, object_groups)) - if parameter_values: - yield ( - "parameter_value", - _get_parameter_values_for_import(db_map, parameter_values, unparse_value, on_conflict), - ) if object_parameter_values: yield ( "parameter_value", @@ -233,13 +243,6 @@ def get_data_for_import( "parameter_value", _get_parameter_values_for_import(db_map, relationship_parameter_values, unparse_value, on_conflict), ) - if metadata: - yield ("metadata", _get_metadata_for_import(db_map, metadata)) - if entity_metadata: - yield ("metadata", _get_metadata_for_import(db_map, (metadata for _, _, metadata in entity_metadata))) - yield ("entity_metadata", _get_entity_metadata_for_import(db_map, entity_metadata)) - if parameter_value_metadata: - yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) if object_metadata: yield from get_data_for_import(db_map, entity_metadata=object_metadata) if relationship_metadata: @@ -296,6 +299,29 @@ def import_entities(db_map, data): return import_data(db_map, entities=data) +def import_entity_alternatives(db_map, data): + """Imports entity alternatives. 
+ + Example:: + + data = [ + ('class_name1', 'entity_name1', 'alternative_name3', True), + ('class_name2', 'entity_name2', 'alternative_name4', False), + ('class_name3', ('entity_name1', 'entity_name2'), 'alternative_name5', False) + ] + import_entity_alternatives(db_map, data) + + Args: + db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into + data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name, + entity name or list/tuple of element names, alternative name, active boolean value + + Returns: + (Int, List) Number of successfully inserted entity alternatives, list of errors + """ + return import_data(db_map, entity_alternatives=data) + + def import_entity_groups(db_map, data): """Imports list of entity groups by name with associated class name into given database mapping: Ignores duplicate and existing (group, member) tuples. @@ -814,6 +840,17 @@ def _data_iterator(): return _get_items_for_import(db_map, "entity", _data_iterator()) +def _get_entity_alternatives_for_import(db_map, data): + def _data_iterator(): + for class_name, entity_name_or_element_name_list, alternative, active in data: + is_zero_dim = isinstance(entity_name_or_element_name_list, str) + entity_byname = (entity_name_or_element_name_list,) if is_zero_dim else entity_name_or_element_name_list + key = ("entity_class_name", "entity_byname", "alternative_name", "active") + yield dict(zip(key, (class_name, entity_byname, alternative, active))) + + return _get_items_for_import(db_map, "entity_alternative", _data_iterator()) + + def _get_entity_groups_for_import(db_map, data): key = ("class_name", "group_name", "member_name") return _get_items_for_import(db_map, "entity_group", (dict(zip(key, x)) for x in data)) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index ceba23fa..5915c8b4 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -328,6 +328,7 @@ def _do_call_method(self, method_name, *args, **kwargs): def _do_clear_filters(self): self._db_map.restore_entity_sq_maker() + self._db_map.restore_entity_element_sq_maker() self._db_map.restore_entity_class_sq_maker() self._db_map.restore_parameter_definition_sq_maker() self._db_map.restore_parameter_value_sq_maker() diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 13120381..c2f2f51f 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -22,6 +22,9 @@ create_new_spine_database, DatabaseMapping, import_alternatives, + import_entity_classes, + import_entities, + import_entity_alternatives, import_object_classes, import_object_parameter_values, import_object_parameters, @@ -107,6 +110,49 @@ def test_scenario_filter_uncommitted_data(self): self.assertEqual(len(scenarios), 0) self._out_db_map.rollback_session() + def test_scenario_filter_works_for_entity_sq(self): + import_alternatives(self._out_db_map, ["alternative1", "alternative2"]) + import_entity_classes( + self._out_db_map, [("class1", ()), ("class2", ()), ("class1__class2", ("class1", "class2"))] + ) + import_entities( + self._out_db_map, + [ + ("class1", "obj1"), + ("class2", "obj2"), + ("class2", "obj22"), + ("class1__class2", ("obj1", "obj2")), + ("class1__class2", ("obj1", "obj22")), + ], + ) + import_entity_alternatives( + self._out_db_map, + [ + ("class2", "obj2", "alternative1", True), + ("class2", "obj2", "alternative2", False), + ("class2", "obj22", "alternative1", False), + ("class2", "obj22", "alternative2", True), + ], + ) +
import_scenarios(self._out_db_map, [("scenario1", True)]) + import_scenario_alternatives( + self._out_db_map, [("scenario1", "alternative2"), ("scenario1", "alternative1", "alternative2")] + ) + self._out_db_map.commit_session("Add test data") + entities = self._db_map.query(self._db_map.entity_sq).all() + self.assertEqual(len(entities), 5) + apply_scenario_filter_to_subqueries(self._db_map, "scenario1") + # After this, obj2 should be excluded because it is inactive in the highest-ranked alternative2 + # The multidimensional entity 'class1__class2, (obj1, obj2)' should also be excluded because it involves obj2 + entities = self._db_map.query(self._db_map.wide_entity_sq).all() + self.assertEqual(len(entities), 3) + entity_names = { + name + for x in entities + for name in (x["element_name_list"].split(",") if x["element_name_list"] else (x["name"],)) + } + self.assertFalse("obj2" in entity_names) + def test_scenario_filter_works_for_object_parameter_value_sq(self): _build_data_with_single_scenario(self._out_db_map) apply_scenario_filter_to_subqueries(self._db_map, "scenario") From 99ea33e2398852b534feca3fb20e0e84805e87a5 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 2 Oct 2023 08:37:00 +0200 Subject: [PATCH 101/317] Improve documentation and work towards the public API --- docs/source/conf.py | 15 +- spinedb_api/__init__.py | 9 +- spinedb_api/compatibility.py | 2 +- spinedb_api/db_cache_base.py | 94 ++-- spinedb_api/db_cache_impl.py | 35 +- spinedb_api/db_mapping.py | 153 +++++- spinedb_api/db_mapping_add_mixin.py | 19 +- spinedb_api/db_mapping_base.py | 619 +++++++++--------------- spinedb_api/db_mapping_commit_mixin.py | 5 - spinedb_api/db_mapping_remove_mixin.py | 29 +- spinedb_api/db_mapping_update_mixin.py | 29 +- spinedb_api/exception.py | 34 +- spinedb_api/export_functions.py | 83 ++-- spinedb_api/filters/execution_filter.py | 4 +- spinedb_api/graph_layout_generator.py | 33 +- spinedb_api/helpers.py | 60 ++- spinedb_api/import_functions.py | 610 ++++++----------------- spinedb_api/import_mapping/__init__.py | 4 + spinedb_api/mapping.py | 4 +- spinedb_api/parameter_value.py | 10 +- spinedb_api/perfect_split.py | 10 +- spinedb_api/purge.py | 10 +- spinedb_api/query.py | 9 +- spinedb_api/server_client_helpers.py | 22 +- spinedb_api/spine_db_client.py | 48 +- spinedb_api/spine_db_server.py | 90 +++- spinedb_api/temp_id.py | 4 - tests/test_DatabaseMapping.py | 59 ++- tests/test_db_cache_base.py | 4 +- 29 files changed, 919 insertions(+), 1188 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 28d0824a..643000a6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -51,7 +51,7 @@ 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx', 'recommonmark', - 'autoapi.extension' + 'autoapi.extension', ] # Add any paths that contain templates here, relative to this directory.
@@ -85,6 +85,7 @@ pygments_style = 'sphinx' # Settings for Sphinx AutoAPI +autoapi_options = ['members', 'inherited-members'] autoapi_python_class_content = "both" autoapi_add_toctree_entry = True autoapi_root = "autoapi" @@ -92,7 +93,17 @@ autoapi_ignore = [ '*/spinedb_api/alembic/*', ] # ignored modules -autoapi_keep_files=True +autoapi_keep_files = True + + +def _skip_member(app, what, name, obj, skip, options): + if what == "class" and any(x in name for x in ("SpineDBServer", "group_concat")): + skip = True + return skip + + +def setup(sphinx): + sphinx.connect("autoapi-skip-member", _skip_member) # -- Options for HTML output ------------------------------------------------- diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 5e823c50..679c68e9 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -9,14 +9,15 @@ # this program. If not, see . ###################################################################################################################### +""" +A package to interact with Spine DBs. +""" + from .db_mapping import DatabaseMapping from .exception import ( SpineDBAPIError, SpineIntegrityError, SpineDBVersionError, - SpineTableNotFoundError, - RecordNotFoundError, - ParameterValueError, ParameterValueFormatError, InvalidMapping, ) @@ -26,8 +27,6 @@ SUPPORTED_DIALECTS, create_new_spine_database, copy_database, - is_unlocked, - is_head, is_empty, forward_sweep, Asterisk, diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 4fb2e1c0..2e748ce6 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -9,7 +9,7 @@ # this program. If not, see . ###################################################################################################################### -"""Dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it.""" +# Dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it import sqlalchemy as sa diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py index 8571046f..9143bc44 100644 --- a/spinedb_api/db_cache_base.py +++ b/spinedb_api/db_cache_base.py @@ -8,13 +8,15 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -"""DB cache base.""" + import threading from enum import Enum, unique, auto from .temp_id import TempId # TODO: Implement CacheItem.pop() to do lookup? +_LIMIT = 10000 + @unique class Status(Enum): @@ -30,27 +32,31 @@ class Status(Enum): class DBCacheBase(dict): """A dictionary that maps table names to ids to items. Used to store and retrieve database contents.""" - def __init__(self, chunk_size=None): + def __init__(self): super().__init__() self._offsets = {} self._offset_lock = threading.Lock() self._fetched_item_types = set() - self._chunk_size = chunk_size - item_types = self._item_types + item_types = self.item_types self._sorted_item_types = [] while item_types: item_type = item_types.pop(0) - if self._item_factory(item_type).ref_types() & set(item_types): + if self.item_factory(item_type).ref_types() & set(item_types): item_types.append(item_type) else: self._sorted_item_types.append(item_type) @property def fetched_item_types(self): + """Returns a set with the item types that are already fetched. 
+
+ Returns:
+ set
+ """
return self._fetched_item_types

@property
- def _item_types(self):
+ def item_types(self):
"""Returns a list of supported item type strings.

Returns:
@@ -59,7 +65,7 @@
raise NotImplementedError()

@staticmethod
- def _item_factory(item_type):
+ def item_factory(item_type):
"""Returns a subclass of CacheItemBase to build items of given type.

Args:
@@ -70,7 +76,7 @@
"""
raise NotImplementedError()

- def _query(self, item_type):
+ def query(self, item_type):
"""Returns a Query object to fetch items of given type.

Args:
@@ -82,12 +88,14 @@
raise NotImplementedError()

def make_item(self, item_type, **item):
- factory = self._item_factory(item_type)
+ factory = self.item_factory(item_type)
return factory(self, item_type, **item)

def dirty_ids(self, item_type):
return {
- item["id"] for item in self.get(item_type, {}).values() if item.status in (Status.to_add, Status.to_update)
+ item["id"]
+ for item in self.table_cache(item_type).valid_values()
+ if item.status in (Status.to_add, Status.to_update)
}

def dirty_items(self):
@@ -105,7 +113,7 @@
to_add = []
to_update = []
to_remove = []
- for item in dict.values(table_cache):
+ for item in table_cache.values():
_ = item.is_valid()
if item.status == Status.to_add:
to_add.append(item)
@@ -118,10 +126,10 @@
# This ensures cascade removal.
# FIXME: We should also fetch the current item type because of multi-dimensional entities and
# classes which also depend on zero-dimensional ones
- for other_item_type in self._item_types:
+ for other_item_type in self.item_types:
if (
other_item_type not in self.fetched_item_types
- and item_type in self._item_factory(other_item_type).ref_types()
+ and item_type in self.item_factory(other_item_type).ref_types()
):
self.fetch_all(other_item_type)
if to_add or to_update or to_remove:
@@ -166,20 +174,20 @@ def refresh(self):
self._offsets.clear()
self._fetched_item_types.clear()

- def _get_next_chunk(self, item_type):
- qry = self._query(item_type)
+ def _get_next_chunk(self, item_type, limit):
+ qry = self.query(item_type)
if not qry:
return []
- if not self._chunk_size:
+ if not limit:
self._fetched_item_types.add(item_type)
return [dict(x) for x in qry]
with self._offset_lock:
offset = self._offsets.setdefault(item_type, 0)
- chunk = [dict(x) for x in qry.limit(self._chunk_size).offset(offset)]
+ chunk = [dict(x) for x in qry.limit(limit).offset(offset)]
self._offsets[item_type] += len(chunk)
return chunk

- def advance_query(self, item_type):
+ def _advance_query(self, item_type, limit):
"""Advances the DB query that fetches items of given type
and adds the results to the corresponding table cache.
@@ -189,7 +197,7 @@
Returns:
list: items fetched from the DB
"""
- chunk = self._get_next_chunk(item_type)
+ chunk = self._get_next_chunk(item_type, limit)
if not chunk:
self._fetched_item_types.add(item_type)
return []
@@ -208,10 +216,10 @@ def get_item(self, item_type, id_):
return {}
return item

- def fetch_more(self, item_type):
+ def fetch_more(self, item_type, limit=_LIMIT):
if item_type in self._fetched_item_types:
- return False
- return bool(self.advance_query(item_type))
+ return []
+ return self._advance_query(item_type, limit)

def fetch_all(self, item_type):
while self.fetch_more(item_type):
@@ -264,7 +272,8 @@ def _callback(db_id):
return temp_id

def unique_key_value_to_id(self, key, value, strict=False):
- """Returns the id that has the given value for the given unique key, or None.
+ """Returns the id that has the given value for the given unique key, or None if not found.
+
+ Fetches items from the DB until the answer is certain.

Args:
key (tuple)
@@ -287,8 +296,8 @@ def _unique_key_value_to_item(self, key, value):
return self.get(self.unique_key_value_to_id(key, value))

- def values(self):
- return (x for x in super().values() if x.is_valid())
+ def valid_values(self):
+ return (x for x in self.values() if x.is_valid())

def _make_item(self, item):
"""Returns a cache item.
@@ -311,27 +320,22 @@ def find_item(self, item, skip_keys=()):
CacheItemBase or None
"""
id_ = item.get("id")
- if isinstance(id_, int):
- # id is an int, easy
+ if id_ is not None:
+ # id is given, easy
return self.get(id_) or self._db_cache.fetch_ref(self._item_type, id_)
- if isinstance(id_, dict):
- # id is a dict specifying the values for one of the unique constraints
- key, value = zip(*id_.items())
- return self._unique_key_value_to_item(key, value)
- if id_ is None:
- # No id. Try to locate the item by the value of one of the unique keys.
- # Used by import_data (and more...)
- cache_item = self._make_item(item)
- error = cache_item.resolve_inverse_references(item.keys())
- if error:
- return None
- error = cache_item.polish()
- if error:
- return None
- for key, value in cache_item.unique_values(skip_keys=skip_keys):
- current_item = self._unique_key_value_to_item(key, value)
- if current_item:
- return current_item
+ # No id. Try to locate the item by the value of one of the unique keys.
+ # Used by import_data (and more...)
+ cache_item = self._make_item(item)
+ error = cache_item.resolve_inverse_references(item.keys())
+ if error:
+ return None
+ error = cache_item.polish()
+ if error:
+ return None
+ for key, value in cache_item.unique_values(skip_keys=skip_keys):
+ current_item = self._unique_key_value_to_item(key, value)
+ if current_item:
+ return current_item

def check_item(self, item, for_update=False, skip_keys=()):
# FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives,
diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py
index 0a451382..8250554f 100644
--- a/spinedb_api/db_cache_impl.py
+++ b/spinedb_api/db_cache_impl.py
@@ -8,10 +8,9 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see .
######################################################################################################################
-"""
-DB cache implementation.
-""" +# The Spine implementation for DBCacheBase + import uuid from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError @@ -38,20 +37,20 @@ class DBCache(DBCacheBase): "commit": "commit_sq", } - def __init__(self, db_map, chunk_size=None): + def __init__(self, db_map): """ Args: db_map (DatabaseMapping) """ - super().__init__(chunk_size=chunk_size) + super().__init__() self._db_map = db_map @property - def _item_types(self): + def item_types(self): return list(self._sq_name_by_item_type) @staticmethod - def _item_factory(item_type): + def item_factory(item_type): return { "entity_class": EntityClassItem, "entity": EntityItem, @@ -69,7 +68,7 @@ def _item_factory(item_type): "parameter_value_metadata": ParameterValueMetadataItem, }.get(item_type, CacheItemBase) - def _query(self, item_type): + def query(self, item_type): if self._db_map.closed: return None sq_name = self._sq_name_by_item_type[item_type] @@ -292,7 +291,7 @@ def merge(self, other): and other_parameter_value_list_id != self["parameter_value_list_id"] and any( x["parameter_definition_id"] == self["id"] - for x in self._db_cache.table_cache("parameter_value").values() + for x in self._db_cache.table_cache("parameter_value").valid_values() ) ): del other["parameter_value_list_id"] @@ -409,19 +408,21 @@ class ScenarioItem(CacheItemBase): _defaults = {"active": False, "description": None} _unique_keys = (("name",),) - @property - def sorted_scenario_alternatives(self): - self._db_cache.fetch_all("scenario_alternative") - return sorted( - (x for x in self._db_cache.get("scenario_alternative", {}).values() if x["scenario_id"] == self["id"]), - key=itemgetter("rank"), - ) - def __getitem__(self, key): if key == "alternative_id_list": return [x["alternative_id"] for x in self.sorted_scenario_alternatives] if key == "alternative_name_list": return [x["alternative_name"] for x in self.sorted_scenario_alternatives] + if key == "sorted_scenario_alternatives": + self._db_cache.fetch_all("scenario_alternative") + return sorted( + ( + x + for x in self._db_cache.table_cache("scenario_alternative").valid_values() + if x["scenario_id"] == self["id"] + ), + key=itemgetter("rank"), + ) return super().__getitem__(key) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index bb45dcff..1dfe84d2 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -10,16 +10,15 @@ ###################################################################################################################### """ -Provides :class:`.DatabaseMapping`. - +This module defines the :class:`.DatabaseMapping` class. """ +import sqlalchemy.exc from .db_mapping_base import DatabaseMappingBase from .db_mapping_add_mixin import DatabaseMappingAddMixin from .db_mapping_update_mixin import DatabaseMappingUpdateMixin from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin from .db_mapping_commit_mixin import DatabaseMappingCommitMixin -from .filters.tools import apply_filter_stack, load_filters class DatabaseMapping( @@ -29,15 +28,149 @@ class DatabaseMapping( DatabaseMappingCommitMixin, DatabaseMappingBase, ): - """A basic read-write database mapping. + """Enables communication with a Spine DB. + + An in-memory clone (ORM) of the DB is incrementally formed as data is requested/modified. + + Data is typically retrieved using :meth:`get_item` or :meth:`get_items`. 
+ If the requested data is already in the in-memory clone, it is returned from there;
+ otherwise it is fetched from the DB, stored in the clone, and then returned.
+ In other words, the data is fetched from the DB exactly once.
+
+ Data is added via :meth:`add_item` or :meth:`add_items`;
+ updated via :meth:`update_item` or :meth:`update_items`;
+ removed via :meth:`remove_item` or :meth:`remove_items`;
+ and restored via :meth:`restore_item` or :meth:`restore_items`.
+ All the above methods modify the in-memory clone (not the DB itself).
+ These methods also fetch data from the DB into the in-memory clone to perform the necessary integrity checks
+ (unique constraints, foreign key constraints) as needed.
+
+ Modifications to the in-memory clone are committed (written) to the DB via :meth:`commit_session`,
+ or rolled back (discarded) via :meth:`rollback_session`.
- :param str db_url: A database URL in RFC-1738 format pointing to the database to be mapped.
- :param str username: A user name. If ``None``, it gets replaced by the string ``"anon"``.
- :param bool upgrade: Whether or not the db at the given URL should be upgraded to the most recent version.
+ The in-memory clone is reset via :meth:`refresh_session`.
+
+ You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`.
+ These methods are especially useful when called asynchronously.
+
+ Data can also be retrieved using :meth:`query` in combination with one of the multiple subquery properties
+ documented below.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- if self._filter_configs is not None:
- stack = load_filters(self._filter_configs)
- apply_filter_stack(self, stack)
+ for item_type in self.ITEM_TYPES:
+ setattr(self, "get_" + item_type, self._make_getter(item_type))
+
+ def _make_getter(self, item_type):
+ def _get_item(**kwargs):
+ return self.get_item(item_type, **kwargs)
+
+ return _get_item
+
+ def get_item(self, tablename, **kwargs):
+ tablename = self._real_tablename(tablename)
+ cache_item = self.cache.table_cache(tablename).find_item(kwargs)
+ if not cache_item:
+ return None
+ return PublicItem(self, cache_item)
+
+ def get_items(self, tablename, fetch=True, valid_only=True):
+ tablename = self._real_tablename(tablename)
+ if fetch and tablename not in self.cache.fetched_item_types:
+ self.fetch_all(tablename)
+ if valid_only:
+ return [PublicItem(self, x) for x in self.cache.table_cache(tablename).valid_values()]
+ return [PublicItem(self, x) for x in self.cache.table_cache(tablename).values()]
+
+ def can_fetch_more(self, tablename):
+ return tablename not in self.cache.fetched_item_types
+
+ def fetch_more(self, tablename, limit):
+ """Fetches items from the DB into memory, incrementally.
+
+ Args:
+ tablename (str): The table to fetch.
+ limit (int): The maximum number of items to fetch. Successive calls to this function
+ will start from the point where the last one left off.
+ In other words, each item is fetched from the DB exactly once.
+
+ Returns:
+ list(PublicItem): The items fetched.
+ """
+ tablename = self._real_tablename(tablename)
+ return self.cache.fetch_more(tablename, limit=limit)
+
+ def fetch_all(self, *tablenames):
+ """Fetches items from the DB into memory. Unlike :meth:`fetch_more`, this method fetches entire tables.
+
+ Args:
+ *tablenames (str): The tables to fetch. If none given, then the entire DB is fetched.
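+
+ Example (a usage sketch; assumes ``db_map`` is an open :class:`.DatabaseMapping`)::
+
+ db_map.fetch_all("entity_class", "entity")
+ classes = db_map.get_items("entity_class", fetch=False)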
+ """ + tablenames = set(self.ITEM_TYPES) if not tablenames else set(tablenames) & set(self.ITEM_TYPES) + for tablename in tablenames: + tablename = self._real_tablename(tablename) + self.cache.fetch_all(tablename) + + def add_item(self, tablename, **kwargs): + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + self._convert_legacy(tablename, kwargs) + checked_item, error = table_cache.check_item(kwargs) + if error: + return None, error + return table_cache.add_item(checked_item, new=True), None + + +class PublicItem: + def __init__(self, db_map, cache_item): + self._db_map = db_map + self._cache_item = cache_item + + @property + def item_type(self): + return self._cache_item.item_type + + def __getitem__(self, key): + return self._cache_item[key] + + def __eq__(self, other): + if isinstance(other, dict): + return self._cache_item == other + return super().__eq__(other) + + def __repr__(self): + return repr(self._cache_item) + + def __str__(self): + return str(self._cache_item) + + def get(self, key, default=None): + return self._cache_item.get(key, default) + + def is_valid(self): + return self._cache_item.is_valid() + + def is_committed(self): + return self._cache_item.is_committed() + + def _asdict(self): + return self._cache_item._asdict() + + def update(self, **kwargs): + self._db_map.update_item(self.item_type, id=self["id"], **kwargs) + + def remove(self): + return self._db_map.remove_item(self.item_type, self["id"]) + + def restore(self): + return self._db_map.restore_item(self.item_type, self["id"]) + + def add_update_callback(self, callback): + self._cache_item.update_callbacks.add(callback) + + def add_remove_callback(self, callback): + self._cache_item.remove_callbacks.add(callback) + + def add_restore_callback(self, callback): + self._cache_item.restore_callbacks.add(callback) diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 5416f2bc..81158b4f 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -9,9 +9,6 @@ # this program. If not, see . ###################################################################################################################### -"""Provides :class:`.DatabaseMappingAddMixin`. 
- -""" # TODO: improve docstrings from sqlalchemy.exc import DBAPIError @@ -37,24 +34,26 @@ def add_items(self, tablename, *items, check=True, strict=False): list(str): found violations """ added, errors = [], [] - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) if not check: for item in items: - self._convert_legacy(tablename, item) - added.append(table_cache.add_item(item, new=True)) + added.append(self._add_item_unsafe(tablename, item)) else: for item in items: - self._convert_legacy(tablename, item) - checked_item, error = table_cache.check_item(item) + item, error = self.add_item(tablename, **item) if error: if strict: raise SpineIntegrityError(error) errors.append(error) continue - added.append(table_cache.add_item(checked_item, new=True)) + added.append(item) return added, errors + def _add_item_unsafe(self, tablename, item): + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + self._convert_legacy(tablename, item) + return table_cache.add_item(item, new=True) + def _do_add_items(self, connection, tablename, *items_to_add): """Add items to DB without checking integrity.""" if not items_to_add: diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 1545878b..9a12f724 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -9,8 +9,6 @@ # this program. If not, see . ###################################################################################################################### -"""Provides :class:`.DatabaseMappingBase`.""" -# TODO: Finish docstrings import hashlib import os import logging @@ -38,19 +36,18 @@ model_meta, copy_database_bind, ) -from .filters.tools import pop_filter_configs +from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters from .spine_db_client import get_db_url_from_server from .db_cache_impl import DBCache from .query import Query - logging.getLogger("alembic").setLevel(logging.CRITICAL) class DatabaseMappingBase: """Base class for all database mappings. - It provides the :meth:`query` method for custom db querying. + Provides the :meth:`query` method for performing custom ``SELECT`` queries. """ _session_kwargs = {} @@ -80,20 +77,18 @@ def __init__( apply_filters=True, memory=False, sqlite_timeout=1800, - chunk_size=None, ): """ Args: db_url (str or URL): A URL in RFC-1738 format pointing to the database to be mapped, or to a DB server. - username (str, optional): A user name. If ``None``, it gets replaced by the string ``"anon"``. - upgrade (bool): Whether or not the db at the given URL should be upgraded to the most recent version. - codename (str, optional): A name that uniquely identifies the class instance within a client application. - create (bool): Whether or not to create a Spine db at the given URL if it's not already. - apply_filters (bool): Whether or not filters in the URL's query part are applied to the database map. - memory (bool): Whether or not to use a sqlite memory db as replacement for this DB map. - sqlite_timeout (int): How many seconds to wait before raising connection errors. - chunk_size (int, optional): How many rows to fetch from the DB at a time when populating the cache. - If not specified, then all rows are fetched at once. + username (str, optional): A user name. If not given, it gets replaced by the string ``"anon"``. + upgrade (bool, optional): Whether the db at the given URL should be upgraded to the most recent + version. 
+ codename (str, optional): A name to associate with the DB mapping. + create (bool, optional): Whether to create a Spine db at the given URL if it's not one already. + apply_filters (bool, optional): Whether to apply filters in the URL's query part. + memory (bool, optional): Whether or not to use a sqlite memory db as replacement for this DB map. + sqlite_timeout (int, optional): How many seconds to wait before raising connection errors. """ # FIXME: We should also check the server memory property and use it here db_url = get_db_url_from_server(db_url) @@ -110,7 +105,7 @@ def __init__( self.codename = self._make_codename(codename) self._memory = memory self._memory_dirty = False - self._original_engine = self.create_engine( + self._original_engine = self._create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason @@ -121,7 +116,7 @@ def __init__( self._metadata = MetaData(self.engine) self._metadata.reflect() self._tablenames = [t.name for t in self._metadata.sorted_tables] - self.cache = DBCache(self, chunk_size=chunk_size) + self.cache = DBCache(self) self.closed = False # Subqueries that select everything from each table self._commit_sq = None @@ -186,6 +181,9 @@ def __init__( "entity_element": ("entity_id", "position"), "entity_class_dimension": ("entity_class_id", "position"), } + if self._filter_configs is not None: + stack = load_filters(self._filter_configs) + apply_filter_stack(self, stack) def __enter__(self): return self @@ -194,12 +192,19 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): self.close() def get_filter_configs(self): + """Returns filters applicable to this DB mapping. + + Returns: + list(dict) + """ return self._filter_configs def close(self): + """Closes this DB mapping.""" self.closed = True - def _real_tablename(self, tablename): + @staticmethod + def _real_tablename(tablename): return { "object_class": "entity_class", "relationship_class": "entity_class", @@ -208,6 +213,7 @@ def _real_tablename(self, tablename): }.get(tablename, tablename) def get_table(self, tablename): + # For tests return self._metadata.tables[tablename] def _make_codename(self, codename): @@ -222,8 +228,8 @@ def _make_codename(self, codename): return hashing.hexdigest() @staticmethod - def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): - """Create engine. + def _create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): + """Creates engine. Args sa_url (URL) @@ -339,28 +345,28 @@ def _clear_subqueries(self, *tablenames): setattr(self, attr_name, None) def query(self, *args, **kwargs): - """Return a sqlalchemy :class:`~Query` object bound to this :class:`.DatabaseMappingBase`. + """Returns a :class:`~spinedb_api.query.Query` object bound to this :class:`.DatabaseMappingBase`. To perform custom ``SELECT`` statements, call this method with one or more of the class documented - :class:`~sqlalchemy.sql.expression.Alias` properties. For example, to select the object class with + :class:`~sqlalchemy.sql.expression.Alias` properties. For example, to select the entity class with ``id`` equal to 1:: from spinedb_api import DatabaseMapping url = 'sqlite:///spine.db' ... db_map = DatabaseMapping(url) - db_map.query(db_map.object_class_sq).filter_by(id=1).one_or_none() + db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() - To perform more complex queries, just use this method in combination with the SQLAlchemy API. 
- For example, to select all object class names and the names of their objects concatenated in a string:: + To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface. + For example, to select all entity class names and the names of their entities concatenated in a string:: from sqlalchemy import func db_map.query( - db_map.object_class_sq.c.name, func.group_concat(db_map.object_sq.c.name) + db_map.entity_class_sq.c.name, func.group_concat(db_map.entity_sq.c.name) ).filter( - db_map.object_sq.c.class_id == db_map.object_class_sq.c.id - ).group_by(db_map.object_class_sq.c.name).all() + db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id + ).group_by(db_map.entity_class_sq.c.name).all() """ return Query(self.engine, *args) @@ -380,31 +386,13 @@ def _subquery(self, tablename): table = self._metadata.tables[tablename] return self.query(table).subquery(tablename + "_sq") - @property - def alternative_sq(self): - if self._alternative_sq is None: - self._alternative_sq = self._make_alternative_sq() - return self._alternative_sq - - @property - def scenario_sq(self): - if self._scenario_sq is None: - self._scenario_sq = self._make_scenario_sq() - return self._scenario_sq - - @property - def scenario_alternative_sq(self): - if self._scenario_alternative_sq is None: - self._scenario_alternative_sq = self._make_scenario_alternative_sq() - return self._scenario_alternative_sq - @property def entity_class_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM class + SELECT * FROM entity_class Returns: sqlalchemy.sql.expression.Alias @@ -415,36 +403,18 @@ def entity_class_sq(self): @property def entity_class_dimension_sq(self): - if self._entity_class_dimension_sq is None: - self._entity_class_dimension_sq = self._subquery("entity_class_dimension") - return self._entity_class_dimension_sq - - @property - def entity_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM entity + SELECT * FROM entity_class_dimension Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_sq is None: - self._entity_sq = self._make_entity_sq() - return self._entity_sq - - @property - def entity_element_sq(self): - if self._entity_element_sq is None: - self._entity_element_sq = self._make_entity_element_sq() - return self._entity_element_sq - - @property - def entity_alternative_sq(self): - if self._entity_alternative_sq is None: - self._entity_alternative_sq = self._subquery("entity_alternative") - return self._entity_alternative_sq + if self._entity_class_dimension_sq is None: + self._entity_class_dimension_sq = self._subquery("entity_class_dimension") + return self._entity_class_dimension_sq @property def wide_entity_class_sq(self): @@ -519,6 +489,36 @@ def wide_entity_class_sq(self): ) return self._wide_entity_class_sq + @property + def entity_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM entity + + Returns: + sqlalchemy.sql.expression.Alias + """ + if self._entity_sq is None: + self._entity_sq = self._make_entity_sq() + return self._entity_sq + + @property + def entity_element_sq(self): + """A subquery of the form: + + .. 
code-block:: sql + + SELECT * FROM entity_element + + Returns: + sqlalchemy.sql.expression.Alias + """ + if self._entity_element_sq is None: + self._entity_element_sq = self._make_entity_element_sq() + return self._entity_element_sq + @property def wide_entity_sq(self): """A subquery of the form: @@ -563,139 +563,121 @@ def wide_entity_sq(self): group_concat(ext_entity_sq.c.element_id, ext_entity_sq.c.position).label("element_id_list"), group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"), ) - # element count might be lower than dimension count when element-entities have been filtered out - # .filter(self.wide_entity_class_sq.c.id == ext_entity_sq.c.class_id) - # .having(self.wide_entity_class_sq.c.dimension_count == func.count(ext_entity_sq.c.element_id)) .group_by( ext_entity_sq.c.id, ext_entity_sq.c.class_id, ext_entity_sq.c.name, ext_entity_sq.c.description, ext_entity_sq.c.commit_id, - ).subquery("wide_entity_sq") + ) + .subquery("wide_entity_sq") ) return self._wide_entity_sq @property - def object_class_sq(self): + def entity_group_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM object_class + SELECT * FROM entity_group Returns: sqlalchemy.sql.expression.Alias """ - if self._object_class_sq is None: - self._object_class_sq = ( - self.query( - self.wide_entity_class_sq.c.id.label("id"), - self.wide_entity_class_sq.c.name.label("name"), - self.wide_entity_class_sq.c.description.label("description"), - self.wide_entity_class_sq.c.display_order.label("display_order"), - self.wide_entity_class_sq.c.display_icon.label("display_icon"), - self.wide_entity_class_sq.c.hidden.label("hidden"), - ) - .filter(self.wide_entity_class_sq.c.dimension_id_list == None) - .subquery("object_class_sq") - ) - return self._object_class_sq + if self._entity_group_sq is None: + self._entity_group_sq = self._subquery("entity_group") + return self._entity_group_sq @property - def object_sq(self): + def alternative_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM object + SELECT * FROM alternative Returns: sqlalchemy.sql.expression.Alias """ - if self._object_sq is None: - self._object_sq = ( - self.query( - self.wide_entity_sq.c.id.label("id"), - self.wide_entity_sq.c.class_id.label("class_id"), - self.wide_entity_sq.c.name.label("name"), - self.wide_entity_sq.c.description.label("description"), - self.wide_entity_sq.c.commit_id.label("commit_id"), - ) - .filter(self.wide_entity_sq.c.element_id_list == None) - .subquery("object_sq") - ) - return self._object_sq + if self._alternative_sq is None: + self._alternative_sq = self._make_alternative_sq() + return self._alternative_sq @property - def relationship_class_sq(self): + def scenario_sq(self): """A subquery of the form: .. 
code-block:: sql - SELECT * FROM relationship_class + SELECT * FROM scenario Returns: sqlalchemy.sql.expression.Alias """ - if self._relationship_class_sq is None: - ent_cls_dim_sq = self._subquery("entity_class_dimension") - self._relationship_class_sq = ( - self.query( - ent_cls_dim_sq.c.entity_class_id.label("id"), - ent_cls_dim_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept - ent_cls_dim_sq.c.dimension_id.label("object_class_id"), - self.wide_entity_class_sq.c.name.label("name"), - self.wide_entity_class_sq.c.description.label("description"), - self.wide_entity_class_sq.c.display_icon.label("display_icon"), - self.wide_entity_class_sq.c.hidden.label("hidden"), - ) - .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) - .subquery("relationship_class_sq") - ) - return self._relationship_class_sq + if self._scenario_sq is None: + self._scenario_sq = self._make_scenario_sq() + return self._scenario_sq @property - def relationship_sq(self): + def scenario_alternative_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM relationship + SELECT * FROM scenario_alternative Returns: sqlalchemy.sql.expression.Alias """ - if self._relationship_sq is None: - ent_el_sq = self._subquery("entity_element") - self._relationship_sq = ( - self.query( - ent_el_sq.c.entity_id.label("id"), - ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept - ent_el_sq.c.element_id.label("object_id"), - ent_el_sq.c.entity_class_id.label("class_id"), - self.wide_entity_sq.c.name.label("name"), - self.wide_entity_sq.c.commit_id.label("commit_id"), - ) - .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) - .subquery("relationship_sq") - ) - return self._relationship_sq + if self._scenario_alternative_sq is None: + self._scenario_alternative_sq = self._make_scenario_alternative_sq() + return self._scenario_alternative_sq @property - def entity_group_sq(self): + def entity_alternative_sq(self): """A subquery of the form: .. code-block:: sql - SELECT * FROM entity_group + SELECT * FROM entity_alternative Returns: sqlalchemy.sql.expression.Alias """ - if self._entity_group_sq is None: - self._entity_group_sq = self._subquery("entity_group") - return self._entity_group_sq + if self._entity_alternative_sq is None: + self._entity_alternative_sq = self._subquery("entity_alternative") + return self._entity_alternative_sq + + @property + def parameter_value_list_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM parameter_value_list + + Returns: + sqlalchemy.sql.expression.Alias + """ + if self._parameter_value_list_sq is None: + self._parameter_value_list_sq = self._subquery("parameter_value_list") + return self._parameter_value_list_sq + + @property + def list_value_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM list_value + + Returns: + sqlalchemy.sql.expression.Alias + """ + if self._list_value_sq is None: + self._list_value_sq = self._subquery("list_value") + return self._list_value_sq @property def parameter_definition_sq(self): @@ -729,51 +711,136 @@ def parameter_value_sq(self): return self._parameter_value_sq @property - def parameter_value_list_sq(self): + def metadata_sq(self): """A subquery of the form: .. 
code-block:: sql

- SELECT * FROM parameter_value_list
+ SELECT * FROM metadata

Returns:
sqlalchemy.sql.expression.Alias
"""
- if self._parameter_value_list_sq is None:
- self._parameter_value_list_sq = self._subquery("parameter_value_list")
- return self._parameter_value_list_sq
-
- @property
- def list_value_sq(self):
- if self._list_value_sq is None:
- self._list_value_sq = self._subquery("list_value")
- return self._list_value_sq
-
- @property
- def metadata_sq(self):
if self._metadata_sq is None:
self._metadata_sq = self._subquery("metadata")
return self._metadata_sq

@property
def parameter_value_metadata_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM parameter_value_metadata
+
+ Returns:
+ sqlalchemy.sql.expression.Alias
+ """
if self._parameter_value_metadata_sq is None:
self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata")
return self._parameter_value_metadata_sq

@property
def entity_metadata_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity_metadata
+
+ Returns:
+ sqlalchemy.sql.expression.Alias
+ """
if self._entity_metadata_sq is None:
self._entity_metadata_sq = self._subquery("entity_metadata")
return self._entity_metadata_sq

@property
def commit_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM commit
+
+ Returns:
+ sqlalchemy.sql.expression.Alias
+ """
if self._commit_sq is None:
commit_sq = self._subquery("commit")
self._commit_sq = self.query(commit_sq).filter(commit_sq.c.comment != "").subquery()
return self._commit_sq

+ @property
+ def object_class_sq(self):
+ if self._object_class_sq is None:
+ self._object_class_sq = (
+ self.query(
+ self.wide_entity_class_sq.c.id.label("id"),
+ self.wide_entity_class_sq.c.name.label("name"),
+ self.wide_entity_class_sq.c.description.label("description"),
+ self.wide_entity_class_sq.c.display_order.label("display_order"),
+ self.wide_entity_class_sq.c.display_icon.label("display_icon"),
+ self.wide_entity_class_sq.c.hidden.label("hidden"),
+ )
+ .filter(self.wide_entity_class_sq.c.dimension_id_list == None)
+ .subquery("object_class_sq")
+ )
+ return self._object_class_sq
+
+ @property
+ def object_sq(self):
+ if self._object_sq is None:
+ self._object_sq = (
+ self.query(
+ self.wide_entity_sq.c.id.label("id"),
+ self.wide_entity_sq.c.class_id.label("class_id"),
+ self.wide_entity_sq.c.name.label("name"),
+ self.wide_entity_sq.c.description.label("description"),
+ self.wide_entity_sq.c.commit_id.label("commit_id"),
+ )
+ .filter(self.wide_entity_sq.c.element_id_list == None)
+ .subquery("object_sq")
+ )
+ return self._object_sq
+
+ @property
+ def relationship_class_sq(self):
+ if self._relationship_class_sq is None:
+ ent_cls_dim_sq = self._subquery("entity_class_dimension")
+ self._relationship_class_sq = (
+ self.query(
+ ent_cls_dim_sq.c.entity_class_id.label("id"),
+ ent_cls_dim_sq.c.position.label("dimension"),  # NOTE: nothing to do with the `dimension` concept
+ ent_cls_dim_sq.c.dimension_id.label("object_class_id"),
+ self.wide_entity_class_sq.c.name.label("name"),
+ self.wide_entity_class_sq.c.description.label("description"),
+ self.wide_entity_class_sq.c.display_icon.label("display_icon"),
+ self.wide_entity_class_sq.c.hidden.label("hidden"),
+ )
+ .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id)
+ .subquery("relationship_class_sq")
+ )
+ return self._relationship_class_sq
+
+ @property
+ def relationship_sq(self):
+ if self._relationship_sq is None:
+ ent_el_sq =
self._subquery("entity_element") + self._relationship_sq = ( + self.query( + ent_el_sq.c.entity_id.label("id"), + ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept + ent_el_sq.c.element_id.label("object_id"), + ent_el_sq.c.entity_class_id.label("class_id"), + self.wide_entity_sq.c.name.label("name"), + self.wide_entity_sq.c.commit_id.label("commit_id"), + ) + .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) + .subquery("relationship_sq") + ) + return self._relationship_sq + @property def ext_parameter_value_list_sq(self): if self._ext_parameter_value_list_sq is None: @@ -938,22 +1005,6 @@ def ext_linked_scenario_alternative_sq(self): @property def ext_object_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT - o.id, - o.class_id, - oc.name AS class_name, - o.name, - o.description, - FROM object AS o, object_class AS oc - WHERE o.class_id = oc.id - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_object_sq is None: self._ext_object_sq = ( self.query( @@ -974,22 +1025,6 @@ def ext_object_sq(self): @property def ext_relationship_class_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT - rc.id, - rc.name, - oc.id AS object_class_id, - oc.name AS object_class_name - FROM relationship_class AS rc, object_class AS oc - WHERE rc.object_class_id = oc.id - ORDER BY rc.id, rc.dimension - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_relationship_class_sq is None: self._ext_relationship_class_sq = ( self.query( @@ -1009,30 +1044,6 @@ def ext_relationship_class_sq(self): @property def wide_relationship_class_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT - id, - name, - GROUP_CONCAT(object_class_id) AS object_class_id_list, - GROUP_CONCAT(object_class_name) AS object_class_name_list - FROM ( - SELECT - rc.id, - rc.name, - oc.id AS object_class_id, - oc.name AS object_class_name - FROM relationship_class AS rc, object_class AS oc - WHERE rc.object_class_id = oc.id - ORDER BY rc.id, rc.dimension - ) - GROUP BY id, name - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._wide_relationship_class_sq is None: self._wide_relationship_class_sq = ( self.query( @@ -1059,24 +1070,6 @@ def wide_relationship_class_sq(self): @property def ext_relationship_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT - r.id, - r.class_id, - r.name, - o.id AS object_id, - o.name AS object_name, - o.class_id AS object_class_id, - FROM relationship as r, object AS o - WHERE r.object_id = o.id - ORDER BY r.id, r.dimension - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_relationship_sq is None: self._ext_relationship_sq = ( self.query( @@ -1100,33 +1093,6 @@ def ext_relationship_sq(self): @property def wide_relationship_sq(self): - """A subquery of the form: - - .. 
code-block:: sql - - SELECT - id, - class_id, - class_name, - name, - GROUP_CONCAT(object_id) AS object_id_list, - GROUP_CONCAT(object_name) AS object_name_list - FROM ( - SELECT - r.id, - r.class_id, - r.name, - o.id AS object_id, - o.name AS object_name - FROM relationship as r, object AS o - WHERE r.object_id = o.id - ORDER BY r.id, r.dimension - ) - GROUP BY id, class_id, name - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._wide_relationship_sq is None: self._wide_relationship_sq = ( self.query( @@ -1165,11 +1131,6 @@ def wide_relationship_sq(self): @property def ext_entity_group_sq(self): - """A subquery of the form: - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_entity_group_sq is None: group_entity = aliased(self.entity_sq) member_entity = aliased(self.entity_sq) @@ -1194,10 +1155,6 @@ def ext_entity_group_sq(self): @property def entity_parameter_definition_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ if self._entity_parameter_definition_sq is None: self._entity_parameter_definition_sq = ( self.query( @@ -1237,38 +1194,6 @@ def entity_parameter_definition_sq(self): @property def object_parameter_definition_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT - pd.id, - oc.id AS object_class_id, - oc.name AS object_class_name, - pd.name AS parameter_name, - wpvl.id AS value_list_id, - wpvl.name AS value_list_name, - pd.default_value - FROM parameter_definition AS pd, object_class AS oc - ON wpdt.parameter_definition_id = pd.id - LEFT JOIN ( - SELECT - id, - name, - GROUP_CONCAT(value) AS value_list - FROM ( - SELECT id, name, value - FROM parameter_value_list - ORDER BY id, value_index - ) - GROUP BY id, name - ) AS wpvl - ON wpvl.id = pd.parameter_value_list_id - WHERE pd.object_class_id = oc.id - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._object_parameter_definition_sq is None: self._object_parameter_definition_sq = ( self.query( @@ -1295,59 +1220,6 @@ def object_parameter_definition_sq(self): @property def relationship_parameter_definition_sq(self): - """A subquery of the form: - - .. 
code-block:: sql - - SELECT - pd.id, - wrc.id AS relationship_class_id, - wrc.name AS relationship_class_name, - wrc.object_class_id_list, - wrc.object_class_name_list, - pd.name AS parameter_name, - wpvl.id AS value_list_id, - wpvl.name AS value_list_name, - pd.default_value - FROM - parameter_definition AS pd, - ( - SELECT - id, - name, - GROUP_CONCAT(object_class_id) AS object_class_id_list, - GROUP_CONCAT(object_class_name) AS object_class_name_list - FROM ( - SELECT - rc.id, - rc.name, - oc.id AS object_class_id, - oc.name AS object_class_name - FROM relationship_class AS rc, object_class AS oc - WHERE rc.object_class_id = oc.id - ORDER BY rc.id, rc.dimension - ) - GROUP BY id, name - ) AS wrc - ON wpdt.parameter_definition_id = pd.id - LEFT JOIN ( - SELECT - id, - name, - GROUP_CONCAT(value) AS value_list - FROM ( - SELECT id, name, value - FROM parameter_value_list - ORDER BY id, value_index - ) - GROUP BY id, name - ) AS wpvl - ON wpvl.id = pd.parameter_value_list_id - WHERE pd.relationship_class_id = wrc.id - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._relationship_parameter_definition_sq is None: self._relationship_parameter_definition_sq = ( self.query( @@ -1376,10 +1248,6 @@ def relationship_parameter_definition_sq(self): @property def entity_parameter_value_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ if self._entity_parameter_value_sq is None: self._entity_parameter_value_sq = ( self.query( @@ -1436,11 +1304,6 @@ def entity_parameter_value_sq(self): @property def object_parameter_value_sq(self): - """A subquery of the form: - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._object_parameter_value_sq is None: self._object_parameter_value_sq = ( self.query( @@ -1468,11 +1331,6 @@ def object_parameter_value_sq(self): @property def relationship_parameter_value_sq(self): - """A subquery of the form: - - Returns: - sqlalchemy.sql.expression.Alias - """ if self._relationship_parameter_value_sq is None: self._relationship_parameter_value_sq = ( self.query( @@ -1503,10 +1361,6 @@ def relationship_parameter_value_sq(self): @property def ext_parameter_value_metadata_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_parameter_value_metadata_sq is None: self._ext_parameter_value_metadata_sq = ( self.query( @@ -1531,10 +1385,6 @@ def ext_parameter_value_metadata_sq(self): @property def ext_entity_metadata_sq(self): - """ - Returns: - sqlalchemy.sql.expression.Alias - """ if self._ext_entity_metadata_sq is None: self._ext_entity_metadata_sq = ( self.query( @@ -1684,12 +1534,6 @@ def _create_import_alternative(self): self._import_alternative_name = "Base" def override_create_import_alternative(self, method): - """ - Overrides the ``_create_import_alternative`` function. 
-
- Args:
- method (Callable)
- """
self._create_import_alternative = MethodType(method, self)
self._import_alternative_name = None

@@ -1838,11 +1682,6 @@ def _reset_mapping(self):
connection.execute(table.delete())
connection.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)")

- def fetch_all(self, tablenames=None):
- tablenames = set(self.ITEM_TYPES) if tablenames is None else tablenames & set(self.ITEM_TYPES)
- for tablename in tablenames:
- self.cache.fetch_all(tablename)
-
def _object_class_id(self):
return case(
[(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None
@@ -1904,14 +1743,6 @@ def _object_name_list(self):
[(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None
)

- def advance_cache_query(self, item_type):
- """Schedules an advance of the DB query that fetches items of given type.
-
- Args:
- item_type (str)
- """
- return self.cache.advance_query(item_type)
-
@staticmethod
def _convert_legacy(tablename, item):
if tablename in ("entity_class", "entity"):
diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py
index 7658641e..0b1f16fd 100644
--- a/spinedb_api/db_mapping_commit_mixin.py
+++ b/spinedb_api/db_mapping_commit_mixin.py
@@ -9,11 +9,6 @@
# this program. If not, see .
######################################################################################################################

-"""
-Provides :class:`.QuickDatabaseMappingBase`.
-
-"""
-
from datetime import datetime, timezone
import sqlalchemy.exc
from .exception import SpineDBAPIError
diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py
index 67882bfd..4b703e27 100644
--- a/spinedb_api/db_mapping_remove_mixin.py
+++ b/spinedb_api/db_mapping_remove_mixin.py
@@ -9,9 +9,6 @@
# this program. If not, see .
######################################################################################################################

-"""Provides :class:`.DiffDatabaseMappingRemoveMixin`.
-
-"""
from sqlalchemy import and_, or_
from sqlalchemy.exc import DBAPIError

@@ -26,6 +23,16 @@ class DatabaseMappingRemoveMixin:
"""Provides methods to perform ``REMOVE`` operations over a Spine db."""

def remove_items(self, tablename, *ids):
+ """Removes items from the DB.
+
+ Args:
+ tablename (str): Target database table name.
+ *ids (int): Ids of items to be removed.
+
+ Returns:
+ set: ids or items successfully removed
+ list(SpineIntegrityError): found violations
+ """
if not ids:
return []
tablename = self._real_tablename(tablename)
@@ -46,6 +53,16 @@ def restore_items(self, tablename, *ids):
table_cache = self.cache.table_cache(tablename)
return [table_cache.restore_item(id_) for id_ in ids]

+ def remove_item(self, tablename, id_):
+ tablename = self._real_tablename(tablename)
+ table_cache = self.cache.table_cache(tablename)
+ return table_cache.remove_item(id_)
+
+ def restore_item(self, tablename, id_):
+ tablename = self._real_tablename(tablename)
+ table_cache = self.cache.table_cache(tablename)
+ return table_cache.restore_item(id_)
+
def purge_items(self, tablename):
"""Removes all items from given table.
@@ -90,9 +107,9 @@ def _do_remove_items(self, connection, tablename, *ids): def remove_unused_metadata(self): used_metadata_ids = set() - for x in self.cache.get("entity_metadata", {}).values(): + for x in self.cache.table_cache("entity_metadata").valid_values(): used_metadata_ids.add(x["metadata_id"]) - for x in self.cache.get("parameter_value_metadata", {}).values(): + for x in self.cache.table_cache("parameter_value_metadata").valid_values(): used_metadata_ids.add(x["metadata_id"]) - unused_metadata_ids = {x["id"] for x in self.cache.get("metadata", {}).values()} - used_metadata_ids + unused_metadata_ids = {x["id"] for x in self.cache.table_cache("metadata").valid_values()} - used_metadata_ids self.remove_items("metadata", *unused_metadata_ids) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 78619b76..1b4f03c1 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -9,9 +9,6 @@ # this program. If not, see . ###################################################################################################################### -"""Provides :class:`DatabaseMappingUpdateMixin`. - -""" from sqlalchemy.exc import DBAPIError from sqlalchemy.sql.expression import bindparam from .exception import SpineIntegrityError, SpineDBAPIError @@ -88,25 +85,33 @@ def update_items(self, tablename, *items, check=True, strict=False): list(SpineIntegrityError): found violations """ updated, errors = [], [] - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) if not check: for item in items: - self._convert_legacy(tablename, item) - updated.append(table_cache.update_item(item)) + updated.append(self._update_item_unsafe(tablename, item)) else: for item in items: - self._convert_legacy(tablename, item) - checked_item, error = table_cache.check_item(item, for_update=True) + item, error = self.update_item(tablename, **item) if error: if strict: raise SpineIntegrityError(error) errors.append(error) - if checked_item: - item = checked_item._asdict() - updated.append(table_cache.update_item(item)) + if item: + updated.append(item) return updated, errors + def _update_item_unsafe(self, tablename, item): + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + self._convert_legacy(tablename, item) + return table_cache.update_item(item) + + def update_item(self, tablename, **kwargs): + tablename = self._real_tablename(tablename) + table_cache = self.cache.table_cache(tablename) + self._convert_legacy(tablename, kwargs) + checked_item, error = table_cache.check_item(kwargs, for_update=True) + return table_cache.update_item(checked_item._asdict()) if checked_item else None, error + def update_alternatives(self, *items, **kwargs): return self.update_items("alternative", *items, **kwargs) diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index ea9f8980..c2554dab 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -10,8 +10,7 @@ ###################################################################################################################### """ -Classes to handle exceptions while using the Spine database API. - +Spine DB API exceptions. 
""" @@ -50,33 +49,6 @@ def __init__(self, url=None, current=None, expected=None, upgrade_available=True self.upgrade_available = upgrade_available -class SpineTableNotFoundError(SpineDBAPIError): - """Can't find one of the tables.""" - - def __init__(self, table, url=None): - super().__init__(msg="Table(s) '{}' couldn't be mapped from the database at '{}'.".format(table, url)) - self.table = table - - -class RecordNotFoundError(SpineDBAPIError): - """Can't find one record in one of the tables.""" - - def __init__(self, table, name=None, id=None): - super().__init__(msg="Unable to find item in table '{}'.".format(table)) - self.table = table - self.name = name - self.id = id - - -class ParameterValueError(SpineDBAPIError): - """The value given for a parameter does not fit the datatype.""" - - def __init__(self, value, data_type): - super().__init__(msg="The value {} does not fit the datatype '{}'.".format(value, data_type)) - self.value = value - self.data_type = data_type - - class ParameterValueFormatError(SpineDBAPIError): """ Failure in encoding/decoding a parameter value. @@ -91,7 +63,7 @@ def __init__(self, msg): class InvalidMapping(SpineDBAPIError): """ - Failure in import/export mapping + Failure in import/export mapping. """ def __init__(self, msg): @@ -106,4 +78,4 @@ def __init__(self, msg, rank=None, key=None): class ConnectorError(SpineDBAPIError): - """Failure in import connector.""" + """Failure in import/export connector.""" diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index a2ccf395..f53cf525 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -10,7 +10,7 @@ ###################################################################################################################### """ -Functions for exporting data from a Spine database using entity names as references. +Functions for exporting data from a Spine database in a standard format. """ from operator import itemgetter @@ -31,22 +31,26 @@ def export_data( alternative_ids=Asterisk, scenario_ids=Asterisk, scenario_alternative_ids=Asterisk, + entity_alternative_ids=Asterisk, parse_value=from_database, ): """ - Exports data from given database into a dictionary that can be splatted into keyword arguments for ``import_data``. + Exports data from a Spine DB into a standard dictionary format. + The result can be splatted into keyword arguments for :func:`spinedb_api.import_functions.import_data` + to transfer data from one DB to another. Args: - db_map (DiffDatabaseMapping): The db to pull stuff from. - entity_class_ids (Iterable, optional): A collection of ids to pick from the database table - entity_ids (Iterable, optional): A collection of ids to pick from the database table - entity_group_ids (Iterable, optional): A collection of ids to pick from the database table - parameter_value_list_ids (Iterable, optional): A collection of ids to pick from the database table - parameter_definition_ids (Iterable, optional): A collection of ids to pick from the database table - parameter_value_ids (Iterable, optional): A collection of ids to pick from the database table - alternative_ids (Iterable, optional): A collection of ids to pick from the database table - scenario_ids (Iterable, optional): A collection of ids to pick from the database table - scenario_alternative_ids (Iterable, optional): A collection of ids to pick from the database table + db_map (DatabaseMapping): The db to pull data from. 
+ entity_class_ids (Iterable, optional): If given, only exports classes with these ids + entity_ids (Iterable, optional): If given, only exports entities with these ids + entity_group_ids (Iterable, optional): If given, only exports groups with these ids + parameter_value_list_ids (Iterable, optional): If given, only exports lists with these ids + parameter_definition_ids (Iterable, optional): If given, only exports parameter definitions with these ids + parameter_value_ids (Iterable, optional): If given, only exports parameter values with these ids + alternative_ids (Iterable, optional): If given, only exports alternatives with these ids + scenario_ids (Iterable, optional): If given, only exports scenarios with these ids + scenario_alternative_ids (Iterable, optional): If given, only exports scenario alternatives with these ids + entity_alternative_ids (Iterable, optional): If given, only exports entity alternatives with these ids Returns: dict: exported data @@ -54,6 +58,7 @@ def export_data( data = { "entity_classes": export_entity_classes(db_map, entity_class_ids), "entities": export_entities(db_map, entity_ids), + "entity_alternatives": export_entity_alternatives(db_map, entity_alternative_ids), "entity_groups": export_entity_groups(db_map, entity_group_ids), "parameter_value_lists": export_parameter_value_lists( db_map, parameter_value_list_ids, parse_value=parse_value @@ -72,7 +77,7 @@ def export_data( def _get_items(db_map, tablename, ids): if not ids: return () - _process_item = _make_item_processor(db_map, tablename) + _process_item = _make_item_processor(db_map.cache, tablename) for item in _get_items_from_cache(db_map.cache, tablename, ids): yield from _process_item(item) @@ -80,7 +85,7 @@ def _get_items(db_map, tablename, ids): def _get_items_from_cache(cache, tablename, ids): if ids is Asterisk: cache.fetch_all(tablename) - yield from cache.get(tablename, {}).values() + yield from cache.table_cache(tablename).valid_values() return for id_ in ids: item = cache.get_item(tablename, id_) or cache.fetch_ref(tablename, id_) @@ -88,10 +93,10 @@ def _get_items_from_cache(cache, tablename, ids): yield item -def _make_item_processor(db_map, tablename): +def _make_item_processor(cache, tablename): if tablename == "parameter_value_list": - db_map.fetch_all({"list_value"}) - return _ParameterValueListProcessor(db_map.cache.get("list_value", {}).values()) + cache.fetch_all("list_value") + return _ParameterValueListProcessor(cache.table_cache("list_value").valid_values()) return lambda item: (item,) @@ -134,6 +139,13 @@ def export_entity_groups(db_map, ids=Asterisk): return sorted((x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids)) +def export_entity_alternatives(db_map, ids=Asterisk): + return sorted( + (x.entity_class_name, x.entity_byname, x.alternative_name, x.active) + for x in _get_items(db_map, "entity_alternative", ids) + ) + + def export_parameter_definitions(db_map, ids=Asterisk, parse_value=from_database): return sorted( ( @@ -164,51 +176,14 @@ def export_parameter_values(db_map, ids=Asterisk, parse_value=from_database): def export_alternatives(db_map, ids=Asterisk): - """ - Exports alternatives from database. - - The format is what :func:`import_alternatives` accepts as its input. 
- - Args: - db_map (spinedb_api.DatabaseMapping or spinedb_api.DiffDatabaseMapping): a database map - ids (Iterable, optional): ids of the alternatives to export - - Returns: - Iterable: tuples of two elements: name of alternative and description - """ return sorted((x.name, x.description) for x in _get_items(db_map, "alternative", ids)) def export_scenarios(db_map, ids=Asterisk): - """ - Exports scenarios from database. - - The format is what :func:`import_scenarios` accepts as its input. - - Args: - db_map (spinedb_api.DatabaseMapping or spinedb_api.DiffDatabaseMapping): a database map - ids (Iterable, optional): ids of the scenarios to export - - Returns: - Iterable: tuples of two elements: name of scenario and description - """ return sorted((x.name, x.active, x.description) for x in _get_items(db_map, "scenario", ids)) def export_scenario_alternatives(db_map, ids=Asterisk): - """ - Exports scenario alternatives from database. - - The format is what :func:`import_scenario_alternatives` accepts as its input. - - Args: - db_map (spinedb_api.DatabaseMapping or spinedb_api.DiffDatabaseMapping): a database map - ids (Iterable, optional): ids of the scenario alternatives to export - - Returns: - Iterable: tuples of three elements: name of scenario, tuple containing one alternative name, - and name of next alternative - """ return sorted( ( (x.scenario_name, x.alternative_name, x.before_alternative_name) diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index e76fc6c3..10a6f7e7 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -162,8 +162,8 @@ def _create_import_alternative(db_map, state): db_map.add_alternatives({"name": db_map._import_alternative_name}) db_map.add_scenarios(*({"name": scen_name} for scen_name in scenarios)) for scen_name in scenarios: - scen = db_map.cache.table_cache("scenario").find_item({"name": scen_name}) - rank = len(scen.sorted_scenario_alternatives) + 1 # ranks are 1-based + scen = db_map.get_item("scenario", name=scen_name) + rank = len(scen["sorted_scenario_alternatives"]) + 1 # ranks are 1-based db_map.add_scenario_alternatives( {"scenario_name": scen_name, "alternative_name": db_map._import_alternative_name, "rank": rank} ) diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index 074df3e9..e1e3be49 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -10,8 +10,7 @@ ###################################################################################################################### """ -Contains the GraphLayoutGenerator class. - +This module defines the :class:`.GraphLayoutGenerator` class. """ import math @@ -21,7 +20,7 @@ class GraphLayoutGenerator: - """Computes the layout for the Entity Graph View.""" + """A class to build an optimised layout for an undirected graph.""" def __init__( self, @@ -37,6 +36,26 @@ def __init__( layout_available=lambda x, y: None, layout_progressed=lambda iter: None, ): + """ + Args: + vertex_count (int): The number of vertices in the graph. Graph vertices will have indices 0, 1, 2, ... + src_inds (tuple,optional): The indices of the source vertices of each edge. + dst_inds (tuple,optional): The indices of the destination vertices of each edge. + spread (int,optional): the ideal edge length. 
+        heavy_positions (dict,optional): a dictionary mapping vertex indices to another dictionary
+            with keys "x" and "y" specifying the position the vertex should have in the generated layout.
+        max_iters (int,optional): the maximum number of iterations of the layout generation algorithm.
+        weight_exp (int,optional): The exponential decay rate of attraction between vertices. The higher this
+            number, the weaker the attraction between distant vertices.
+        is_stopped (function,optional): A function to call without arguments that returns a boolean indicating
+            whether the layout generation process needs to be stopped.
+        preview_available (function,optional): A function to call after every iteration with two lists, x and y,
+            representing the current layout.
+        layout_available (function,optional): A function to call after the last iteration with two lists, x and y,
+            representing the final layout.
+        layout_progressed (function,optional): A function to call after each iteration with the current iteration
+            number.
+        """
         super().__init__()
         if vertex_count == 0:
             vertex_count = 1
@@ -56,7 +75,6 @@ def __init__(
         self._layout_progressed = layout_progressed
 
     def shortest_path_matrix(self):
-        """Returns the shortest-path matrix."""
         if not self.src_inds:
             # Graph with no edges, just vertices. Introduce fake pair of edges to help 'spreadness'.
             self.src_inds = [self.vertex_count, self.vertex_count]
@@ -85,7 +103,6 @@ def shortest_path_matrix(self):
         return matrix
 
     def sets(self):
-        """Returns sets of vertex pairs indices."""
         sets = []
         for n in range(1, self.vertex_count):
             pairs = np.zeros((self.vertex_count - n, 2), int)  # pairs on diagonal n
@@ -101,7 +118,11 @@ def sets(self):
         return sets
 
     def compute_layout(self):
-        """Computes and returns x and y coordinates for each vertex in the graph, using VSGD-MS."""
+        """Computes the layout using VSGD-MS and returns x and y coordinates for each vertex in the graph.
+
+        Returns:
+            tuple(list,list): x and y coordinates
+        """
         if len(self.heavy_positions) == self.vertex_count:
             x, y = zip(*[(pos["x"], pos["y"]) for pos in self.heavy_positions.values()])
             self._layout_available(x, y)
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index f5bbbfc2..ecb464a2 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -10,7 +10,7 @@
 ######################################################################################################################
 
 """
-General helper functions and classes.
+General helper functions.
 
 """
 
@@ -94,13 +94,13 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
 
 @compiles(TINYINT, "sqlite")
 def compile_TINYINT_mysql_sqlite(element, compiler, **kw):
-    """Handles mysql TINYINT datatype as INTEGER in sqlite."""
+    # Handles mysql TINYINT datatype as INTEGER in sqlite.
     return compiler.visit_INTEGER(element, **kw)
 
 
 @compiles(DOUBLE, "sqlite")
 def compile_DOUBLE_mysql_sqlite(element, compiler, **kw):
-    """Handles mysql DOUBLE datatype as REAL in sqlite."""
+    # Handles mysql DOUBLE datatype as REAL in sqlite.
     return compiler.visit_REAL(element, **kw)
 
 
@@ -154,8 +154,8 @@ def _parse_metadata(metadata):
         yield (key, str(value))
 
 
-def is_head(db_url, upgrade=False):
-    """Check whether or not db_url is head.
+def _is_head(db_url, upgrade=False):
+    """Check whether or not db_url is at the head revision.
 
     Args:
         db_url (str): database url
@@ -166,12 +166,6 @@ def is_head(db_url, upgrade=False):
 
 
 def is_head_engine(engine, upgrade=False):
-    """Check whether or not engine is head.
-
-    Args:
-        engine (Engine): database engine
-        upgrade (Bool): if True, upgrade db to head
-    """
     config = Config()
     config.set_main_option("script_location", "spinedb_api:alembic")
     script = ScriptDirectory.from_config(config)
@@ -198,8 +192,17 @@
 
 
 def copy_database(dest_url, source_url, overwrite=True, upgrade=False, only_tables=(), skip_tables=()):
-    """Copy the database from source_url into dest_url."""
-    if not is_head(source_url, upgrade=upgrade):
+    """Copy the database from one url to another.
+
+    Args:
+        dest_url (str): The destination url.
+        source_url (str): The source url.
+        overwrite (bool,optional): whether to overwrite the destination.
+        upgrade (bool,optional): whether to upgrade the source to the latest Spine schema revision.
+        only_tables (tuple,optional): If given, only these tables are copied.
+        skip_tables (tuple,optional): If given, these tables are skipped.
+    """
+    if not _is_head(source_url, upgrade=upgrade):
         raise SpineDBVersionError(url=source_url)
     source_engine = create_engine(source_url)
     dest_engine = create_engine(dest_url)
@@ -249,7 +252,7 @@ def copy_database_bind(dest_bind, source_bind, overwrite=True, upgrade=False, on
 
 
 def custom_generate_relationship(base, direction, return_fn, attrname, local_cls, referred_cls, **kw):
-    """Make all relationships view only to avoid warnings."""
+    # Make all relationships view only to avoid warnings.
    kw["viewonly"] = True
     kw["cascade"] = ""
     kw["passive_deletes"] = False
@@ -257,21 +260,7 @@ def custom_generate_relationship(base, direction, return_fn, attrname, local_cls
     return generate_relationship(base, direction, return_fn, attrname, local_cls, referred_cls, **kw)
 
 
-def is_unlocked(db_url, timeout=0):
-    """Return True if the SQLite db_url is unlocked, after waiting at most timeout seconds.
-    Otherwise return False."""
-    if not db_url.startswith("sqlite"):
-        return False
-    try:
-        engine = create_engine(db_url, connect_args={"timeout": timeout})
-        engine.execute("BEGIN IMMEDIATE")
-        return True
-    except OperationalError:
-        return False
-
-
 def compare_schemas(left_engine, right_engine):
-    """Whether or not the left and right engine have the same schema."""
     left_insp = inspect(left_engine)
     right_insp = inspect(right_engine)
     left_dict = schema_dict(left_insp)
@@ -570,7 +559,14 @@ def create_spine_metadata():
 
 
 def create_new_spine_database(db_url):
-    """Create a new Spine database at the given url."""
+    """Create a new Spine database at the given url.
+
+    Args:
+        db_url (str): The url.
+
+    Returns:
+        Engine
+    """
     try:
         engine = create_engine(db_url)
     except DatabaseError as e:
@@ -745,8 +741,8 @@ def _create_first_spine_database(db_url):
 
 
 def forward_sweep(root, fn, *args):
-    """Recursively visit, using `get_children()`, the given sqlalchemy object.
-    Apply `fn` on every visited node."""
+    # Recursively visit, using `get_children()`, the given sqlalchemy object.
+    # Apply `fn` on every visited node.
     current = root
     parent = {}
     children = {current: iter(current.get_children(column_collections=False))}
@@ -786,7 +782,7 @@ def __repr__(self):
 
 
 def fix_name_ambiguity(input_list, offset=0, prefix=""):
-    """Modify repeated entries in name list by appending an increasing integer."""
+    # Modify repeated entries in name list by appending an increasing integer.
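+    # A sketch of the intended behavior (assuming every duplicate gets a numbered suffix):
+    #     fix_name_ambiguity(["a", "b", "a"]) -> ["a1", "b", "a2"]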
 result = []
 ocurrences = {}
 for item in input_list:
diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py
index 8b3bec9b..30280a21 100644
--- a/spinedb_api/import_functions.py
+++ b/spinedb_api/import_functions.py
@@ -10,91 +10,85 @@
 ######################################################################################################################
 
 """
-Functions for importing data into a Spine database using entity names as references.
-
+Functions for importing data into a Spine database in a standard format.
 """
 
 from .parameter_value import to_database, fix_conflict
 from .helpers import _parse_metadata
 
 
-# TODO: update docstrings
-
-
-class ImportErrorLogItem:
-    """Class to hold log data for import errors"""
-
-    def __init__(self, msg="", db_type="", imported_from="", other=""):
-        self.msg = msg
-        self.db_type = db_type
-        self.imported_from = imported_from
-        self.other = other
-
-    def __repr__(self):
-        return self.msg
-
 
 def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs):
-    """Imports data into a Spine database using name references (rather than id references).
+    """Imports data into a Spine database using a standard format.
 
     Example::
 
-        object_c = ['example_class', 'other_class']
-        obj_parameters = [['example_class', 'example_parameter']]
-        relationship_c = [['example_rel_class', ['example_class', 'other_class']]]
-        rel_parameters = [['example_rel_class', 'rel_parameter']]
-        objects = [['example_class', 'example_object'],
-                   ['other_class', 'other_object']]
-        object_p_values = [['example_object_class', 'example_object', 'example_parameter', 3.14]]
-        relationships = [['example_rel_class', ['example_object', 'other_object']]]
-        rel_p_values = [['example_rel_class', ['example_object', 'other_object'], 'rel_parameter', 2.718]]
-        object_groups = [['object_class_name', 'object_group_name', ['member_name', 'another_member_name']]]
-        alternatives = [['example_alternative', 'An example']]
-        scenarios = [['example_scenario', 'An example']]
-        scenario_alternatives = [('scenario', 'alternative1'), ('scenario', 'alternative0', 'alternative1')]
-
-        import_data(db_map,
-                    object_classes=object_c,
-                    relationship_classes=relationship_c,
-                    object_parameters=obj_parameters,
-                    relationship_parameters=rel_parameters,
-                    objects=objects,
-                    relationships=relationships,
-                    object_groups=object_groups,
-                    object_parameter_values=object_p_values,
-                    relationship_parameter_values=rel_p_values,
-                    alternatives=alternatives,
-                    scenarios=scenarios,
-                    scenario_alternatives=scenario_alternatives)
+        entity_classes = [
+            ('example_class', ()), ('other_class', ()), ('multi_d_class', ('example_class', 'other_class'))
+        ]
+        alternatives = [('example_alternative', 'An example')]
+        scenarios = [('example_scenario', 'An example')]
+        scenario_alternatives = [
+            ('example_scenario', 'example_alternative'), ('example_scenario', 'Base', 'example_alternative')
+        ]
+        parameter_value_lists = [("example_list", "value1"), ("example_list", "value2")]
+        parameter_definitions = [('example_class', 'example_parameter'), ('multi_d_class', 'other_parameter')]
+        entities = [
+            ('example_class', 'example_entity'),
+            ('example_class', 'example_group'),
+            ('example_class', 'example_member'),
+            ('other_class', 'other_entity'),
+            ('multi_d_class', ('example_entity', 'other_entity')),
+        ]
+        entity_groups = [
+            ('example_class', 'example_group', 'example_member'),
+            ('example_class', 'example_group', 'example_entity'),
+        ]
+        parameter_values = [
+            ('example_class', 'example_entity',
'example_parameter', 3.14),
+            ('multi_d_class', ('example_entity', 'other_entity'), 'other_parameter', 2.718),
+        ]
+        entity_alternatives = [
+            ('example_class', 'example_entity', "example_alternative", True),
+            ('example_class', 'example_entity', "example_alternative", False),
+        ]
+        import_data(
+            db_map,
+            entity_classes=entity_classes,
+            alternatives=alternatives,
+            scenarios=scenarios,
+            scenario_alternatives=scenario_alternatives,
+            parameter_value_lists=parameter_value_lists,
+            parameter_definitions=parameter_definitions,
+            entities=entities,
+            entity_groups=entity_groups,
+            parameter_values=parameter_values,
+            entity_alternatives=entity_alternatives,
+        )
 
     Args:
         db_map (spinedb_api.DiffDatabaseMapping): database mapping
-        on_conflict (str): Conflict resolution strategy for ``parameter_value.fix_conflict``
-        object_classes (List[str]): List of object class names
-        relationship_classes (List[List[str, List(str)]):
-            List of lists with relationship class names and list of object class names
-        object_parameters (List[List[str, str]]):
-            list of lists with object class name and parameter name
-        relationship_parameters (List[List[str, str]]):
-            list of lists with relationship class name and parameter name
-        objects (List[List[str, str]]):
-            list of lists with object class name and object name
-        relationships: (List[List[str,List(String)]]):
-            list of lists with relationship class name and list of object names
-        object_groups (List[List/Tuple]): list/set/iterable of lists/tuples with object class name, group name,
-            and member name
-        object_parameter_values (List[List[str, str, str|numeric]]):
-            list of lists with object name, parameter name, parameter value
-        relationship_parameter_values (List[List[str, List(str), str, str|numeric]]):
-            list of lists with relationship class name, list of object names, parameter name, parameter value
-        alternatives (Iterable): alternative names or lists of two elements: alternative name and description
-        scenarios (Iterable): scenario names or lists of two elements: scenario name and description
-        scenario_alternatives (Iterable): lists of two elements: scenario name and a list of names of alternatives
+        on_conflict (str): Conflict resolution strategy for :func:`parameter_value.fix_conflict`
+        entity_classes (list(tuple(str,tuple,str,int))): tuples of
+            (name, dimension name tuple, description, display icon integer)
+        parameter_definitions (list(tuple(str,str,str,str,str))):
+            tuples of (class name, parameter name, default value, parameter value list name, description)
+        entities (list(tuple(str,str or tuple(str)))): tuples of (class name, entity name or element name list)
+        entity_alternatives (list(tuple(str,str or tuple(str),str,bool))): tuples of
+            (class name, entity name or element name list, alternative name, activity)
+        entity_groups (list(tuple(str,str,str))): tuples of (class name, group entity name, member entity name)
+        parameter_values (list(tuple(str,str or tuple(str),str,str|numeric,str))):
+            tuples of (class name, entity name or element name list, parameter name, value, alternative name)
+        alternatives (list(tuple(str,str))): tuples of (name, description)
+        scenarios (list(tuple(str,str))): tuples of (name, description)
+        scenario_alternatives (list(tuple(str,str,str))): tuples of
+            (scenario name, alternative name, preceding alternative name)
+        parameter_value_lists (list(tuple(str,str|numeric))): tuples of (list name, value)
 
     Returns:
-        tuple: number of inserted/changed entities and list of ImportErrorLogItem with
-            any import errors
+        int: number of items imported
+        list:
errors
     """
-    error_log = []
+    all_errors = []
     num_imports = 0
     for tablename, (to_add, to_update, errors) in get_data_for_import(
         db_map, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs
@@ -102,8 +96,8 @@
         updated, _ = db_map.update_items(tablename, *to_update, check=False)
         added, _ = db_map.add_items(tablename, *to_add, check=False)
         num_imports += len(added) + len(updated)
-        error_log.extend(errors)
-    return num_imports, error_log
+        all_errors.extend(errors)
+    return num_imports, all_errors
 
 
 def get_data_for_import(
@@ -143,34 +137,30 @@
     tool_features=(),
     tool_feature_methods=(),
 ):
-    """Returns an iterator of data for import, that the user can call instead of `import_data`
-    if they want to add and update the data by themselves.
-    Especially intended to be used with the toolbox undo/redo functionality.
+    """Yields data to import into a Spine DB.
 
     Args:
         db_map (spinedb_api.DiffDatabaseMapping): database mapping
-        on_conflict (str): Conflict resolution strategy for ``parameter_value.fix_conflict``
-        object_classes (List[str]): List of object class names
-        relationship_classes (List[List[str, List(str)]):
-            List of lists with relationship class names and list of object class names
-        object_parameters (List[List[str, str]]):
-            list of lists with object class name and parameter name
-        relationship_parameters (List[List[str, str]]):
-            list of lists with relationship class name and parameter name
-        objects (List[List[str, str]]):
-            list of lists with object class name and object name
-        relationships: (List[List[str,List(String)]]):
-            list of lists with relationship class name and list of object names
-        object_groups (List[List/Tuple]): list/set/iterable of lists/tuples with object class name, group name,
-            and member name
-        object_parameter_values (List[List[str, str, str|numeric]]):
-            list of lists with object name, parameter name, parameter value
-        relationship_parameter_values (List[List[str, List(str), str, str|numeric]]):
-            list of lists with relationship class name, list of object names, parameter name,
-            parameter value
+        on_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict`
+        entity_classes (list(tuple(str,tuple,str,int))): tuples of
+            (name, dimension name tuple, description, display icon integer)
+        parameter_definitions (list(tuple(str,str,str,str))):
+            tuples of (class name, parameter name, default value, parameter value list name)
+        entities (list(tuple(str,str or tuple(str)))): tuples of (class name, entity name or element name list)
+        entity_alternatives (list(tuple(str,str or tuple(str),str,bool))): tuples of
+            (class name, entity name or element name list, alternative name, activity)
+        entity_groups (list(tuple(str,str,str))): tuples of (class name, group entity name, member entity name)
+        parameter_values (list(tuple(str,str or tuple(str),str,str|numeric,str))):
+            tuples of (class name, entity name or element name list, parameter name, value, alternative name)
+        alternatives (list(tuple(str,str))): tuples of (name, description)
+        scenarios (list(tuple(str,str))): tuples of (name, description)
+        scenario_alternatives (list(tuple(str,str,str))): tuples of
+            (scenario name, alternative name, preceding alternative name)
+        parameter_value_lists (list(tuple(str,str|numeric))): tuples of (list name, value)
 
     Yields:
-        tuple(str, list)
+        str: item type
+        tuple(list,list,list): tuple of (items to add, items to update, errors)
     """
     # NOTE: The order is
important, because of references. E.g., we want to import alternatives before parameter_values
     if alternatives:
@@ -254,521 +244,213 @@
 
 
 def import_entity_classes(db_map, data):
-    """Imports entity classes.
-
-    Example::
-
-        data = [
-            'new_class',
-            ('another_class', 'description', 123456),
-            ('multidimensional_class', 'description', 654321, ("new_class", "another_class"))
-        ]
-        import_entity_classes(db_map, data)
+    """Imports entity classes into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (Iterable): list/set/iterable of string entity class names,
-            and optionally description, integer display icon reference, and lists/tuples with dimension names,
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,tuple,str,int))): tuples of
+            (name, dimension name tuple, description, display icon integer)
 
     Returns:
-        tuple of int and list: Number of successfully inserted object classes, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, entity_classes=data)
 
 
 def import_entities(db_map, data):
-    """Imports entities.
-
-    Example::
-
-        data = [
-            ('class_name1', 'entity_name1'),
-            ('class_name2', 'entity_name2'),
-            ('class_name3', ('entity_name1', 'entity_name2'))
-        ]
-        import_entities(db_map, data)
+    """Imports entities into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name
-            and entity name or list/tuple of element names
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str or tuple(str)))): tuples of (class name, entity name or element name list)
 
     Returns:
-        (Int, List) Number of successful inserted entities, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, entities=data)
 
 
 def import_entity_alternatives(db_map, data):
-    """Imports entity alternatives.
-
-    Example::
-
-        data = [
-            ('class_name1', 'entity_name1', 'alternative_name3', True),
-            ('class_name2', 'entity_name2', 'alternative_name4', False),
-            ('class_name3', ('entity_name1', 'entity_name2'), 'alternative_name5', False)
-        ]
-        import_entity_alternatives(db_map, data)
+    """Imports entity alternatives into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name,
-            entity name or list/tuple of element names, alternative name, active boolean value
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str or tuple(str),str,bool))): tuples of
+            (class name, entity name or element name list, alternative name, activity)
 
     Returns:
-        (Int, List) Number of successful inserted entities, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, entity_alternatives=data)
 
 
 def import_entity_groups(db_map, data):
-    """Imports list of entity groups by name with associated class name into given database mapping:
-    Ignores duplicate and existing (group, member) tuples.
-
-    Example::
-
-        data = [
-            ('class_name', 'group_name', 'member_name'),
-            ('class_name', 'group_name', 'another_member_name')
-        ]
-        import_entity_groups(db_map, data)
+    """Imports entity groups into a Spine database using a standard format.
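+
+    Example (a sketch; the class, group, and member names are hypothetical)::
+
+        data = [
+            ("my_class", "my_group", "member_a"),
+            ("my_class", "my_group", "member_b"),
+        ]
+        import_entity_groups(db_map, data)
+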
 Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name, group name,
-            and member name
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str,str))): tuples of (class name, group entity name, member entity name)
 
     Returns:
-        (Int, List) Number of successful inserted entity groups, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, entity_groups=data)
 
 
 def import_parameter_definitions(db_map, data, unparse_value=to_database):
-    """Imports list of parameter definitions:
-
-    Example::
-
-        data = [
-            ('entity_class_1', 'new_parameter'),
-            ('entity_class_2', 'other_parameter', 'default_value', 'value_list_name', 'description')
-        ]
-        import_parameter_definitions(db_map, data)
+    """Imports parameter definitions into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with entity class name, parameter name,
-            and optionally default value, value list name, and description
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str,str,str))):
+            tuples of (class name, parameter name, default value, parameter value list name)
 
     Returns:
-        (Int, List) Number of successful inserted parameter definitions, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, parameter_definitions=data, unparse_value=unparse_value)
 
 
 def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"):
-    """Imports parameter values:
-
-    Example::
-
-        data = [
-            ['example_class2', 'example_entity', 'parameter', 5.5, 'alternative'],
-            ['example_class1', ('example_entity', 'other_entity'), 'parameter', 2.718]
-        ]
-        import_parameter_values(db_map, data)
+    """Imports parameter values into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with
-            entity class name, entity name or list of element names, parameter name, (deserialized) parameter value,
-            optional name of an alternative
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str or tuple(str),str,str|numeric,str))):
+            tuples of (class name, entity name or element name list, parameter name, value, alternative name)
+        on_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict`
 
     Returns:
-        (Int, List) Number of successful inserted parameter values, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict)
 
 
 def import_alternatives(db_map, data):
-    """
-    Imports alternatives.
-
-    Example:
-
-        data = ['new_alternative', ('another_alternative', 'description')]
-        import_alternatives(db_map, data)
+    """Imports alternatives into a Spine database using a standard format.
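+
+    Example (a sketch; the names and descriptions are hypothetical)::
+
+        data = [("alt_a", "An alternative"), ("alt_b", "Another alternative")]
+        import_alternatives(db_map, data)
+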
 Args:
-        db_map (DiffDatabaseMapping): mapping for database to insert into
-        data (Iterable): an iterable of alternative names,
-            or of lists/tuples with alternative names and optional descriptions
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str))): tuples of (name, description)
 
     Returns:
-        tuple of int and list: Number of successfully inserted alternatives, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, alternatives=data)
 
 
 def import_scenarios(db_map, data):
-    """
-    Imports scenarios.
-
-    Example:
-
-        second_active = True
-        third_active = False
-        data = ['scenario', ('second_scenario', second_active), ('third_scenario', third_active, 'description')]
-        import_scenarios(db_map, data)
+    """Imports scenarios into a Spine database using a standard format.
 
     Args:
-        db_map (DiffDatabaseMapping): mapping for database to insert into
-        data (Iterable): an iterable of scenario names,
-            or of lists/tuples with scenario names and optional descriptions
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str))): tuples of (name, description)
 
     Returns:
-        tuple of int and list: Number of successfully inserted scenarios, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, scenarios=data)
 
 
 def import_scenario_alternatives(db_map, data):
-    """
-    Imports scenario alternatives.
-
-    Example:
-
-        data = [('scenario', 'bottom_alternative'), ('another_scenario', 'top_alternative', 'bottom_alternative')]
-        import_scenario_alternatives(db_map, data)
+    """Imports scenario alternatives into a Spine database using a standard format.
 
     Args:
-        db_map (DiffDatabaseMapping): mapping for database to insert into
-        data (Iterable): an iterable of (scenario name, alternative name,
-            and optionally, 'before' alternative name).
-            Alternatives are inserted before the 'before' alternative,
-            or at the end if not given.
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str,str))): tuples of (scenario name, alternative name, preceding alternative name)
 
     Returns:
-        tuple of int and list: Number of successfully inserted scenario alternatives, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, scenario_alternatives=data)
 
 
 def import_parameter_value_lists(db_map, data, unparse_value=to_database):
-    """Imports list of parameter value lists:
-
-    Example::
-
-        data = [
-            ['value_list_name', value1], ['value_list_name', value2],
-            ['another_value_list_name', value3],
-        ]
-        import_parameter_value_lists(db_map, data)
+    """Imports parameter value lists into a Spine database using a standard format.
 
     Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of lists/tuples with
-            value list name, list of values
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str|numeric))): tuples of (list name, value)
 
     Returns:
-        (Int, List) Number of successful inserted objects, list of errors
+        int: number of items imported
+        list: errors
     """
     return import_data(db_map, parameter_value_lists=data, unparse_value=unparse_value)
 
 
-def import_metadata(db_map, data=None):
-    """Imports metadata. Ignores duplicates.
-
-    Example::
-
-        data = ['{"name1": "value1"}', '{"name2": "value2"}']
-        import_metadata(db_map, data)
+def import_metadata(db_map, data):
+    """Imports metadata into a Spine database using a standard format.
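+
+    Example (a sketch that follows the (entry name, value) format documented below)::
+
+        data = [("source", "example.com"), ("contributor", "Jane Doe")]
+        import_metadata(db_map, data)
+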
Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of string metadata entries in JSON format + db_map (spinedb_api.DiffDatabaseMapping): database mapping + data (list(tuple(str,str))): tuples of (entry name, value) Returns: - (Int, List) Number of successful inserted objects, list of errors + int: number of items imported + list: errors """ return import_data(db_map, metadata=data) def import_object_classes(db_map, data): - """Imports object classes. - - Example:: - - data = ['new_object_class', ('another_object_class', 'description', 123456)] - import_object_classes(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (Iterable): list/set/iterable of string object class names, or of lists/tuples with object class names, - and optionally description and integer display icon reference - - Returns: - tuple of int and list: Number of successfully inserted object classes, list of errors - """ return import_data(db_map, object_classes=data) def import_relationship_classes(db_map, data): - """Imports relationship classes. - - Example:: - - data = [ - ('new_rel_class', ['object_class_1', 'object_class_2']), - ('another_rel_class', ['object_class_3', 'object_class_4'], 'description'), - ] - import_relationship_classes(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with relationship class names, - list of object class names, and optionally description - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, relationship_classes=data) def import_objects(db_map, data): - """Imports list of object by name with associated object class name into given database mapping: - Ignores duplicate names and existing names. - - Example:: - - data = [ - ('object_class_name', 'new_object'), - ('object_class_name', 'other_object', 'description') - ] - import_objects(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with object name and object class name - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, objects=data) def import_object_groups(db_map, data): - """Imports list of object groups by name with associated object class name into given database mapping: - Ignores duplicate and existing (group, member) tuples. - - Example:: - - data = [ - ('object_class_name', 'object_group_name', 'member_name'), - ('object_class_name', 'object_group_name', 'another_member_name') - ] - import_object_groups(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with object class name, group name, - and member name - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, object_groups=data) def import_relationships(db_map, data): - """Imports relationships. 
- - Example:: - - data = [('relationship_class_name', ('object_name1', 'object_name2'))] - import_relationships(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with relationship class name - and list/tuple of object names - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, relationships=data) def import_object_parameters(db_map, data, unparse_value=to_database): - """Imports list of object class parameters: - - Example:: - - data = [ - ('object_class_1', 'new_parameter'), - ('object_class_2', 'other_parameter', 'default_value', 'value_list_name', 'description') - ] - import_object_parameters(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with object class name, parameter name, - and optionally default value, value list name, and description - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, object_parameters=data, unparse_value=unparse_value) def import_relationship_parameters(db_map, data, unparse_value=to_database): - """Imports list of relationship class parameters: - - Example:: - - data = [ - ('relationship_class_1', 'new_parameter'), - ('relationship_class_2', 'other_parameter', 'default_value', 'value_list_name', 'description') - ] - import_relationship_parameters(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with relationship class name, parameter name, - and optionally default value, value list name, and description - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, relationship_parameters=data, unparse_value=unparse_value) def import_object_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): - """Imports object parameter values: - - Example:: - - data = [('object_class_name', 'object_name', 'parameter_name', 123.4), - ('object_class_name', 'object_name', 'parameter_name2', ), - ('object_class_name', 'object_name', 'parameter_name', , 'alternative')] - import_object_parameter_values(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with - object_class_name, object name, parameter name, (deserialized) parameter value, - optional name of an alternative - - Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, object_parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) def import_relationship_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): - """Imports relationship parameter values: - - Example:: - - data = [['example_rel_class', - ['example_object', 'other_object'], 'rel_parameter', 2.718], - ['example_object', 'other_object'], 'rel_parameter', 5.5, 'alternative']] - import_relationship_parameter_values(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of lists/tuples with - relationship class name, list of object names, parameter name, (deserialized) parameter value, - optional name of an alternative - 
- Returns: - (Int, List) Number of successful inserted objects, list of errors - """ return import_data(db_map, relationship_parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) def import_object_metadata(db_map, data): - """Imports object metadata. Ignores duplicates. - - Example:: - - data = [("classA", "object1", '{"name1": "value1"}'), ("classA", "object1", '{"name2": "value2"}')] - import_object_metadata(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of tuples with class name, object name, - and string metadata entries in JSON format - - Returns: - (Int, List) Number of successful inserted items, list of errors - """ return import_data(db_map, object_metadata=data) def import_relationship_metadata(db_map, data): - """Imports relationship metadata. Ignores duplicates. - - Example:: - - data = [ - ("classA", ("object1", "object2"), '{"name1": "value1"}'), - ("classA", ("object3", "object4"), '{"name2": "value2"}') - ] - import_relationship_metadata(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of tuples with class name, tuple of object names, - and string metadata entries in JSON format - - Returns: - (Int, List) Number of successful inserted items, list of errors - """ return import_data(db_map, relationship_metadata=data) def import_object_parameter_value_metadata(db_map, data): - """Imports object parameter value metadata. Ignores duplicates. - - Example:: - - data = [ - ("classA", "object1", "parameterX", '{"name1": "value1"}'), - ("classA", "object1", "parameterY", '{"name2": "value2"}', "alternativeA") - ] - import_object_parameter_value_metadata(db_map, data) - - Args: - db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into - data (List[List/Tuple]): list/set/iterable of tuples with class name, object name, - parameter name, string metadata entries in JSON format, and optionally alternative name - - Returns: - (Int, List) Number of successful inserted items, list of errors - """ return import_data(db_map, object_parameter_value_metadata=data) def import_relationship_parameter_value_metadata(db_map, data): - """Imports relationship parameter value metadata. Ignores duplicates. 
-
-    Example::
-
-        data = [
-            ("classA", ("object1", "object2"), "parameterX", '{"name1": "value1"}'),
-            ("classA", ("object3", "object4"), "parameterY", '{"name2": "value2"}', "alternativeA")
-        ]
-        import_object_parameter_value_metadata(db_map, data)
-
-    Args:
-        db_map (spinedb_api.DiffDatabaseMapping): mapping for database to insert into
-        data (List[List/Tuple]): list/set/iterable of tuples with class name, tuple of object names,
-            parameter name, string metadata entries in JSON format, and optionally alternative name
-
-    Returns:
-        (Int, List) Number of successful inserted items, list of errors
-    """
     return import_data(db_map, relationship_parameter_value_metadata=data)
 
 
@@ -954,7 +636,7 @@ def _data_iterator():
             index = max(
                 (
                     x["index"]
-                    for x in db_map.cache.get("list_value", {}).values()
+                    for x in db_map.cache.table_cache("list_value").valid_values()
                     if x["parameter_value_list_id"] == current_list["id"]
                 ),
                 default=-1,
diff --git a/spinedb_api/import_mapping/__init__.py b/spinedb_api/import_mapping/__init__.py
index 46105c99..9966601e 100644
--- a/spinedb_api/import_mapping/__init__.py
+++ b/spinedb_api/import_mapping/__init__.py
@@ -8,3 +8,7 @@
 # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
 # this program. If not, see .
 ######################################################################################################################
+"""
+This package contains facilities to map tables into a Spine database.
+
+"""
diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py
index c2fc5888..0d5d0461 100644
--- a/spinedb_api/mapping.py
+++ b/spinedb_api/mapping.py
@@ -8,10 +8,8 @@
 # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
 # this program. If not, see .
 ######################################################################################################################
-"""
-Contains export mappings for database items such as entities, entity classes and parameter values.
-"""
 
+# Base class for import and export mappings.
 
 from enum import Enum, unique
 from itertools import takewhile
diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py
index d6e3741b..c2cd23ed 100644
--- a/spinedb_api/parameter_value.py
+++ b/spinedb_api/parameter_value.py
@@ -10,21 +10,19 @@
 ######################################################################################################################
 
 """
-Support utilities and classes to deal with Spine data (relationship)
-parameter values.
+Support utilities and classes to deal with Spine parameter values.
 
-The `from_database` function reads the database's value format returning
-a float, Datatime, Duration, TimePattern, TimeSeriesFixedResolution
+The :func:`from_database` function receives the parameter value and type fields from the database and returns
+a float, DateTime, Duration, Array, TimePattern, TimeSeriesFixedResolution,
 TimeSeriesVariableResolution or Map objects.
 
-The above objects can be converted back to the database format by the `to_database` free function
+The above objects can be converted back to the database format by the :func:`to_database` free function
 or by their `to_database` member functions.
 
 Individual datetimes are represented as datetime objects from the standard Python library.
 Individual time steps are represented as relativedelta objects from the dateutil package.
 Datetime indexes (as returned by TimeSeries.indexes()) are represented as
 numpy.ndarray arrays holding numpy.datetime64 objects.
-
 """
 
 from collections.abc import Sequence
diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py
index 82d9b56d..f9413317 100644
--- a/spinedb_api/perfect_split.py
+++ b/spinedb_api/perfect_split.py
@@ -10,7 +10,7 @@
 ######################################################################################################################
 
 """
-Functions for the perfect db split.
+This module provides the :func:`perfect_split` function.
 """
 
 from .db_mapping import DatabaseMapping
@@ -19,12 +19,12 @@
 
 
 def perfect_split(input_urls, intersection_url, diff_urls):
-    """Splits dbs into disjoint subsets.
+    """Splits DBs into disjoint subsets.
 
     Args:
-        input_urls (list(str)): List of urls to split
-        intersection_url (str): A url to store the data common to all input urls
-        diff_urls (list(str)): List of urls to store the differences of each input with respect to the intersection.
+        input_urls (list(str)): List of urls of DBs to split.
+        intersection_url (str): The url of a DB to store the data common to all input DBs (i.e., their intersection).
+        diff_urls (list(str)): List of urls of DBs to store the differences between each input and the intersection.
     """
     diff_url_lookup = dict(zip(input_urls, diff_urls))
     input_data_sets = {}
diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py
index 99790306..20cfed62 100644
--- a/spinedb_api/purge.py
+++ b/spinedb_api/purge.py
@@ -10,7 +10,7 @@
 ######################################################################################################################
 
 """
-Functions to purge dbs.
+Functions to purge DBs.
 """
 
 
@@ -35,11 +35,11 @@ def _ids_for_item_type(db_map, item_type):
 
 
 def purge_url(url, purge_settings, logger=None):
-    """Removes all given types of items from database.
+    """Removes all items of selected types from the database at a given URL.
 
     Args:
         url (str): database URL
-        purge_settings (dict): mapping from item type to boolean
+        purge_settings (dict): mapping from item type to a boolean indicating whether items of that type should be removed
         logger (LoggerInterface, optional): logger
 
     Returns:
@@ -58,11 +58,11 @@
 
 
 def purge(db_map, purge_settings, logger=None):
-    """Removes items from database.
+    """Removes all items of selected types from a database.
 
     Args:
         db_map (DatabaseMapping): target database mapping
-        purge_settings (dict, optional): mapping from item type to purge flag
+        purge_settings (dict): mapping from item type to a boolean indicating whether items of that type should be removed
         logger (LoggerInterface): logger
 
     Returns:
diff --git a/spinedb_api/query.py b/spinedb_api/query.py
index 3df938e6..db23f1e7 100644
--- a/spinedb_api/query.py
+++ b/spinedb_api/query.py
@@ -9,7 +9,7 @@
 # this program. If not, see .
 ######################################################################################################################
 
-"""Provides :class:`.Query`."""
+"""The :class:`Query` class."""
 
 from sqlalchemy import select, and_
 from sqlalchemy.sql.functions import count
@@ -17,7 +17,14 @@
 
 class Query:
+    """A clone of SQLAlchemy's :class:`~sqlalchemy.orm.query.Query`."""
+
     def __init__(self, bind, *entities):
+        """
+        Args:
+            bind (Engine or Connection): An engine or connection to a DB against which the query will be executed.
+            entities (Iterable): A sequence of SQL expressions.
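+
+        Example (a sketch; ``my_table`` is a hypothetical SQLAlchemy table, and the
+        ``all()`` accessor is assumed to execute the query)::
+
+            rows = Query(engine, my_table.c.name).all()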
+ """ self._bind = bind self._entities = entities self._select = select(entities) diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index dcbcef95..ae7715a0 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -9,13 +9,7 @@ # this program. If not, see . ###################################################################################################################### -""" -Helpers for server and client. - -""" - import json -from .import_functions import ImportErrorLogItem from .exception import SpineDBAPIError # Encode decode server messages @@ -52,7 +46,7 @@ def _recvall(self): class _TailJSONEncoder(json.JSONEncoder): """ A custom JSON encoder that accummulates bytes objects into a tail. - The bytes object are encoded as a string pointing to the address in the tail. + Each bytes object is encoded as a string pointing to the address in the tail. """ def __init__(self): @@ -70,7 +64,7 @@ def default(self, o): return address if isinstance(o, set): return list(o) - if isinstance(o, (SpineDBAPIError, ImportErrorLogItem)): + if isinstance(o, SpineDBAPIError): return str(o) return super().default(o) @@ -81,7 +75,7 @@ def tail(self): def encode(o): """ - Encodes given object (representing a server response) into a message with the following structure: + Encodes given object into a message to be sent via a socket, with the following structure: body | start of tail character | tail @@ -90,10 +84,10 @@ def encode(o): See class:`_TailJSONEncoder`. Args: - o (any): A Python object representing a server response. + o (any): A Python object to encode. Returns: - bytes: A message to the client. + bytes: Encoded message. """ encoder = _TailJSONEncoder() s = encoder.encode(o) @@ -102,7 +96,7 @@ def encode(o): def decode(b): """ - Decodes given message (representing a client request) into a Python object. + Decodes given message received via a socket into a Python object. The message must have the following structure: body | start of tail character | tail @@ -111,10 +105,10 @@ def decode(b): from the tail. Args: - b (bytes): A message from the client. + b (bytes): A message to decode. Returns: - any: A Python object representing a client request. + any: Decoded object. """ body, tail = b.split(_START_OF_TAIL.encode()) o = json.loads(body) diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py index b42c4440..cbd2eeaa 100644 --- a/spinedb_api/spine_db_client.py +++ b/spinedb_api/spine_db_client.py @@ -10,8 +10,7 @@ ###################################################################################################################### """ -Contains the SpineDBClient class. - +The :class:`SpineDBClient` class. """ from urllib.parse import urlparse @@ -24,7 +23,8 @@ class SpineDBClient(ReceiveAllMixing): def __init__(self, server_address): - """ + """Represents a client connection to a Spine DB server. + Args: server_address (tuple(str,int)): hostname and port """ @@ -33,34 +33,62 @@ def __init__(self, server_address): @classmethod def from_server_url(cls, url): + """Creates a client from a server's URL. + + Args: + url (str, URL): the url of a Spine DB server. + """ parsed = urlparse(url) if parsed.scheme != "http": raise ValueError(f"unable to create client, invalid server url {url}") return cls((parsed.hostname, parsed.port)) def get_db_url(self): - """ + """Returns the URL of the Spine DB associated with the server. 
+
         Returns:
-            str: The underlying db url from the server
+            str
         """
         return self._send("get_db_url")
 
     def db_checkin(self):
+        """Blocks until all the servers that need to write to the same DB before this one
+        have reported all their writes."""
         return self._send("db_checkin")
 
     def db_checkout(self):
+        """Reports one write for this server."""
         return self._send("db_checkout")
 
     def cancel_db_checkout(self):
+        """Reverts the last write report for this server."""
         return self._send("cancel_db_checkout")
 
     def import_data(self, data, comment):
+        """Imports data to the DB using :func:`spinedb_api.import_functions.import_data` and commits the changes.
+
+        Args:
+            data (dict): to be splatted into keyword arguments to :func:`spinedb_api.import_functions.import_data`
+            comment (str): a commit message.
+        """
         return self._send("import_data", args=(data, comment))
 
     def export_data(self, **kwargs):
+        """Exports data from the DB using :func:`spinedb_api.export_functions.export_data`.
+
+        Args:
+            kwargs: keyword arguments passed to :func:`spinedb_api.export_functions.export_data`
+        """
         return self._send("export_data", kwargs=kwargs)
 
     def call_method(self, method_name, *args, **kwargs):
+        """Calls a method from :class:`spinedb_api.db_mapping.DatabaseMapping`.
+
+        Args:
+            method_name (str): the name of the method to call
+            args: positional arguments passed to the method call
+            kwargs: keyword arguments passed to the method call
+        """
         return self._send("call_method", args=(method_name, *args), kwargs=kwargs)
 
     def open_db_map(self, db_url, upgrade, memory):
@@ -93,16 +121,6 @@ def _send(self, request, args=None, kwargs=None, receive=True):
 
 
 def get_db_url_from_server(url):
-    """Returns the underlying db url associated with the given url, if it's a server url.
-    Otherwise, it assumes it's the url of DB and returns it unaltered.
-    Used by ``DatabaseMappingBase()``.
-
-    Args:
-        url (str, URL): a url, either from a Spine DB or from a Spine DB server.
-
-    Returns:
-        str
-    """
     if isinstance(url, URL):
         return url
     parsed = urlparse(url)
diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py
index 5915c8b4..534967f5 100644
--- a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -10,7 +10,32 @@
 ######################################################################################################################
 
 """
-Contains the SpineDBServer class.
+Spine DB server
+===============
+
+The Spine DB server provides almost the same functionality as :class:`spinedb_api.db_mapping.DatabaseMapping`,
+but it does it via a socket. This removes the ``spinedb_api`` requirement (and the Python requirement altogether)
+from third-party applications that want to interact with Spine DBs.
+
+Typically this is done in the following steps:
+    #. Start a server by specifying the URL of the Spine DB that you want to interact with.
+    #. Communicate the URL of the server to your third-party application running in another process.
+    #. Send requests from your application to the server via sockets in order to interact with the DB.
+
+Available requests
+------------------
+The requests currently issued by :class:`spinedb_api.spine_db_client.SpineDBClient` are ``get_db_url``,
+``db_checkin``, ``db_checkout``, ``cancel_db_checkout``, ``import_data``, ``export_data``, and ``call_method``.
+
+Encoding/decoding
+-----------------
+Messages are encoded and decoded with :func:`spinedb_api.server_client_helpers.encode` and
+:func:`spinedb_api.server_client_helpers.decode`: each message is a JSON body, followed by a start-of-tail
+character and a tail that accumulates any ``bytes`` objects referenced from the body.
+
+This module also provides a mechanism to control the order in which multiple servers
+running in parallel should write to the same DB.
+
+The server is started using :func:`closing_spine_db_server`.
+If you want to also control order of writing from multiple servers,
+you first need to obtain an 'ordering queue' using :func:`db_server_manager`.
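+
+For example (a minimal sketch; the DB url is illustrative, and ``SpineDBClient`` comes from
+:mod:`spinedb_api.spine_db_client`)::
+
+    with closing_spine_db_server("sqlite:///my_db.sqlite") as server_url:
+        client = SpineDBClient.from_server_url(server_url)
+        print(client.get_db_url())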
""" @@ -543,10 +568,6 @@ def close(self): class DBRequestHandler(ReceiveAllMixing, HandleDBMixin, socketserver.BaseRequestHandler): - """ - The request handler class for our server. - """ - @property def server_address(self): return self.server.server_address @@ -566,15 +587,6 @@ def quick_db_checkout(server_manager_queue, ordering): def start_spine_db_server(server_manager_queue, db_url, upgrade=False, memory=False, ordering=None): - """ - Args: - db_url (str): Spine db url - upgrade (bool): Whether to upgrade db or not - memory (bool): Whether to use an in-memory database together with a persistent connection to it - - Returns: - tuple: server address (e.g. (127.0.0.1, 54321)) - """ handler = _ManagerRequestHandler(server_manager_queue) server_address = handler.start_server(db_url, upgrade, memory, ordering) return server_address @@ -586,17 +598,65 @@ def shutdown_spine_db_server(server_manager_queue, server_address): @contextmanager -def closing_spine_db_server(server_manager_queue, db_url, upgrade=False, memory=False, ordering=None): +def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None, server_manager_queue=None): + """Creates a Spine DB server. + + Example:: + + with closing_spine_db_server(db_url) as server_url: + client = SpineDBClient.from_server_url(server_url) + data = client.import_data({"entity_class": [("fish", ()), ("dog", ())]}, "Add two entity classes.") + + + Args: + db_url (str): the URL of a Spine DB. + upgrade (bool): Whether to upgrade the DB to the last revision. + memory (bool): Whether to use an in-memory database together with a persistent connection. + server_manager_queue (Queue,optional): A queue that can be used to control order of writing. + Only needed if you also specify `ordering` below. + ordering (dict,optional): A dictionary specifying an ordering to be followed by multiple concurrent servers + writing to the same DB. It must have the following keys: + - "id": an identifier for the ordering, shared by all the servers in the ordering. + - "current": an identifier for this server within the ordering. + - "precursors": a set of identifiers of other servers that must have written to the DB before this server can write. + - "part_count": the number of times this server needs to write to the DB before their successors can write. + + Yields: + str: server url + """ + if server_manager_queue is None: + mngr = _DBServerManager() + server_manager_queue = mngr.queue + else: + mngr = None server_address = start_spine_db_server(server_manager_queue, db_url, memory=memory, ordering=ordering) host, port = server_address try: yield urlunsplit(("http", f"{host}:{port}", "", "", "")) finally: shutdown_spine_db_server(server_manager_queue, server_address) + if mngr is not None: + mngr.shutdown() @contextmanager def db_server_manager(): + """Creates a DB server manager that can be used to control the order in which different servers + write to the same DB. + + Example:: + + with db_server_manager() as mngr_queue: + with closing_spine_db_server(db_url, server_manager_queue=mngr_queue) as server1_url: + with closing_spine_db_server(db_url, server_manager_queue=mngr_queue) as server1_url: + client1 = SpineDBClient.from_server_url(server_url1) + client2 = SpineDBClient.from_server_url(server_url2) + # TODO: ordering + + Yields: + :class:`~multiprocessing.queues.Queue`: a queue that can be passed to :func:`.closing_spine_db_server` + in order to control write order. 
+ """ mngr = _DBServerManager() try: yield mngr.queue diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 3a57ce6f..501a797f 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -8,10 +8,6 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Temp id stuff. - -""" class TempId(int): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 19cc8a6e..3b6d2a90 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -43,12 +43,27 @@ def query_wrapper(*args, orig_query=db_map.query, **kwargs): IN_MEMORY_DB_URL = "sqlite://" +class TestDatabaseMappingPublic(unittest.TestCase): + _db_map = None + + @classmethod + def setUpClass(cls): + cls._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + + @classmethod + def tearDownClass(cls): + cls._db_map.close() + + def test_getters(self): + print(dir(self._db_map.get_entity_class)) + + class TestDatabaseMappingConstruction(unittest.TestCase): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) db_map.close() @@ -58,9 +73,9 @@ def test_construction_with_filters(self): def test_construction_with_sqlalchemy_url_and_filters(self): db_url = IN_MEMORY_DB_URL + "/?spinedbfilter=fltr1&spinedbfilter=fltr2" sa_url = make_url(db_url) - with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) db_map.close() @@ -97,9 +112,9 @@ def tearDownClass(cls): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: with patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) db_map.close() @@ -109,9 +124,9 @@ def test_construction_with_filters(self): def test_construction_with_sqlalchemy_url_and_filters(self): sa_url = URL("sqlite") sa_url.query = {"spinedbfilter": ["fltr1", "fltr2"]} - with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: + with patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: with patch( - "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping_base.load_filters", 
return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) db_map.close() @@ -2217,10 +2232,10 @@ def test_rollback_addition(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") import_functions.import_object_classes(self._db_map, ("second_class",)) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2230,10 +2245,10 @@ def test_rollback_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.remove_items("entity_class", 1) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2242,11 +2257,11 @@ def test_rollback_removal(self): def test_rollback_update(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") - self._db_map.update_items("entity_class", {"id": {"name": "my_class"}, "name": "new_name"}) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self._db_map.get_item("entity_class", name="my_class").update(name="new_name") + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2256,33 +2271,33 @@ def test_refresh_addition(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") import_functions.import_object_classes(self._db_map, ("second_class",)) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in 
self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) def test_refresh_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.remove_items("entity_class", 1) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) def test_refresh_update(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") - self._db_map.update_items("entity_class", {"id": {"name": "my_class"}, "name": "new_name"}) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + self._db_map.get_item("entity_class", name="my_class").update(name="new_name") + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").values()} + entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) def test_cascade_remove_unfetched(self): diff --git a/tests/test_db_cache_base.py b/tests/test_db_cache_base.py index 012fe647..524bde5b 100644 --- a/tests/test_db_cache_base.py +++ b/tests/test_db_cache_base.py @@ -15,11 +15,11 @@ class TestCache(DBCacheBase): @property - def _item_types(self): + def item_types(self): return ["cutlery"] @staticmethod - def _item_factory(item_type): + def item_factory(item_type): if item_type == "cutlery": return CacheItemBase raise RuntimeError(f"unknown item_type '{item_type}'") From 6a5c23f7d933992a3d97f8c0a57d232c384baa38 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 3 Oct 2023 10:41:46 +0200 Subject: [PATCH 102/317] Fix bad method renaming --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 9a12f724..c1f197ca 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -228,7 +228,7 @@ def _make_codename(self, codename): return hashing.hexdigest() @staticmethod - def _create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): + def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): """Creates engine. 
Args From cb4ffea4f5f33620d6abab938095a39e6bec23d7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 3 Oct 2023 14:19:57 +0200 Subject: [PATCH 103/317] Improve documentation --- docs/source/conf.py | 2 +- docs/source/tutorial.rst | 179 ++++++++++++---- spinedb_api/db_cache_base.py | 44 ++-- spinedb_api/db_cache_impl.py | 86 ++++++++ spinedb_api/db_mapping.py | 254 ++++++++++++++++++----- spinedb_api/db_mapping_add_mixin.py | 37 ++-- spinedb_api/db_mapping_base.py | 70 +++---- spinedb_api/db_mapping_commit_mixin.py | 4 +- spinedb_api/db_mapping_remove_mixin.py | 10 - spinedb_api/db_mapping_update_mixin.py | 46 ++-- spinedb_api/export_functions.py | 4 +- spinedb_api/graph_layout_generator.py | 24 ++- spinedb_api/helpers.py | 8 +- spinedb_api/spine_db_server.py | 4 +- spinedb_api/spine_io/exporters/writer.py | 2 +- spinedb_api/temp_id.py | 2 +- tests/test_DatabaseMapping.py | 15 -- 17 files changed, 537 insertions(+), 254 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 643000a6..4e0b6a10 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -85,7 +85,7 @@ pygments_style = 'sphinx' # Settings for Sphinx AutoAPI -autoapi_options = ['members', 'inherited-members'] +autoapi_options = ['members', 'inherited-members', 'show-module-summary'] autoapi_python_class_content = "both" autoapi_add_toctree_entry = True autoapi_root = "autoapi" diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index e7a3fcbf..d966a5e7 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -8,70 +8,169 @@ Tutorial ******** -Spine database API provides for the creation and management of -Spine databases, using SQLAlchemy_ as the underlying engine. -This tutorial will provide a full introduction to the usage of this package. +The Spine DB API allows one to create and manipulate +Spine databases in a standard way, using SQLAlchemy_ as the underlying engine. +This tutorial provides a quick introduction to the usage of the package. To begin, make sure Spine database API is installed as described at :ref:`installation`. -Creation --------- +Database Mapping +---------------- -Usage of Spine database API starts with the creation of a Spine database. +The main means of communication with a Spine DB is the :class:`.DatabaseMapping`, +specially designed to retrieve and modify data from the DB. +To create a :class:`.DatabaseMapping`, we just pass the URL of the DB to the class constructor:: -Mapping -------- + from spinedb_api import DatabaseMapping + + url = "mysql://spine_db" # The URL of an existing Spine DB -Next step is the creation of a *Database Mapping*, -a Python object that provides means of interacting with the database. -Spine database API provides two classes of mapping: + with DatabaseMapping(url) as db_map: + # Do something with db_map + pass + +The URL should be formatted following the RFC-1738 standard, as described +`here `_. -- :class:`.DatabaseMapping`, just for *querying* the database (i.e., run ``SELECT`` statements). -- :class:`.DiffDatabaseMapping`, for both querying and *modifying* the database. +.. note:: + + Currently supported database backends are only SQLite and MySQL. More will be added later. -The differences between these two will become more apparent as we go through this tutorial. -However, it is important to note that everything you can do with a :class:`.DatabaseMapping`, -you can also do with a :class:`.DiffDatabaseMapping`. 
+Creating a DB
+-------------

-To create a :class:`.DatabaseMapping`, we just pass the database URL to the class constructor::
+If you're following this tutorial, chances are you don't have a Spine DB to play with just yet.
+We can remedy this by creating a SQLite DB (which is just a file in your system), as follows::

     from spinedb_api import DatabaseMapping

-    url = "sqlite:///spine.db"
+    url = "sqlite:///first.sqlite"

-    db_map = DatabaseMapping(url)
+    with DatabaseMapping(url, create=True) as db_map:
+        # Do something with db_map
+        pass

-The URL should be formatted following the RFC-1738 standard, so it basically
-works with :func:`sqlalchemy.create_engine` as described
-`here `_.
+The above will create a file called ``first.sqlite`` in your current working directory.
+Note that we pass the keyword argument ``create=True`` to :class:`.DatabaseMapping` to explicitly say
+that we want the DB to be created at the given URL.

 .. note::

-   Currently supported database backends are only SQLite and MySQL. More will be added later.
+   In the remainder we will skip the above step and work directly with ``db_map``. In other words,
+   all the examples below assume we are inside the ``with`` block above.
+
+Adding data
+-----------
+
+To insert data, we use :meth:`~.DatabaseMapping.add_item`.
+
+Let's begin the party by adding a couple of entity classes::
+
+    db_map.add_item("entity_class", name="fish", description="It swims.")
+    db_map.add_item("entity_class", name="cat", description="Eats fish.")
+
+Now let's add a multi-dimensional entity class between the two above. For this we need to specify the class names
+as `dimensions`::
+
+    db_map.add_item(
+        "entity_class",
+        name="fish__cat",
+        dimension_name_list=("fish", "cat"),
+        description="A fish getting eaten by a cat?",
+    )
+
+
+Let's add entities to our zero-dimensional classes::
+
+    db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (for now).")
+    db_map.add_item(
+        "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat."
+    )
+
+Let's add a multi-dimensional entity to our multi-dimensional class. For this we need to specify the entity names
+as `elements`::
+
+    db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))
+
+Let's add a parameter definition for one of our entity classes::
+
+    db_map.add_item("parameter_definition", entity_class_name="fish", name="color")
+
+Finally, let's specify a parameter value for one of our entities::
+
+    db_map.add_item(
+        "parameter_value",
+        entity_class_name="fish",
+        entity_name="Nemo",
+        parameter_definition_name="color",
+        value="mainly orange"
+    )
+
+.. note::
+
+    The data we've added so far is not yet in the DB, but only in an in-memory mapping within our ``db_map`` object.
+
+
+Retrieving data
+---------------
+
+To retrieve data from the DB, we use :meth:`~.DatabaseMapping.get_item`.
+For example, let's find one of the entities we inserted above::
+
+    felix = db_map.get_item("entity", class_name="cat", name="Felix")
+    print(felix["description"])  # Prints 'The wonderful wonderful cat.'
+
+
+Above, ``felix`` is a :class:`~.PublicItem` object, representing an item (or row) in a Spine DB.
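+Being dictionary-like, its fields can be read directly (a quick sketch; the field names below are the
+same ones we used when adding the item)::
+
+    print(felix["name"])        # Prints 'Felix'
+    print(felix["class_name"])  # Prints 'cat'
+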
+
+Let's find our multi-dimensional entity::
+
+    nemo_felix = db_map.get_item("entity", class_name="fish__cat", byname=("Nemo", "Felix"))
+    print(nemo_felix["dimension_name_list"])  # Prints '(fish, cat)'
+
+To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`::
+
+    print([entity["byname"] for entity in db_map.get_items("entity")])
+    # Prints [("Nemo",), ("Felix",), ("Nemo", "Felix")]
+
+.. note::
+
+    You should use the above to try and find Nemo!
+
+
+Updating data
+-------------
+
+To update data, we use the :meth:`~.PublicItem.update` method of :class:`~.PublicItem`.
+
+Let's rename our fish entity to avoid any copyright infringements::
+
+    db_map.get_item("entity", class_name="fish", name="Nemo").update(name="NotNemo")
+
+To be safe, let's also change the color::
+
+    db_map.get_item(
+        "parameter_value",
+        entity_class_name="fish",
+        parameter_definition_name="color",
+        entity_name="NotNemo"
+    ).update(value="definitely purple")
+
+
+Note how we need to use the new entity name (``"NotNemo"``) to retrieve the parameter value. This makes sense,
+since the rename is already in effect in the in-memory mapping.
+
+Removing data
+-------------

-Querying
---------
+You know what, let's just remove the entity entirely.
+To do this we use the :meth:`~.PublicItem.remove` method of :class:`~.PublicItem`::

-The database mapping object provides two mechanisms for querying the database.
-The first is for running *standard*, general-purpose queries
-such as selecting all records from the ``object_class`` table.
-The second is for performing *custom* queries that one may need for a particular purpose.
+    db_map.get_item("entity", class_name="fish", name="NotNemo").remove()

-Standard querying
-=================

-To perform standard querying, we chose among the methods of the :class:`~.DatabaseMappingQueryMixin` class,
-the one that bets suits our purpose. E.g.::

-    TODO

-Custom querying
-===============

-TODO

-Inserting
----------

-TODO
diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
index 9143bc44..eb582e68 100644
--- a/spinedb_api/db_cache_base.py
+++ b/spinedb_api/db_cache_base.py
@@ -30,7 +30,12 @@ class Status(Enum):


 class DBCacheBase(dict):
-    """A dictionary that maps table names to ids to items. Used to store and retrieve database contents."""
+    """A dictionary representation of a DB, mapping item types (table names) to numeric ids, to items.
+
+    This class is not meant to be used directly. Instead, you need to subclass it for each DB schema you want to use.
+
+    When subclassing, you need to implement :attr:`item_types`, :meth:`item_factory`, and :meth:`query`.
+    """

     def __init__(self):
         super().__init__()
@@ -57,33 +62,33 @@ def fetched_item_types(self):

     @property
     def item_types(self):
-        """Returns a list of supported item type strings.
+        """Returns a list of item types in the DB (equivalent to the table names).

         Returns:
-            list
+            list(str)
         """
         raise NotImplementedError()

     @staticmethod
     def item_factory(item_type):
-        """Returns a subclass of CacheItemBase to build items of given type.
+        """Returns a subclass of :class:`.CacheItemBase` to make items of given type.

         Args:
             item_type (str)

         Returns:
-            CacheItemBase
+            function
         """
         raise NotImplementedError()

     def query(self, item_type):
-        """Returns a Query object to fecth items of given type.
+        """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.
 Args:
             item_type (str)

         Returns:
-            Query
+            :class:`~spinedb_api.query.Query`
         """
         raise NotImplementedError()
@@ -249,7 +254,7 @@ class _TableCache(dict):
     def __init__(self, db_cache, item_type, *args, **kwargs):
         """
         Args:
-            db_cache (DBCache): the DB cache where this table cache belongs.
+            db_cache (DBCacheBase): the DB cache where this table cache belongs.
             item_type (str): the item type, equal to a table name
         """
         super().__init__(*args, **kwargs)
@@ -271,20 +276,20 @@ def _callback(db_id):
         temp_id.add_resolve_callback(_callback)
         return temp_id

-    def unique_key_value_to_id(self, key, value, strict=False):
+    def unique_key_value_to_id(self, key, value, strict=False, fetch=True):
         """Returns the id that has the given value for the given unique key, or None if not found.
-        Fetches until being sure.

         Args:
             key (tuple)
             value (tuple)
             strict (bool): if True, raise a KeyError if id is not found
+            fetch (bool): whether to fetch the DB until found.

         Returns:
             int
         """
         id_by_unique_value = self._id_by_unique_key_value.get(key, {})
-        if not id_by_unique_value:
+        if not id_by_unique_value and fetch:
             id_by_unique_value = self._db_cache.fetch_value(
                 self._item_type, lambda: self._id_by_unique_key_value.get(key, {})
             )
@@ -293,8 +298,8 @@ def unique_key_value_to_id(self, key, value, strict=False):
             return id_by_unique_value[value]
         return id_by_unique_value.get(value)

-    def _unique_key_value_to_item(self, key, value):
-        return self.get(self.unique_key_value_to_id(key, value))
+    def _unique_key_value_to_item(self, key, value, fetch=True):
+        return self.get(self.unique_key_value_to_id(key, value, fetch=fetch))

     def valid_values(self):
         return (x for x in self.values() if x.is_valid())
@@ -310,7 +315,7 @@ def _make_item(self, item):
         """
         return self._db_cache.make_item(self._item_type, **item)

-    def find_item(self, item, skip_keys=()):
+    def find_item(self, item, skip_keys=(), fetch=True):
         """Returns a CacheItemBase that matches the given dictionary-item.

         Args:
@@ -322,7 +327,10 @@ def find_item(self, item, skip_keys=()):
         id_ = item.get("id")
         if id_ is not None:
             # id is given, easy
-            return self.get(id_) or self._db_cache.fetch_ref(self._item_type, id_)
+            item = self.get(id_)
+            if item or not fetch:
+                return item
+            return self._db_cache.fetch_ref(self._item_type, id_)
         # No id. Try to locate the item by the value of one of the unique keys.
         # Used by import_data (and more...)
         cache_item = self._make_item(item)
@@ -333,7 +341,7 @@ def find_item(self, item, skip_keys=()):
         if error:
             return None
         for key, value in cache_item.unique_values(skip_keys=skip_keys):
-            current_item = self._unique_key_value_to_item(key, value)
+            current_item = self._unique_key_value_to_item(key, value, fetch=fetch)
             if current_item:
                 return current_item

@@ -431,6 +439,8 @@ def restore_item(self, id_):
 class CacheItemBase(dict):
     """A dictionary that represents a db item."""

+    _fields = {}
+    """A dictionary mapping fields to a tuple of (type, description)"""
     _defaults = {}
     """A dictionary mapping keys to their default values"""
     _unique_keys = ()
@@ -460,7 +470,7 @@ class CacheItemBase(dict):
     def __init__(self, db_cache, item_type, **kwargs):
         """
         Args:
-            db_cache (DBCache): the DB cache where this item belongs.
+            db_cache (DBCacheBase): the DB cache where this item belongs.
""" super().__init__(**kwargs) self._db_cache = db_cache diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/db_cache_impl.py index 8250554f..3edf6224 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/db_cache_impl.py @@ -76,6 +76,14 @@ def query(self, item_type): class EntityClassItem(CacheItemBase): + _fields = { + "name": ("str", "The class name."), + "dimension_name_list": ("tuple, optional", "The dimension names for a multi-dimensional class."), + "description": ("str, optional", "The class description."), + "display_icon": ("int, optional", "An integer representing an icon within your application."), + "display_order": ("int, optional", "Not in use at the moment"), + "hidden": ("bool, optional", "Not in use at the moment"), + } _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = (("name",),) _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))} @@ -105,6 +113,12 @@ def commit(self, _commit_id): class EntityItem(CacheItemBase): + _fields = { + "class_name": ("str", "The entity class name."), + "name": ("str, optional", "The entity name - must be given for a zero-dimensional entity."), + "element_name_list": ("tuple, optional", "The element names - must be given for a multi-dimensional entity."), + "description": ("str, optional", "The entity description."), + } _defaults = {"description": None} _unique_keys = (("class_name", "name"), ("class_name", "byname")) _references = { @@ -147,6 +161,11 @@ def polish(self): class EntityGroupItem(CacheItemBase): + _fields = { + "class_name": ("str", "The entity class name."), + "group_name": ("str", "The group entity name."), + "member_name": ("str", "The member entity name."), + } _unique_keys = (("group_name", "member_name"),) _references = { "class_name": ("entity_class_id", ("entity_class", "name")), @@ -169,6 +188,15 @@ def __getitem__(self, key): class EntityAlternativeItem(CacheItemBase): + _fields = { + "entity_class_name": ("str", "The entity class name."), + "entity_byname": ( + "str or tuple", + "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + ), + "alternative_name": ("str", "The alternative name."), + "active": ("bool, optional", "Whether the entity is active in the alternative - defaults to True."), + } _defaults = {"active": True} _unique_keys = (("entity_class_name", "entity_byname", "alternative_name"),) _references = { @@ -213,6 +241,14 @@ def __getitem__(self, key): class ParameterDefinitionItem(ParsedValueBase): + _fields = { + "entity_class_name": ("str", "The entity class name."), + "name": ("str", "The parameter name."), + "default_value": ("any, optional", "The default value."), + "default_type": ("str, optional", "The default value type."), + "parameter_value_list_name": ("str, optional", "The parameter value list name if any."), + "description": ("str, optional", "The parameter description."), + } _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} _unique_keys = (("entity_class_name", "name"),) _references = { @@ -303,6 +339,17 @@ def merge(self, other): class ParameterValueItem(ParsedValueBase): + _fields = { + "entity_class_name": ("str", "The entity class name."), + "parameter_definition_name": ("str", "The parameter name."), + "entity_byname": ( + "str or tuple", + "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + ), + "value": ("any", "The 
value."), + "type": ("str", "The value type."), + "alternative_name": ("str, optional", "The alternative name - defaults to 'Base'."), + } _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_name": ("entity_class_id", ("entity_class", "name")), @@ -353,6 +400,9 @@ def __getitem__(self, key): return super().__getitem__(key) def polish(self): + error = super().polish() + if error: + return error list_name = self["parameter_value_list_name"] if list_name is None: return @@ -382,10 +432,17 @@ def callback(new_id): class ParameterValueListItem(CacheItemBase): + _fields = {"name": ("str", "The parameter value list name.")} _unique_keys = (("name",),) class ListValueItem(ParsedValueBase): + _fields = { + "parameter_value_list_name": ("str", "The parameter value list name."), + "value": ("any", "The value."), + "type": ("str", "The value type."), + "index": ("int, optional", "The value index."), + } _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} _inverse_references = { @@ -400,11 +457,20 @@ def _make_parsed_value(self): class AlternativeItem(CacheItemBase): + _fields = { + "name": ("str", "The alternative name."), + "description": ("str, optional", "The alternative description."), + } _defaults = {"description": None} _unique_keys = (("name",),) class ScenarioItem(CacheItemBase): + _fields = { + "name": ("str", "The scenario name."), + "description": ("str, optional", "The scenario description."), + "active": ("bool, optional", "Not in use at the moment."), + } _defaults = {"active": False, "description": None} _unique_keys = (("name",),) @@ -427,6 +493,11 @@ def __getitem__(self, key): class ScenarioAlternativeItem(CacheItemBase): + _fields = { + "scenario_name": ("str", "The scenario name."), + "alternative_name": ("str", "The alternative name."), + "rank": ("int", "The rank - the higher has precedence."), + } _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) _references = { "scenario_name": ("scenario_id", ("scenario", "name")), @@ -454,10 +525,16 @@ def __getitem__(self, key): class MetadataItem(CacheItemBase): + _fields = {"name": ("str", "The metadata entry name."), "value": ("str", "The metadata entry value.")} _unique_keys = (("name", "value"),) class EntityMetadataItem(CacheItemBase): + _fields = { + "entity_name": ("str", "The entity name."), + "metadata_name": ("str", "The metadata entry name."), + "metadata_value": ("str", "The metadata entry value."), + } _unique_keys = (("entity_name", "metadata_name", "metadata_value"),) _references = { "entity_name": ("entity_id", ("entity", "name")), @@ -471,6 +548,15 @@ class EntityMetadataItem(CacheItemBase): class ParameterValueMetadataItem(CacheItemBase): + _fields = { + "parameter_definition_name": ("str", "The parameter name."), + "entity_byname": ( + "str or tuple", + "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + ), + "alternative_name": ("str", "The alternative name."), + "metadata_value": ("str", "The metadata entry value."), + } _unique_keys = ( ("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name", "metadata_value"), ) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 1dfe84d2..69520889 100644 --- a/spinedb_api/db_mapping.py +++ 
b/spinedb_api/db_mapping.py
@@ -19,6 +19,7 @@
 from .db_mapping_update_mixin import DatabaseMappingUpdateMixin
 from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin
 from .db_mapping_commit_mixin import DatabaseMappingCommitMixin
+from .db_cache_impl import DBCache


 class DatabaseMapping(
@@ -30,67 +31,184 @@ class DatabaseMapping(
 ):
     """Enables communication with a Spine DB.

-    An in-memory clone (ORM) of the DB is incrementally formed as data is requested/modified.
+    A mapping of the DB is incrementally created in memory as data is requested/modified.

     Data is typically retrieved using :meth:`get_item` or :meth:`get_items`.
-    If the requested data is already in the in-memory clone, it is returned from there;
-    otherwise it is fetched from the DB, stored in the clone, and then returned.
+    If the requested data is already in memory, it is returned from there;
+    otherwise it is fetched from the DB, stored in memory, and then returned.
     In other words, the data is fetched from the DB exactly once.

-    Data is added via :meth:`add_item` or :meth:`add_items`;
-    updated via :meth:`update_item` or :meth:`update_items`;
-    removed via :meth:`remove_item` or :meth:`remove_items`;
-    and restored via :meth:`restore_item` or :meth:`restore_items`.
-    All the above methods modify the in-memory clone (not the DB itself).
-    These methods also fetch data from the DB into the in-memory clone to perform the necessary integrity checks
-    (unique constraints, foreign key constraints) as needed.
+    Data is added via :meth:`add_item`;
+    updated via :meth:`update_item`;
+    removed via :meth:`remove_item`;
+    and restored via :meth:`restore_item`.
+    All the above methods modify the in-memory mapping (not the DB itself).
+    These methods also fetch data from the DB into the in-memory mapping to perform the necessary integrity checks
+    (unique and foreign key constraints).

-    Modifications to the in-memory clone are committed (written) to the DB via :meth:`commit_session`,
+    To retrieve an item or to manipulate it, you typically need to specify certain fields.
+    The :meth:`describe_item_type` method is provided to help you identify these fields.
+
+    Modifications to the in-memory mapping are committed (written) to the DB via :meth:`commit_session`,
     or rolled back (discarded) via :meth:`rollback_session`.

-    The in-memory clone is reset via :meth:`refresh_session`.
+    The DB fetch status is reset via :meth:`refresh_session`.
+    This causes new items in the DB to be merged into the in-memory mapping as data is further requested/modified.

     You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`.
-    These methods are especially useful to be called asynchronously.
+    For example, a UI application might want to fetch data in the background so the UI is not blocked in the process.
+    In that case it can call e.g. :meth:`fetch_more` asynchronously as the user scrolls or expands the views.

-    Data can also be retreived using :meth:`query` in combination with one of the multiple subquery properties
-    documented below.
+    The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
+    while bypassing the in-memory mapping entirely.
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - for item_type in self.ITEM_TYPES: - setattr(self, "get_" + item_type, self._make_getter(item_type)) - - def _make_getter(self, item_type): - def _get_item(self, **kwargs): - return self.get_item(item_type, **kwargs) + def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): + """Finds and returns and item matching the arguments, or None if none found. - return _get_item + Args: + item_type (str): The type of the item. + fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory. + skip_removed (bool, optional): Whether to ignore removed items. + **kwargs: Fields of one of the item type's unique keys and their values for the requested item. - def get_item(self, tablename, **kwargs): - tablename = self._real_tablename(tablename) - cache_item = self.cache.table_cache(tablename).find_item(kwargs) + Returns: + :class:`PublicItem` or None + """ + item_type = self._real_tablename(item_type) + cache_item = self.cache.table_cache(item_type).find_item(kwargs, fetch=fetch) if not cache_item: return None + if skip_removed and not cache_item.is_valid(): + return None return PublicItem(self, cache_item) - def get_items(self, tablename, fetch=True, valid_only=True): - tablename = self._real_tablename(tablename) - if fetch and tablename not in self.cache.fetched_item_types: - self.fetch_all(tablename) - if valid_only: - return [PublicItem(self, x) for x in self.cache.table_cache(tablename).valid_values()] - return [PublicItem(self, x) for x in self.cache.table_cache(tablename).values()] + def get_items(self, item_type, fetch=True, skip_removed=True): + """Finds and returns and item matching the arguments, or None if none found. + + Args: + item_type (str): The type of items to get. + fetch (bool, optional): Whether to fetch the DB before returning the items. + skip_removed (bool, optional): Whether to ignore removed items. + + Returns: + :class:`PublicItem` or None + """ + item_type = self._real_tablename(item_type) + if fetch and item_type not in self.cache.fetched_item_types: + self.fetch_all(item_type) + table_cache = self.cache.table_cache(item_type) + get_items = table_cache.valid_values if skip_removed else table_cache.values + return [PublicItem(self, x) for x in get_items()] + + def add_item(self, item_type, check=True, **kwargs): + """Adds an item to the in-memory mapping. + + Example:: + + with DatabaseMapping(url) as db_map: + db_map.add_item("entity", class_name="dog", name="Pete") + + + Args: + item_type (str): The type of the item. + check (bool, optional): Whether to carry out integrity checks. + **kwargs: Mandatory fields for the item type and their values. + + Returns: + tuple(:class:`PublicItem` or None, str): The added item and any errors. + """ + item_type = self._real_tablename(item_type) + table_cache = self.cache.table_cache(item_type) + self._convert_legacy(item_type, kwargs) + if not check: + return table_cache.add_item(kwargs, new=True), None + checked_item, error = table_cache.check_item(kwargs) + return table_cache.add_item(checked_item, new=True) if checked_item and not error else None, error + + def update_item(self, item_type, check=True, **kwargs): + """Updates an item in the in-memory mapping. + + Example:: + + with DatabaseMapping(url) as db_map: + my_dog = db_map.get_item("entity", class_name="dog", name="Pete") + db_map.update_item("entity", id=my_dog["id], name="Pluto") + + Args: + item_type (str): The type of the item. 
+            check (bool, optional): Whether to carry out integrity checks.
+            id (int): The id of the item to update.
+            **kwargs: Fields to update and their new values.
+
+        Returns:
+            tuple(:class:`PublicItem` or None, str): The updated item and any errors.
+        """
+        item_type = self._real_tablename(item_type)
+        table_cache = self.cache.table_cache(item_type)
+        self._convert_legacy(item_type, kwargs)
+        if not check:
+            return table_cache.update_item(kwargs), None
+        checked_item, error = table_cache.check_item(kwargs, for_update=True)
+        return table_cache.update_item(checked_item._asdict()) if checked_item and not error else None, error
+
+    def remove_item(self, item_type, id_):
+        """Removes an item from the in-memory mapping.
+
+        Example::
+
+            with DatabaseMapping(url) as db_map:
+                my_dog = db_map.get_item("entity", class_name="dog", name="Pluto")
+                db_map.remove_item("entity", my_dog["id"])
+
+
+        Args:
+            item_type (str): The type of the item.
+            id (int): The id of the item to remove.
+
+        Returns:
+            :class:`PublicItem` or None: The removed item if any.
+        """
+        item_type = self._real_tablename(item_type)
+        table_cache = self.cache.table_cache(item_type)
+        return table_cache.remove_item(id_)
+
+    def restore_item(self, item_type, id_):
+        """Restores a previously removed item into the in-memory mapping.
+
+        Example::
+
+            with DatabaseMapping(url) as db_map:
+                my_dog = db_map.get_item("entity", skip_removed=False, class_name="dog", name="Pluto")
+                db_map.restore_item("entity", my_dog["id"])
+
+        Args:
+            item_type (str): The type of the item.
+            id (int): The id of the item to restore.
+
+        Returns:
+            :class:`PublicItem` or None: The restored item if any.
+        """
+        item_type = self._real_tablename(item_type)
+        table_cache = self.cache.table_cache(item_type)
+        return table_cache.restore_item(id_)
+
+    def can_fetch_more(self, item_type):
+        """Whether or not more data can be fetched from the DB for the given item type.
+
+        Args:
+            item_type (str): The item type (table) to check.

-    def can_fetch_more(self, tablename):
-        return tablename not in self.cache.fetched_item_types
+        Returns:
+            bool
+        """
+        return item_type not in self.cache.fetched_item_types

-    def fetch_more(self, tablename, limit):
-        """Fetches items from the DB into memory, incrementally.
+    def fetch_more(self, item_type, limit):
+        """Fetches items from the DB into the in-memory mapping, incrementally.

         Args:
-            tablename (str): The table to fetch.
+            item_type (str): The item type (table) to fetch.
             limit (int): The maximum number of items to fetch. Successive calls to this function will start
                 from the point where the last one left. In other words, each item is fetched from the DB exactly
                 once.
@@ -98,28 +216,50 @@ def fetch_more(self, tablename, limit):
         Returns:
             list(PublicItem): The items fetched.
         """
-        tablename = self._real_tablename(tablename)
-        return self.cache.fetch_more(tablename, limit=limit)
+        item_type = self._real_tablename(item_type)
+        return self.cache.fetch_more(item_type, limit=limit)

-    def fetch_all(self, *tablenames):
-        """Fetches items from the DB into memory. Unlike :meth:`fetch_more`, this method fetches entire tables.
+    def fetch_all(self, *item_types):
+        """Fetches items from the DB into the in-memory mapping.
+        Unlike :meth:`fetch_more`, this method fetches entire tables.

         Args:
-            *tablenames (str): The tables to fetch. If none given, then the entire DB is fecthed.
+            *item_types (str): The item types (tables) to fetch. If none given, then the entire DB is fetched.
""" - tablenames = set(self.ITEM_TYPES) if not tablenames else set(tablenames) & set(self.ITEM_TYPES) - for tablename in tablenames: - tablename = self._real_tablename(tablename) - self.cache.fetch_all(tablename) - - def add_item(self, tablename, **kwargs): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - self._convert_legacy(tablename, kwargs) - checked_item, error = table_cache.check_item(kwargs) - if error: - return None, error - return table_cache.add_item(checked_item, new=True), None + item_types = set(self.ITEM_TYPES) if not item_types else set(item_types) & set(self.ITEM_TYPES) + for item_type in item_types: + item_type = self._real_tablename(item_type) + self.cache.fetch_all(item_type) + + @staticmethod + def describe_item_type(item_type): + """Prints a synopsis of the given item type to the stdout. + + Args: + item_type (str): The type of item to describe. + """ + factory = DBCache.item_factory(item_type) + sections = ("Fields:", "Unique keys:") + width = max(len(s) for s in sections) + 4 + print() + print(item_type) + print("-" * len(item_type)) + section = sections[0] + field_iter = (f"{field} ({type_}) - {description}" for field, (type_, description) in factory._fields.items()) + _print_section(section, width, field_iter) + print() + section = sections[1] + unique_key_iter = ("(" + ", ".join(key) + ")" for key in factory._unique_keys) + _print_section(section, width, unique_key_iter) + print() + + +def _print_section(section, width, iterator): + row = next(iterator) + bullet = "- " + print(f"{section:<{width}}" + bullet + row) + for row in iterator: + print(" " * width + bullet + row) class PublicItem: diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py index 81158b4f..58f61ae5 100644 --- a/spinedb_api/db_mapping_add_mixin.py +++ b/spinedb_api/db_mapping_add_mixin.py @@ -20,40 +20,29 @@ class DatabaseMappingAddMixin: """Provides methods to perform ``INSERT`` operations over a Spine db.""" def add_items(self, tablename, *items, check=True, strict=False): - """Add items to cache. + """Add items to the in-memory mapping. Args: - tablename (str) - items (Iterable): One or more Python :class:`dict` objects representing the items to be inserted. - check (bool): Whether or not to check integrity + tablename (str): The table where items are inserted. + items (Iterable): One or more :class:`dict` objects representing the items to be inserted. + check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. Returns: - set: ids or items successfully added - list(str): found violations + tuple(list(dict),list(str)): items successfully added and found violations. 
""" added, errors = [], [] - if not check: - for item in items: - added.append(self._add_item_unsafe(tablename, item)) - else: - for item in items: - item, error = self.add_item(tablename, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - continue - added.append(item) + for item in items: + item, error = self.add_item(tablename, check, **item) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + added.append(item) return added, errors - def _add_item_unsafe(self, tablename, item): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - self._convert_legacy(tablename, item) - return table_cache.add_item(item, new=True) - def _do_add_items(self, connection, tablename, *items_to_add): """Add items to DB without checking integrity.""" if not items_to_add: diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c1f197ca..42ae72c4 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -80,7 +80,8 @@ def __init__( ): """ Args: - db_url (str or URL): A URL in RFC-1738 format pointing to the database to be mapped, or to a DB server. + db_url (str or :class:`~sqlalchemy.engine.url.URL`): A URL in RFC-1738 format pointing to the database + to be mapped, or to a DB server. username (str, optional): A user name. If not given, it gets replaced by the string ``"anon"``. upgrade (bool, optional): Whether the db at the given URL should be upgraded to the most recent version. @@ -105,7 +106,7 @@ def __init__( self.codename = self._make_codename(codename) self._memory = memory self._memory_dirty = False - self._original_engine = self._create_engine( + self._original_engine = self.create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason @@ -195,7 +196,7 @@ def get_filter_configs(self): """Returns filters applicable to this DB mapping. Returns: - list(dict) + list(dict): """ return self._filter_configs @@ -231,13 +232,13 @@ def _make_codename(self, codename): def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): """Creates engine. - Args + Args: sa_url (URL) upgrade (bool, optional): If True, upgrade the db to the latest version. create (bool, optional): If True, create a new Spine db at the given url if none found. - Returns - Engine + Returns: + :class:`~sqlalchemy.engine.Engine` """ if sa_url.drivername == "sqlite": connect_args = {'timeout': sqlite_timeout} @@ -345,11 +346,11 @@ def _clear_subqueries(self, *tablenames): setattr(self, attr_name, None) def query(self, *args, **kwargs): - """Returns a :class:`~spinedb_api.query.Query` object bound to this :class:`.DatabaseMappingBase`. + """Returns a :class:`~spinedb_api.query.Query` object to execute against this DB. To perform custom ``SELECT`` statements, call this method with one or more of the class documented - :class:`~sqlalchemy.sql.expression.Alias` properties. For example, to select the entity class with - ``id`` equal to 1:: + subquery properties (of :class:`~sqlalchemy.sql.expression.Alias` type). 
+ For example, to select the entity class with ``id`` equal to 1:: from spinedb_api import DatabaseMapping url = 'sqlite:///spine.db' @@ -357,8 +358,10 @@ def query(self, *args, **kwargs): db_map = DatabaseMapping(url) db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() - To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface. - For example, to select all entity class names and the names of their entities concatenated in a string:: + To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface + (which is a close clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`). + For example, to select all entity class names and the names of their entities concatenated in a comma-separated + string:: from sqlalchemy import func @@ -381,7 +384,7 @@ def _subquery(self, tablename): tablename (str): the table to be queried. Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ table = self._metadata.tables[tablename] return self.query(table).subquery(tablename + "_sq") @@ -395,7 +398,7 @@ def entity_class_sq(self): SELECT * FROM entity_class Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_class_sq is None: self._entity_class_sq = self._make_entity_class_sq() @@ -410,7 +413,7 @@ def entity_class_dimension_sq(self): SELECT * FROM entity_class_dimension Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_class_dimension_sq is None: self._entity_class_dimension_sq = self._subquery("entity_class_dimension") @@ -433,7 +436,7 @@ def wide_entity_class_sq(self): ec.id == ecd.entity_class_id Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._wide_entity_class_sq is None: entity_class_dimension_sq = ( @@ -498,7 +501,7 @@ def entity_sq(self): SELECT * FROM entity Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_sq is None: self._entity_sq = self._make_entity_sq() @@ -513,7 +516,7 @@ def entity_element_sq(self): SELECT * FROM entity_element Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_element_sq is None: self._entity_element_sq = self._make_entity_element_sq() @@ -536,7 +539,7 @@ def wide_entity_sq(self): e.id == ee.entity_id Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._wide_entity_sq is None: entity_element_sq = ( @@ -583,7 +586,7 @@ def entity_group_sq(self): SELECT * FROM entity_group Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_group_sq is None: self._entity_group_sq = self._subquery("entity_group") @@ -598,7 +601,7 @@ def alternative_sq(self): SELECT * FROM alternative Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._alternative_sq is None: self._alternative_sq = self._make_alternative_sq() @@ -613,7 +616,7 @@ def scenario_sq(self): SELECT * FROM scenario Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._scenario_sq is None: self._scenario_sq = self._make_scenario_sq() @@ -628,7 +631,7 @@ def scenario_alternative_sq(self): SELECT * FROM scenario_alternative Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._scenario_alternative_sq is None: 
self._scenario_alternative_sq = self._make_scenario_alternative_sq() @@ -643,7 +646,7 @@ def entity_alternative_sq(self): SELECT * FROM entity_alternative Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_alternative_sq is None: self._entity_alternative_sq = self._subquery("entity_alternative") @@ -658,7 +661,7 @@ def parameter_value_list_sq(self): SELECT * FROM parameter_value_list Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._parameter_value_list_sq is None: self._parameter_value_list_sq = self._subquery("parameter_value_list") @@ -673,7 +676,7 @@ def list_value_sq(self): SELECT * FROM list_value Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._list_value_sq is None: self._list_value_sq = self._subquery("list_value") @@ -688,7 +691,7 @@ def parameter_definition_sq(self): SELECT * FROM parameter_definition Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._parameter_definition_sq is None: @@ -704,7 +707,7 @@ def parameter_value_sq(self): SELECT * FROM parameter_value Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._parameter_value_sq is None: self._parameter_value_sq = self._make_parameter_value_sq() @@ -719,7 +722,7 @@ def metadata_sq(self): SELECT * FROM list_value Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._metadata_sq is None: self._metadata_sq = self._subquery("metadata") @@ -734,7 +737,7 @@ def parameter_value_metadata_sq(self): SELECT * FROM parameter_value_metadata Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._parameter_value_metadata_sq is None: self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata") @@ -749,7 +752,7 @@ def entity_metadata_sq(self): SELECT * FROM entity_metadata Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._entity_metadata_sq is None: self._entity_metadata_sq = self._subquery("entity_metadata") @@ -764,7 +767,7 @@ def commit_sq(self): SELECT * FROM commit Returns: - sqlalchemy.sql.expression.Alias + :class:`~sqlalchemy.sql.expression.Alias` """ if self._commit_sq is None: commit_sq = self._subquery("commit") @@ -1520,11 +1523,6 @@ def _make_scenario_alternative_sq(self): return self._subquery("scenario_alternative") def get_import_alternative_name(self): - """Returns the name of the alternative to use as default for all import operations. - - Returns: - str: import alternative name - """ if self._import_alternative_name is None: self._create_import_alternative() return self._import_alternative_name diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 0b1f16fd..207a157e 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -19,7 +19,7 @@ class DatabaseMappingCommitMixin: """Provides methods to commit or rollback pending changes onto a Spine database.""" def commit_session(self, comment): - """Commits current session to the database. + """Commits the changes from the in-memory mapping to the database. 
Args: comment (str): commit message @@ -49,10 +49,12 @@ def commit_session(self, comment): return compatibility_transformations(connection) def rollback_session(self): + """Discards all the changes from the in-memory mapping.""" if not self.cache.rollback(): raise SpineDBAPIError("Nothing to rollback.") if self._memory: self._memory_dirty = False def refresh_session(self): + """Resets the fetch status so new items from the DB can be retrieved.""" self.cache.refresh() diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 4b703e27..15c240bf 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -53,16 +53,6 @@ def restore_items(self, tablename, *ids): table_cache = self.cache.table_cache(tablename) return [table_cache.restore_item(id_) for id_ in ids] - def remove_item(self, tablename, id_): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - return table_cache.remove_item(id_) - - def restore_item(self, tablename, id_): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - return table_cache.restore_item(id_) - def purge_items(self, tablename): """Removes all items from given table. diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index 1b4f03c1..f50d980e 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -71,47 +71,29 @@ def _extra_items_to_update_per_table(tablename, items_to_update): yield ("entity_element", ee_items_to_update) def update_items(self, tablename, *items, check=True, strict=False): - """Updates items in cache. + """Updates items in the in-memory mapping. Args: - tablename (str): Target database table name - *items: One or more Python :class:`dict` objects representing the items to be inserted. - check (bool): Whether or not to check integrity + tablename (str): The table where items are updated + *items: One or more :class:`dict` objects representing the items to be updated. + check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if the insertion of one of the items violates an integrity constraint. + if the update of one of the items violates an integrity constraint. Returns: - set: ids or items successfully updated - list(SpineIntegrityError): found violations + tuple(list(dict),list(str)): items successfully updated and found violations. 
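+
+        Example (a quick sketch; assumes ``db_map`` is an open
+        :class:`~spinedb_api.db_mapping.DatabaseMapping` and that an entity class with id 1 exists)::
+
+            updated, errors = db_map.update_items(
+                "entity_class", {"id": 1, "name": "new_name"}
+            )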
""" updated, errors = [], [] - if not check: - for item in items: - updated.append(self._update_item_unsafe(tablename, item)) - else: - for item in items: - item, error = self.update_item(tablename, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - if item: - updated.append(item) + for item in items: + item, error = self.update_item(tablename, check=check, **item) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + if item: + updated.append(item) return updated, errors - def _update_item_unsafe(self, tablename, item): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - self._convert_legacy(tablename, item) - return table_cache.update_item(item) - - def update_item(self, tablename, **kwargs): - tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - self._convert_legacy(tablename, kwargs) - checked_item, error = table_cache.check_item(kwargs, for_update=True) - return table_cache.update_item(checked_item._asdict()) if checked_item else None, error - def update_alternatives(self, *items, **kwargs): return self.update_items("alternative", *items, **kwargs) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index f53cf525..af67469d 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -36,8 +36,8 @@ def export_data( ): """ Exports data from a Spine DB into a standard dictionary format. - The result can be splatted into keyword arguments for :func:`spinedb_api.import_functions.import_data` - to transfer data from one DB to another. + The result can be splatted into keyword arguments for :func:`spinedb_api.import_functions.import_data`, + to copy data from one DB to another. Args: db_map (DatabaseMapping): The db to pull data from. diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index e1e3be49..5ffd3595 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -20,7 +20,9 @@ class GraphLayoutGenerator: - """A class to build an optimised layout for an undirected graph.""" + """A class to build an optimised layout for an undirected graph. + This can help visualizing the Spine data structure of multi-dimensional entities. + """ def __init__( self, @@ -39,21 +41,21 @@ def __init__( """ Args: vertex_count (int): The number of vertices in the graph. Graph vertices will have indices 0, 1, 2, ... - src_inds (tuple,optional): The indices of the source vertices of each edge. - dst_inds (tuple,optional): The indices of the destination vertices of each edge. - spread (int,optional): the ideal edge length. - heavy_positions (dict,optional): a dictionary mapping vertex indices to another dictionary + src_inds (tuple, optional): The indices of the source vertices of each edge. + dst_inds (tuple, optional): The indices of the destination vertices of each edge. + spread (int, optional): the ideal edge length. + heavy_positions (dict, optional): a dictionary mapping vertex indices to another dictionary with keys "x" and "y" specifying the position it should have in the generated layout. - max_iters (int,optional): the maximum numbers of iterations of the layout generation algorithm. - weight_exp (int,optional): The exponential decay rate of attraction between vertices. The higher this + max_iters (int, optional): the maximum numbers of iterations of the layout generation algorithm. 
+            weight_exp (int, optional): The exponential decay rate of attraction between vertices. The higher this
                number, the lesser the attraction between distant vertices.
-            is_stopped (function,optional): A function to call without arguments, that returns a boolean indicating
+            is_stopped (function, optional): A function to call without arguments that returns a boolean indicating
                whether the layout generation process needs to be stopped.
-            preview_available (function,optional): A function to call after every iteration with two lists, x and y,
+            preview_available (function, optional): A function to call after every iteration with two lists, x and y,
                representing the current layout.
-            layout_available (function,optional): A function to call after the last iteration with two lists, x and y,
+            layout_available (function, optional): A function to call after the last iteration with two lists, x and y,
                representing the final layout.
-            layout_progressed (function,optional): A function to call after each iteration with the current iteration
+            layout_progressed (function, optional): A function to call after each iteration with the current iteration
                number.
        """
        super().__init__()
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index ecb464a2..7cfc6653 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -197,10 +197,10 @@ def copy_database(dest_url, source_url, overwrite=True, upgrade=False, only_tabl
     Args:
         dest_url (str): The destination url.
         source_url (str): The source url.
-        overwrite (bool,optional): whether to overwrite the destination.
-        upgrade (bool,optional): whether to upgrade the source to the latest Spine schema revision.
-        only_tables (tuple,optional): If given, only these tables are copied.
-        skip_tables (tuple,optional): If given, these tables are skipped.
+        overwrite (bool, optional): whether to overwrite the destination.
+        upgrade (bool, optional): whether to upgrade the source to the latest Spine schema revision.
+        only_tables (tuple, optional): If given, only these tables are copied.
+        skip_tables (tuple, optional): If given, these tables are skipped.
     """
     if not _is_head(source_url, upgrade=upgrade):
         raise SpineDBVersionError(url=source_url)
diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py
index 534967f5..3e67b997 100644
--- a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -612,9 +612,9 @@ def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None,
         db_url (str): the URL of a Spine DB.
         upgrade (bool): Whether to upgrade the DB to the last revision.
         memory (bool): Whether to use an in-memory database together with a persistent connection.
-        server_manager_queue (Queue,optional): A queue that can be used to control order of writing.
+        server_manager_queue (Queue, optional): A queue that can be used to control the order of writing.
             Only needed if you also specify `ordering` below.
-        ordering (dict,optional): A dictionary specifying an ordering to be followed by multiple concurrent servers
+        ordering (dict, optional): A dictionary specifying an ordering to be followed by multiple concurrent servers
             writing to the same DB. It must have the following keys:
             - "id": an identifier for the ordering, shared by all the servers in the ordering.
            - "current": an identifier for this server within the ordering.
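The ``copy_database`` helper documented above is used along these lines (a sketch; the URLs are
made up for illustration)::

    from spinedb_api.helpers import copy_database

    # Copy everything except parameter values from one sqlite Spine DB to another.
    copy_database(
        "sqlite:///copy.sqlite", "sqlite:///original.sqlite", skip_tables=("parameter_value",)
    )
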
diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py index f199d724..a569f7f3 100644 --- a/spinedb_api/spine_io/exporters/writer.py +++ b/spinedb_api/spine_io/exporters/writer.py @@ -130,7 +130,7 @@ def _new_table(writer, table_name, title_key): Args: writer (Writer): a writer table_name (str): table's name - title_key (dict,optional) + title_key (dict, optional) Yields: bool: whether or not the new table was successfully started diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 501a797f..3ec329d8 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -28,7 +28,7 @@ def __eq__(self, other): return super().__eq__(other) or (self._db_id is not None and other == self._db_id) def __hash__(self): - return -int(self) + return int(self) @property def db_id(self): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 3b6d2a90..687a25c5 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -43,21 +43,6 @@ def query_wrapper(*args, orig_query=db_map.query, **kwargs): IN_MEMORY_DB_URL = "sqlite://" -class TestDatabaseMappingPublic(unittest.TestCase): - _db_map = None - - @classmethod - def setUpClass(cls): - cls._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) - - @classmethod - def tearDownClass(cls): - cls._db_map.close() - - def test_getters(self): - print(dir(self._db_map.get_entity_class)) - - class TestDatabaseMappingConstruction(unittest.TestCase): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" From e387bec7ae17fa984468491968b7d9c9d20a6d39 Mon Sep 17 00:00:00 2001 From: Pekka T Savolainen Date: Tue, 3 Oct 2023 15:24:47 +0300 Subject: [PATCH 104/317] Fix a call --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c1f197ca..9cb6f1da 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -105,7 +105,7 @@ def __init__( self.codename = self._make_codename(codename) self._memory = memory self._memory_dirty = False - self._original_engine = self._create_engine( + self._original_engine = self.create_engine( self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason From c93ab95fd12128c151bdcb4862551a08595f1473 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 3 Oct 2023 14:32:27 +0200 Subject: [PATCH 105/317] Fix tests and return type for some methods --- spinedb_api/db_mapping.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 69520889..bbf548b7 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -124,7 +124,10 @@ def add_item(self, item_type, check=True, **kwargs): if not check: return table_cache.add_item(kwargs, new=True), None checked_item, error = table_cache.check_item(kwargs) - return table_cache.add_item(checked_item, new=True) if checked_item and not error else None, error + return ( + PublicItem(self, table_cache.add_item(checked_item, new=True)) if checked_item and not error else None, + error, + ) def update_item(self, item_type, check=True, **kwargs): """Updates an item in the in-memory mapping. 
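With this change, ``add_item`` hands back the checked item wrapped in a ``PublicItem`` together
with any error, so a calling pattern could look like this (a sketch; assumes ``db_map`` is an open
``DatabaseMapping``)::

    item, error = db_map.add_item("entity_class", name="fish")
    if error:
        raise RuntimeError(error)
    print(item["name"])  # "fish"
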
@@ -150,7 +153,7 @@ def update_item(self, item_type, check=True, **kwargs): if not check: return table_cache.update_item(kwargs), None checked_item, error = table_cache.check_item(kwargs, for_update=True) - return table_cache.update_item(checked_item._asdict()) if checked_item and not error else None, error + return (PublicItem(self, table_cache.update_item(checked_item._asdict())) if checked_item else None, error) def remove_item(self, item_type, id_): """Removes an item from the in-memory mapping. @@ -171,7 +174,7 @@ def remove_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) table_cache = self.cache.table_cache(item_type) - return table_cache.remove_item(id_) + return PublicItem(self, table_cache.remove_item(id_)) def restore_item(self, item_type, id_): """Restores a previously removed item into the in-memory mapping. @@ -191,7 +194,7 @@ def restore_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) table_cache = self.cache.table_cache(item_type) - return table_cache.restore_item(id_) + return PublicItem(self, table_cache.restore_item(id_)) def can_fetch_more(self, item_type): """Whether or not more data can be fetched from the DB for the given item type. From 59bdd4849ba41ba2e01272e0aa08b5224b9473b0 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 3 Oct 2023 19:35:19 +0200 Subject: [PATCH 106/317] Improve documentation of Spine DB server and client --- docs/source/conf.py | 4 +- spinedb_api/server_client_helpers.py | 2 - spinedb_api/spine_db_client.py | 26 ++--- spinedb_api/spine_db_server.py | 168 +++++++++++++++++---------- 4 files changed, 123 insertions(+), 77 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4e0b6a10..9d7f39af 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -97,7 +97,9 @@ def _skip_member(app, what, name, obj, skip, options): - if what == "class" and any(x in name for x in ("SpineDBServer", "group_concat")): + if what == "class" and any( + x in name for x in ("SpineDBServer", "group_concat", "DBRequestHandler", "ReceiveAllMixing") + ): skip = True return skip diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index ae7715a0..3a39f9de 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -19,8 +19,6 @@ class ReceiveAllMixing: - """Provides _recvall, to read everything from a socket until the _EOT character is found.""" - _ENCODING = "utf-8" _BUFF_SIZE = 4096 _EOT = '\u0004' # End of transmission diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py index cbd2eeaa..570043b4 100644 --- a/spinedb_api/spine_db_client.py +++ b/spinedb_api/spine_db_client.py @@ -10,7 +10,7 @@ ###################################################################################################################### """ -The :class:`SpineDBClient` class. +This module defines the :class:`SpineDBClient` class. """ from urllib.parse import urlparse @@ -23,10 +23,10 @@ class SpineDBClient(ReceiveAllMixing): def __init__(self, server_address): - """Represents a client connection to a Spine DB server. + """Enables sending requests to a Spine DB server. Args: - server_address (tuple(str,int)): hostname and port + server_address (tuple(str,int)): the hostname and port where the server is listening. """ self._server_address = server_address self.request = None @@ -36,7 +36,7 @@ def from_server_url(cls, url): """Creates a client from a server's URL. Args: - url (str, URL): the url of a Spine DB server. 
+            url (str, URL): the URL where the server is listening.
         """
         parsed = urlparse(url)
         if parsed.scheme != "http":
@@ -44,7 +44,7 @@ def from_server_url(cls, url):
         return cls((parsed.hostname, parsed.port))
 
     def get_db_url(self):
-        """Returns the URL of the Spine DB associated with the server.
+        """Returns the URL of the target Spine DB, i.e., the one the server is set to communicate with.
 
         Returns:
             str
@@ -65,29 +65,29 @@ def cancel_db_checkout(self):
         return self._send("cancel_db_checkout")
 
     def import_data(self, data, comment):
-        """Imports data to the DB using :func:`spinedb_api.import_functions.import_data` and commits the changes.
+        """Imports data to the DB using :func:`~spinedb_api.import_functions.import_data` and commits the changes.
 
         Args:
-            data (dict): to be splatted into keyword arguments to :func:`spinedb_api.import_functions.import_data`
+            data (dict): to be splatted into keyword arguments to :func:`~spinedb_api.import_functions.import_data`
             comment (str): a commit message.
         """
         return self._send("import_data", args=(data, comment))
 
     def export_data(self, **kwargs):
-        """Exports data from the DB using :func:`spinedb_api.export_functions.export_data`.
+        """Exports data from the DB using :func:`~spinedb_api.export_functions.export_data`.
 
         Args:
-            kwargs: keyword arguments passed to :func:`spinedb_api.import_functions.import_data`
+            **kwargs: keyword arguments passed to :func:`~spinedb_api.export_functions.export_data`
         """
         return self._send("export_data", kwargs=kwargs)
 
     def call_method(self, method_name, *args, **kwargs):
-        """Calls a method from :class:`spinedb_api.db_mapping.DatabaseMapping`.
+        """Calls a method from :class:`~spinedb_api.db_mapping.DatabaseMapping`.
 
         Args:
-            method_name (str): the name of the method to call
-            args: positional arguments passed to the method call
-            kwargs: keyword arguments passed to the method call
+            method_name (str): the name of the method to call.
+            *args: positional arguments passed to the method call.
+            **kwargs: keyword arguments passed to the method call.
         """
         return self._send("call_method", args=(method_name, *args), kwargs=kwargs)
 
diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py
index 3e67b997..ced29530 100644
--- a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -10,33 +10,86 @@
 ######################################################################################################################
 
 """
-Spine DB server
-===============
-
-The Spine DB server provides almost the same functionality as :class:`spinedb_api.db_mapping.DatabaseMapping`,
-but it does it via a socket. This removes the ``spinedb_api`` requirement (and the Python requirement altogether)
-from third-party applications that want to interact with Spine DBs.
-
-Typically this is done in the following steps:
-    #. Start a server by specifying the URL of the Spine DB that you want to interact with.
-    #. Communicate the URL of the server to your third-party application running in another process.
-    #. Send requests from your application to the server via sockets in order to interact with the DB.
-
-Available requests
-------------------
-TODO
-
-Encoding/decoding
------------------
-TODO
-
-This module also provides a mechanism to control the order in which multiples servers
-running in parallel should write to the same DB.
+This module provides a mechanism to create a socket server interface to a Spine DB.
+The server exposes most of the functionality of :class:`~spinedb_api.db_mapping.DatabaseMapping`,
+and can eventually remove the ``spinedb_api`` requirement (and the Python requirement altogether)
+from third-party applications that want to interact with Spine DBs. (Of course, they would need to have
+access to sockets instead.)
+
+Typically, you would start the server in a background Python process by specifying the URL of the target Spine DB,
+getting back the URL where the server is listening.
+You can then use that URL in any number of instances of your application, which would connect to the server
+via a socket and then send requests to retrieve or modify the data in the DB.
+
+Requests to the server must be encoded using JSON.
+Each request must be a JSON array with the following elements:
+
+#. A JSON string with one of the available request names:
+   ``"get_db_url"``, ``"import_data"``, ``"export_data"``, ``"query"``, ``"filtered_query"``,
+   ``"call_method"``, ``"db_checkin"``, ``"db_checkout"``.
+#. A JSON array with positional arguments to the request.
+#. A JSON object with keyword arguments to the request.
+#. A JSON integer indicating the version of the server you want to talk to.
+
+The positional and keyword arguments to the different requests are documented
+in the :class:`~spinedb_api.spine_db_client.SpineDBClient` class
+(just look for a member function named after the request).
+
+The point of the server version is to allow client developers to adapt to changes in the Spine DB server API.
+Say we update ``spinedb_api`` and change the signature of one of the requests; in this case, we also bump
+the current server version to the next integer.
+If you then upgrade your ``spinedb_api`` installation but not your client, the server will see the version mismatch
+and will respond that the client is outdated.
+The current server version can be queried by calling :func:`get_current_server_version`.
+
+The module also provides a mechanism to control the order in which multiple servers write to the same DB.
+This is particularly useful in high-concurrency scenarios.
 
 The server is started using :func:`closing_spine_db_server`.
-If you want to also control order of writing from multiple servers,
-you first need to obtain an 'ordering queue' using :func:`db_server_manager`.
-
+To control the order of writing, you need to provide a queue, which you obtain by calling :func:`db_server_manager`.
+
+
+The example below illustrates most of the functionality of the module.
+We create two DB servers targeting the same DB, and set the second to write before the first
+(via the ``ordering`` argument to :func:`closing_spine_db_server`).
+Then we spawn two threads that connect to those two servers and import an entity class.
+We make sure to call :meth:`~spinedb_api.spine_db_client.SpineDBClient.db_checkin` before importing,
+and :meth:`~spinedb_api.spine_db_client.SpineDBClient.db_checkout` after, so the order of writing is respected.
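+For illustration, the ``import_data`` request sent by the first thread below would travel over the socket
+encoded more or less as follows (a sketch of the array format described above; ``6`` stands for the assumed
+server version, and the exact serialization of values is not shown)::
+
+    ["import_data", [{"entity_classes": [["monkey", []]]}, "Import monkey"], {}, 6]
+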
+When the first thread attempts to write to the DB, it hangs because the second one hasn't written yet.
+Only after the second one writes does the first one write too, and the program finishes::
+
+    import threading
+    from spinedb_api.spine_db_server import db_server_manager, closing_spine_db_server
+    from spinedb_api.spine_db_client import SpineDBClient
+    from spinedb_api.db_mapping import DatabaseMapping
+
+
+    def _import_entity_class(server_url, class_name):
+        client = SpineDBClient.from_server_url(server_url)
+        client.db_checkin()
+        _answer = client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}")
+        client.db_checkout()
+
+
+    db_url = 'sqlite:///somedb.sqlite'
+    with db_server_manager() as mngr_queue:
+        first_ordering = {"id": "second_before_first", "current": "first", "precursors": {"second"}, "part_count": 1}
+        second_ordering = {"id": "second_before_first", "current": "second", "precursors": set(), "part_count": 1}
+        with closing_spine_db_server(db_url, server_manager_queue=mngr_queue, ordering=first_ordering) as first_server_url:
+            with closing_spine_db_server(
+                db_url, server_manager_queue=mngr_queue, ordering=second_ordering
+            ) as second_server_url:
+                t1 = threading.Thread(target=_import_entity_class, args=(first_server_url, "monkey"))
+                t2 = threading.Thread(target=_import_entity_class, args=(second_server_url, "donkey"))
+                t1.start()
+                with DatabaseMapping(db_url) as db_map:
+                    assert db_map.get_items("entity_class") == []  # Nothing written yet
+                t2.start()
+                t1.join()
+                t2.join()
+
+    with DatabaseMapping(db_url) as db_map:
+        assert [x["name"] for x in db_map.get_items("entity_class")] == ["donkey", "monkey"]
 """
 
 from urllib.parse import urlunsplit
@@ -62,7 +115,16 @@
 from .filters.tools import apply_filter_stack
 from .spine_db_client import SpineDBClient
 
-_required_client_version = 6
+_current_server_version = 6
+
+
+def get_current_server_version():
+    """Returns the current server version.
+
+    Returns:
+        int: the current server version
+    """
+    return _current_server_version
 
 
 def _parse_value(v, value_type=None):
@@ -527,8 +589,8 @@ def _get_response(self, request):
             args, kwargs, client_version = extras
         except ValueError:
             client_version = 0
-        if client_version < _required_client_version:
-            return dict(error=1, result=_required_client_version)
+        if client_version < _current_server_version:
+            return dict(error=1, result=_current_server_version)
         handler = {
             "query": self.query,
             "filtered_query": self.filtered_query,
@@ -598,15 +660,24 @@ def shutdown_spine_db_server(server_manager_queue, server_address):
 
 
 @contextmanager
-def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None, server_manager_queue=None):
-    """Creates a Spine DB server.
+def db_server_manager():
+    """Creates a DB server manager that can be used to control the order in which different servers
+    write to the same DB.
 
-    Example::
+    Yields:
+        :class:`~multiprocessing.queues.Queue`: a queue that can be passed to :func:`.closing_spine_db_server`
+        in order to control write order.
+    """
+    mngr = _DBServerManager()
+    try:
+        yield mngr.queue
+    finally:
+        mngr.shutdown()
 
-        with closing_spine_db_server(db_url) as server_url:
-            client = SpineDBClient.from_server_url(server_url)
-            data = client.import_data({"entity_class": [("fish", ()), ("dog", ())]}, "Add two entity classes.")
 
+@contextmanager
+def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None, server_manager_queue=None):
+    """Creates a Spine DB server.
 
     Args:
         db_url (str): the URL of a Spine DB.
@@ -618,8 +689,8 @@ def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None,
         writing to the same DB. It must have the following keys:
         - "id": an identifier for the ordering, shared by all the servers in the ordering.
         - "current": an identifier for this server within the ordering.
-        - "precursors": a set of identifiers of other servers that must have written to the DB before this server can write.
-        - "part_count": the number of times this server needs to write to the DB before their successors can write.
+        - "precursors": a set of identifiers of other servers that must have checked out from the DB before this one can check in.
+        - "part_count": the number of times this server needs to check out from the DB before its successors can check in.
 
     Yields:
         str: server url
@@ -637,28 +708,3 @@ def closing_spine_db_server(db_url, upgrade=False, memory=False, ordering=None,
         shutdown_spine_db_server(server_manager_queue, server_address)
         if mngr is not None:
             mngr.shutdown()
-
-
-@contextmanager
-def db_server_manager():
-    """Creates a DB server manager that can be used to control the order in which different servers
-    write to the same DB.
-
-    Example::
-
-        with db_server_manager() as mngr_queue:
-            with closing_spine_db_server(db_url, server_manager_queue=mngr_queue) as server1_url:
-                with closing_spine_db_server(db_url, server_manager_queue=mngr_queue) as server1_url:
-                    client1 = SpineDBClient.from_server_url(server_url1)
-                    client2 = SpineDBClient.from_server_url(server_url2)
-                    # TODO: ordering
-
-    Yields:
-        :class:`~multiprocessing.queues.Queue`: a queue that can be passed to :func:`.closing_spine_db_server`
-        in order to control write order.
-    """
-    mngr = _DBServerManager()
-    try:
-        yield mngr.queue
-    finally:
-        mngr.shutdown()

From 79cfc0cc3f0b9954b86860356ee3095ca2017aeb Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Wed, 4 Oct 2023 11:24:19 +0200
Subject: [PATCH 107/317] Refactor DatabaseMapping class hierarchy

---
 docs/source/tutorial.rst                     |   12 +-
 spinedb_api/db_cache_base.py                 |  909 -------
 spinedb_api/db_mapping.py                    |  368 ++-
 spinedb_api/db_mapping_add_mixin.py          |    2 +-
 spinedb_api/db_mapping_base.py               | 2372 ++++++-----------
 spinedb_api/db_mapping_commit_mixin.py       |    6 +-
 spinedb_api/db_mapping_query_mixin.py        | 1478 ++++++++++
 spinedb_api/db_mapping_remove_mixin.py       |   18 +-
 spinedb_api/db_mapping_update_mixin.py       |    4 +-
 spinedb_api/export_functions.py              |   18 +-
 spinedb_api/export_mapping/export_mapping.py |   20 +-
 spinedb_api/export_mapping/generator.py      |    4 +-
 spinedb_api/filters/alternative_filter.py    |   14 +-
 spinedb_api/filters/execution_filter.py      |    8 +-
 spinedb_api/filters/renamer.py               |   20 +-
 spinedb_api/filters/scenario_filter.py       |   46 +-
 spinedb_api/filters/tools.py                 |    2 +-
 spinedb_api/filters/value_transformer.py     |   10 +-
 spinedb_api/import_functions.py              |   14 +-
 .../{db_cache_impl.py => mapped_items.py}    |  116 +-
 spinedb_api/purge.py                         |   14 -
 spinedb_api/spine_db_server.py               |    2 +-
 spinedb_api/spine_io/exporters/writer.py     |    2 +-
 tests/filters/test_execution_filter.py       |    4 +-
 tests/test_DatabaseMapping.py                |   46 +-
 tests/test_db_cache_base.py                  |   38 +-
 26 files changed, 2762 insertions(+), 2785 deletions(-)
 delete mode 100644 spinedb_api/db_cache_base.py
 create mode 100644 spinedb_api/db_mapping_query_mixin.py
 rename spinedb_api/{db_cache_impl.py => mapped_items.py} (87%)

diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst
index d966a5e7..057342b6 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -115,7 +115,7 @@ Finally, let's specify a parameter value for
 one of our entities::
 
 Retrieving data
 ---------------
 
-To retrieve data from the DB, we use :meth:`~.DatabaseMapping.get_item`.
+To retrieve data from the DB (and the in-memory mapping), we use :meth:`~.DatabaseMapping.get_item`.
 For example, let's find one of the entities we inserted above::
 
     felix = db_map.get_item("entity", class_name="cat", name="Felix")
@@ -127,16 +127,14 @@ Above, ``felix`` is a :class:`~.PublicItem` object, representing an item (or row
 Let's find our multi-dimensional entity::
 
     nemo_felix = db_map.get_item("entity", class_name="fish__cat", byname=("Nemo", "Felix"))
-    print(nemo_felix["dimension_name_list"])  # Prints '(fish, cat)'
+    print(nemo_felix["dimension_name_list"])  # Prints "('fish', 'cat')"
 
 To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`::
 
     print(entity["byname"] for entity in db_map.get_items("entity"))
     # Prints [("Nemo",), ("Felix",), ("Nemo", "Felix"),]
 
-.. note::
-
-  You should use the above to try and find Nemo!
+Now you should use the above to try and find Nemo.
 
 
 Updating data
@@ -155,10 +153,10 @@ To be safe, let's also change the color::
 
         entity_class_name="fish", parameter_definition_name="color", entity_name="NotNemo"
-    ).update(value="definitely purple")
+    ).update(value="not that orange")
 
 
-Note how we need to use then new entity name (``"NotNemo"``) to retrieve the parameter value. This makes sense.
+Note how we need to use the new entity name ``"NotNemo"`` to retrieve the parameter value. This makes sense.
 
 Removing data
 -------------
diff --git a/spinedb_api/db_cache_base.py b/spinedb_api/db_cache_base.py
deleted file mode 100644
index eb582e68..00000000
--- a/spinedb_api/db_cache_base.py
+++ /dev/null
@@ -1,909 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see .
-######################################################################################################################
-
-import threading
-from enum import Enum, unique, auto
-from .temp_id import TempId
-
-# TODO: Implement CacheItem.pop() to do lookup?
-
-_LIMIT = 10000
-
-
-@unique
-class Status(Enum):
-    """Cache item status."""
-
-    committed = auto()
-    to_add = auto()
-    to_update = auto()
-    to_remove = auto()
-    added_and_removed = auto()
-
-
-class DBCacheBase(dict):
-    """A dictionary representation of a DB, mapping item types (table names), to numeric ids, to items.
-
-    This class is not meant to be used directly. Instead, you need to subclass it for each DB schema you want to use.
-
-    When subclassing, you need to implement :attr:`item_types`, :meth:`item_factory`, and :meth:`query`.
- """ - - def __init__(self): - super().__init__() - self._offsets = {} - self._offset_lock = threading.Lock() - self._fetched_item_types = set() - item_types = self.item_types - self._sorted_item_types = [] - while item_types: - item_type = item_types.pop(0) - if self.item_factory(item_type).ref_types() & set(item_types): - item_types.append(item_type) - else: - self._sorted_item_types.append(item_type) - - @property - def fetched_item_types(self): - """Returns a set with the item types that are already fetched. - - Returns: - set - """ - return self._fetched_item_types - - @property - def item_types(self): - """Returns a list of item types in the DB (equivalent to the table names). - - Returns: - list(str) - """ - raise NotImplementedError() - - @staticmethod - def item_factory(item_type): - """Returns a subclass of :class:`.CacheItemBase` to make items of given type. - - Args: - item_type (str) - - Returns: - function - """ - raise NotImplementedError() - - def query(self, item_type): - """Returns a :class:`~spinedb_api.query.Query` object to fecth items of given type. - - Args: - item_type (str) - - Returns: - :class:`~spinedb_api.query.Query` - """ - raise NotImplementedError() - - def make_item(self, item_type, **item): - factory = self.item_factory(item_type) - return factory(self, item_type, **item) - - def dirty_ids(self, item_type): - return { - item["id"] - for item in self.table_cache(item_type).valid_values() - if item.status in (Status.to_add, Status.to_update) - } - - def dirty_items(self): - """Returns a list of tuples of the form (item_type, (to_add, to_update, to_remove)) corresponding to - items that have been modified but not yet committed. - - Returns: - list - """ - dirty_items = [] - for item_type in self._sorted_item_types: - table_cache = self.get(item_type) - if table_cache is None: - continue - to_add = [] - to_update = [] - to_remove = [] - for item in table_cache.values(): - _ = item.is_valid() - if item.status == Status.to_add: - to_add.append(item) - elif item.status == Status.to_update: - to_update.append(item) - elif item.status == Status.to_remove: - to_remove.append(item) - if to_remove: - # Fetch descendants, so that they are validated in next iterations of the loop. - # This ensures cascade removal. - # FIXME: We should also fetch the current item type because of multi-dimensional entities and - # classes which also depend on zero-dimensional ones - for other_item_type in self.item_types: - if ( - other_item_type not in self.fetched_item_types - and item_type in self.item_factory(other_item_type).ref_types() - ): - self.fetch_all(other_item_type) - if to_add or to_update or to_remove: - dirty_items.append((item_type, (to_add, to_update, to_remove))) - return dirty_items - - def rollback(self): - """Discards uncommitted changes. - - Namely, removes all the added items, resets all the updated items, and restores all the removed items. - - Returns: - bool: False if there is no uncommitted items, True if successful. 
- """ - dirty_items = self.dirty_items() - if not dirty_items: - return False - to_add_by_type = [] - to_update_by_type = [] - to_remove_by_type = [] - for item_type, (to_add, to_update, to_remove) in reversed(dirty_items): - to_add_by_type.append((item_type, to_add)) - to_update_by_type.append((item_type, to_update)) - to_remove_by_type.append((item_type, to_remove)) - for item_type, to_remove in to_remove_by_type: - table_cache = self.table_cache(item_type) - for item in to_remove: - table_cache.restore_item(item["id"]) - for item_type, to_update in to_update_by_type: - table_cache = self.table_cache(item_type) - for item in to_update: - table_cache.update_item(item.backup) - for item_type, to_add in to_add_by_type: - table_cache = self.table_cache(item_type) - for item in to_add: - if table_cache.remove_item(item["id"]) is not None: - item.invalidate_id() - return True - - def refresh(self): - """Clears fetch progress, so the DB is queried again.""" - self._offsets.clear() - self._fetched_item_types.clear() - - def _get_next_chunk(self, item_type, limit): - qry = self.query(item_type) - if not qry: - return [] - if not limit: - self._fetched_item_types.add(item_type) - return [dict(x) for x in qry] - with self._offset_lock: - offset = self._offsets.setdefault(item_type, 0) - chunk = [dict(x) for x in qry.limit(limit).offset(offset)] - self._offsets[item_type] += len(chunk) - return chunk - - def _advance_query(self, item_type, limit): - """Advances the DB query that fetches items of given type - and adds the results to the corresponding table cache. - - Args: - item_type (str) - - Returns: - list: items fetched from the DB - """ - chunk = self._get_next_chunk(item_type, limit) - if not chunk: - self._fetched_item_types.add(item_type) - return [] - table_cache = self.table_cache(item_type) - for item in chunk: - table_cache.add_item(item) - return chunk - - def table_cache(self, item_type): - return self.setdefault(item_type, _TableCache(self, item_type)) - - def get_item(self, item_type, id_): - table_cache = self.get(item_type, {}) - item = table_cache.get(id_) - if item is None: - return {} - return item - - def fetch_more(self, item_type, limit=_LIMIT): - if item_type in self._fetched_item_types: - return [] - return self._advance_query(item_type, limit) - - def fetch_all(self, item_type): - while self.fetch_more(item_type): - pass - - def fetch_value(self, item_type, return_fn): - while self.fetch_more(item_type): - return_value = return_fn() - if return_value: - return return_value - return return_fn() - - def fetch_ref(self, item_type, id_): - while self.fetch_more(item_type): - ref = self.get_item(item_type, id_) - if ref: - return ref - # It is possible that fetching was completed between deciding to call this function - # and starting the while loop above resulting in self.fetch_more() to return False immediately. - # Therefore, we should try one last time if the ref is available. - ref = self.get_item(item_type, id_) - if ref: - return ref - - -class _TableCache(dict): - def __init__(self, db_cache, item_type, *args, **kwargs): - """ - Args: - db_cache (DBCacheBase): the DB cache where this table cache belongs. 
- item_type (str): the item type, equal to a table name - """ - super().__init__(*args, **kwargs) - self._db_cache = db_cache - self._item_type = item_type - self._id_by_unique_key_value = {} - self._temp_id_by_db_id = {} - - def get(self, id_, default=None): - id_ = self._temp_id_by_db_id.get(id_, id_) - return super().get(id_, default) - - def _new_id(self): - temp_id = TempId(self._item_type) - - def _callback(db_id): - self._temp_id_by_db_id[db_id] = temp_id - - temp_id.add_resolve_callback(_callback) - return temp_id - - def unique_key_value_to_id(self, key, value, strict=False, fetch=True): - """Returns the id that has the given value for the given unique key, or None if not found. - - Args: - key (tuple) - value (tuple) - strict (bool): if True, raise a KeyError if id is not found - fetch (bool): whether to fetch the DB until found. - - Returns: - int - """ - id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - if not id_by_unique_value and fetch: - id_by_unique_value = self._db_cache.fetch_value( - self._item_type, lambda: self._id_by_unique_key_value.get(key, {}) - ) - value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - if strict: - return id_by_unique_value[value] - return id_by_unique_value.get(value) - - def _unique_key_value_to_item(self, key, value, fetch=True): - return self.get(self.unique_key_value_to_id(key, value, fetch=fetch)) - - def valid_values(self): - return (x for x in self.values() if x.is_valid()) - - def _make_item(self, item): - """Returns a cache item. - - Args: - item (dict): the 'db item' to use as base - - Returns: - CacheItem - """ - return self._db_cache.make_item(self._item_type, **item) - - def find_item(self, item, skip_keys=(), fetch=True): - """Returns a CacheItemBase that matches the given dictionary-item. - - Args: - item (dict) - - Returns: - CacheItemBase or None - """ - id_ = item.get("id") - if id_ is not None: - # id is given, easy - item = self.get(id_) - if item or not fetch: - return item - return self._db_cache.fetch_ref(self._item_type, id_) - # No id. Try to locate the item by the value of one of the unique keys. - # Used by import_data (and more...) 
- cache_item = self._make_item(item) - error = cache_item.resolve_inverse_references(item.keys()) - if error: - return None - error = cache_item.polish() - if error: - return None - for key, value in cache_item.unique_values(skip_keys=skip_keys): - current_item = self._unique_key_value_to_item(key, value, fetch=fetch) - if current_item: - return current_item - - def check_item(self, item, for_update=False, skip_keys=()): - # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, - # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) - if for_update: - current_item = self.find_item(item, skip_keys=skip_keys) - if current_item is None: - return None, f"no {self._item_type} matching {item} to update" - full_item, merge_error = current_item.merge(item) - if full_item is None: - return None, merge_error - else: - current_item = None - full_item, merge_error = item, None - candidate_item = self._make_item(full_item) - error = candidate_item.resolve_inverse_references(skip_keys=item.keys()) - if error: - return None, error - error = candidate_item.polish() - if error: - return None, error - first_invalid_key = candidate_item.first_invalid_key() - if first_invalid_key: - return None, f"invalid {first_invalid_key} for {self._item_type}" - try: - for key, value in candidate_item.unique_values(skip_keys=skip_keys): - empty = {k for k, v in zip(key, value) if v == ""} - if empty: - return None, f"invalid empty keys {empty} for {self._item_type}" - unique_item = self._unique_key_value_to_item(key, value) - if unique_item not in (None, current_item) and unique_item.is_valid(): - return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" - except KeyError as e: - return None, f"missing {e} for {self._item_type}" - if "id" not in candidate_item: - candidate_item["id"] = self._new_id() - return candidate_item, merge_error - - def _add_unique(self, item): - for key, value in item.unique_values(): - self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"] - - def _remove_unique(self, item): - for key, value in item.unique_values(): - id_by_value = self._id_by_unique_key_value.get(key, {}) - if id_by_value.get(value) == item["id"]: - del id_by_value[value] - - def add_item(self, item, new=False): - if not isinstance(item, CacheItemBase): - item = self._make_item(item) - item.polish() - if not new: - # Item comes from the DB - id_ = item["id"] - if id_ in self or id_ in self._temp_id_by_db_id: - # The item is already in the cache - return - if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()): - # An item with the same unique key is already in the cache - return - else: - item.status = Status.to_add - if "id" not in item or not item.is_id_valid: - item["id"] = self._new_id() - self[item["id"]] = item - self._add_unique(item) - return item - - def update_item(self, item): - current_item = self.find_item(item) - self._remove_unique(current_item) - current_item.update(item) - self._add_unique(current_item) - current_item.cascade_update() - return current_item - - def remove_item(self, id_): - current_item = self.find_item({"id": id_}) - if current_item is not None: - self._remove_unique(current_item) - current_item.cascade_remove() - return current_item - - def restore_item(self, id_): - current_item = self.find_item({"id": id_}) - if current_item is not None: - self._add_unique(current_item) - current_item.cascade_restore() - return current_item - - -class 
CacheItemBase(dict): - """A dictionary that represents a db item.""" - - _fields = {} - """A dictionaty mapping fields to a tuple of (type, description)""" - _defaults = {} - """A dictionary mapping keys to their default values""" - _unique_keys = () - """A tuple where each element is itself a tuple of keys that are unique""" - _references = {} - """A dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the field they reference in another item. - - The recipe is a tuple of the form (original_field, (ref_item_type, ref_field)), - to be interpreted as follows: - 1. take the value from the original_field of this item, which should be an id, - 2. locate the item of type ref_item_type that has that id, - 3. return the value from the ref_field of that item. - """ - _inverse_references = {} - """Another dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the field they reference in another item. - Used only for creating new items, when the user provides names and we want to find the ids. - - The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)), - to be interpreted as follows: - 1. take the values from the src_unique_key of this item, to form a tuple, - 2. locate the item of type ref_item_type where the ref_unique_key is exactly that tuple of values, - 3. return the id of that item. - """ - - def __init__(self, db_cache, item_type, **kwargs): - """ - Args: - db_cache (DBCacheBase): the DB cache where this item belongs. - """ - super().__init__(**kwargs) - self._db_cache = db_cache - self._item_type = item_type - self._referrers = {} - self._weak_referrers = {} - self.restore_callbacks = set() - self.update_callbacks = set() - self.remove_callbacks = set() - self._is_id_valid = True - self._to_remove = False - self._removed = False - self._corrupted = False - self._valid = None - self._status = Status.committed - self._removal_source = None - self._status_when_removed = None - self._backup = None - - @classmethod - def ref_types(cls): - """Returns a set of item types that this class refers. - - Returns: - set(str) - """ - return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values()) - - @property - def status(self): - """Returns the status of this item. - - Returns: - Status - """ - return self._status - - @status.setter - def status(self, status): - """Sets the status of this item. - - Args: - status (Status) - """ - self._status = status - - @property - def backup(self): - """Returns the committed version of this item. - - Returns: - dict or None - """ - return self._backup - - @property - def removed(self): - """Returns whether or not this item has been removed. - - Returns: - bool - """ - return self._removed - - @property - def item_type(self): - """Returns this item's type - - Returns: - str - """ - return self._item_type - - @property - def key(self): - """Returns a tuple (item_type, id) for convenience, or None if this item doesn't yet have an id. - TODO: When does the latter happen? - - Returns: - tuple(str,int) or None - """ - id_ = dict.get(self, "id") - if id_ is None: - return None - return (self._item_type, id_) - - @property - def is_id_valid(self): - return self._is_id_valid - - def invalidate_id(self): - """Sets id as invalid.""" - self._is_id_valid = False - - def _extended(self): - """Returns a dict from this item's original fields plus all the references resolved statically. 
- - Returns: - dict - """ - d = self._asdict() - d.update({key: self[key] for key in self._references}) - return d - - def _asdict(self): - """Returns a dict from this item's original fields. - - Returns: - dict - """ - return dict(self) - - def merge(self, other): - """Merges this item with another and returns the merged item together with any errors. - Used for updating items. - - Args: - other (dict): the item to merge into this. - - Returns: - dict: merged item. - str: error description if any. - """ - if all(self.get(key) == value for key, value in other.items()): - return None, "" - merged = {**self._extended(), **other} - if not isinstance(merged["id"], int): - merged["id"] = self["id"] - return merged, "" - - def first_invalid_key(self): - """Goes through the ``_references`` class attribute and returns the key of the first one - that cannot be resolved. - - Returns: - str or None: unresolved reference's key if any. - """ - for src_key, (ref_type, _ref_key) in self._references.values(): - try: - ref_id = self[src_key] - except KeyError: - return src_key - if isinstance(ref_id, tuple): - for x in ref_id: - if not self._get_ref(ref_type, x): - return src_key - elif not self._get_ref(ref_type, ref_id): - return src_key - - def unique_values(self, skip_keys=()): - """Yields tuples of unique keys and their values. - - Args: - skip_keys: Don't yield these keys - - Yields: - tuple(tuple,tuple): the first element is the unique key, the second is the values. - """ - for key in self._unique_keys: - if key not in skip_keys: - yield key, tuple(self.get(k) for k in key) - - def resolve_inverse_references(self, skip_keys=()): - """Goes through the ``_inverse_references`` class attribute and updates this item - by resolving those references. - Returns any error. - - Args: - skip_keys (tuple): don't resolve references for these keys. - - Returns: - str or None: error description if any. - """ - for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): - if src_key in skip_keys: - continue - id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) - if None in id_value: - continue - table_cache = self._db_cache.table_cache(ref_type) - try: - self[src_key] = ( - tuple(table_cache.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) - if all(isinstance(v, (tuple, list)) for v in id_value) - else table_cache.unique_key_value_to_id(ref_key, id_value, strict=True) - ) - except KeyError as err: - # Happens at unique_key_value_to_id(..., strict=True) - return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" - - def polish(self): - """Polishes this item once all it's references have been resolved. Returns any error. - - The base implementation sets defaults but subclasses can do more work if needed. - - Returns: - str or None: error description if any. - """ - for key, default_value in self._defaults.items(): - self.setdefault(key, default_value) - return "" - - def _get_ref(self, ref_type, ref_id, strong=True): - """Collects a reference from the cache. - Adds this item to the reference's list of referrers if strong is True; - or weak referrers if strong is False. - If the reference is not found, sets some flags. 
- - Args: - ref_type (str): The references's type - ref_id (int): The references's id - strong (bool): True if the reference corresponds to a foreign key, False otherwise - - Returns: - CacheItemBase or dict - """ - ref = self._db_cache.get_item(ref_type, ref_id) - if not ref: - if not strong: - return {} - ref = self._db_cache.fetch_ref(ref_type, ref_id) - if not ref: - self._corrupted = True - return {} - # Here we have a ref - if strong: - ref.add_referrer(self) - if ref.removed: - self._to_remove = True - else: - ref.add_weak_referrer(self) - if ref.removed: - return {} - return ref - - def _invalidate_ref(self, ref_type, ref_id): - """Invalidates a reference previously collected from the cache. - - Args: - ref_type (str): The references's type - ref_id (int): The references's id - """ - ref = self._db_cache.get_item(ref_type, ref_id) - ref.remove_referrer(self) - - def is_valid(self): - """Checks if this item has all its references. - Removes the item from the cache if not valid by calling ``cascade_remove``. - - Returns: - bool - """ - if self._valid is not None: - return self._valid - if self._removed or self._corrupted: - return False - self._to_remove = False - self._corrupted = False - for key in self._references: - _ = self[key] - if self._to_remove: - self.cascade_remove() - self._valid = not self._removed and not self._corrupted - return self._valid - - def add_referrer(self, referrer): - """Adds a strong referrer to this item. Strong referrers are removed, updated and restored - in cascade with this item. - - Args: - referrer (CacheItemBase) - """ - if referrer.key is None: - return - self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer) - - def remove_referrer(self, referrer): - """Removes a strong referrer. - - Args: - referrer (CacheItemBase) - """ - if referrer.key is None: - return - self._referrers.pop(referrer.key, None) - - def add_weak_referrer(self, referrer): - """Adds a weak referrer to this item. - Weak referrers' update callbacks are called whenever this item changes. - - Args: - referrer (CacheItemBase) - """ - if referrer.key is None: - return - if referrer.key not in self._referrers: - self._weak_referrers[referrer.key] = referrer - - def _update_weak_referrers(self): - for weak_referrer in self._weak_referrers.values(): - weak_referrer.call_update_callbacks() - - def cascade_restore(self, source=None): - """Restores this item (if removed) and all its referrers in cascade. - Also, updates items' status and calls their restore callbacks. - """ - if not self._removed: - return - if source is not self._removal_source: - return - if self.status in (Status.added_and_removed, Status.to_remove): - self._status = self._status_when_removed - elif self.status == Status.committed: - self._status = Status.to_add - else: - raise RuntimeError("invalid status for item being restored") - self._removed = False - # First restore this, then referrers - obsolete = set() - for callback in list(self.restore_callbacks): - if not callback(self): - obsolete.add(callback) - self.restore_callbacks -= obsolete - for referrer in self._referrers.values(): - referrer.cascade_restore(source=self) - self._update_weak_referrers() - - def cascade_remove(self, source=None): - """Removes this item and all its referrers in cascade. - Also, updates items' status and calls their remove callbacks. 
- """ - if self._removed: - return - self._status_when_removed = self._status - if self._status == Status.to_add: - self._status = Status.added_and_removed - elif self._status in (Status.committed, Status.to_update): - self._status = Status.to_remove - else: - raise RuntimeError("invalid status for item being removed") - self._removal_source = source - self._removed = True - self._to_remove = False - self._valid = None - # First remove referrers, then this - for referrer in self._referrers.values(): - referrer.cascade_remove(source=self) - self._update_weak_referrers() - obsolete = set() - for callback in list(self.remove_callbacks): - if not callback(self): - obsolete.add(callback) - self.remove_callbacks -= obsolete - - def cascade_update(self): - """Updates this item and all its referrers in cascade. - Also, calls items' update callbacks. - """ - self.call_update_callbacks() - for referrer in self._referrers.values(): - referrer.cascade_update() - self._update_weak_referrers() - - def call_update_callbacks(self): - obsolete = set() - for callback in list(self.update_callbacks): - if not callback(self): - obsolete.add(callback) - self.update_callbacks -= obsolete - - def is_committed(self): - """Returns whether or not this item is committed to the DB. - - Returns: - bool - """ - return self._status == Status.committed - - def commit(self, commit_id): - """Sets this item as committed with the given commit id.""" - self._status = Status.committed - if commit_id: - self["commit_id"] = commit_id - - def __repr__(self): - """Overridden to return a more verbose representation.""" - return f"{self._item_type}{self._extended()}" - - def __getattr__(self, name): - """Overridden to return the dictionary key named after the attribute, or None if it doesn't exist.""" - # FIXME: We should try and get rid of this one - return self.get(name) - - def __getitem__(self, key): - """Overridden to return references.""" - ref = self._references.get(key) - if ref: - src_key, (ref_type, ref_key) = ref - ref_id = self[src_key] - if isinstance(ref_id, tuple): - return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) - return self._get_ref(ref_type, ref_id).get(ref_key) - return super().__getitem__(key) - - def __setitem__(self, key, value): - """Sets id valid if key is 'id'.""" - if key == "id": - self._is_id_valid = True - super().__setitem__(key, value) - - def get(self, key, default=None): - """Overridden to return references.""" - try: - return self[key] - except KeyError: - return default - - def update(self, other): - """Overridden to update the item status and also to invalidate references that become obsolete.""" - if self._status == Status.committed: - self._status = Status.to_update - self._backup = self._asdict() - elif self._status in (Status.to_remove, Status.added_and_removed): - raise RuntimeError("invalid status of item being updated") - for src_key, (ref_type, _ref_key) in self._references.values(): - ref_id = self[src_key] - if src_key in other and other[src_key] != ref_id: - # Invalidate references - if isinstance(ref_id, tuple): - for x in ref_id: - self._invalidate_ref(ref_type, x) - else: - self._invalidate_ref(ref_type, ref_id) - super().update(other) - if self._asdict() == self._backup: - self._status = Status.committed diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index bbf548b7..01252863 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -13,16 +13,45 @@ This module defines the :class:`.DatabaseMapping` class. 
""" -import sqlalchemy.exc +import hashlib +import os +import time +import logging +from types import MethodType +from sqlalchemy import create_engine, MetaData, inspect +from sqlalchemy.pool import NullPool +from sqlalchemy.event import listen +from sqlalchemy.exc import DatabaseError +from sqlalchemy.engine.url import make_url, URL +from alembic.migration import MigrationContext +from alembic.environment import EnvironmentContext +from alembic.script import ScriptDirectory +from alembic.config import Config +from alembic.util.exc import CommandError + +from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters +from .spine_db_client import get_db_url_from_server +from .mapped_items import item_factory from .db_mapping_base import DatabaseMappingBase +from .db_mapping_query_mixin import DatabaseMappingQueryMixin from .db_mapping_add_mixin import DatabaseMappingAddMixin from .db_mapping_update_mixin import DatabaseMappingUpdateMixin from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin from .db_mapping_commit_mixin import DatabaseMappingCommitMixin -from .db_cache_impl import DBCache +from .exception import SpineDBAPIError, SpineDBVersionError +from .helpers import ( + _create_first_spine_database, + create_new_spine_database, + compare_schemas, + model_meta, + copy_database_bind, +) + +logging.getLogger("alembic").setLevel(logging.CRITICAL) class DatabaseMapping( + DatabaseMappingQueryMixin, DatabaseMappingAddMixin, DatabaseMappingUpdateMixin, DatabaseMappingRemoveMixin, @@ -31,7 +60,7 @@ class DatabaseMapping( ): """Enables communication with a Spine DB. - A mapping of the DB is incrementally created in memory as data is requested/modified. + The DB is incrementally mapped into memory as data is requested/modified. Data is typically retrieved using :meth:`get_item` or :meth:`get_items`. If the requested data is already in memory, it is returned from there; @@ -53,7 +82,7 @@ class DatabaseMapping( or rolled back (discarded) via :meth:`rollback_session`. The DB fetch status is reset via :meth:`refresh_session`. - This causes new items in the DB to be merged into the memory mapping as data is further requested/modified. + This allows new items in the DB (added by other clients in the meantime) to be retrieved as well. You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`. For example, a UI application might want to fetch data in the background so the UI is not blocked in the process. @@ -63,6 +92,292 @@ class DatabaseMapping( while bypassing the in-memory mapping entirely. 
""" + ITEM_TYPES = ( + "entity_class", + "entity", + "entity_group", + "alternative", + "scenario", + "scenario_alternative", + "entity_alternative", + "parameter_value_list", + "list_value", + "parameter_definition", + "parameter_value", + "metadata", + "entity_metadata", + "parameter_value_metadata", + ) + _sq_name_by_item_type = { + "entity_class": "wide_entity_class_sq", + "entity": "wide_entity_sq", + "entity_alternative": "entity_alternative_sq", + "parameter_value_list": "parameter_value_list_sq", + "list_value": "list_value_sq", + "alternative": "alternative_sq", + "scenario": "scenario_sq", + "scenario_alternative": "scenario_alternative_sq", + "entity_group": "entity_group_sq", + "parameter_definition": "parameter_definition_sq", + "parameter_value": "parameter_value_sq", + "metadata": "metadata_sq", + "entity_metadata": "entity_metadata_sq", + "parameter_value_metadata": "parameter_value_metadata_sq", + "commit": "commit_sq", + } + + def __init__( + self, + db_url, + username=None, + upgrade=False, + codename=None, + create=False, + apply_filters=True, + memory=False, + sqlite_timeout=1800, + ): + """ + Args: + db_url (str or :class:`~sqlalchemy.engine.url.URL`): A URL in RFC-1738 format pointing to the database + to be mapped, or to a DB server. + username (str, optional): A user name. If not given, it gets replaced by the string ``"anon"``. + upgrade (bool, optional): Whether the db at the given URL should be upgraded to the most recent + version. + codename (str, optional): A name to associate with the DB mapping. + create (bool, optional): Whether to create a Spine db at the given URL if it's not one already. + apply_filters (bool, optional): Whether to apply filters in the URL's query part. + memory (bool, optional): Whether or not to use a sqlite memory db as replacement for this DB map. + sqlite_timeout (int, optional): How many seconds to wait before raising connection errors. 
+ """ + super().__init__() + # FIXME: We should also check the server memory property and use it here + db_url = get_db_url_from_server(db_url) + self.db_url = str(db_url) + if isinstance(db_url, str): + filter_configs, db_url = pop_filter_configs(db_url) + elif isinstance(db_url, URL): + filter_configs = db_url.query.pop("spinedbfilter", []) + else: + filter_configs = [] + self._filter_configs = filter_configs if apply_filters else None + self.sa_url = make_url(db_url) + self.username = username if username else "anon" + self.codename = self._make_codename(codename) + self._memory = memory + self._memory_dirty = False + self._original_engine = self.create_engine( + self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout + ) + # NOTE: The NullPool is needed to receive the close event (or any events), for some reason + self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine + listen(self.engine, 'close', self._receive_engine_close) + if self._memory: + copy_database_bind(self.engine, self._original_engine) + self._metadata = MetaData(self.engine) + self._metadata.reflect() + self._tablenames = [t.name for t in self._metadata.sorted_tables] + self.closed = False + if self._filter_configs is not None: + stack = load_filters(self._filter_configs) + apply_filter_stack(self, stack) + # Table primary ids map: + self._id_fields = { + "entity_class_dimension": "entity_class_id", + "entity_element": "entity_id", + "object_class": "entity_class_id", + "relationship_class": "entity_class_id", + "object": "entity_id", + "relationship": "entity_id", + } + self.composite_pks = { + "entity_element": ("entity_id", "position"), + "entity_class_dimension": ("entity_class_id", "position"), + } + + def __enter__(self): + return self + + def __exit__(self, _exc_type, _exc_val, _exc_tb): + self.close() + + def __del__(self): + self.close() + + @property + def item_types(self): + return list(self._sq_name_by_item_type) + + @staticmethod + def item_factory(item_type): + return item_factory(item_type) + + def make_query(self, item_type): + if self.closed: + return None + sq_name = self._sq_name_by_item_type[item_type] + return self.query(getattr(self, sq_name)) + + def close(self): + """Closes this DB mapping.""" + self.closed = True + + def _make_codename(self, codename): + if codename: + return str(codename) + if not self.sa_url.drivername.startswith("sqlite"): + return self.sa_url.database + if self.sa_url.database is not None: + return os.path.basename(self.sa_url.database) + hashing = hashlib.sha1() + hashing.update(bytes(str(time.time()), "utf-8")) + return hashing.hexdigest() + + @staticmethod + def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): + """Creates engine. + + Args: + sa_url (URL) + upgrade (bool, optional): If True, upgrade the db to the latest version. + create (bool, optional): If True, create a new Spine db at the given url if none found. + + Returns: + :class:`~sqlalchemy.engine.Engine` + """ + if sa_url.drivername == "sqlite": + connect_args = {'timeout': sqlite_timeout} + else: + connect_args = {} + try: + engine = create_engine(sa_url, connect_args=connect_args) + with engine.connect(): + pass + except Exception as e: + raise SpineDBAPIError( + f"Could not connect to '{sa_url}': {str(e)}. " + f"Please make sure that '{sa_url}' is a valid sqlalchemy URL." 
+ ) from None + config = Config() + config.set_main_option("script_location", "spinedb_api:alembic") + script = ScriptDirectory.from_config(config) + head = script.get_current_head() + with engine.connect() as connection: + migration_context = MigrationContext.configure(connection) + try: + current = migration_context.get_current_revision() + except DatabaseError as error: + raise SpineDBAPIError(str(error)) from None + if current is None: + # No revision information. Check that the schema of the given url corresponds to a 'first' Spine db + # Otherwise we either raise or create a new Spine db at the url. + ref_engine = _create_first_spine_database("sqlite://") + if not compare_schemas(engine, ref_engine): + if not create or inspect(engine).get_table_names(): + raise SpineDBAPIError( + "Unable to determine db revision. " + f"Please check that\n\n\t{sa_url}\n\nis the URL of a valid Spine db." + ) + return create_new_spine_database(sa_url) + if current != head: + if not upgrade: + try: + script.get_revision(current) # Check if current revision is part of alembic rev. history + except CommandError: + # Can't find 'current' revision + raise SpineDBVersionError( + url=sa_url, current=current, expected=head, upgrade_available=False + ) from None + raise SpineDBVersionError(url=sa_url, current=current, expected=head) + + # Upgrade function + def upgrade_to_head(rev, context): + return script._upgrade_revs("head", rev) + + with EnvironmentContext( + config, + script, + fn=upgrade_to_head, + as_sql=False, + starting_rev=None, + destination_rev="head", + tag=None, + ) as environment_context: + environment_context.configure(connection=connection, target_metadata=model_meta) + with environment_context.begin_transaction(): + environment_context.run_migrations() + return engine + + def _receive_engine_close(self, dbapi_con, _connection_record): + if self._memory_dirty: + copy_database_bind(self._original_engine, self.engine) + + def _get_primary_key(self, tablename): + pk = self.composite_pks.get(tablename) + if pk is None: + id_field = self._id_fields.get(tablename, "id") + pk = (id_field,) + return pk + + @staticmethod + def _real_tablename(tablename): + return { + "object_class": "entity_class", + "relationship_class": "entity_class", + "object": "entity", + "relationship": "entity", + }.get(tablename, tablename) + + @staticmethod + def _convert_legacy(tablename, item): + if tablename in ("entity_class", "entity"): + object_class_id_list = tuple(item.pop("object_class_id_list", ())) + if object_class_id_list: + item["dimension_id_list"] = object_class_id_list + object_class_name_list = tuple(item.pop("object_class_name_list", ())) + if object_class_name_list: + item["dimension_name_list"] = object_class_name_list + if tablename == "entity": + object_id_list = tuple(item.pop("object_id_list", ())) + if object_id_list: + item["element_id_list"] = object_id_list + object_name_list = tuple(item.pop("object_name_list", ())) + if object_name_list: + item["element_name_list"] = object_name_list + if tablename in ("parameter_definition", "parameter_value"): + entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) + if entity_class_id: + item["entity_class_id"] = entity_class_id + if tablename == "parameter_value": + entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) + if entity_id: + item["entity_id"] = entity_id + + def get_import_alternative_name(self): + if self._import_alternative_name is None: + self._create_import_alternative() + return 
self._import_alternative_name
+
+    def _create_import_alternative(self):
+        """Creates the alternative to be used as default for all import operations."""
+        self._import_alternative_name = "Base"
+
+    def override_create_import_alternative(self, method):
+        self._create_import_alternative = MethodType(method, self)
+        self._import_alternative_name = None
+
+    def get_filter_configs(self):
+        """Returns filters applicable to this DB mapping.
+
+        Returns:
+            list(dict):
+        """
+        return self._filter_configs
+
+    def get_table(self, tablename):
+        # For tests
+        return self._metadata.tables[tablename]
+
     def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """Finds and returns an item matching the arguments, or None if none found.
 
@@ -76,7 +391,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
             :class:`PublicItem` or None
         """
         item_type = self._real_tablename(item_type)
-        cache_item = self.cache.table_cache(item_type).find_item(kwargs, fetch=fetch)
+        cache_item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch)
         if not cache_item:
             return None
         if skip_removed and not cache_item.is_valid():
@@ -95,10 +410,10 @@ def get_items(self, item_type, fetch=True, skip_removed=True):
             :class:`PublicItem` or None
         """
         item_type = self._real_tablename(item_type)
-        if fetch and item_type not in self.cache.fetched_item_types:
+        if fetch and item_type not in self.fetched_item_types:
             self.fetch_all(item_type)
-        table_cache = self.cache.table_cache(item_type)
-        get_items = table_cache.valid_values if skip_removed else table_cache.values
+        mapped_table = self.mapped_table(item_type)
+        get_items = mapped_table.valid_values if skip_removed else mapped_table.values
         return [PublicItem(self, x) for x in get_items()]
 
     def add_item(self, item_type, check=True, **kwargs):
@@ -119,13 +434,13 @@ def add_item(self, item_type, check=True, **kwargs):
             tuple(:class:`PublicItem` or None, str): The added item and any errors.
         """
         item_type = self._real_tablename(item_type)
-        table_cache = self.cache.table_cache(item_type)
+        mapped_table = self.mapped_table(item_type)
         self._convert_legacy(item_type, kwargs)
         if not check:
-            return table_cache.add_item(kwargs, new=True), None
-        checked_item, error = table_cache.check_item(kwargs)
+            return mapped_table.add_item(kwargs, new=True), None
+        checked_item, error = mapped_table.check_item(kwargs)
         return (
-            PublicItem(self, table_cache.add_item(checked_item, new=True)) if checked_item and not error else None,
+            PublicItem(self, mapped_table.add_item(checked_item, new=True)) if checked_item and not error else None,
             error,
         )
 
@@ -148,12 +463,12 @@ def update_item(self, item_type, check=True, **kwargs):
             tuple(:class:`PublicItem` or None, str): The updated item and any errors.
         """
         item_type = self._real_tablename(item_type)
-        table_cache = self.cache.table_cache(item_type)
+        mapped_table = self.mapped_table(item_type)
         self._convert_legacy(item_type, kwargs)
         if not check:
-            return table_cache.update_item(kwargs), None
-        checked_item, error = table_cache.check_item(kwargs, for_update=True)
-        return (PublicItem(self, table_cache.update_item(checked_item._asdict())) if checked_item else None, error)
+            return mapped_table.update_item(kwargs), None
+        checked_item, error = mapped_table.check_item(kwargs, for_update=True)
+        return (PublicItem(self, mapped_table.update_item(checked_item._asdict())) if checked_item else None, error)
 
     def remove_item(self, item_type, id_):
         """Removes an item from the in-memory mapping.
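
# A hedged sketch of the calling convention implemented above: with check=True,
# add_item and update_item return a (PublicItem or None, error) pair. The URL is
# illustrative, and PublicItem is assumed to support dict-style field access.
db_map = DatabaseMapping("sqlite:///example.db", create=True)
fish, error = db_map.add_item("entity_class", name="fish")
if error:
    raise RuntimeError(error)
updated, error = db_map.update_item("entity_class", id=fish["id"], description="It swims.")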
@@ -173,8 +488,8 @@ def remove_item(self, item_type, id_):
             :class:`PublicItem` or None: The removed item, if any.
         """
         item_type = self._real_tablename(item_type)
-        table_cache = self.cache.table_cache(item_type)
-        return PublicItem(self, table_cache.remove_item(id_))
+        mapped_table = self.mapped_table(item_type)
+        return PublicItem(self, mapped_table.remove_item(id_))
 
     def restore_item(self, item_type, id_):
         """Restores a previously removed item into the in-memory mapping.
@@ -193,8 +508,8 @@ def restore_item(self, item_type, id_):
             :class:`PublicItem` or None: The restored item, if any.
         """
         item_type = self._real_tablename(item_type)
-        table_cache = self.cache.table_cache(item_type)
-        return PublicItem(self, table_cache.restore_item(id_))
+        mapped_table = self.mapped_table(item_type)
+        return PublicItem(self, mapped_table.restore_item(id_))
 
     def can_fetch_more(self, item_type):
         """Whether or not more data can be fetched from the DB for the given item type.
@@ -205,9 +520,9 @@ def can_fetch_more(self, item_type):
         Returns:
             bool
         """
-        return item_type not in self.cache.fetched_item_types
+        return item_type not in self.fetched_item_types
 
-    def fetch_more(self, item_type, limit):
+    def fetch_more(self, item_type, limit=None):
         """Fetches items from the DB into the in-memory mapping, incrementally.
 
         Args:
@@ -220,7 +535,7 @@ def fetch_more(self, item_type, limit):
             list(PublicItem): The items fetched.
         """
         item_type = self._real_tablename(item_type)
-        return self.cache.fetch_more(item_type, limit=limit)
+        return [PublicItem(self, x) for x in self.do_fetch_more(item_type, limit=limit)]
 
     def fetch_all(self, *item_types):
         """Fetches items from the DB into the in-memory mapping.
@@ -232,16 +547,15 @@ def fetch_all(self, *item_types):
         item_types = set(self.ITEM_TYPES) if not item_types else set(item_types) & set(self.ITEM_TYPES)
         for item_type in item_types:
             item_type = self._real_tablename(item_type)
-            self.cache.fetch_all(item_type)
+            self.do_fetch_all(item_type)
 
-    @staticmethod
-    def describe_item_type(item_type):
+    def describe_item_type(self, item_type):
         """Prints a synopsis of the given item type to stdout.
 
         Args:
             item_type (str): The type of item to describe.
         """
-        factory = DBCache.item_factory(item_type)
+        factory = self.item_factory(item_type)
         sections = ("Fields:", "Unique keys:")
         width = max(len(s) for s in sections) + 4
         print()
diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py
index 58f61ae5..6a7cc734 100644
--- a/spinedb_api/db_mapping_add_mixin.py
+++ b/spinedb_api/db_mapping_add_mixin.py
@@ -176,4 +176,4 @@ def add_ext_parameter_value_metadata(self, *items, **kwargs):
 
     def get_metadata_to_add_with_item_metadata_items(self, *items):
         metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items)
-        return [x for x in metadata_items if not self.cache.table_cache("metadata").find_item(x)]
+        return [x for x in metadata_items if not self.mapped_table("metadata").find_item(x)]
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 42ae72c4..917b67af 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -9,1762 +9,910 @@
 # this program. If not, see <http://www.gnu.org/licenses/>.
###################################################################################################################### -import hashlib -import os -import logging -import time -from types import MethodType -from sqlalchemy import create_engine, MetaData, Table, Integer, inspect, case, func, cast, and_, or_ -from sqlalchemy.sql.expression import Alias, label -from sqlalchemy.engine.url import make_url, URL -from sqlalchemy.orm import aliased -from sqlalchemy.exc import DatabaseError -from sqlalchemy.event import listen -from sqlalchemy.pool import NullPool -from alembic.migration import MigrationContext -from alembic.environment import EnvironmentContext -from alembic.script import ScriptDirectory -from alembic.config import Config -from alembic.util.exc import CommandError -from .exception import SpineDBAPIError, SpineDBVersionError -from .helpers import ( - _create_first_spine_database, - create_new_spine_database, - compare_schemas, - forward_sweep, - group_concat, - model_meta, - copy_database_bind, -) -from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters -from .spine_db_client import get_db_url_from_server -from .db_cache_impl import DBCache -from .query import Query - -logging.getLogger("alembic").setLevel(logging.CRITICAL) +import threading +from enum import Enum, unique, auto +from .temp_id import TempId +# TODO: Implement MappedItem.pop() to do lookup? -class DatabaseMappingBase: - """Base class for all database mappings. - - Provides the :meth:`query` method for performing custom ``SELECT`` queries. - """ +_LIMIT = 10000 - _session_kwargs = {} - ITEM_TYPES = ( - "entity_class", - "parameter_value_list", - "list_value", - "parameter_definition", - "entity", - "entity_group", - "parameter_value", - "alternative", - "scenario", - "scenario_alternative", - "metadata", - "entity_metadata", - "parameter_value_metadata", - ) - - def __init__( - self, - db_url, - username=None, - upgrade=False, - codename=None, - create=False, - apply_filters=True, - memory=False, - sqlite_timeout=1800, - ): - """ - Args: - db_url (str or :class:`~sqlalchemy.engine.url.URL`): A URL in RFC-1738 format pointing to the database - to be mapped, or to a DB server. - username (str, optional): A user name. If not given, it gets replaced by the string ``"anon"``. - upgrade (bool, optional): Whether the db at the given URL should be upgraded to the most recent - version. - codename (str, optional): A name to associate with the DB mapping. - create (bool, optional): Whether to create a Spine db at the given URL if it's not one already. - apply_filters (bool, optional): Whether to apply filters in the URL's query part. - memory (bool, optional): Whether or not to use a sqlite memory db as replacement for this DB map. - sqlite_timeout (int, optional): How many seconds to wait before raising connection errors. 
- """ - # FIXME: We should also check the server memory property and use it here - db_url = get_db_url_from_server(db_url) - self.db_url = str(db_url) - if isinstance(db_url, str): - filter_configs, db_url = pop_filter_configs(db_url) - elif isinstance(db_url, URL): - filter_configs = db_url.query.pop("spinedbfilter", []) - else: - filter_configs = [] - self._filter_configs = filter_configs if apply_filters else None - self.sa_url = make_url(db_url) - self.username = username if username else "anon" - self.codename = self._make_codename(codename) - self._memory = memory - self._memory_dirty = False - self._original_engine = self.create_engine( - self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout - ) - # NOTE: The NullPool is needed to receive the close event (or any events), for some reason - self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine - listen(self.engine, 'close', self._receive_engine_close) - if self._memory: - copy_database_bind(self.engine, self._original_engine) - self._metadata = MetaData(self.engine) - self._metadata.reflect() - self._tablenames = [t.name for t in self._metadata.sorted_tables] - self.cache = DBCache(self) - self.closed = False - # Subqueries that select everything from each table - self._commit_sq = None - self._alternative_sq = None - self._scenario_sq = None - self._scenario_alternative_sq = None - self._entity_class_sq = None - self._entity_sq = None - self._entity_class_dimension_sq = None - self._entity_element_sq = None - self._entity_alternative_sq = None - self._object_class_sq = None - self._object_sq = None - self._relationship_class_sq = None - self._relationship_sq = None - self._entity_group_sq = None - self._parameter_definition_sq = None - self._parameter_value_sq = None - self._parameter_value_list_sq = None - self._list_value_sq = None - self._metadata_sq = None - self._parameter_value_metadata_sq = None - self._entity_metadata_sq = None - # Special convenience subqueries that join two or more tables - self._wide_entity_class_sq = None - self._wide_entity_sq = None - self._ext_parameter_value_list_sq = None - self._wide_parameter_value_list_sq = None - self._ord_list_value_sq = None - self._ext_scenario_sq = None - self._wide_scenario_sq = None - self._linked_scenario_alternative_sq = None - self._ext_linked_scenario_alternative_sq = None - self._ext_object_sq = None - self._ext_relationship_class_sq = None - self._wide_relationship_class_sq = None - self._ext_relationship_class_object_parameter_definition_sq = None - self._wide_relationship_class_object_parameter_definition_sq = None - self._ext_relationship_sq = None - self._wide_relationship_sq = None - self._ext_entity_group_sq = None - self._entity_parameter_definition_sq = None - self._object_parameter_definition_sq = None - self._relationship_parameter_definition_sq = None - self._entity_parameter_value_sq = None - self._object_parameter_value_sq = None - self._relationship_parameter_value_sq = None - self._ext_parameter_value_metadata_sq = None - self._ext_entity_metadata_sq = None - self._import_alternative_name = None - self._table_to_sq_attr = {} - # Table primary ids map: - self._id_fields = { - "entity_class_dimension": "entity_class_id", - "entity_element": "entity_id", - "object_class": "entity_class_id", - "relationship_class": "entity_class_id", - "object": "entity_id", - "relationship": "entity_id", - } - self.composite_pks = { - "entity_element": ("entity_id", "position"), - 
"entity_class_dimension": ("entity_class_id", "position"), - } - if self._filter_configs is not None: - stack = load_filters(self._filter_configs) - apply_filter_stack(self, stack) - def __enter__(self): - return self +@unique +class Status(Enum): + """Cache item status.""" - def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.close() + committed = auto() + to_add = auto() + to_update = auto() + to_remove = auto() + added_and_removed = auto() - def get_filter_configs(self): - """Returns filters applicable to this DB mapping. - Returns: - list(dict): - """ - return self._filter_configs - - def close(self): - """Closes this DB mapping.""" - self.closed = True - - @staticmethod - def _real_tablename(tablename): - return { - "object_class": "entity_class", - "relationship_class": "entity_class", - "object": "entity", - "relationship": "entity", - }.get(tablename, tablename) - - def get_table(self, tablename): - # For tests - return self._metadata.tables[tablename] - - def _make_codename(self, codename): - if codename: - return str(codename) - if not self.sa_url.drivername.startswith("sqlite"): - return self.sa_url.database - if self.sa_url.database is not None: - return os.path.basename(self.sa_url.database) - hashing = hashlib.sha1() - hashing.update(bytes(str(time.time()), "utf-8")) - return hashing.hexdigest() - - @staticmethod - def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): - """Creates engine. - - Args: - sa_url (URL) - upgrade (bool, optional): If True, upgrade the db to the latest version. - create (bool, optional): If True, create a new Spine db at the given url if none found. - - Returns: - :class:`~sqlalchemy.engine.Engine` - """ - if sa_url.drivername == "sqlite": - connect_args = {'timeout': sqlite_timeout} - else: - connect_args = {} - try: - engine = create_engine(sa_url, connect_args=connect_args) - with engine.connect(): - pass - except Exception as e: - raise SpineDBAPIError( - f"Could not connect to '{sa_url}': {str(e)}. " - f"Please make sure that '{sa_url}' is a valid sqlalchemy URL." - ) from None - config = Config() - config.set_main_option("script_location", "spinedb_api:alembic") - script = ScriptDirectory.from_config(config) - head = script.get_current_head() - with engine.connect() as connection: - migration_context = MigrationContext.configure(connection) - try: - current = migration_context.get_current_revision() - except DatabaseError as error: - raise SpineDBAPIError(str(error)) from None - if current is None: - # No revision information. Check that the schema of the given url corresponds to a 'first' Spine db - # Otherwise we either raise or create a new Spine db at the url. - ref_engine = _create_first_spine_database("sqlite://") - if not compare_schemas(engine, ref_engine): - if not create or inspect(engine).get_table_names(): - raise SpineDBAPIError( - "Unable to determine db revision. " - f"Please check that\n\n\t{sa_url}\n\nis the URL of a valid Spine db." - ) - return create_new_spine_database(sa_url) - if current != head: - if not upgrade: - try: - script.get_revision(current) # Check if current revision is part of alembic rev. 
history - except CommandError: - # Can't find 'current' revision - raise SpineDBVersionError( - url=sa_url, current=current, expected=head, upgrade_available=False - ) from None - raise SpineDBVersionError(url=sa_url, current=current, expected=head) - - # Upgrade function - def upgrade_to_head(rev, context): - return script._upgrade_revs("head", rev) - - with EnvironmentContext( - config, - script, - fn=upgrade_to_head, - as_sql=False, - starting_rev=None, - destination_rev="head", - tag=None, - ) as environment_context: - environment_context.configure(connection=connection, target_metadata=model_meta) - with environment_context.begin_transaction(): - environment_context.run_migrations() - return engine - - def _receive_engine_close(self, dbapi_con, _connection_record): - if self._memory_dirty: - copy_database_bind(self._original_engine, self.engine) - - def _get_table_to_sq_attr(self): - if not self._table_to_sq_attr: - self._table_to_sq_attr = self._make_table_to_sq_attr() - return self._table_to_sq_attr - - def _make_table_to_sq_attr(self): - """Returns a dict mapping table names to subquery attribute names, involving that table.""" - - def _func(x, tables): - if isinstance(x, Table): - tables.add(x.name) # pylint: disable=cell-var-from-loop - - # This 'loads' our subquery attributes - for attr in dir(self): - getattr(self, attr) - table_to_sq_attr = {} - for attr, val in vars(self).items(): - if not isinstance(val, Alias): - continue - tables = set() - forward_sweep(val, _func, tables) - # Now `tables` contains all tables related to `val` - for table in tables: - table_to_sq_attr.setdefault(table, set()).add(attr) - return table_to_sq_attr - - def _clear_subqueries(self, *tablenames): - """Set to `None` subquery attributes involving the affected tables. - This forces the subqueries to be refreshed when the corresponding property is accessed. - """ - tablenames = list(tablenames) - for tablename in tablenames: - if self.cache.pop(tablename, False): - self.cache.fetch_all(tablename) - attr_names = set(attr for tablename in tablenames for attr in self._get_table_to_sq_attr().get(tablename, [])) - for attr_name in attr_names: - setattr(self, attr_name, None) - - def query(self, *args, **kwargs): - """Returns a :class:`~spinedb_api.query.Query` object to execute against this DB. - - To perform custom ``SELECT`` statements, call this method with one or more of the class documented - subquery properties (of :class:`~sqlalchemy.sql.expression.Alias` type). - For example, to select the entity class with ``id`` equal to 1:: - - from spinedb_api import DatabaseMapping - url = 'sqlite:///spine.db' - ... - db_map = DatabaseMapping(url) - db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() - - To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface - (which is a close clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`). - For example, to select all entity class names and the names of their entities concatenated in a comma-separated - string:: - - from sqlalchemy import func - - db_map.query( - db_map.entity_class_sq.c.name, func.group_concat(db_map.entity_sq.c.name) - ).filter( - db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id - ).group_by(db_map.entity_class_sq.c.name).all() - """ - return Query(self.engine, *args) +class DatabaseMappingBase: + """An in-memory mapping of a DB, mapping item types (table names), to numeric ids, to items. 
- def _subquery(self, tablename): - """A subquery of the form: + This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - .. code-block:: sql + When subclassing, you need to implement :attr:`item_types`, :meth:`item_factory`, and :meth:`make_query`. + """ - SELECT * FROM tablename - - Args: - tablename (str): the table to be queried. - - Returns: - :class:`~sqlalchemy.sql.expression.Alias` - """ - table = self._metadata.tables[tablename] - return self.query(table).subquery(tablename + "_sq") + def __init__(self): + self._mapped_tables = {} + self._offsets = {} + self._offset_lock = threading.Lock() + self._fetched_item_types = set() + item_types = self.item_types + self._sorted_item_types = [] + while item_types: + item_type = item_types.pop(0) + if self.item_factory(item_type).ref_types() & set(item_types): + item_types.append(item_type) + else: + self._sorted_item_types.append(item_type) @property - def entity_class_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM entity_class + def fetched_item_types(self): + """Returns a set with the item types that are already fetched. Returns: - :class:`~sqlalchemy.sql.expression.Alias` + set """ - if self._entity_class_sq is None: - self._entity_class_sq = self._make_entity_class_sq() - return self._entity_class_sq + return self._fetched_item_types @property - def entity_class_dimension_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM entity_class_dimension + def item_types(self): + """Returns a list of item types from the DB schema (equivalent to the table names). Returns: - :class:`~sqlalchemy.sql.expression.Alias` + list(str) """ - if self._entity_class_dimension_sq is None: - self._entity_class_dimension_sq = self._subquery("entity_class_dimension") - return self._entity_class_dimension_sq + raise NotImplementedError() - @property - def wide_entity_class_sq(self): - """A subquery of the form: - - .. code-block:: sql + @staticmethod + def item_factory(item_type): + """Returns a subclass of :class:`.MappedItemBase` to make items of given type. 
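+
+        A concrete subclass could implement this as a simple lookup, e.g. (a sketch;
+        ``EntityClassItem`` stands for a hypothetical :class:`.MappedItemBase` subclass)::
+
+            @staticmethod
+            def item_factory(item_type):
+                return {"entity_class": EntityClassItem}.get(item_type, MappedItemBase)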
-            SELECT
-                ec.*,
-                count(ecd.dimension_id) AS dimension_count
-                group_concat(ecd.dimension_id) AS dimension_id_list
-            FROM
-                entity_class AS ec
-                entity_class_dimension AS ecd
-            WHERE
-                ec.id == ecd.entity_class_id
+        Args:
+            item_type (str)
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
+            type
         """
-        if self._wide_entity_class_sq is None:
-            entity_class_dimension_sq = (
-                self.query(
-                    self.entity_class_dimension_sq.c.entity_class_id,
-                    self.entity_class_dimension_sq.c.dimension_id,
-                    self.entity_class_dimension_sq.c.position,
-                    self.entity_class_sq.c.name.label("dimension_name"),
-                )
-                .filter(self.entity_class_dimension_sq.c.dimension_id == self.entity_class_sq.c.id)
-                .subquery("entity_class_dimension_sq")
-            )
-            ecd_sq = (
-                self.query(
-                    self.entity_class_sq.c.id,
-                    self.entity_class_sq.c.name,
-                    self.entity_class_sq.c.description,
-                    self.entity_class_sq.c.display_order,
-                    self.entity_class_sq.c.display_icon,
-                    self.entity_class_sq.c.hidden,
-                    entity_class_dimension_sq.c.dimension_id,
-                    entity_class_dimension_sq.c.dimension_name,
-                    entity_class_dimension_sq.c.position,
-                )
-                .outerjoin(
-                    entity_class_dimension_sq,
-                    self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id,
-                )
-                .order_by(self.entity_class_sq.c.id, entity_class_dimension_sq.c.position)
-                .subquery("ext_entity_class_sq")
-            )
-            self._wide_entity_class_sq = (
-                self.query(
-                    ecd_sq.c.id,
-                    ecd_sq.c.name,
-                    ecd_sq.c.description,
-                    ecd_sq.c.display_order,
-                    ecd_sq.c.display_icon,
-                    ecd_sq.c.hidden,
-                    group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"),
-                    group_concat(ecd_sq.c.dimension_name, ecd_sq.c.position).label("dimension_name_list"),
-                    func.count(ecd_sq.c.dimension_id).label("dimension_count"),
-                )
-                .group_by(
-                    ecd_sq.c.id,
-                    ecd_sq.c.name,
-                    ecd_sq.c.description,
-                    ecd_sq.c.display_order,
-                    ecd_sq.c.display_icon,
-                    ecd_sq.c.hidden,
-                )
-                .subquery("wide_entity_class_sq")
-            )
-        return self._wide_entity_class_sq
+        raise NotImplementedError()
 
-    @property
-    def entity_sq(self):
-        """A subquery of the form:
-
-        .. code-block:: sql
+    def make_query(self, item_type):
+        """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.
 
-        SELECT * FROM entity
+        Args:
+            item_type (str)
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
+            :class:`~spinedb_api.query.Query`
         """
-        if self._entity_sq is None:
-            self._entity_sq = self._make_entity_sq()
-        return self._entity_sq
+        raise NotImplementedError()
 
-    @property
-    def entity_element_sq(self):
-        """A subquery of the form:
+    def make_item(self, item_type, **item):
+        factory = self.item_factory(item_type)
+        return factory(self, item_type, **item)
 
-        .. code-block:: sql
+    def dirty_ids(self, item_type):
+        return {
+            item["id"]
+            for item in self.mapped_table(item_type).valid_values()
+            if item.status in (Status.to_add, Status.to_update)
+        }
 
-        SELECT * FROM entity_element
+    def dirty_items(self):
+        """Returns a list of tuples of the form (item_type, (to_add, to_update, to_remove)) corresponding to
+        items that have been modified but not yet committed.
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
+            list
         """
-        if self._entity_element_sq is None:
-            self._entity_element_sq = self._make_entity_element_sq()
-        return self._entity_element_sq
-
-    @property
-    def wide_entity_sq(self):
-        """A subquery of the form:
-
-        ..
code-block:: sql
-
-        SELECT
-            e.*,
-            count(ee.element_id) AS element_count
-            group_concat(ee.element_id) AS element_id_list
-        FROM
-            entity AS e
-            entity_element AS ee
-        WHERE
-            e.id == ee.entity_id
+        dirty_items = []
+        for item_type in self._sorted_item_types:
+            mapped_table = self.get(item_type)
+            if mapped_table is None:
+                continue
+            to_add = []
+            to_update = []
+            to_remove = []
+            for item in mapped_table.values():
+                _ = item.is_valid()
+                if item.status == Status.to_add:
+                    to_add.append(item)
+                elif item.status == Status.to_update:
+                    to_update.append(item)
+                elif item.status == Status.to_remove:
+                    to_remove.append(item)
+            if to_remove:
+                # Fetch descendants, so that they are validated in next iterations of the loop.
+                # This ensures cascade removal.
+                # FIXME: We should also fetch the current item type because of multi-dimensional entities and
+                # classes which also depend on zero-dimensional ones
+                for other_item_type in self.item_types:
+                    if (
+                        other_item_type not in self.fetched_item_types
+                        and item_type in self.item_factory(other_item_type).ref_types()
+                    ):
+                        self.fetch_all(other_item_type)
+            if to_add or to_update or to_remove:
+                dirty_items.append((item_type, (to_add, to_update, to_remove)))
+        return dirty_items
+
+    def rollback(self):
+        """Discards uncommitted changes.
+
+        Namely, removes all the added items, resets all the updated items, and restores all the removed items.
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
-        """
-        if self._wide_entity_sq is None:
-            entity_element_sq = (
-                self.query(self.entity_element_sq, self.entity_sq.c.name.label("element_name"))
-                .filter(self.entity_element_sq.c.element_id == self.entity_sq.c.id)
-                .subquery("entity_element_sq")
-            )
-            ext_entity_sq = (
-                self.query(self.entity_sq, entity_element_sq)
-                .outerjoin(
-                    entity_element_sq,
-                    self.entity_sq.c.id == entity_element_sq.c.entity_id,
-                )
-                .order_by(self.entity_sq.c.id, entity_element_sq.c.position)
-                .subquery("ext_entity_sq")
-            )
-            self._wide_entity_sq = (
-                self.query(
-                    ext_entity_sq.c.id,
-                    ext_entity_sq.c.class_id,
-                    ext_entity_sq.c.name,
-                    ext_entity_sq.c.description,
-                    ext_entity_sq.c.commit_id,
-                    group_concat(ext_entity_sq.c.element_id, ext_entity_sq.c.position).label("element_id_list"),
-                    group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"),
-                )
-                .group_by(
-                    ext_entity_sq.c.id,
-                    ext_entity_sq.c.class_id,
-                    ext_entity_sq.c.name,
-                    ext_entity_sq.c.description,
-                    ext_entity_sq.c.commit_id,
-                )
-                .subquery("wide_entity_sq")
-            )
-        return self._wide_entity_sq
+            bool: False if there are no uncommitted items, True if successful.
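+
+        Example (a sketch; ``db_map`` is assumed to be an instance of a concrete subclass)::
+
+            db_map.add_item("alternative", name="draft")
+            db_map.rollback()  # the uncommitted alternative is discarded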
+        """
+        dirty_items = self.dirty_items()
+        if not dirty_items:
+            return False
+        to_add_by_type = []
+        to_update_by_type = []
+        to_remove_by_type = []
+        for item_type, (to_add, to_update, to_remove) in reversed(dirty_items):
+            to_add_by_type.append((item_type, to_add))
+            to_update_by_type.append((item_type, to_update))
+            to_remove_by_type.append((item_type, to_remove))
+        for item_type, to_remove in to_remove_by_type:
+            mapped_table = self.mapped_table(item_type)
+            for item in to_remove:
+                mapped_table.restore_item(item["id"])
+        for item_type, to_update in to_update_by_type:
+            mapped_table = self.mapped_table(item_type)
+            for item in to_update:
+                mapped_table.update_item(item.backup)
+        for item_type, to_add in to_add_by_type:
+            mapped_table = self.mapped_table(item_type)
+            for item in to_add:
+                if mapped_table.remove_item(item["id"]) is not None:
+                    item.invalidate_id()
+        return True
+
+    def refresh(self):
+        """Clears fetch progress, so the DB is queried again."""
+        self._offsets.clear()
+        self._fetched_item_types.clear()
+
+    def _get_next_chunk(self, item_type, limit):
+        qry = self.make_query(item_type)
+        if not qry:
+            return []
+        if not limit:
+            self._fetched_item_types.add(item_type)
+            return [dict(x) for x in qry]
+        with self._offset_lock:
+            offset = self._offsets.setdefault(item_type, 0)
+            chunk = [dict(x) for x in qry.limit(limit).offset(offset)]
+            self._offsets[item_type] += len(chunk)
+        return chunk
+
+    def _advance_query(self, item_type, limit):
+        """Advances the DB query that fetches items of given type
+        and adds the results to the corresponding mapped table.
 
-    @property
-    def entity_group_sq(self):
-        """A subquery of the form:
-
-        .. code-block:: sql
-
-        SELECT * FROM entity_group
+        Args:
+            item_type (str)
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
+            list: items fetched from the DB
+        """
+        chunk = self._get_next_chunk(item_type, limit)
+        if not chunk:
+            self._fetched_item_types.add(item_type)
+            return []
+        mapped_table = self.mapped_table(item_type)
+        for item in chunk:
+            mapped_table.add_item(item)
+        return chunk
+
+    def mapped_table(self, item_type):
+        return self._mapped_tables.setdefault(item_type, _MappedTable(self, item_type))
+
+    def get(self, item_type, default=None):
+        return self._mapped_tables.get(item_type, default)
+
+    def pop(self, item_type, default):
+        return self._mapped_tables.pop(item_type, default)
+
+    def clear(self):
+        self._mapped_tables.clear()
+
+    def get_mapped_item(self, item_type, id_):
+        mapped_table = self.mapped_table(item_type)
+        item = mapped_table.get(id_)
+        if item is None:
+            return {}
+        return item
+
+    def do_fetch_more(self, item_type, limit=_LIMIT):
+        if item_type in self._fetched_item_types:
+            return []
+        return self._advance_query(item_type, limit)
+
+    def do_fetch_all(self, item_type):
+        while self.do_fetch_more(item_type):
+            pass
+
+    def fetch_value(self, item_type, return_fn):
+        while self.do_fetch_more(item_type):
+            return_value = return_fn()
+            if return_value:
+                return return_value
+        return return_fn()
+
+    def fetch_ref(self, item_type, id_):
+        while self.do_fetch_more(item_type):
+            ref = self.get_mapped_item(item_type, id_)
+            if ref:
+                return ref
+        # It is possible that fetching was completed between deciding to call this function
+        # and starting the while loop above, causing self.do_fetch_more() to return False immediately.
+        # Therefore, we should try one last time if the ref is available.
+ ref = self.get_mapped_item(item_type, id_) + if ref: + return ref + + +class _MappedTable(dict): + def __init__(self, db_cache, item_type, *args, **kwargs): """ - if self._entity_group_sq is None: - self._entity_group_sq = self._subquery("entity_group") - return self._entity_group_sq - - @property - def alternative_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM alternative - - Returns: - :class:`~sqlalchemy.sql.expression.Alias` + Args: + db_cache (DBCacheBase): the DB cache where this table cache belongs. + item_type (str): the item type, equal to a table name """ - if self._alternative_sq is None: - self._alternative_sq = self._make_alternative_sq() - return self._alternative_sq + super().__init__(*args, **kwargs) + self._db_cache = db_cache + self._item_type = item_type + self._id_by_unique_key_value = {} + self._temp_id_by_db_id = {} - @property - def scenario_sq(self): - """A subquery of the form: + def get(self, id_, default=None): + id_ = self._temp_id_by_db_id.get(id_, id_) + return super().get(id_, default) - .. code-block:: sql + def _new_id(self): + temp_id = TempId(self._item_type) - SELECT * FROM scenario + def _callback(db_id): + self._temp_id_by_db_id[db_id] = temp_id - Returns: - :class:`~sqlalchemy.sql.expression.Alias` - """ - if self._scenario_sq is None: - self._scenario_sq = self._make_scenario_sq() - return self._scenario_sq + temp_id.add_resolve_callback(_callback) + return temp_id - @property - def scenario_alternative_sq(self): - """A subquery of the form: - - .. code-block:: sql + def unique_key_value_to_id(self, key, value, strict=False, fetch=True): + """Returns the id that has the given value for the given unique key, or None if not found. - SELECT * FROM scenario_alternative + Args: + key (tuple) + value (tuple) + strict (bool): if True, raise a KeyError if id is not found + fetch (bool): whether to fetch the DB until found. Returns: - :class:`~sqlalchemy.sql.expression.Alias` + int """ - if self._scenario_alternative_sq is None: - self._scenario_alternative_sq = self._make_scenario_alternative_sq() - return self._scenario_alternative_sq + id_by_unique_value = self._id_by_unique_key_value.get(key, {}) + if not id_by_unique_value and fetch: + id_by_unique_value = self._db_cache.fetch_value( + self._item_type, lambda: self._id_by_unique_key_value.get(key, {}) + ) + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) + if strict: + return id_by_unique_value[value] + return id_by_unique_value.get(value) - @property - def entity_alternative_sq(self): - """A subquery of the form: + def _unique_key_value_to_item(self, key, value, fetch=True): + return self.get(self.unique_key_value_to_id(key, value, fetch=fetch)) - .. code-block:: sql + def valid_values(self): + return (x for x in self.values() if x.is_valid()) - SELECT * FROM entity_alternative + def _make_item(self, item): + """Returns a cache item. + + Args: + item (dict): the 'db item' to use as base Returns: - :class:`~sqlalchemy.sql.expression.Alias` + CacheItem """ - if self._entity_alternative_sq is None: - self._entity_alternative_sq = self._subquery("entity_alternative") - return self._entity_alternative_sq - - @property - def parameter_value_list_sq(self): - """A subquery of the form: + return self._db_cache.make_item(self._item_type, **item) - .. code-block:: sql + def find_item(self, item, skip_keys=(), fetch=True): + """Returns a MappedItemBase that matches the given dictionary-item. 
- SELECT * FROM parameter_value_list + Args: + item (dict) Returns: - :class:`~sqlalchemy.sql.expression.Alias` - """ - if self._parameter_value_list_sq is None: - self._parameter_value_list_sq = self._subquery("parameter_value_list") - return self._parameter_value_list_sq - - @property - def list_value_sq(self): - """A subquery of the form: - - .. code-block:: sql + MappedItemBase or None + """ + id_ = item.get("id") + if id_ is not None: + # id is given, easy + item = self.get(id_) + if item or not fetch: + return item + return self._db_cache.fetch_ref(self._item_type, id_) + # No id. Try to locate the item by the value of one of the unique keys. + # Used by import_data (and more...) + cache_item = self._make_item(item) + error = cache_item.resolve_inverse_references(item.keys()) + if error: + return None + error = cache_item.polish() + if error: + return None + for key, value in cache_item.unique_values(skip_keys=skip_keys): + current_item = self._unique_key_value_to_item(key, value, fetch=fetch) + if current_item: + return current_item + + def check_item(self, item, for_update=False, skip_keys=()): + # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, + # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) + if for_update: + current_item = self.find_item(item, skip_keys=skip_keys) + if current_item is None: + return None, f"no {self._item_type} matching {item} to update" + full_item, merge_error = current_item.merge(item) + if full_item is None: + return None, merge_error + else: + current_item = None + full_item, merge_error = item, None + candidate_item = self._make_item(full_item) + error = candidate_item.resolve_inverse_references(skip_keys=item.keys()) + if error: + return None, error + error = candidate_item.polish() + if error: + return None, error + first_invalid_key = candidate_item.first_invalid_key() + if first_invalid_key: + return None, f"invalid {first_invalid_key} for {self._item_type}" + try: + for key, value in candidate_item.unique_values(skip_keys=skip_keys): + empty = {k for k, v in zip(key, value) if v == ""} + if empty: + return None, f"invalid empty keys {empty} for {self._item_type}" + unique_item = self._unique_key_value_to_item(key, value) + if unique_item not in (None, current_item) and unique_item.is_valid(): + return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" + except KeyError as e: + return None, f"missing {e} for {self._item_type}" + if "id" not in candidate_item: + candidate_item["id"] = self._new_id() + return candidate_item, merge_error + + def _add_unique(self, item): + for key, value in item.unique_values(): + self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"] + + def _remove_unique(self, item): + for key, value in item.unique_values(): + id_by_value = self._id_by_unique_key_value.get(key, {}) + if id_by_value.get(value) == item["id"]: + del id_by_value[value] + + def add_item(self, item, new=False): + if not isinstance(item, MappedItemBase): + item = self._make_item(item) + item.polish() + if not new: + # Item comes from the DB + id_ = item["id"] + if id_ in self or id_ in self._temp_id_by_db_id: + # The item is already in the cache + return + if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()): + # An item with the same unique key is already in the cache + return + else: + item.status = Status.to_add + if "id" not in item or not item.is_id_valid: + item["id"] = self._new_id() + 
self[item["id"]] = item
+        self._add_unique(item)
+        return item
+
+    def update_item(self, item):
+        current_item = self.find_item(item)
+        self._remove_unique(current_item)
+        current_item.update(item)
+        self._add_unique(current_item)
+        current_item.cascade_update()
+        return current_item
+
+    def remove_item(self, id_):
+        current_item = self.find_item({"id": id_})
+        if current_item is not None:
+            self._remove_unique(current_item)
+            current_item.cascade_remove()
+        return current_item
+
+    def restore_item(self, id_):
+        current_item = self.find_item({"id": id_})
+        if current_item is not None:
+            self._add_unique(current_item)
+            current_item.cascade_restore()
+        return current_item
+
+
+class MappedItemBase(dict):
+    """A dictionary that represents a db item."""
+
+    _fields = {}
+    """A dictionary mapping fields to a tuple of (type, description)"""
+    _defaults = {}
+    """A dictionary mapping keys to their default values"""
+    _unique_keys = ()
+    """A tuple where each element is itself a tuple of keys that are unique"""
+    _references = {}
+    """A dictionary mapping keys that are not in the original dictionary,
+    to a recipe for finding the field they reference in another item.
+
+    The recipe is a tuple of the form (original_field, (ref_item_type, ref_field)),
+    to be interpreted as follows:
+    1. take the value from the original_field of this item, which should be an id,
+    2. locate the item of type ref_item_type that has that id,
+    3. return the value from the ref_field of that item.
+    """
+    _inverse_references = {}
+    """Another dictionary mapping keys that are not in the original dictionary,
+    to a recipe for finding the field they reference in another item.
+    Used only for creating new items, when the user provides names and we want to find the ids.
+
+    The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)),
+    to be interpreted as follows:
+    1. take the values from the src_unique_key of this item, to form a tuple,
+    2. locate the item of type ref_item_type where the ref_unique_key is exactly that tuple of values,
+    3. return the id of that item.
+    """
 
-        SELECT * FROM list_value
 
+    def __init__(self, db_cache, item_type, **kwargs):
+        """
+        Args:
+            db_cache (DBCacheBase): the DB cache where this item belongs.
+        """
+        super().__init__(**kwargs)
+        self._db_cache = db_cache
+        self._item_type = item_type
+        self._referrers = {}
+        self._weak_referrers = {}
+        self.restore_callbacks = set()
+        self.update_callbacks = set()
+        self.remove_callbacks = set()
+        self._is_id_valid = True
+        self._to_remove = False
+        self._removed = False
+        self._corrupted = False
+        self._valid = None
+        self._status = Status.committed
+        self._removal_source = None
+        self._status_when_removed = None
+        self._backup = None
+
+    @classmethod
+    def ref_types(cls):
+        """Returns a set of item types that this class refers to.
 
         Returns:
-            :class:`~sqlalchemy.sql.expression.Alias`
+            set(str)
         """
-        if self._list_value_sq is None:
-            self._list_value_sq = self._subquery("list_value")
-        return self._list_value_sq
+        return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values())
 
     @property
-    def parameter_definition_sq(self):
-        """A subquery of the form:
-
-        .. code-block:: sql
-
-        SELECT * FROM parameter_definition
+    def status(self):
+        """Returns the status of this item.
Returns: - :class:`~sqlalchemy.sql.expression.Alias` + Status """ + return self._status - if self._parameter_definition_sq is None: - self._parameter_definition_sq = self._make_parameter_definition_sq() - return self._parameter_definition_sq - - @property - def parameter_value_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM parameter_value + @status.setter + def status(self, status): + """Sets the status of this item. - Returns: - :class:`~sqlalchemy.sql.expression.Alias` + Args: + status (Status) """ - if self._parameter_value_sq is None: - self._parameter_value_sq = self._make_parameter_value_sq() - return self._parameter_value_sq + self._status = status @property - def metadata_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM list_value + def backup(self): + """Returns the committed version of this item. Returns: - :class:`~sqlalchemy.sql.expression.Alias` + dict or None """ - if self._metadata_sq is None: - self._metadata_sq = self._subquery("metadata") - return self._metadata_sq + return self._backup @property - def parameter_value_metadata_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM parameter_value_metadata + def removed(self): + """Returns whether or not this item has been removed. Returns: - :class:`~sqlalchemy.sql.expression.Alias` + bool """ - if self._parameter_value_metadata_sq is None: - self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata") - return self._parameter_value_metadata_sq + return self._removed @property - def entity_metadata_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM entity_metadata + def item_type(self): + """Returns this item's type Returns: - :class:`~sqlalchemy.sql.expression.Alias` + str """ - if self._entity_metadata_sq is None: - self._entity_metadata_sq = self._subquery("entity_metadata") - return self._entity_metadata_sq + return self._item_type @property - def commit_sq(self): - """A subquery of the form: - - .. code-block:: sql - - SELECT * FROM commit + def key(self): + """Returns a tuple (item_type, id) for convenience, or None if this item doesn't yet have an id. + TODO: When does the latter happen? 
Returns: - :class:`~sqlalchemy.sql.expression.Alias` + tuple(str,int) or None """ - if self._commit_sq is None: - commit_sq = self._subquery("commit") - self._commit_sq = self.query(commit_sq).filter(commit_sq.c.comment != "").subquery() - return self._commit_sq - - @property - def object_class_sq(self): - if self._object_class_sq is None: - self._object_class_sq = ( - self.query( - self.wide_entity_class_sq.c.id.label("id"), - self.wide_entity_class_sq.c.name.label("name"), - self.wide_entity_class_sq.c.description.label("description"), - self.wide_entity_class_sq.c.display_order.label("display_order"), - self.wide_entity_class_sq.c.display_icon.label("display_icon"), - self.wide_entity_class_sq.c.hidden.label("hidden"), - ) - .filter(self.wide_entity_class_sq.c.dimension_id_list == None) - .subquery("object_class_sq") - ) - return self._object_class_sq - - @property - def object_sq(self): - if self._object_sq is None: - self._object_sq = ( - self.query( - self.wide_entity_sq.c.id.label("id"), - self.wide_entity_sq.c.class_id.label("class_id"), - self.wide_entity_sq.c.name.label("name"), - self.wide_entity_sq.c.description.label("description"), - self.wide_entity_sq.c.commit_id.label("commit_id"), - ) - .filter(self.wide_entity_sq.c.element_id_list == None) - .subquery("object_sq") - ) - return self._object_sq + id_ = dict.get(self, "id") + if id_ is None: + return None + return (self._item_type, id_) @property - def relationship_class_sq(self): - if self._relationship_class_sq is None: - ent_cls_dim_sq = self._subquery("entity_class_dimension") - self._relationship_class_sq = ( - self.query( - ent_cls_dim_sq.c.entity_class_id.label("id"), - ent_cls_dim_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept - ent_cls_dim_sq.c.dimension_id.label("object_class_id"), - self.wide_entity_class_sq.c.name.label("name"), - self.wide_entity_class_sq.c.description.label("description"), - self.wide_entity_class_sq.c.display_icon.label("display_icon"), - self.wide_entity_class_sq.c.hidden.label("hidden"), - ) - .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) - .subquery("relationship_class_sq") - ) - return self._relationship_class_sq - - @property - def relationship_sq(self): - if self._relationship_sq is None: - ent_el_sq = self._subquery("entity_element") - self._relationship_sq = ( - self.query( - ent_el_sq.c.entity_id.label("id"), - ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept - ent_el_sq.c.element_id.label("object_id"), - ent_el_sq.c.entity_class_id.label("class_id"), - self.wide_entity_sq.c.name.label("name"), - self.wide_entity_sq.c.commit_id.label("commit_id"), - ) - .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) - .subquery("relationship_sq") - ) - return self._relationship_sq - - @property - def ext_parameter_value_list_sq(self): - if self._ext_parameter_value_list_sq is None: - self._ext_parameter_value_list_sq = ( - self.query( - self.parameter_value_list_sq.c.id, - self.parameter_value_list_sq.c.name, - self.parameter_value_list_sq.c.commit_id, - self.list_value_sq.c.id.label("value_id"), - self.list_value_sq.c.index.label("value_index"), - ).outerjoin( - self.list_value_sq, - self.list_value_sq.c.parameter_value_list_id == self.parameter_value_list_sq.c.id, - ) - ).subquery() - return self._ext_parameter_value_list_sq + def is_id_valid(self): + return self._is_id_valid - @property - def wide_parameter_value_list_sq(self): - if self._wide_parameter_value_list_sq is None: 
- self._wide_parameter_value_list_sq = ( - self.query( - self.ext_parameter_value_list_sq.c.id, - self.ext_parameter_value_list_sq.c.name, - self.ext_parameter_value_list_sq.c.commit_id, - group_concat( - self.ext_parameter_value_list_sq.c.value_id, self.ext_parameter_value_list_sq.c.value_index - ).label("value_id_list"), - group_concat( - self.ext_parameter_value_list_sq.c.value_index, self.ext_parameter_value_list_sq.c.value_index - ).label("value_index_list"), - ).group_by( - self.ext_parameter_value_list_sq.c.id, - self.ext_parameter_value_list_sq.c.name, - self.ext_parameter_value_list_sq.c.commit_id, - ) - ).subquery() - return self._wide_parameter_value_list_sq + def invalidate_id(self): + """Sets id as invalid.""" + self._is_id_valid = False - @property - def ord_list_value_sq(self): - if self._ord_list_value_sq is None: - self._ord_list_value_sq = ( - self.query( - self.list_value_sq.c.id, - self.list_value_sq.c.parameter_value_list_id, - self.list_value_sq.c.index, - self.list_value_sq.c.value, - self.list_value_sq.c.type, - self.list_value_sq.c.commit_id, - ) - .order_by(self.list_value_sq.c.parameter_value_list_id, self.list_value_sq.c.index) - .subquery() - ) - return self._ord_list_value_sq - - @property - def ext_scenario_sq(self): - if self._ext_scenario_sq is None: - self._ext_scenario_sq = ( - self.query( - self.scenario_sq.c.id.label("id"), - self.scenario_sq.c.name.label("name"), - self.scenario_sq.c.description.label("description"), - self.scenario_sq.c.active.label("active"), - self.scenario_alternative_sq.c.alternative_id.label("alternative_id"), - self.scenario_alternative_sq.c.rank.label("rank"), - self.alternative_sq.c.name.label("alternative_name"), - self.scenario_sq.c.commit_id.label("commit_id"), - ) - .outerjoin( - self.scenario_alternative_sq, self.scenario_alternative_sq.c.scenario_id == self.scenario_sq.c.id - ) - .outerjoin( - self.alternative_sq, self.alternative_sq.c.id == self.scenario_alternative_sq.c.alternative_id - ) - .order_by(self.scenario_sq.c.id, self.scenario_alternative_sq.c.rank) - .subquery() - ) - return self._ext_scenario_sq - - @property - def wide_scenario_sq(self): - if self._wide_scenario_sq is None: - self._wide_scenario_sq = ( - self.query( - self.ext_scenario_sq.c.id.label("id"), - self.ext_scenario_sq.c.name.label("name"), - self.ext_scenario_sq.c.description.label("description"), - self.ext_scenario_sq.c.active.label("active"), - self.ext_scenario_sq.c.commit_id.label("commit_id"), - group_concat(self.ext_scenario_sq.c.alternative_id, self.ext_scenario_sq.c.rank).label( - "alternative_id_list" - ), - group_concat(self.ext_scenario_sq.c.alternative_name, self.ext_scenario_sq.c.rank).label( - "alternative_name_list" - ), - ) - .group_by( - self.ext_scenario_sq.c.id, - self.ext_scenario_sq.c.name, - self.ext_scenario_sq.c.description, - self.ext_scenario_sq.c.active, - self.ext_scenario_sq.c.commit_id, - ) - .subquery() - ) - return self._wide_scenario_sq - - @property - def linked_scenario_alternative_sq(self): - if self._linked_scenario_alternative_sq is None: - scenario_next_alternative = aliased(self.scenario_alternative_sq) - self._linked_scenario_alternative_sq = ( - self.query( - self.scenario_alternative_sq.c.id.label("id"), - self.scenario_alternative_sq.c.scenario_id.label("scenario_id"), - self.scenario_alternative_sq.c.alternative_id.label("alternative_id"), - self.scenario_alternative_sq.c.rank.label("rank"), - scenario_next_alternative.c.alternative_id.label("before_alternative_id"), - 
scenario_next_alternative.c.rank.label("before_rank"), - self.scenario_alternative_sq.c.commit_id.label("commit_id"), - ) - .outerjoin( - scenario_next_alternative, - and_( - scenario_next_alternative.c.scenario_id == self.scenario_alternative_sq.c.scenario_id, - scenario_next_alternative.c.rank == self.scenario_alternative_sq.c.rank + 1, - ), - ) - .order_by(self.scenario_alternative_sq.c.scenario_id, self.scenario_alternative_sq.c.rank) - .subquery() - ) - return self._linked_scenario_alternative_sq - - @property - def ext_linked_scenario_alternative_sq(self): - if self._ext_linked_scenario_alternative_sq is None: - next_alternative = aliased(self.alternative_sq) - self._ext_linked_scenario_alternative_sq = ( - self.query( - self.linked_scenario_alternative_sq.c.id.label("id"), - self.linked_scenario_alternative_sq.c.scenario_id.label("scenario_id"), - self.scenario_sq.c.name.label("scenario_name"), - self.linked_scenario_alternative_sq.c.alternative_id.label("alternative_id"), - self.alternative_sq.c.name.label("alternative_name"), - self.linked_scenario_alternative_sq.c.rank.label("rank"), - self.linked_scenario_alternative_sq.c.before_alternative_id.label("before_alternative_id"), - self.linked_scenario_alternative_sq.c.before_rank.label("before_rank"), - next_alternative.c.name.label("before_alternative_name"), - self.linked_scenario_alternative_sq.c.commit_id.label("commit_id"), - ) - .filter(self.linked_scenario_alternative_sq.c.scenario_id == self.scenario_sq.c.id) - .filter(self.alternative_sq.c.id == self.linked_scenario_alternative_sq.c.alternative_id) - .outerjoin( - next_alternative, - next_alternative.c.id == self.linked_scenario_alternative_sq.c.before_alternative_id, - ) - .subquery() - ) - return self._ext_linked_scenario_alternative_sq - - @property - def ext_object_sq(self): - if self._ext_object_sq is None: - self._ext_object_sq = ( - self.query( - self.object_sq.c.id.label("id"), - self.object_sq.c.class_id.label("class_id"), - self.object_class_sq.c.name.label("class_name"), - self.object_sq.c.name.label("name"), - self.object_sq.c.description.label("description"), - self.entity_group_sq.c.entity_id.label("group_id"), - self.object_sq.c.commit_id.label("commit_id"), - ) - .filter(self.object_sq.c.class_id == self.object_class_sq.c.id) - .outerjoin(self.entity_group_sq, self.entity_group_sq.c.entity_id == self.object_sq.c.id) - .distinct(self.entity_group_sq.c.entity_id) - .subquery() - ) - return self._ext_object_sq - - @property - def ext_relationship_class_sq(self): - if self._ext_relationship_class_sq is None: - self._ext_relationship_class_sq = ( - self.query( - self.relationship_class_sq.c.id.label("id"), - self.relationship_class_sq.c.name.label("name"), - self.relationship_class_sq.c.description.label("description"), - self.relationship_class_sq.c.dimension.label("dimension"), - self.relationship_class_sq.c.display_icon.label("display_icon"), - self.object_class_sq.c.id.label("object_class_id"), - self.object_class_sq.c.name.label("object_class_name"), - ) - .filter(self.relationship_class_sq.c.object_class_id == self.object_class_sq.c.id) - .order_by(self.relationship_class_sq.c.id, self.relationship_class_sq.c.dimension) - .subquery() - ) - return self._ext_relationship_class_sq - - @property - def wide_relationship_class_sq(self): - if self._wide_relationship_class_sq is None: - self._wide_relationship_class_sq = ( - self.query( - self.ext_relationship_class_sq.c.id, - self.ext_relationship_class_sq.c.name, - self.ext_relationship_class_sq.c.description, - 
self.ext_relationship_class_sq.c.display_icon, - group_concat( - self.ext_relationship_class_sq.c.object_class_id, self.ext_relationship_class_sq.c.dimension - ).label("object_class_id_list"), - group_concat( - self.ext_relationship_class_sq.c.object_class_name, self.ext_relationship_class_sq.c.dimension - ).label("object_class_name_list"), - ) - .group_by( - self.ext_relationship_class_sq.c.id, - self.ext_relationship_class_sq.c.name, - self.ext_relationship_class_sq.c.description, - self.ext_relationship_class_sq.c.display_icon, - ) - .subquery() - ) - return self._wide_relationship_class_sq - - @property - def ext_relationship_sq(self): - if self._ext_relationship_sq is None: - self._ext_relationship_sq = ( - self.query( - self.relationship_sq.c.id.label("id"), - self.relationship_sq.c.name.label("name"), - self.relationship_sq.c.class_id.label("class_id"), - self.relationship_sq.c.dimension.label("dimension"), - self.wide_relationship_class_sq.c.name.label("class_name"), - self.ext_object_sq.c.id.label("object_id"), - self.ext_object_sq.c.name.label("object_name"), - self.ext_object_sq.c.class_id.label("object_class_id"), - self.ext_object_sq.c.class_name.label("object_class_name"), - self.relationship_sq.c.commit_id.label("commit_id"), - ) - .filter(self.relationship_sq.c.class_id == self.wide_relationship_class_sq.c.id) - .outerjoin(self.ext_object_sq, self.relationship_sq.c.object_id == self.ext_object_sq.c.id) - .order_by(self.relationship_sq.c.id, self.relationship_sq.c.dimension) - .subquery() - ) - return self._ext_relationship_sq - - @property - def wide_relationship_sq(self): - if self._wide_relationship_sq is None: - self._wide_relationship_sq = ( - self.query( - self.ext_relationship_sq.c.id, - self.ext_relationship_sq.c.name, - self.ext_relationship_sq.c.class_id, - self.ext_relationship_sq.c.class_name, - self.ext_relationship_sq.c.commit_id, - group_concat(self.ext_relationship_sq.c.object_id, self.ext_relationship_sq.c.dimension).label( - "object_id_list" - ), - group_concat(self.ext_relationship_sq.c.object_name, self.ext_relationship_sq.c.dimension).label( - "object_name_list" - ), - group_concat( - self.ext_relationship_sq.c.object_class_id, self.ext_relationship_sq.c.dimension - ).label("object_class_id_list"), - group_concat( - self.ext_relationship_sq.c.object_class_name, self.ext_relationship_sq.c.dimension - ).label("object_class_name_list"), - ) - .group_by( - self.ext_relationship_sq.c.id, - self.ext_relationship_sq.c.name, - self.ext_relationship_sq.c.class_id, - self.ext_relationship_sq.c.class_name, - self.ext_relationship_sq.c.commit_id, - ) - # dimension count might be higher than object count when objects have been filtered out - .having( - func.count(self.ext_relationship_sq.c.dimension) == func.count(self.ext_relationship_sq.c.object_id) - ) - .subquery() - ) - return self._wide_relationship_sq - - @property - def ext_entity_group_sq(self): - if self._ext_entity_group_sq is None: - group_entity = aliased(self.entity_sq) - member_entity = aliased(self.entity_sq) - self._ext_entity_group_sq = ( - self.query( - self.entity_group_sq.c.id.label("id"), - self.entity_group_sq.c.entity_class_id.label("class_id"), - self.entity_group_sq.c.entity_id.label("group_id"), - self.entity_group_sq.c.member_id.label("member_id"), - self.wide_entity_class_sq.c.name.label("class_name"), - group_entity.c.name.label("group_name"), - member_entity.c.name.label("member_name"), - label("object_class_id", self._object_class_id()), - label("relationship_class_id", 
self._relationship_class_id()), - ) - .filter(self.entity_group_sq.c.entity_class_id == self.wide_entity_class_sq.c.id) - .join(group_entity, self.entity_group_sq.c.entity_id == group_entity.c.id) - .join(member_entity, self.entity_group_sq.c.member_id == member_entity.c.id) - .subquery() - ) - return self._ext_entity_group_sq - - @property - def entity_parameter_definition_sq(self): - if self._entity_parameter_definition_sq is None: - self._entity_parameter_definition_sq = ( - self.query( - self.parameter_definition_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.wide_entity_class_sq.c.name.label("entity_class_name"), - label("object_class_id", self._object_class_id()), - label("relationship_class_id", self._relationship_class_id()), - label("object_class_name", self._object_class_name()), - label("relationship_class_name", self._relationship_class_name()), - label("object_class_id_list", self._object_class_id_list()), - label("object_class_name_list", self._object_class_name_list()), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), - self.parameter_value_list_sq.c.name.label("value_list_name"), - self.parameter_definition_sq.c.default_value, - self.parameter_definition_sq.c.default_type, - self.parameter_definition_sq.c.list_value_id, - self.parameter_definition_sq.c.description, - self.parameter_definition_sq.c.commit_id, - ) - .join( - self.wide_entity_class_sq, - self.wide_entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id, - ) - .outerjoin( - self.parameter_value_list_sq, - self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, - ) - .outerjoin( - self.wide_relationship_class_sq, - self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, - ) - .subquery() - ) - return self._entity_parameter_definition_sq - - @property - def object_parameter_definition_sq(self): - if self._object_parameter_definition_sq is None: - self._object_parameter_definition_sq = ( - self.query( - self.parameter_definition_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.object_class_sq.c.name.label("entity_class_name"), - self.object_class_sq.c.id.label("object_class_id"), - self.object_class_sq.c.name.label("object_class_name"), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), - self.parameter_value_list_sq.c.name.label("value_list_name"), - self.parameter_definition_sq.c.default_value, - self.parameter_definition_sq.c.default_type, - self.parameter_definition_sq.c.description, - ) - .filter(self.object_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id) - .outerjoin( - self.parameter_value_list_sq, - self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, - ) - .subquery() - ) - return self._object_parameter_definition_sq - - @property - def relationship_parameter_definition_sq(self): - if self._relationship_parameter_definition_sq is None: - self._relationship_parameter_definition_sq = ( - self.query( - self.parameter_definition_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.wide_relationship_class_sq.c.name.label("entity_class_name"), - self.wide_relationship_class_sq.c.id.label("relationship_class_id"), - self.wide_relationship_class_sq.c.name.label("relationship_class_name"), - 
self.wide_relationship_class_sq.c.object_class_id_list, - self.wide_relationship_class_sq.c.object_class_name_list, - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), - self.parameter_value_list_sq.c.name.label("value_list_name"), - self.parameter_definition_sq.c.default_value, - self.parameter_definition_sq.c.default_type, - self.parameter_definition_sq.c.description, - ) - .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) - .outerjoin( - self.parameter_value_list_sq, - self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, - ) - .subquery() - ) - return self._relationship_parameter_definition_sq - - @property - def entity_parameter_value_sq(self): - if self._entity_parameter_value_sq is None: - self._entity_parameter_value_sq = ( - self.query( - self.parameter_value_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.wide_entity_class_sq.c.name.label("entity_class_name"), - label("object_class_id", self._object_class_id()), - label("relationship_class_id", self._relationship_class_id()), - label("object_class_name", self._object_class_name()), - label("relationship_class_name", self._relationship_class_name()), - label("object_class_id_list", self._object_class_id_list()), - label("object_class_name_list", self._object_class_name_list()), - self.parameter_value_sq.c.entity_id, - self.wide_entity_sq.c.name.label("entity_name"), - label("object_id", self._object_id()), - label("relationship_id", self._relationship_id()), - label("object_name", self._object_name()), - label("object_id_list", self._object_id_list()), - label("object_name_list", self._object_name_list()), - self.parameter_definition_sq.c.id.label("parameter_id"), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_value_sq.c.alternative_id, - self.alternative_sq.c.name.label("alternative_name"), - self.parameter_value_sq.c.value, - self.parameter_value_sq.c.type, - self.parameter_value_sq.c.list_value_id, - self.parameter_value_sq.c.commit_id, - ) - .join( - self.parameter_definition_sq, - self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id, - ) - .join(self.wide_entity_sq, self.parameter_value_sq.c.entity_id == self.wide_entity_sq.c.id) - .join( - self.wide_entity_class_sq, - self.parameter_definition_sq.c.entity_class_id == self.wide_entity_class_sq.c.id, - ) - .join(self.alternative_sq, self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) - .outerjoin( - self.wide_relationship_class_sq, - self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, - ) - .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.wide_entity_sq.c.id) - # object_id_list might be None when objects have been filtered out - .filter( - or_( - self.wide_relationship_sq.c.id.is_(None), - self.wide_relationship_sq.c.object_id_list.isnot(None), - ) - ) - .subquery() - ) - return self._entity_parameter_value_sq - - @property - def object_parameter_value_sq(self): - if self._object_parameter_value_sq is None: - self._object_parameter_value_sq = ( - self.query( - self.parameter_value_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.object_class_sq.c.id.label("object_class_id"), - self.object_class_sq.c.name.label("object_class_name"), - self.parameter_value_sq.c.entity_id, - self.object_sq.c.id.label("object_id"), - 
self.object_sq.c.name.label("object_name"), - self.parameter_definition_sq.c.id.label("parameter_id"), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_value_sq.c.alternative_id, - self.alternative_sq.c.name.label("alternative_name"), - self.parameter_value_sq.c.value, - self.parameter_value_sq.c.type, - ) - .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) - .filter(self.parameter_value_sq.c.entity_id == self.object_sq.c.id) - .filter(self.parameter_definition_sq.c.entity_class_id == self.object_class_sq.c.id) - .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) - .subquery() - ) - return self._object_parameter_value_sq - - @property - def relationship_parameter_value_sq(self): - if self._relationship_parameter_value_sq is None: - self._relationship_parameter_value_sq = ( - self.query( - self.parameter_value_sq.c.id.label("id"), - self.parameter_definition_sq.c.entity_class_id, - self.wide_relationship_class_sq.c.id.label("relationship_class_id"), - self.wide_relationship_class_sq.c.name.label("relationship_class_name"), - self.wide_relationship_class_sq.c.object_class_id_list, - self.wide_relationship_class_sq.c.object_class_name_list, - self.parameter_value_sq.c.entity_id, - self.wide_relationship_sq.c.id.label("relationship_id"), - self.wide_relationship_sq.c.object_id_list, - self.wide_relationship_sq.c.object_name_list, - self.parameter_definition_sq.c.id.label("parameter_id"), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.parameter_value_sq.c.alternative_id, - self.alternative_sq.c.name.label("alternative_name"), - self.parameter_value_sq.c.value, - self.parameter_value_sq.c.type, - ) - .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) - .filter(self.parameter_value_sq.c.entity_id == self.wide_relationship_sq.c.id) - .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) - .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) - .subquery() - ) - return self._relationship_parameter_value_sq - - @property - def ext_parameter_value_metadata_sq(self): - if self._ext_parameter_value_metadata_sq is None: - self._ext_parameter_value_metadata_sq = ( - self.query( - self.parameter_value_metadata_sq.c.id, - self.parameter_value_metadata_sq.c.parameter_value_id, - self.metadata_sq.c.id.label("metadata_id"), - self.entity_sq.c.name.label("entity_name"), - self.parameter_definition_sq.c.name.label("parameter_name"), - self.alternative_sq.c.name.label("alternative_name"), - self.metadata_sq.c.name.label("metadata_name"), - self.metadata_sq.c.value.label("metadata_value"), - self.parameter_value_metadata_sq.c.commit_id, - ) - .filter(self.parameter_value_metadata_sq.c.parameter_value_id == self.parameter_value_sq.c.id) - .filter(self.parameter_value_sq.c.parameter_definition_id == self.parameter_definition_sq.c.id) - .filter(self.parameter_value_sq.c.entity_id == self.entity_sq.c.id) - .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) - .filter(self.parameter_value_metadata_sq.c.metadata_id == self.metadata_sq.c.id) - .subquery() - ) - return self._ext_parameter_value_metadata_sq - - @property - def ext_entity_metadata_sq(self): - if self._ext_entity_metadata_sq is None: - self._ext_entity_metadata_sq = ( - self.query( - self.entity_metadata_sq.c.id, - self.entity_metadata_sq.c.entity_id, - self.metadata_sq.c.id.label("metadata_id"), - 
self.entity_sq.c.name.label("entity_name"), - self.metadata_sq.c.name.label("metadata_name"), - self.metadata_sq.c.value.label("metadata_value"), - self.entity_metadata_sq.c.commit_id, - ) - .filter(self.entity_metadata_sq.c.entity_id == self.entity_sq.c.id) - .filter(self.entity_metadata_sq.c.metadata_id == self.metadata_sq.c.id) - .subquery() - ) - return self._ext_entity_metadata_sq - - def _make_entity_class_sq(self): - """ - Creates a subquery for entity classes. + def _extended(self): + """Returns a dict from this item's original fields plus all the references resolved statically. Returns: - Alias: an entity class subquery + dict """ - return self._subquery("entity_class") + d = self._asdict() + d.update({key: self[key] for key in self._references}) + return d - def _make_entity_sq(self): - """ - Creates a subquery for entities. + def _asdict(self): + """Returns a dict from this item's original fields. Returns: - Alias: an entity subquery - """ - return self._subquery("entity") - - def _make_entity_element_sq(self): + dict """ - Creates a subquery for entity-elements. + return dict(self) - Returns: - Alias: an entity_element subquery - """ - return self._subquery("entity_element") + def merge(self, other): + """Merges this item with another and returns the merged item together with any errors. + Used for updating items. - def _make_parameter_definition_sq(self): - """ - Creates a subquery for parameter definitions. + Args: + other (dict): the item to merge into this. Returns: - Alias: a parameter definition subquery + dict: merged item. + str: error description if any. """ - par_def_sq = self._subquery("parameter_definition") - list_value_id = case( - [(par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer()))], else_=None - ) - default_value = case( - [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value)], - else_=par_def_sq.c.default_value, - ) - default_type = case( - [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type)], - else_=par_def_sq.c.default_type, - ) - return ( - self.query( - par_def_sq.c.id.label("id"), - par_def_sq.c.name.label("name"), - par_def_sq.c.description.label("description"), - par_def_sq.c.entity_class_id, - label("default_value", default_value), - label("default_type", default_type), - label("list_value_id", list_value_id), - par_def_sq.c.commit_id.label("commit_id"), - par_def_sq.c.parameter_value_list_id.label("parameter_value_list_id"), - ) - .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) - .subquery("clean_parameter_definition_sq") - ) + if all(self.get(key) == value for key, value in other.items()): + return None, "" + merged = {**self._extended(), **other} + if not isinstance(merged["id"], int): + merged["id"] = self["id"] + return merged, "" - def _make_parameter_value_sq(self): - """ - Creates a subquery for parameter values. + def first_invalid_key(self): + """Goes through the ``_references`` class attribute and returns the key of the first one + that cannot be resolved. Returns: - Alias: a parameter value subquery + str or None: unresolved reference's key if any. 
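+
+ For illustration only (each subclass defines its own ``_references``; the exact
+ entry below is hypothetical), an entity item could declare::
+
+ _references = {"class_name": ("class_id", ("entity_class", "name"))}
+
+ in which case this method would return ``"class_id"`` whenever that id is missing
+ or cannot be resolved to an entity class in the cache.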
""" - par_val_sq = self._subquery("parameter_value") - list_value_id = case([(par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer()))], else_=None) - value = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value)], else_=par_val_sq.c.value) - type_ = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type)], else_=par_val_sq.c.type) - return ( - self.query( - par_val_sq.c.id.label("id"), - par_val_sq.c.parameter_definition_id, - par_val_sq.c.entity_class_id, - par_val_sq.c.entity_id, - label("value", value), - label("type", type_), - label("list_value_id", list_value_id), - par_val_sq.c.commit_id.label("commit_id"), - par_val_sq.c.alternative_id, - ) - .filter(par_val_sq.c.entity_id == self.entity_sq.c.id) - .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) - .subquery("clean_parameter_value_sq") - ) - - def _make_alternative_sq(self): - """ - Creates a subquery for alternatives. + for src_key, (ref_type, _ref_key) in self._references.values(): + try: + ref_id = self[src_key] + except KeyError: + return src_key + if isinstance(ref_id, tuple): + for x in ref_id: + if not self._get_ref(ref_type, x): + return src_key + elif not self._get_ref(ref_type, ref_id): + return src_key + + def unique_values(self, skip_keys=()): + """Yields tuples of unique keys and their values. - Returns: - Alias: an alternative subquery - """ - return self._subquery("alternative") + Args: + skip_keys: Don't yield these keys - def _make_scenario_sq(self): + Yields: + tuple(tuple,tuple): the first element is the unique key, the second is the values. """ - Creates a subquery for scenarios. + for key in self._unique_keys: + if key not in skip_keys: + yield key, tuple(self.get(k) for k in key) - Returns: - Alias: a scenario subquery - """ - return self._subquery("scenario") + def resolve_inverse_references(self, skip_keys=()): + """Goes through the ``_inverse_references`` class attribute and updates this item + by resolving those references. + Returns any error. - def _make_scenario_alternative_sq(self): - """ - Creates a subquery for scenario alternatives. + Args: + skip_keys (tuple): don't resolve references for these keys. Returns: - Alias: a scenario alternative subquery + str or None: error description if any. 
""" - return self._subquery("scenario_alternative") - - def get_import_alternative_name(self): - if self._import_alternative_name is None: - self._create_import_alternative() - return self._import_alternative_name - - def _create_import_alternative(self): - """Creates the alternative to be used as default for all import operations.""" - self._import_alternative_name = "Base" + for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): + if src_key in skip_keys: + continue + id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) + if None in id_value: + continue + mapped_table = self._db_cache.mapped_table(ref_type) + try: + self[src_key] = ( + tuple(mapped_table.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) + if all(isinstance(v, (tuple, list)) for v in id_value) + else mapped_table.unique_key_value_to_id(ref_key, id_value, strict=True) + ) + except KeyError as err: + # Happens at unique_key_value_to_id(..., strict=True) + return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" - def override_create_import_alternative(self, method): - self._create_import_alternative = MethodType(method, self) - self._import_alternative_name = None + def polish(self): + """Polishes this item once all it's references have been resolved. Returns any error. - def override_entity_class_sq_maker(self, method): - """ - Overrides the function that creates the ``entity_class_sq`` property. + The base implementation sets defaults but subclasses can do more work if needed. - Args: - method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and - returns entity class subquery as an :class:`Alias` object + Returns: + str or None: error description if any. """ - self._make_entity_class_sq = MethodType(method, self) - self._clear_subqueries("entity_class") + for key, default_value in self._defaults.items(): + self.setdefault(key, default_value) + return "" - def override_entity_sq_maker(self, method): - """ - Overrides the function that creates the ``entity_sq`` property. + def _get_ref(self, ref_type, ref_id, strong=True): + """Collects a reference from the cache. + Adds this item to the reference's list of referrers if strong is True; + or weak referrers if strong is False. + If the reference is not found, sets some flags. Args: - method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and - returns entity subquery as an :class:`Alias` object - """ - self._make_entity_sq = MethodType(method, self) - self._clear_subqueries("entity") + ref_type (str): The references's type + ref_id (int): The references's id + strong (bool): True if the reference corresponds to a foreign key, False otherwise - def override_entity_element_sq_maker(self, method): - """ - Overrides the function that creates the ``entity_element_sq`` property. 
-
- Args:
- method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and
- returns entity_element subquery as an :class:`Alias` object
- """
- self._make_entity_element_sq = MethodType(method, self)
- self._clear_subqueries("entity_element")
+ Returns:
+ MappedItemBase or dict
+ """
+ ref = self._db_cache.get_mapped_item(ref_type, ref_id)
+ if not ref:
+ if not strong:
+ return {}
+ ref = self._db_cache.fetch_ref(ref_type, ref_id)
+ if not ref:
+ self._corrupted = True
+ return {}
+ # Here we have a ref
+ if strong:
+ ref.add_referrer(self)
+ if ref.removed:
+ self._to_remove = True
+ else:
+ ref.add_weak_referrer(self)
+ if ref.removed:
+ return {}
+ return ref

- def override_parameter_definition_sq_maker(self, method):
- """
- Overrides the function that creates the ``parameter_definition_sq`` property.
+ def _invalidate_ref(self, ref_type, ref_id):
+ """Invalidates a reference previously collected from the cache.

Args:
- method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and
- returns parameter definition subquery as an :class:`Alias` object
- """
- self._make_parameter_definition_sq = MethodType(method, self)
- self._clear_subqueries("parameter_definition")
-
- def override_parameter_value_sq_maker(self, method):
+ ref_type (str): The reference's type
+ ref_id (int): The reference's id
"""
- Overrides the function that creates the ``parameter_value_sq`` property.
+ ref = self._db_cache.get_mapped_item(ref_type, ref_id)
+ ref.remove_referrer(self)

- Args:
- method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and
- returns parameter value subquery as an :class:`Alias` object
- """
- self._make_parameter_value_sq = MethodType(method, self)
- self._clear_subqueries("parameter_value")
+ def is_valid(self):
+ """Checks if this item has all its references.
+ Removes the item from the cache if not valid by calling ``cascade_remove``.

- def override_alternative_sq_maker(self, method):
- """
- Overrides the function that creates the ``alternative_sq`` property.
+ Returns:
+ bool
+ """
+ if self._valid is not None:
+ return self._valid
+ if self._removed or self._corrupted:
+ return False
+ self._to_remove = False
+ self._corrupted = False
+ for key in self._references:
+ _ = self[key]
+ if self._to_remove:
+ self.cascade_remove()
+ self._valid = not self._removed and not self._corrupted
+ return self._valid
+
+ def add_referrer(self, referrer):
+ """Adds a strong referrer to this item. Strong referrers are removed, updated and restored
+ in cascade with this item.

Args:
- method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and
- returns alternative subquery as an :class:`Alias` object
+ referrer (MappedItemBase)
"""
- self._make_alternative_sq = MethodType(method, self)
- self._clear_subqueries("alternative")
+ if referrer.key is None:
+ return
+ self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer)

- def override_scenario_sq_maker(self, method):
- """
- Overrides the function that creates the ``scenario_sq`` property.
+ def remove_referrer(self, referrer):
+ """Removes a strong referrer.
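+
+ The referrer is identified by its ``key``, so referrers whose id has not been
+ set yet are silently ignored.
+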
Args: - method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and - returns scenario subquery as an :class:`Alias` object + referrer (MappedItemBase) """ - self._make_scenario_sq = MethodType(method, self) - self._clear_subqueries("scenario") + if referrer.key is None: + return + self._referrers.pop(referrer.key, None) - def override_scenario_alternative_sq_maker(self, method): - """ - Overrides the function that creates the ``scenario_alternative_sq`` property. + def add_weak_referrer(self, referrer): + """Adds a weak referrer to this item. + Weak referrers' update callbacks are called whenever this item changes. Args: - method (Callable): a function that accepts a :class:`DatabaseMappingBase` as its argument and - returns scenario alternative subquery as an :class:`Alias` object - """ - self._make_scenario_alternative_sq = MethodType(method, self) - self._clear_subqueries("scenario_alternative") - - def restore_entity_class_sq_maker(self): - """Restores the original function that creates the ``entity_class_sq`` property.""" - self._make_entity_class_sq = MethodType(DatabaseMappingBase._make_entity_class_sq, self) - self._clear_subqueries("entity_class") - - def restore_entity_sq_maker(self): - """Restores the original function that creates the ``entity_sq`` property.""" - self._make_entity_sq = MethodType(DatabaseMappingBase._make_entity_sq, self) - self._clear_subqueries("entity") - - def restore_entity_element_sq_maker(self): - """Restores the original function that creates the ``entity_element_sq`` property.""" - self._make_entity_element_sq = MethodType(DatabaseMappingBase._make_entity_element_sq, self) - self._clear_subqueries("entity_element") - - def restore_parameter_definition_sq_maker(self): - """Restores the original function that creates the ``parameter_definition_sq`` property.""" - self._make_parameter_definition_sq = MethodType(DatabaseMappingBase._make_parameter_definition_sq, self) - self._clear_subqueries("parameter_definition") - - def restore_parameter_value_sq_maker(self): - """Restores the original function that creates the ``parameter_value_sq`` property.""" - self._make_parameter_value_sq = MethodType(DatabaseMappingBase._make_parameter_value_sq, self) - self._clear_subqueries("parameter_value") - - def restore_alternative_sq_maker(self): - """Restores the original function that creates the ``alternative_sq`` property.""" - self._make_alternative_sq = MethodType(DatabaseMappingBase._make_alternative_sq, self) - self._clear_subqueries("alternative") - - def restore_scenario_sq_maker(self): - """Restores the original function that creates the ``scenario_sq`` property.""" - self._make_scenario_sq = MethodType(DatabaseMappingBase._make_scenario_sq, self) - self._clear_subqueries("scenario") - - def restore_scenario_alternative_sq_maker(self): - """Restores the original function that creates the ``scenario_alternative_sq`` property.""" - self._make_scenario_alternative_sq = MethodType(DatabaseMappingBase._make_scenario_alternative_sq, self) - self._clear_subqueries("scenario_alternative") - - def _get_primary_key(self, tablename): - pk = self.composite_pks.get(tablename) - if pk is None: - id_field = self._id_fields.get(tablename, "id") - pk = (id_field,) - return pk - - def _reset_mapping(self): - """Delete all records from all tables but don't drop the tables. 
- Useful for writing tests - """ - with self.engine.connect() as connection: - for tablename in self._tablenames: - table = self._metadata.tables[tablename] - connection.execute(table.delete()) - connection.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', null)") - - def _object_class_id(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None - ) - - def _relationship_class_id(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id)], else_=None - ) - - def _object_id(self): - return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id)], else_=None) - - def _relationship_id(self): - return case([(self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id)], else_=None) - - def _object_class_name(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name)], else_=None - ) - - def _relationship_class_name(self): - return case( - [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name)], else_=None - ) - - def _object_class_id_list(self): - return case( - [ - ( - self.wide_entity_class_sq.c.dimension_id_list != None, - self.wide_relationship_class_sq.c.object_class_id_list, - ) - ], - else_=None, - ) - - def _object_class_name_list(self): - return case( - [ - ( - self.wide_entity_class_sq.c.dimension_id_list != None, - self.wide_relationship_class_sq.c.object_class_name_list, - ) - ], - else_=None, - ) - - def _object_name(self): - return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name)], else_=None) - - def _object_id_list(self): - return case( - [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None - ) - - def _object_name_list(self): - return case( - [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None - ) + referrer (MappedItemBase) + """ + if referrer.key is None: + return + if referrer.key not in self._referrers: + self._weak_referrers[referrer.key] = referrer + + def _update_weak_referrers(self): + for weak_referrer in self._weak_referrers.values(): + weak_referrer.call_update_callbacks() + + def cascade_restore(self, source=None): + """Restores this item (if removed) and all its referrers in cascade. + Also, updates items' status and calls their restore callbacks. + """ + if not self._removed: + return + if source is not self._removal_source: + return + if self.status in (Status.added_and_removed, Status.to_remove): + self._status = self._status_when_removed + elif self.status == Status.committed: + self._status = Status.to_add + else: + raise RuntimeError("invalid status for item being restored") + self._removed = False + # First restore this, then referrers + obsolete = set() + for callback in list(self.restore_callbacks): + if not callback(self): + obsolete.add(callback) + self.restore_callbacks -= obsolete + for referrer in self._referrers.values(): + referrer.cascade_restore(source=self) + self._update_weak_referrers() + + def cascade_remove(self, source=None): + """Removes this item and all its referrers in cascade. + Also, updates items' status and calls their remove callbacks. 
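+
+ For illustration only (``remove_callbacks`` is a plain set of callables; the
+ callback name below is hypothetical), a client could subscribe to removals::
+
+ def on_remove(item):
+ print("removed", item.key)
+ return True # a falsy return value would unsubscribe this callback
+
+ item.remove_callbacks.add(on_remove)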
+ """ + if self._removed: + return + self._status_when_removed = self._status + if self._status == Status.to_add: + self._status = Status.added_and_removed + elif self._status in (Status.committed, Status.to_update): + self._status = Status.to_remove + else: + raise RuntimeError("invalid status for item being removed") + self._removal_source = source + self._removed = True + self._to_remove = False + self._valid = None + # First remove referrers, then this + for referrer in self._referrers.values(): + referrer.cascade_remove(source=self) + self._update_weak_referrers() + obsolete = set() + for callback in list(self.remove_callbacks): + if not callback(self): + obsolete.add(callback) + self.remove_callbacks -= obsolete + + def cascade_update(self): + """Updates this item and all its referrers in cascade. + Also, calls items' update callbacks. + """ + self.call_update_callbacks() + for referrer in self._referrers.values(): + referrer.cascade_update() + self._update_weak_referrers() + + def call_update_callbacks(self): + obsolete = set() + for callback in list(self.update_callbacks): + if not callback(self): + obsolete.add(callback) + self.update_callbacks -= obsolete + + def is_committed(self): + """Returns whether or not this item is committed to the DB. - @staticmethod - def _convert_legacy(tablename, item): - if tablename in ("entity_class", "entity"): - object_class_id_list = tuple(item.pop("object_class_id_list", ())) - if object_class_id_list: - item["dimension_id_list"] = object_class_id_list - object_class_name_list = tuple(item.pop("object_class_name_list", ())) - if object_class_name_list: - item["dimension_name_list"] = object_class_name_list - if tablename == "entity": - object_id_list = tuple(item.pop("object_id_list", ())) - if object_id_list: - item["element_id_list"] = object_id_list - object_name_list = tuple(item.pop("object_name_list", ())) - if object_name_list: - item["element_name_list"] = object_name_list - if tablename in ("parameter_definition", "parameter_value"): - entity_class_id = item.pop("object_class_id", None) or item.pop("relationship_class_id", None) - if entity_class_id: - item["entity_class_id"] = entity_class_id - if tablename == "parameter_value": - entity_id = item.pop("object_id", None) or item.pop("relationship_id", None) - if entity_id: - item["entity_id"] = entity_id - - def __del__(self): - self.close() + Returns: + bool + """ + return self._status == Status.committed + + def commit(self, commit_id): + """Sets this item as committed with the given commit id.""" + self._status = Status.committed + if commit_id: + self["commit_id"] = commit_id + + def __repr__(self): + """Overridden to return a more verbose representation.""" + return f"{self._item_type}{self._extended()}" + + def __getattr__(self, name): + """Overridden to return the dictionary key named after the attribute, or None if it doesn't exist.""" + # FIXME: We should try and get rid of this one + return self.get(name) + + def __getitem__(self, key): + """Overridden to return references.""" + ref = self._references.get(key) + if ref: + src_key, (ref_type, ref_key) = ref + ref_id = self[src_key] + if isinstance(ref_id, tuple): + return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) + return self._get_ref(ref_type, ref_id).get(ref_key) + return super().__getitem__(key) + + def __setitem__(self, key, value): + """Sets id valid if key is 'id'.""" + if key == "id": + self._is_id_valid = True + super().__setitem__(key, value) + + def get(self, key, default=None): + """Overridden 
to return references."""
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def update(self, other):
+ """Overridden to update the item status and also to invalidate references that become obsolete."""
+ if self._status == Status.committed:
+ self._status = Status.to_update
+ self._backup = self._asdict()
+ elif self._status in (Status.to_remove, Status.added_and_removed):
+ raise RuntimeError("invalid status of item being updated")
+ for src_key, (ref_type, _ref_key) in self._references.values():
+ ref_id = self[src_key]
+ if src_key in other and other[src_key] != ref_id:
+ # Invalidate references
+ if isinstance(ref_id, tuple):
+ for x in ref_id:
+ self._invalidate_ref(ref_type, x)
+ else:
+ self._invalidate_ref(ref_type, ref_id)
+ super().update(other)
+ if self._asdict() == self._backup:
+ self._status = Status.committed
diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py
index 207a157e..3aaba6f4 100644
--- a/spinedb_api/db_mapping_commit_mixin.py
+++ b/spinedb_api/db_mapping_commit_mixin.py
@@ -26,7 +26,7 @@ def commit_session(self, comment):
 """
 if not comment:
 raise SpineDBAPIError("Commit message cannot be empty.")
- dirty_items = self.cache.dirty_items()
+ dirty_items = self.dirty_items()
 if not dirty_items:
 raise SpineDBAPIError("Nothing to commit.")
 user = self.username
@@ -50,11 +50,11 @@ def commit_session(self, comment):
 
 def rollback_session(self):
 """Discards all the changes from the in-memory mapping."""
- if not self.cache.rollback():
+ if not self.rollback():
 raise SpineDBAPIError("Nothing to rollback.")
 if self._memory:
 self._memory_dirty = False
 
 def refresh_session(self):
 """Resets the fetch status so new items from the DB can be retrieved."""
- self.cache.refresh()
+ self.refresh()
diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py
new file mode 100644
index 00000000..fb0344e3
--- /dev/null
+++ b/spinedb_api/db_mapping_query_mixin.py
@@ -0,0 +1,1478 @@
+######################################################################################################################
+# Copyright (C) 2017-2022 Spine project consortium
+# This file is part of Spine Database API. Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
+# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
+# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
+# this program. If not, see <http://www.gnu.org/licenses/>.
+###################################################################################################################### + +from types import MethodType +from sqlalchemy import Table, Integer, case, func, cast, and_, or_ +from sqlalchemy.sql.expression import Alias, label +from sqlalchemy.orm import aliased +from .helpers import forward_sweep, group_concat +from .query import Query + + +class DatabaseMappingQueryMixin: + """Provides the :meth:`query` method for performing custom ``SELECT`` queries.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Subqueries that select everything from each table + self._commit_sq = None + self._alternative_sq = None + self._scenario_sq = None + self._scenario_alternative_sq = None + self._entity_class_sq = None + self._entity_sq = None + self._entity_class_dimension_sq = None + self._entity_element_sq = None + self._entity_alternative_sq = None + self._object_class_sq = None + self._object_sq = None + self._relationship_class_sq = None + self._relationship_sq = None + self._entity_group_sq = None + self._parameter_definition_sq = None + self._parameter_value_sq = None + self._parameter_value_list_sq = None + self._list_value_sq = None + self._metadata_sq = None + self._parameter_value_metadata_sq = None + self._entity_metadata_sq = None + # Special convenience subqueries that join two or more tables + self._wide_entity_class_sq = None + self._wide_entity_sq = None + self._ext_parameter_value_list_sq = None + self._wide_parameter_value_list_sq = None + self._ord_list_value_sq = None + self._ext_scenario_sq = None + self._wide_scenario_sq = None + self._linked_scenario_alternative_sq = None + self._ext_linked_scenario_alternative_sq = None + self._ext_object_sq = None + self._ext_relationship_class_sq = None + self._wide_relationship_class_sq = None + self._ext_relationship_class_object_parameter_definition_sq = None + self._wide_relationship_class_object_parameter_definition_sq = None + self._ext_relationship_sq = None + self._wide_relationship_sq = None + self._ext_entity_group_sq = None + self._entity_parameter_definition_sq = None + self._object_parameter_definition_sq = None + self._relationship_parameter_definition_sq = None + self._entity_parameter_value_sq = None + self._object_parameter_value_sq = None + self._relationship_parameter_value_sq = None + self._ext_parameter_value_metadata_sq = None + self._ext_entity_metadata_sq = None + self._import_alternative_name = None + self._table_to_sq_attr = {} + + def _get_table_to_sq_attr(self): + if not self._table_to_sq_attr: + self._table_to_sq_attr = self._make_table_to_sq_attr() + return self._table_to_sq_attr + + def _make_table_to_sq_attr(self): + """Returns a dict mapping table names to subquery attribute names, involving that table.""" + + def _func(x, tables): + if isinstance(x, Table): + tables.add(x.name) # pylint: disable=cell-var-from-loop + + # This 'loads' our subquery attributes + for attr in dir(self): + getattr(self, attr) + table_to_sq_attr = {} + for attr, val in vars(self).items(): + if not isinstance(val, Alias): + continue + tables = set() + forward_sweep(val, _func, tables) + # Now `tables` contains all tables related to `val` + for table in tables: + table_to_sq_attr.setdefault(table, set()).add(attr) + return table_to_sq_attr + + def _clear_subqueries(self, *tablenames): + """Set to `None` subquery attributes involving the affected tables. + This forces the subqueries to be refreshed when the corresponding property is accessed. 
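+
+ For example, ``self._clear_subqueries("entity_class")`` resets not only
+ ``_entity_class_sq`` but also every convenience subquery built on top of it,
+ such as ``_wide_entity_class_sq``, as recorded by ``_make_table_to_sq_attr()``.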
+ """ + tablenames = list(tablenames) + for tablename in tablenames: + if self.pop(tablename, False): + self.fetch_all(tablename) + attr_names = set(attr for tablename in tablenames for attr in self._get_table_to_sq_attr().get(tablename, [])) + for attr_name in attr_names: + setattr(self, attr_name, None) + + def query(self, *args, **kwargs): + """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. + + To perform custom ``SELECT`` statements, call this method with one or more of the class documented + subquery properties (of :class:`~sqlalchemy.sql.expression.Alias` type). + For example, to select the entity class with ``id`` equal to 1:: + + from spinedb_api import DatabaseMapping + url = 'sqlite:///spine.db' + ... + db_map = DatabaseMapping(url) + db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() + + To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface + (which is a close clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`). + For example, to select all entity class names and the names of their entities concatenated in a comma-separated + string:: + + from sqlalchemy import func + + db_map.query( + db_map.entity_class_sq.c.name, func.group_concat(db_map.entity_sq.c.name) + ).filter( + db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id + ).group_by(db_map.entity_class_sq.c.name).all() + """ + return Query(self.engine, *args) + + def _subquery(self, tablename): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM tablename + + Args: + tablename (str): the table to be queried. + + Returns: + :class:`~sqlalchemy.sql.expression.Alias` + """ + table = self._metadata.tables[tablename] + return self.query(table).subquery(tablename + "_sq") + + @property + def entity_class_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM entity_class + + Returns: + :class:`~sqlalchemy.sql.expression.Alias` + """ + if self._entity_class_sq is None: + self._entity_class_sq = self._make_entity_class_sq() + return self._entity_class_sq + + @property + def entity_class_dimension_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM entity_class_dimension + + Returns: + :class:`~sqlalchemy.sql.expression.Alias` + """ + if self._entity_class_dimension_sq is None: + self._entity_class_dimension_sq = self._subquery("entity_class_dimension") + return self._entity_class_dimension_sq + + @property + def wide_entity_class_sq(self): + """A subquery of the form: + + .. 
code-block:: sql
+
+ SELECT
+ ec.*,
+ count(ecd.dimension_id) AS dimension_count,
+ group_concat(ecd.dimension_id) AS dimension_id_list
+ FROM
+ entity_class AS ec,
+ entity_class_dimension AS ecd
+ WHERE
+ ec.id = ecd.entity_class_id
+ GROUP BY
+ ec.id
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._wide_entity_class_sq is None:
+ entity_class_dimension_sq = (
+ self.query(
+ self.entity_class_dimension_sq.c.entity_class_id,
+ self.entity_class_dimension_sq.c.dimension_id,
+ self.entity_class_dimension_sq.c.position,
+ self.entity_class_sq.c.name.label("dimension_name"),
+ )
+ .filter(self.entity_class_dimension_sq.c.dimension_id == self.entity_class_sq.c.id)
+ .subquery("entity_class_dimension_sq")
+ )
+ ecd_sq = (
+ self.query(
+ self.entity_class_sq.c.id,
+ self.entity_class_sq.c.name,
+ self.entity_class_sq.c.description,
+ self.entity_class_sq.c.display_order,
+ self.entity_class_sq.c.display_icon,
+ self.entity_class_sq.c.hidden,
+ entity_class_dimension_sq.c.dimension_id,
+ entity_class_dimension_sq.c.dimension_name,
+ entity_class_dimension_sq.c.position,
+ )
+ .outerjoin(
+ entity_class_dimension_sq,
+ self.entity_class_sq.c.id == entity_class_dimension_sq.c.entity_class_id,
+ )
+ .order_by(self.entity_class_sq.c.id, entity_class_dimension_sq.c.position)
+ .subquery("ext_entity_class_sq")
+ )
+ self._wide_entity_class_sq = (
+ self.query(
+ ecd_sq.c.id,
+ ecd_sq.c.name,
+ ecd_sq.c.description,
+ ecd_sq.c.display_order,
+ ecd_sq.c.display_icon,
+ ecd_sq.c.hidden,
+ group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"),
+ group_concat(ecd_sq.c.dimension_name, ecd_sq.c.position).label("dimension_name_list"),
+ func.count(ecd_sq.c.dimension_id).label("dimension_count"),
+ )
+ .group_by(
+ ecd_sq.c.id,
+ ecd_sq.c.name,
+ ecd_sq.c.description,
+ ecd_sq.c.display_order,
+ ecd_sq.c.display_icon,
+ ecd_sq.c.hidden,
+ )
+ .subquery("wide_entity_class_sq")
+ )
+ return self._wide_entity_class_sq
+
+ @property
+ def entity_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._entity_sq is None:
+ self._entity_sq = self._make_entity_sq()
+ return self._entity_sq
+
+ @property
+ def entity_element_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity_element
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._entity_element_sq is None:
+ self._entity_element_sq = self._make_entity_element_sq()
+ return self._entity_element_sq
+
+ @property
+ def wide_entity_sq(self):
+ """A subquery of the form:
+
+ .. 
code-block:: sql
+
+ SELECT
+ e.*,
+ count(ee.element_id) AS element_count,
+ group_concat(ee.element_id) AS element_id_list
+ FROM
+ entity AS e,
+ entity_element AS ee
+ WHERE
+ e.id = ee.entity_id
+ GROUP BY
+ e.id
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._wide_entity_sq is None:
+ entity_element_sq = (
+ self.query(self.entity_element_sq, self.entity_sq.c.name.label("element_name"))
+ .filter(self.entity_element_sq.c.element_id == self.entity_sq.c.id)
+ .subquery("entity_element_sq")
+ )
+ ext_entity_sq = (
+ self.query(self.entity_sq, entity_element_sq)
+ .outerjoin(
+ entity_element_sq,
+ self.entity_sq.c.id == entity_element_sq.c.entity_id,
+ )
+ .order_by(self.entity_sq.c.id, entity_element_sq.c.position)
+ .subquery("ext_entity_sq")
+ )
+ self._wide_entity_sq = (
+ self.query(
+ ext_entity_sq.c.id,
+ ext_entity_sq.c.class_id,
+ ext_entity_sq.c.name,
+ ext_entity_sq.c.description,
+ ext_entity_sq.c.commit_id,
+ group_concat(ext_entity_sq.c.element_id, ext_entity_sq.c.position).label("element_id_list"),
+ group_concat(ext_entity_sq.c.element_name, ext_entity_sq.c.position).label("element_name_list"),
+ )
+ .group_by(
+ ext_entity_sq.c.id,
+ ext_entity_sq.c.class_id,
+ ext_entity_sq.c.name,
+ ext_entity_sq.c.description,
+ ext_entity_sq.c.commit_id,
+ )
+ .subquery("wide_entity_sq")
+ )
+ return self._wide_entity_sq
+
+ @property
+ def entity_group_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity_group
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._entity_group_sq is None:
+ self._entity_group_sq = self._subquery("entity_group")
+ return self._entity_group_sq
+
+ @property
+ def alternative_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM alternative
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._alternative_sq is None:
+ self._alternative_sq = self._make_alternative_sq()
+ return self._alternative_sq
+
+ @property
+ def scenario_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM scenario
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._scenario_sq is None:
+ self._scenario_sq = self._make_scenario_sq()
+ return self._scenario_sq
+
+ @property
+ def scenario_alternative_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM scenario_alternative
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._scenario_alternative_sq is None:
+ self._scenario_alternative_sq = self._make_scenario_alternative_sq()
+ return self._scenario_alternative_sq
+
+ @property
+ def entity_alternative_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity_alternative
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._entity_alternative_sq is None:
+ self._entity_alternative_sq = self._subquery("entity_alternative")
+ return self._entity_alternative_sq
+
+ @property
+ def parameter_value_list_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM parameter_value_list
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._parameter_value_list_sq is None:
+ self._parameter_value_list_sq = self._subquery("parameter_value_list")
+ return self._parameter_value_list_sq
+
+ @property
+ def list_value_sq(self):
+ """A subquery of the form:
+
+ .. 
code-block:: sql
+
+ SELECT * FROM list_value
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._list_value_sq is None:
+ self._list_value_sq = self._subquery("list_value")
+ return self._list_value_sq
+
+ @property
+ def parameter_definition_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM parameter_definition
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+
+ if self._parameter_definition_sq is None:
+ self._parameter_definition_sq = self._make_parameter_definition_sq()
+ return self._parameter_definition_sq
+
+ @property
+ def parameter_value_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM parameter_value
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._parameter_value_sq is None:
+ self._parameter_value_sq = self._make_parameter_value_sq()
+ return self._parameter_value_sq
+
+ @property
+ def metadata_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM metadata
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._metadata_sq is None:
+ self._metadata_sq = self._subquery("metadata")
+ return self._metadata_sq
+
+ @property
+ def parameter_value_metadata_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM parameter_value_metadata
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._parameter_value_metadata_sq is None:
+ self._parameter_value_metadata_sq = self._subquery("parameter_value_metadata")
+ return self._parameter_value_metadata_sq
+
+ @property
+ def entity_metadata_sq(self):
+ """A subquery of the form:
+
+ .. code-block:: sql
+
+ SELECT * FROM entity_metadata
+
+ Returns:
+ :class:`~sqlalchemy.sql.expression.Alias`
+ """
+ if self._entity_metadata_sq is None:
+ self._entity_metadata_sq = self._subquery("entity_metadata")
+ return self._entity_metadata_sq
+
+ @property
+ def commit_sq(self):
+ """A subquery of the form:
+
+ .. 
code-block:: sql + + SELECT * FROM commit + + Returns: + :class:`~sqlalchemy.sql.expression.Alias` + """ + if self._commit_sq is None: + commit_sq = self._subquery("commit") + self._commit_sq = self.query(commit_sq).filter(commit_sq.c.comment != "").subquery() + return self._commit_sq + + @property + def object_class_sq(self): + if self._object_class_sq is None: + self._object_class_sq = ( + self.query( + self.wide_entity_class_sq.c.id.label("id"), + self.wide_entity_class_sq.c.name.label("name"), + self.wide_entity_class_sq.c.description.label("description"), + self.wide_entity_class_sq.c.display_order.label("display_order"), + self.wide_entity_class_sq.c.display_icon.label("display_icon"), + self.wide_entity_class_sq.c.hidden.label("hidden"), + ) + .filter(self.wide_entity_class_sq.c.dimension_id_list == None) + .subquery("object_class_sq") + ) + return self._object_class_sq + + @property + def object_sq(self): + if self._object_sq is None: + self._object_sq = ( + self.query( + self.wide_entity_sq.c.id.label("id"), + self.wide_entity_sq.c.class_id.label("class_id"), + self.wide_entity_sq.c.name.label("name"), + self.wide_entity_sq.c.description.label("description"), + self.wide_entity_sq.c.commit_id.label("commit_id"), + ) + .filter(self.wide_entity_sq.c.element_id_list == None) + .subquery("object_sq") + ) + return self._object_sq + + @property + def relationship_class_sq(self): + if self._relationship_class_sq is None: + ent_cls_dim_sq = self._subquery("entity_class_dimension") + self._relationship_class_sq = ( + self.query( + ent_cls_dim_sq.c.entity_class_id.label("id"), + ent_cls_dim_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept + ent_cls_dim_sq.c.dimension_id.label("object_class_id"), + self.wide_entity_class_sq.c.name.label("name"), + self.wide_entity_class_sq.c.description.label("description"), + self.wide_entity_class_sq.c.display_icon.label("display_icon"), + self.wide_entity_class_sq.c.hidden.label("hidden"), + ) + .filter(self.wide_entity_class_sq.c.id == ent_cls_dim_sq.c.entity_class_id) + .subquery("relationship_class_sq") + ) + return self._relationship_class_sq + + @property + def relationship_sq(self): + if self._relationship_sq is None: + ent_el_sq = self._subquery("entity_element") + self._relationship_sq = ( + self.query( + ent_el_sq.c.entity_id.label("id"), + ent_el_sq.c.position.label("dimension"), # NOTE: nothing to do with the `dimension` concept + ent_el_sq.c.element_id.label("object_id"), + ent_el_sq.c.entity_class_id.label("class_id"), + self.wide_entity_sq.c.name.label("name"), + self.wide_entity_sq.c.commit_id.label("commit_id"), + ) + .filter(self.wide_entity_sq.c.id == ent_el_sq.c.entity_id) + .subquery("relationship_sq") + ) + return self._relationship_sq + + @property + def ext_parameter_value_list_sq(self): + if self._ext_parameter_value_list_sq is None: + self._ext_parameter_value_list_sq = ( + self.query( + self.parameter_value_list_sq.c.id, + self.parameter_value_list_sq.c.name, + self.parameter_value_list_sq.c.commit_id, + self.list_value_sq.c.id.label("value_id"), + self.list_value_sq.c.index.label("value_index"), + ).outerjoin( + self.list_value_sq, + self.list_value_sq.c.parameter_value_list_id == self.parameter_value_list_sq.c.id, + ) + ).subquery() + return self._ext_parameter_value_list_sq + + @property + def wide_parameter_value_list_sq(self): + if self._wide_parameter_value_list_sq is None: + self._wide_parameter_value_list_sq = ( + self.query( + self.ext_parameter_value_list_sq.c.id, + 
self.ext_parameter_value_list_sq.c.name, + self.ext_parameter_value_list_sq.c.commit_id, + group_concat( + self.ext_parameter_value_list_sq.c.value_id, self.ext_parameter_value_list_sq.c.value_index + ).label("value_id_list"), + group_concat( + self.ext_parameter_value_list_sq.c.value_index, self.ext_parameter_value_list_sq.c.value_index + ).label("value_index_list"), + ).group_by( + self.ext_parameter_value_list_sq.c.id, + self.ext_parameter_value_list_sq.c.name, + self.ext_parameter_value_list_sq.c.commit_id, + ) + ).subquery() + return self._wide_parameter_value_list_sq + + @property + def ord_list_value_sq(self): + if self._ord_list_value_sq is None: + self._ord_list_value_sq = ( + self.query( + self.list_value_sq.c.id, + self.list_value_sq.c.parameter_value_list_id, + self.list_value_sq.c.index, + self.list_value_sq.c.value, + self.list_value_sq.c.type, + self.list_value_sq.c.commit_id, + ) + .order_by(self.list_value_sq.c.parameter_value_list_id, self.list_value_sq.c.index) + .subquery() + ) + return self._ord_list_value_sq + + @property + def ext_scenario_sq(self): + if self._ext_scenario_sq is None: + self._ext_scenario_sq = ( + self.query( + self.scenario_sq.c.id.label("id"), + self.scenario_sq.c.name.label("name"), + self.scenario_sq.c.description.label("description"), + self.scenario_sq.c.active.label("active"), + self.scenario_alternative_sq.c.alternative_id.label("alternative_id"), + self.scenario_alternative_sq.c.rank.label("rank"), + self.alternative_sq.c.name.label("alternative_name"), + self.scenario_sq.c.commit_id.label("commit_id"), + ) + .outerjoin( + self.scenario_alternative_sq, self.scenario_alternative_sq.c.scenario_id == self.scenario_sq.c.id + ) + .outerjoin( + self.alternative_sq, self.alternative_sq.c.id == self.scenario_alternative_sq.c.alternative_id + ) + .order_by(self.scenario_sq.c.id, self.scenario_alternative_sq.c.rank) + .subquery() + ) + return self._ext_scenario_sq + + @property + def wide_scenario_sq(self): + if self._wide_scenario_sq is None: + self._wide_scenario_sq = ( + self.query( + self.ext_scenario_sq.c.id.label("id"), + self.ext_scenario_sq.c.name.label("name"), + self.ext_scenario_sq.c.description.label("description"), + self.ext_scenario_sq.c.active.label("active"), + self.ext_scenario_sq.c.commit_id.label("commit_id"), + group_concat(self.ext_scenario_sq.c.alternative_id, self.ext_scenario_sq.c.rank).label( + "alternative_id_list" + ), + group_concat(self.ext_scenario_sq.c.alternative_name, self.ext_scenario_sq.c.rank).label( + "alternative_name_list" + ), + ) + .group_by( + self.ext_scenario_sq.c.id, + self.ext_scenario_sq.c.name, + self.ext_scenario_sq.c.description, + self.ext_scenario_sq.c.active, + self.ext_scenario_sq.c.commit_id, + ) + .subquery() + ) + return self._wide_scenario_sq + + @property + def linked_scenario_alternative_sq(self): + if self._linked_scenario_alternative_sq is None: + scenario_next_alternative = aliased(self.scenario_alternative_sq) + self._linked_scenario_alternative_sq = ( + self.query( + self.scenario_alternative_sq.c.id.label("id"), + self.scenario_alternative_sq.c.scenario_id.label("scenario_id"), + self.scenario_alternative_sq.c.alternative_id.label("alternative_id"), + self.scenario_alternative_sq.c.rank.label("rank"), + scenario_next_alternative.c.alternative_id.label("before_alternative_id"), + scenario_next_alternative.c.rank.label("before_rank"), + self.scenario_alternative_sq.c.commit_id.label("commit_id"), + ) + .outerjoin( + scenario_next_alternative, + and_( + 
scenario_next_alternative.c.scenario_id == self.scenario_alternative_sq.c.scenario_id, + scenario_next_alternative.c.rank == self.scenario_alternative_sq.c.rank + 1, + ), + ) + .order_by(self.scenario_alternative_sq.c.scenario_id, self.scenario_alternative_sq.c.rank) + .subquery() + ) + return self._linked_scenario_alternative_sq + + @property + def ext_linked_scenario_alternative_sq(self): + if self._ext_linked_scenario_alternative_sq is None: + next_alternative = aliased(self.alternative_sq) + self._ext_linked_scenario_alternative_sq = ( + self.query( + self.linked_scenario_alternative_sq.c.id.label("id"), + self.linked_scenario_alternative_sq.c.scenario_id.label("scenario_id"), + self.scenario_sq.c.name.label("scenario_name"), + self.linked_scenario_alternative_sq.c.alternative_id.label("alternative_id"), + self.alternative_sq.c.name.label("alternative_name"), + self.linked_scenario_alternative_sq.c.rank.label("rank"), + self.linked_scenario_alternative_sq.c.before_alternative_id.label("before_alternative_id"), + self.linked_scenario_alternative_sq.c.before_rank.label("before_rank"), + next_alternative.c.name.label("before_alternative_name"), + self.linked_scenario_alternative_sq.c.commit_id.label("commit_id"), + ) + .filter(self.linked_scenario_alternative_sq.c.scenario_id == self.scenario_sq.c.id) + .filter(self.alternative_sq.c.id == self.linked_scenario_alternative_sq.c.alternative_id) + .outerjoin( + next_alternative, + next_alternative.c.id == self.linked_scenario_alternative_sq.c.before_alternative_id, + ) + .subquery() + ) + return self._ext_linked_scenario_alternative_sq + + @property + def ext_object_sq(self): + if self._ext_object_sq is None: + self._ext_object_sq = ( + self.query( + self.object_sq.c.id.label("id"), + self.object_sq.c.class_id.label("class_id"), + self.object_class_sq.c.name.label("class_name"), + self.object_sq.c.name.label("name"), + self.object_sq.c.description.label("description"), + self.entity_group_sq.c.entity_id.label("group_id"), + self.object_sq.c.commit_id.label("commit_id"), + ) + .filter(self.object_sq.c.class_id == self.object_class_sq.c.id) + .outerjoin(self.entity_group_sq, self.entity_group_sq.c.entity_id == self.object_sq.c.id) + .distinct(self.entity_group_sq.c.entity_id) + .subquery() + ) + return self._ext_object_sq + + @property + def ext_relationship_class_sq(self): + if self._ext_relationship_class_sq is None: + self._ext_relationship_class_sq = ( + self.query( + self.relationship_class_sq.c.id.label("id"), + self.relationship_class_sq.c.name.label("name"), + self.relationship_class_sq.c.description.label("description"), + self.relationship_class_sq.c.dimension.label("dimension"), + self.relationship_class_sq.c.display_icon.label("display_icon"), + self.object_class_sq.c.id.label("object_class_id"), + self.object_class_sq.c.name.label("object_class_name"), + ) + .filter(self.relationship_class_sq.c.object_class_id == self.object_class_sq.c.id) + .order_by(self.relationship_class_sq.c.id, self.relationship_class_sq.c.dimension) + .subquery() + ) + return self._ext_relationship_class_sq + + @property + def wide_relationship_class_sq(self): + if self._wide_relationship_class_sq is None: + self._wide_relationship_class_sq = ( + self.query( + self.ext_relationship_class_sq.c.id, + self.ext_relationship_class_sq.c.name, + self.ext_relationship_class_sq.c.description, + self.ext_relationship_class_sq.c.display_icon, + group_concat( + self.ext_relationship_class_sq.c.object_class_id, self.ext_relationship_class_sq.c.dimension + 
).label("object_class_id_list"), + group_concat( + self.ext_relationship_class_sq.c.object_class_name, self.ext_relationship_class_sq.c.dimension + ).label("object_class_name_list"), + ) + .group_by( + self.ext_relationship_class_sq.c.id, + self.ext_relationship_class_sq.c.name, + self.ext_relationship_class_sq.c.description, + self.ext_relationship_class_sq.c.display_icon, + ) + .subquery() + ) + return self._wide_relationship_class_sq + + @property + def ext_relationship_sq(self): + if self._ext_relationship_sq is None: + self._ext_relationship_sq = ( + self.query( + self.relationship_sq.c.id.label("id"), + self.relationship_sq.c.name.label("name"), + self.relationship_sq.c.class_id.label("class_id"), + self.relationship_sq.c.dimension.label("dimension"), + self.wide_relationship_class_sq.c.name.label("class_name"), + self.ext_object_sq.c.id.label("object_id"), + self.ext_object_sq.c.name.label("object_name"), + self.ext_object_sq.c.class_id.label("object_class_id"), + self.ext_object_sq.c.class_name.label("object_class_name"), + self.relationship_sq.c.commit_id.label("commit_id"), + ) + .filter(self.relationship_sq.c.class_id == self.wide_relationship_class_sq.c.id) + .outerjoin(self.ext_object_sq, self.relationship_sq.c.object_id == self.ext_object_sq.c.id) + .order_by(self.relationship_sq.c.id, self.relationship_sq.c.dimension) + .subquery() + ) + return self._ext_relationship_sq + + @property + def wide_relationship_sq(self): + if self._wide_relationship_sq is None: + self._wide_relationship_sq = ( + self.query( + self.ext_relationship_sq.c.id, + self.ext_relationship_sq.c.name, + self.ext_relationship_sq.c.class_id, + self.ext_relationship_sq.c.class_name, + self.ext_relationship_sq.c.commit_id, + group_concat(self.ext_relationship_sq.c.object_id, self.ext_relationship_sq.c.dimension).label( + "object_id_list" + ), + group_concat(self.ext_relationship_sq.c.object_name, self.ext_relationship_sq.c.dimension).label( + "object_name_list" + ), + group_concat( + self.ext_relationship_sq.c.object_class_id, self.ext_relationship_sq.c.dimension + ).label("object_class_id_list"), + group_concat( + self.ext_relationship_sq.c.object_class_name, self.ext_relationship_sq.c.dimension + ).label("object_class_name_list"), + ) + .group_by( + self.ext_relationship_sq.c.id, + self.ext_relationship_sq.c.name, + self.ext_relationship_sq.c.class_id, + self.ext_relationship_sq.c.class_name, + self.ext_relationship_sq.c.commit_id, + ) + # dimension count might be higher than object count when objects have been filtered out + .having( + func.count(self.ext_relationship_sq.c.dimension) == func.count(self.ext_relationship_sq.c.object_id) + ) + .subquery() + ) + return self._wide_relationship_sq + + @property + def ext_entity_group_sq(self): + if self._ext_entity_group_sq is None: + group_entity = aliased(self.entity_sq) + member_entity = aliased(self.entity_sq) + self._ext_entity_group_sq = ( + self.query( + self.entity_group_sq.c.id.label("id"), + self.entity_group_sq.c.entity_class_id.label("class_id"), + self.entity_group_sq.c.entity_id.label("group_id"), + self.entity_group_sq.c.member_id.label("member_id"), + self.wide_entity_class_sq.c.name.label("class_name"), + group_entity.c.name.label("group_name"), + member_entity.c.name.label("member_name"), + label("object_class_id", self._object_class_id()), + label("relationship_class_id", self._relationship_class_id()), + ) + .filter(self.entity_group_sq.c.entity_class_id == self.wide_entity_class_sq.c.id) + .join(group_entity, self.entity_group_sq.c.entity_id 
== group_entity.c.id) + .join(member_entity, self.entity_group_sq.c.member_id == member_entity.c.id) + .subquery() + ) + return self._ext_entity_group_sq + + @property + def entity_parameter_definition_sq(self): + if self._entity_parameter_definition_sq is None: + self._entity_parameter_definition_sq = ( + self.query( + self.parameter_definition_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.wide_entity_class_sq.c.name.label("entity_class_name"), + label("object_class_id", self._object_class_id()), + label("relationship_class_id", self._relationship_class_id()), + label("object_class_name", self._object_class_name()), + label("relationship_class_name", self._relationship_class_name()), + label("object_class_id_list", self._object_class_id_list()), + label("object_class_name_list", self._object_class_name_list()), + self.parameter_definition_sq.c.name.label("parameter_name"), + self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), + self.parameter_value_list_sq.c.name.label("value_list_name"), + self.parameter_definition_sq.c.default_value, + self.parameter_definition_sq.c.default_type, + self.parameter_definition_sq.c.list_value_id, + self.parameter_definition_sq.c.description, + self.parameter_definition_sq.c.commit_id, + ) + .join( + self.wide_entity_class_sq, + self.wide_entity_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id, + ) + .outerjoin( + self.parameter_value_list_sq, + self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, + ) + .outerjoin( + self.wide_relationship_class_sq, + self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, + ) + .subquery() + ) + return self._entity_parameter_definition_sq + + @property + def object_parameter_definition_sq(self): + if self._object_parameter_definition_sq is None: + self._object_parameter_definition_sq = ( + self.query( + self.parameter_definition_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.object_class_sq.c.name.label("entity_class_name"), + self.object_class_sq.c.id.label("object_class_id"), + self.object_class_sq.c.name.label("object_class_name"), + self.parameter_definition_sq.c.name.label("parameter_name"), + self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), + self.parameter_value_list_sq.c.name.label("value_list_name"), + self.parameter_definition_sq.c.default_value, + self.parameter_definition_sq.c.default_type, + self.parameter_definition_sq.c.description, + ) + .filter(self.object_class_sq.c.id == self.parameter_definition_sq.c.entity_class_id) + .outerjoin( + self.parameter_value_list_sq, + self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, + ) + .subquery() + ) + return self._object_parameter_definition_sq + + @property + def relationship_parameter_definition_sq(self): + if self._relationship_parameter_definition_sq is None: + self._relationship_parameter_definition_sq = ( + self.query( + self.parameter_definition_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.wide_relationship_class_sq.c.name.label("entity_class_name"), + self.wide_relationship_class_sq.c.id.label("relationship_class_id"), + self.wide_relationship_class_sq.c.name.label("relationship_class_name"), + self.wide_relationship_class_sq.c.object_class_id_list, + self.wide_relationship_class_sq.c.object_class_name_list, + self.parameter_definition_sq.c.name.label("parameter_name"), + 
self.parameter_definition_sq.c.parameter_value_list_id.label("value_list_id"), + self.parameter_value_list_sq.c.name.label("value_list_name"), + self.parameter_definition_sq.c.default_value, + self.parameter_definition_sq.c.default_type, + self.parameter_definition_sq.c.description, + ) + .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) + .outerjoin( + self.parameter_value_list_sq, + self.parameter_value_list_sq.c.id == self.parameter_definition_sq.c.parameter_value_list_id, + ) + .subquery() + ) + return self._relationship_parameter_definition_sq + + @property + def entity_parameter_value_sq(self): + if self._entity_parameter_value_sq is None: + self._entity_parameter_value_sq = ( + self.query( + self.parameter_value_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.wide_entity_class_sq.c.name.label("entity_class_name"), + label("object_class_id", self._object_class_id()), + label("relationship_class_id", self._relationship_class_id()), + label("object_class_name", self._object_class_name()), + label("relationship_class_name", self._relationship_class_name()), + label("object_class_id_list", self._object_class_id_list()), + label("object_class_name_list", self._object_class_name_list()), + self.parameter_value_sq.c.entity_id, + self.wide_entity_sq.c.name.label("entity_name"), + label("object_id", self._object_id()), + label("relationship_id", self._relationship_id()), + label("object_name", self._object_name()), + label("object_id_list", self._object_id_list()), + label("object_name_list", self._object_name_list()), + self.parameter_definition_sq.c.id.label("parameter_id"), + self.parameter_definition_sq.c.name.label("parameter_name"), + self.parameter_value_sq.c.alternative_id, + self.alternative_sq.c.name.label("alternative_name"), + self.parameter_value_sq.c.value, + self.parameter_value_sq.c.type, + self.parameter_value_sq.c.list_value_id, + self.parameter_value_sq.c.commit_id, + ) + .join( + self.parameter_definition_sq, + self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id, + ) + .join(self.wide_entity_sq, self.parameter_value_sq.c.entity_id == self.wide_entity_sq.c.id) + .join( + self.wide_entity_class_sq, + self.parameter_definition_sq.c.entity_class_id == self.wide_entity_class_sq.c.id, + ) + .join(self.alternative_sq, self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) + .outerjoin( + self.wide_relationship_class_sq, + self.wide_relationship_class_sq.c.id == self.wide_entity_class_sq.c.id, + ) + .outerjoin(self.wide_relationship_sq, self.wide_relationship_sq.c.id == self.wide_entity_sq.c.id) + # object_id_list might be None when objects have been filtered out + .filter( + or_( + self.wide_relationship_sq.c.id.is_(None), + self.wide_relationship_sq.c.object_id_list.isnot(None), + ) + ) + .subquery() + ) + return self._entity_parameter_value_sq + + @property + def object_parameter_value_sq(self): + if self._object_parameter_value_sq is None: + self._object_parameter_value_sq = ( + self.query( + self.parameter_value_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.object_class_sq.c.id.label("object_class_id"), + self.object_class_sq.c.name.label("object_class_name"), + self.parameter_value_sq.c.entity_id, + self.object_sq.c.id.label("object_id"), + self.object_sq.c.name.label("object_name"), + self.parameter_definition_sq.c.id.label("parameter_id"), + self.parameter_definition_sq.c.name.label("parameter_name"), + 
self.parameter_value_sq.c.alternative_id, + self.alternative_sq.c.name.label("alternative_name"), + self.parameter_value_sq.c.value, + self.parameter_value_sq.c.type, + ) + .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) + .filter(self.parameter_value_sq.c.entity_id == self.object_sq.c.id) + .filter(self.parameter_definition_sq.c.entity_class_id == self.object_class_sq.c.id) + .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) + .subquery() + ) + return self._object_parameter_value_sq + + @property + def relationship_parameter_value_sq(self): + if self._relationship_parameter_value_sq is None: + self._relationship_parameter_value_sq = ( + self.query( + self.parameter_value_sq.c.id.label("id"), + self.parameter_definition_sq.c.entity_class_id, + self.wide_relationship_class_sq.c.id.label("relationship_class_id"), + self.wide_relationship_class_sq.c.name.label("relationship_class_name"), + self.wide_relationship_class_sq.c.object_class_id_list, + self.wide_relationship_class_sq.c.object_class_name_list, + self.parameter_value_sq.c.entity_id, + self.wide_relationship_sq.c.id.label("relationship_id"), + self.wide_relationship_sq.c.object_id_list, + self.wide_relationship_sq.c.object_name_list, + self.parameter_definition_sq.c.id.label("parameter_id"), + self.parameter_definition_sq.c.name.label("parameter_name"), + self.parameter_value_sq.c.alternative_id, + self.alternative_sq.c.name.label("alternative_name"), + self.parameter_value_sq.c.value, + self.parameter_value_sq.c.type, + ) + .filter(self.parameter_definition_sq.c.id == self.parameter_value_sq.c.parameter_definition_id) + .filter(self.parameter_value_sq.c.entity_id == self.wide_relationship_sq.c.id) + .filter(self.parameter_definition_sq.c.entity_class_id == self.wide_relationship_class_sq.c.id) + .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) + .subquery() + ) + return self._relationship_parameter_value_sq + + @property + def ext_parameter_value_metadata_sq(self): + if self._ext_parameter_value_metadata_sq is None: + self._ext_parameter_value_metadata_sq = ( + self.query( + self.parameter_value_metadata_sq.c.id, + self.parameter_value_metadata_sq.c.parameter_value_id, + self.metadata_sq.c.id.label("metadata_id"), + self.entity_sq.c.name.label("entity_name"), + self.parameter_definition_sq.c.name.label("parameter_name"), + self.alternative_sq.c.name.label("alternative_name"), + self.metadata_sq.c.name.label("metadata_name"), + self.metadata_sq.c.value.label("metadata_value"), + self.parameter_value_metadata_sq.c.commit_id, + ) + .filter(self.parameter_value_metadata_sq.c.parameter_value_id == self.parameter_value_sq.c.id) + .filter(self.parameter_value_sq.c.parameter_definition_id == self.parameter_definition_sq.c.id) + .filter(self.parameter_value_sq.c.entity_id == self.entity_sq.c.id) + .filter(self.parameter_value_sq.c.alternative_id == self.alternative_sq.c.id) + .filter(self.parameter_value_metadata_sq.c.metadata_id == self.metadata_sq.c.id) + .subquery() + ) + return self._ext_parameter_value_metadata_sq + + @property + def ext_entity_metadata_sq(self): + if self._ext_entity_metadata_sq is None: + self._ext_entity_metadata_sq = ( + self.query( + self.entity_metadata_sq.c.id, + self.entity_metadata_sq.c.entity_id, + self.metadata_sq.c.id.label("metadata_id"), + self.entity_sq.c.name.label("entity_name"), + self.metadata_sq.c.name.label("metadata_name"), + self.metadata_sq.c.value.label("metadata_value"), + 
self.entity_metadata_sq.c.commit_id, + ) + .filter(self.entity_metadata_sq.c.entity_id == self.entity_sq.c.id) + .filter(self.entity_metadata_sq.c.metadata_id == self.metadata_sq.c.id) + .subquery() + ) + return self._ext_entity_metadata_sq + + def _make_entity_class_sq(self): + """ + Creates a subquery for entity classes. + + Returns: + Alias: an entity class subquery + """ + return self._subquery("entity_class") + + def _make_entity_sq(self): + """ + Creates a subquery for entities. + + Returns: + Alias: an entity subquery + """ + return self._subquery("entity") + + def _make_entity_element_sq(self): + """ + Creates a subquery for entity-elements. + + Returns: + Alias: an entity_element subquery + """ + return self._subquery("entity_element") + + def _make_parameter_definition_sq(self): + """ + Creates a subquery for parameter definitions. + + Returns: + Alias: a parameter definition subquery + """ + par_def_sq = self._subquery("parameter_definition") + list_value_id = case( + [(par_def_sq.c.default_type == "list_value_ref", cast(par_def_sq.c.default_value, Integer()))], else_=None + ) + default_value = case( + [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.value)], + else_=par_def_sq.c.default_value, + ) + default_type = case( + [(par_def_sq.c.default_type == "list_value_ref", self.list_value_sq.c.type)], + else_=par_def_sq.c.default_type, + ) + return ( + self.query( + par_def_sq.c.id.label("id"), + par_def_sq.c.name.label("name"), + par_def_sq.c.description.label("description"), + par_def_sq.c.entity_class_id, + label("default_value", default_value), + label("default_type", default_type), + label("list_value_id", list_value_id), + par_def_sq.c.commit_id.label("commit_id"), + par_def_sq.c.parameter_value_list_id.label("parameter_value_list_id"), + ) + .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) + .subquery("clean_parameter_definition_sq") + ) + + def _make_parameter_value_sq(self): + """ + Creates a subquery for parameter values. + + Returns: + Alias: a parameter value subquery + """ + par_val_sq = self._subquery("parameter_value") + list_value_id = case([(par_val_sq.c.type == "list_value_ref", cast(par_val_sq.c.value, Integer()))], else_=None) + value = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.value)], else_=par_val_sq.c.value) + type_ = case([(par_val_sq.c.type == "list_value_ref", self.list_value_sq.c.type)], else_=par_val_sq.c.type) + return ( + self.query( + par_val_sq.c.id.label("id"), + par_val_sq.c.parameter_definition_id, + par_val_sq.c.entity_class_id, + par_val_sq.c.entity_id, + label("value", value), + label("type", type_), + label("list_value_id", list_value_id), + par_val_sq.c.commit_id.label("commit_id"), + par_val_sq.c.alternative_id, + ) + .filter(par_val_sq.c.entity_id == self.entity_sq.c.id) + .outerjoin(self.list_value_sq, self.list_value_sq.c.id == list_value_id) + .subquery("clean_parameter_value_sq") + ) + + def _make_alternative_sq(self): + """ + Creates a subquery for alternatives. + + Returns: + Alias: an alternative subquery + """ + return self._subquery("alternative") + + def _make_scenario_sq(self): + """ + Creates a subquery for scenarios. + + Returns: + Alias: a scenario subquery + """ + return self._subquery("scenario") + + def _make_scenario_alternative_sq(self): + """ + Creates a subquery for scenario alternatives. 
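+
+        A minimal usage sketch of the resulting subquery (assuming an open ``db_map`` of this class):
+
+        .. code-block:: python
+
+            # iterate scenario alternatives in rank order
+            sq = db_map.scenario_alternative_sq
+            for row in db_map.query(sq).order_by(sq.c.rank):
+                print(row.scenario_id, row.rank, row.alternative_id)
+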
+ + Returns: + Alias: a scenario alternative subquery + """ + return self._subquery("scenario_alternative") + + def override_entity_class_sq_maker(self, method): + """ + Overrides the function that creates the ``entity_class_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns entity class subquery as an :class:`Alias` object + """ + self._make_entity_class_sq = MethodType(method, self) + self._clear_subqueries("entity_class") + + def override_entity_sq_maker(self, method): + """ + Overrides the function that creates the ``entity_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns entity subquery as an :class:`Alias` object + """ + self._make_entity_sq = MethodType(method, self) + self._clear_subqueries("entity") + + def override_entity_element_sq_maker(self, method): + """ + Overrides the function that creates the ``entity_element_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns entity_element subquery as an :class:`Alias` object + """ + self._make_entity_element_sq = MethodType(method, self) + self._clear_subqueries("entity_element") + + def override_parameter_definition_sq_maker(self, method): + """ + Overrides the function that creates the ``parameter_definition_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns parameter definition subquery as an :class:`Alias` object + """ + self._make_parameter_definition_sq = MethodType(method, self) + self._clear_subqueries("parameter_definition") + + def override_parameter_value_sq_maker(self, method): + """ + Overrides the function that creates the ``parameter_value_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns parameter value subquery as an :class:`Alias` object + """ + self._make_parameter_value_sq = MethodType(method, self) + self._clear_subqueries("parameter_value") + + def override_alternative_sq_maker(self, method): + """ + Overrides the function that creates the ``alternative_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns alternative subquery as an :class:`Alias` object + """ + self._make_alternative_sq = MethodType(method, self) + self._clear_subqueries("alternative") + + def override_scenario_sq_maker(self, method): + """ + Overrides the function that creates the ``scenario_sq`` property. + + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns scenario subquery as an :class:`Alias` object + """ + self._make_scenario_sq = MethodType(method, self) + self._clear_subqueries("scenario") + + def override_scenario_alternative_sq_maker(self, method): + """ + Overrides the function that creates the ``scenario_alternative_sq`` property. 
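+
+        A hedged sketch of a custom maker; the rank condition here is illustrative only:
+
+        .. code-block:: python
+
+            def high_ranks_only(db_map):
+                sq = DatabaseMapping._make_scenario_alternative_sq(db_map)
+                return db_map.query(sq).filter(sq.c.rank > 1).subquery()
+
+            db_map.override_scenario_alternative_sq_maker(high_ranks_only)
+            # ... queries now see the filtered subquery ...
+            db_map.restore_scenario_alternative_sq_maker()
+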
+ + Args: + method (Callable): a function that accepts a :class:`DatabaseMapping` as its argument and + returns scenario alternative subquery as an :class:`Alias` object + """ + self._make_scenario_alternative_sq = MethodType(method, self) + self._clear_subqueries("scenario_alternative") + + def restore_entity_class_sq_maker(self): + """Restores the original function that creates the ``entity_class_sq`` property.""" + self._make_entity_class_sq = MethodType(DatabaseMapping._make_entity_class_sq, self) + self._clear_subqueries("entity_class") + + def restore_entity_sq_maker(self): + """Restores the original function that creates the ``entity_sq`` property.""" + self._make_entity_sq = MethodType(DatabaseMapping._make_entity_sq, self) + self._clear_subqueries("entity") + + def restore_entity_element_sq_maker(self): + """Restores the original function that creates the ``entity_element_sq`` property.""" + self._make_entity_element_sq = MethodType(DatabaseMapping._make_entity_element_sq, self) + self._clear_subqueries("entity_element") + + def restore_parameter_definition_sq_maker(self): + """Restores the original function that creates the ``parameter_definition_sq`` property.""" + self._make_parameter_definition_sq = MethodType(DatabaseMapping._make_parameter_definition_sq, self) + self._clear_subqueries("parameter_definition") + + def restore_parameter_value_sq_maker(self): + """Restores the original function that creates the ``parameter_value_sq`` property.""" + self._make_parameter_value_sq = MethodType(DatabaseMapping._make_parameter_value_sq, self) + self._clear_subqueries("parameter_value") + + def restore_alternative_sq_maker(self): + """Restores the original function that creates the ``alternative_sq`` property.""" + self._make_alternative_sq = MethodType(DatabaseMapping._make_alternative_sq, self) + self._clear_subqueries("alternative") + + def restore_scenario_sq_maker(self): + """Restores the original function that creates the ``scenario_sq`` property.""" + self._make_scenario_sq = MethodType(DatabaseMapping._make_scenario_sq, self) + self._clear_subqueries("scenario") + + def restore_scenario_alternative_sq_maker(self): + """Restores the original function that creates the ``scenario_alternative_sq`` property.""" + self._make_scenario_alternative_sq = MethodType(DatabaseMapping._make_scenario_alternative_sq, self) + self._clear_subqueries("scenario_alternative") + + def _object_class_id(self): + return case( + [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.id)], else_=None + ) + + def _relationship_class_id(self): + return case( + [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.id)], else_=None + ) + + def _object_id(self): + return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.id)], else_=None) + + def _relationship_id(self): + return case([(self.wide_entity_sq.c.element_id_list != None, self.wide_entity_sq.c.id)], else_=None) + + def _object_class_name(self): + return case( + [(self.wide_entity_class_sq.c.dimension_id_list == None, self.wide_entity_class_sq.c.name)], else_=None + ) + + def _relationship_class_name(self): + return case( + [(self.wide_entity_class_sq.c.dimension_id_list != None, self.wide_entity_class_sq.c.name)], else_=None + ) + + def _object_class_id_list(self): + return case( + [ + ( + self.wide_entity_class_sq.c.dimension_id_list != None, + self.wide_relationship_class_sq.c.object_class_id_list, + ) + ], + else_=None, + ) + + def 
_object_class_name_list(self): + return case( + [ + ( + self.wide_entity_class_sq.c.dimension_id_list != None, + self.wide_relationship_class_sq.c.object_class_name_list, + ) + ], + else_=None, + ) + + def _object_name(self): + return case([(self.wide_entity_sq.c.element_id_list == None, self.wide_entity_sq.c.name)], else_=None) + + def _object_id_list(self): + return case( + [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_id_list)], else_=None + ) + + def _object_name_list(self): + return case( + [(self.wide_entity_sq.c.element_id_list != None, self.wide_relationship_sq.c.object_name_list)], else_=None + ) diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py index 15c240bf..f49c96e1 100644 --- a/spinedb_api/db_mapping_remove_mixin.py +++ b/spinedb_api/db_mapping_remove_mixin.py @@ -36,22 +36,22 @@ def remove_items(self, tablename, *ids): if not ids: return [] tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) + mapped_table = self.mapped_table(tablename) if Asterisk in ids: - self.cache.fetch_all(tablename) - ids = table_cache + self.fetch_all(tablename) + ids = mapped_table ids = set(ids) if tablename == "alternative": # Do not remove the Base alternative ids.discard(1) - return [table_cache.remove_item(id_) for id_ in ids] + return [mapped_table.remove_item(id_) for id_ in ids] def restore_items(self, tablename, *ids): if not ids: return [] tablename = self._real_tablename(tablename) - table_cache = self.cache.table_cache(tablename) - return [table_cache.restore_item(id_) for id_ in ids] + mapped_table = self.mapped_table(tablename) + return [mapped_table.restore_item(id_) for id_ in ids] def purge_items(self, tablename): """Removes all items from given table. 
@@ -97,9 +97,9 @@ def _do_remove_items(self, connection, tablename, *ids): def remove_unused_metadata(self): used_metadata_ids = set() - for x in self.cache.table_cache("entity_metadata").valid_values(): + for x in self.mapped_table("entity_metadata").valid_values(): used_metadata_ids.add(x["metadata_id"]) - for x in self.cache.table_cache("parameter_value_metadata").valid_values(): + for x in self.mapped_table("parameter_value_metadata").valid_values(): used_metadata_ids.add(x["metadata_id"]) - unused_metadata_ids = {x["id"] for x in self.cache.table_cache("metadata").valid_values()} - used_metadata_ids + unused_metadata_ids = {x["id"] for x in self.mapped_table("metadata").valid_values()} - used_metadata_ids self.remove_items("metadata", *unused_metadata_ids) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py index f50d980e..d64b4ddb 100644 --- a/spinedb_api/db_mapping_update_mixin.py +++ b/spinedb_api/db_mapping_update_mixin.py @@ -172,7 +172,7 @@ def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): scen_alt_ids_to_remove = {} errors = [] for scen in scenarios: - current_scen = self.cache.table_cache("scenario").find_item(scen) + current_scen = self.mapped_table("scenario").find_item(scen) if current_scen is None: error = f"no scenario matching {scen} to set alternatives for" if strict: @@ -187,7 +187,7 @@ def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): scen_alts_to_add.append(item_to_add) for alternative_id in current_scen["alternative_id_list"]: scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id} - current_scen_alt = self.cache.table_cache("scenario_alternative").find_item(scen_alt) + current_scen_alt = self.mapped_table("scenario_alternative").find_item(scen_alt) scen_alt_ids_to_remove[current_scen_alt["id"]] = current_scen_alt # Remove items that are both to add and to remove for id_, to_rm in list(scen_alt_ids_to_remove.items()): diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index af67469d..08a265b5 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -77,26 +77,26 @@ def export_data( def _get_items(db_map, tablename, ids): if not ids: return () - _process_item = _make_item_processor(db_map.cache, tablename) - for item in _get_items_from_cache(db_map.cache, tablename, ids): + _process_item = _make_item_processor(db_map, tablename) + for item in _get_items_from_db_map(db_map, tablename, ids): yield from _process_item(item) -def _get_items_from_cache(cache, tablename, ids): +def _get_items_from_db_map(db_map, tablename, ids): if ids is Asterisk: - cache.fetch_all(tablename) - yield from cache.table_cache(tablename).valid_values() + db_map.fetch_all(tablename) + yield from db_map.mapped_table(tablename).valid_values() return for id_ in ids: - item = cache.get_item(tablename, id_) or cache.fetch_ref(tablename, id_) + item = db_map.get_item(tablename, id=id_) if item.is_valid(): yield item -def _make_item_processor(cache, tablename): +def _make_item_processor(db_map, tablename): if tablename == "parameter_value_list": - cache.fetch_all("list_value") - return _ParameterValueListProcessor(cache.table_cache("list_value").valid_values()) + db_map.fetch_all("list_value") + return _ParameterValueListProcessor(db_map.mapped_table("list_value").valid_values()) return lambda item: (item,) diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 
4b4783f2..597ae1d5 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -179,7 +179,7 @@ def add_query_columns(self, db_map, query): The base class implementation just returns the same query without adding any new columns. Args: - db_map (DatabaseMappingBase) + db_map (DatabaseMapping) query (Alias or dict) Returns: @@ -193,7 +193,7 @@ def filter_query(self, db_map, query): The base class implementation just returns the same query without applying any new filters. Args: - db_map (DatabaseMappingBase) + db_map (DatabaseMapping) query (Alias or dict) Returns: @@ -221,7 +221,7 @@ def _build_query(self, db_map, title_state): """Builds and returns the query to run for this mapping hierarchy. Args: - db_map (DatabaseMappingBase) + db_map (DatabaseMapping) title_state (dict) Returns: @@ -251,7 +251,7 @@ def _build_title_query(self, db_map): """Builds and returns the query to get titles for this mapping hierarchy. Args: - db_map (DatabaseMappingBase): database mapping + db_map (DatabaseMapping): database mapping Returns: Alias: title query @@ -275,7 +275,7 @@ def _build_header_query(self, db_map, title_state, buddies): """Builds the header query for this mapping hierarchy. Args: - db_map (DatabaseMappingBase): database mapping + db_map (DatabaseMapping): database mapping title_state (dict): title state buddies (list of tuple): pairs of buddy mappings @@ -413,7 +413,7 @@ def rows(self, db_map, title_state): """Yields rows issued by this mapping and its children combined. Args: - db_map (DatabaseMappingBase) + db_map (DatabaseMapping) title_state (dict) Returns: @@ -498,7 +498,7 @@ def _non_unique_titles(self, db_map, limit=None): """Yields all titles, not necessarily unique, and associated state dictionaries. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map limit (int, optional): yield only this many items Yields: @@ -512,7 +512,7 @@ def titles(self, db_map, limit=None): """Yields unique titles and associated state dictionaries. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map limit (int, optional): yield only this many items Yields: @@ -540,7 +540,7 @@ def make_header_recursive(self, query, buddies): Args: build_header_query (callable): a function that any mapping in the hierarchy can call to get the query - db_map (DatabaseMappingBase): database map + db_map (DatabaseMapping): database map title_state (dict): title state buddies (list of tuple): buddy mappings @@ -567,7 +567,7 @@ def make_header(self, db_map, title_state, buddies): """Returns the header for this mapping. 
Args: - db_map (DatabaseMappingBase): database map + db_map (DatabaseMapping): database map title_state (dict): title state buddies (list of tuple): buddy mappings diff --git a/spinedb_api/export_mapping/generator.py b/spinedb_api/export_mapping/generator.py index fa29255c..026454be 100644 --- a/spinedb_api/export_mapping/generator.py +++ b/spinedb_api/export_mapping/generator.py @@ -25,7 +25,7 @@ def rows(root_mapping, db_map, fixed_state=None, empty_data_header=True, group_f Args: root_mapping (Mapping): root export mapping - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map fixed_state (dict, optional): mapping state that fixes items empty_data_header (bool): True to yield at least header rows even if there is no data, False to yield nothing group_fn (str): group function name @@ -98,7 +98,7 @@ def titles(root_mapping, db_map, limit=None): Args: root_mapping (Mapping): root export mapping - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map Yield: tuple: title and title's fixed key diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index 8299f7cf..f406f793 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -26,7 +26,7 @@ def apply_alternative_filter_to_parameter_value_sq(db_map, alternatives): Replaces parameter value subquery properties in ``db_map`` such that they return only values of given alternatives. Args: - db_map (DatabaseMappingBase): a database map to alter + db_map (DatabaseMapping): a database map to alter alternatives (Iterable of str or int, optional): alternative names or ids; """ state = _AlternativeFilterState(db_map, alternatives) @@ -52,7 +52,7 @@ def alternative_filter_from_dict(db_map, config): Applies alternative filter to given database map. Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): alternative filter configuration """ apply_alternative_filter_to_parameter_value_sq(db_map, config["alternatives"]) @@ -117,7 +117,7 @@ class _AlternativeFilterState: def __init__(self, db_map, alternatives): """ Args: - db_map (DatabaseMappingBase): database the state applies to + db_map (DatabaseMapping): database the state applies to alternatives (Iterable of str or int): alternative names or ids; """ self.original_parameter_value_sq = db_map.parameter_value_sq @@ -129,7 +129,7 @@ def _alternative_ids(db_map, alternatives): Finds ids for given alternatives. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map alternatives (Iterable): alternative names or ids Returns: @@ -160,12 +160,12 @@ def _alternative_ids(db_map, alternatives): def _make_alternative_filtered_parameter_value_sq(db_map, state): """ - Returns an alternative filtering subquery similar to :func:`DatabaseMappingBase.parameter_value_sq`. + Returns an alternative filtering subquery similar to :func:`DatabaseMapping.parameter_value_sq`. - This function can be used as replacement for parameter value subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for parameter value subquery maker in :class:`DatabaseMapping`. 
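+
+    A minimal sketch of the public entry point this maker backs (the alternative names are illustrative):
+
+    .. code-block:: python
+
+        apply_alternative_filter_to_parameter_value_sq(db_map, ["low", "high"])
+        rows = db_map.query(db_map.parameter_value_sq).all()  # values of those alternatives only
+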
Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_AlternativeFilterState): a state bound to ``db_map`` Returns: diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index 10a6f7e7..bc647b41 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -28,7 +28,7 @@ def apply_execution_filter(db_map, execution): Replaces the import alternative in ``db_map`` with a dedicated alternative for an execution. Args: - db_map (DatabaseMappingBase): a database map to alter + db_map (DatabaseMapping): a database map to alter execution (dict): execution descriptor """ state = _ExecutionFilterState(db_map, execution) @@ -54,7 +54,7 @@ def execution_filter_from_dict(db_map, config): Applies execution filter to given database map. Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): execution filter configuration """ apply_execution_filter(db_map, config["execution"]) @@ -116,7 +116,7 @@ class _ExecutionFilterState: def __init__(self, db_map, execution): """ Args: - db_map (DatabaseMappingBase): database the state applies to + db_map (DatabaseMapping): database the state applies to execution (dict): execution descriptor """ self.original_create_import_alternative = db_map._create_import_alternative @@ -151,7 +151,7 @@ def _create_import_alternative(db_map, state): Creates an alternative to use as default for all import operations on the given db_map. Args: - db_map (DatabaseMappingBase): database the state applies to + db_map (DatabaseMapping): database the state applies to state (_ExecutionFilterState): a state bound to ``db_map`` """ execution_item = state.execution_item diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index f729e8df..fc7cc05f 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -28,7 +28,7 @@ def apply_renaming_to_entity_class_sq(db_map, name_map): Applies renaming to entity class subquery. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): a map from old name to new name """ state = _EntityClassRenamerState(db_map, name_map) @@ -54,7 +54,7 @@ def entity_class_renamer_from_dict(db_map, config): Applies entity class renamer manipulator to given database map. Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): renamer configuration """ apply_renaming_to_entity_class_sq(db_map, config["name_map"]) @@ -98,7 +98,7 @@ def apply_renaming_to_parameter_definition_sq(db_map, name_map): Applies renaming to parameter definition subquery. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): a map from old name to new name """ state = _ParameterRenamerState(db_map, name_map) @@ -124,7 +124,7 @@ def parameter_renamer_from_dict(db_map, config): Applies parameter renamer manipulator to given database map. 
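+
+    A hedged configuration sketch (the class and parameter names are illustrative):
+
+    .. code-block:: python
+
+        # name_map: entity class name -> parameter name -> new name
+        config = {"name_map": {"unit": {"capacity": "max_capacity"}}}
+        parameter_renamer_from_dict(db_map, config)
+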
Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): renamer configuration """ apply_renaming_to_parameter_definition_sq(db_map, config["name_map"]) @@ -168,7 +168,7 @@ class _EntityClassRenamerState: def __init__(self, db_map, name_map): """ Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): a mapping from original name to a new name. """ name_map = {old: new for old, new in name_map.items() if old != new} @@ -179,7 +179,7 @@ def __init__(self, db_map, name_map): def _ids(db_map, name_map): """ Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): a mapping from original name to a new name Returns: @@ -197,7 +197,7 @@ def _make_renaming_entity_class_sq(db_map, state): Returns an entity class subquery which renames classes. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_EntityClassRenamerState): Returns: @@ -223,7 +223,7 @@ class _ParameterRenamerState: def __init__(self, db_map, name_map): """ Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): mapping from entity class name to mapping from parameter name to new name """ self.id_to_name = self._ids(db_map, name_map) @@ -233,7 +233,7 @@ def __init__(self, db_map, name_map): def _ids(db_map, name_map): """ Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map name_map (dict): a mapping from original name to a new name Returns: @@ -256,7 +256,7 @@ def _make_renaming_parameter_definition_sq(db_map, state): Returns an entity class subquery which renames parameters. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ParameterRenamerState): Returns: diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 858bcb40..bacb6c51 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -27,7 +27,7 @@ def apply_scenario_filter_to_subqueries(db_map, scenario): Replaces affected subqueries in ``db_map`` such that they return only values of given scenario. Args: - db_map (DatabaseMappingBase): a database map to alter + db_map (DatabaseMapping): a database map to alter scenario (str or int): scenario name or id """ state = _ScenarioFilterState(db_map, scenario) @@ -63,7 +63,7 @@ def scenario_filter_from_dict(db_map, config): Applies scenario filter to given database map. Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): scenario filter configuration """ apply_scenario_filter_to_subqueries(db_map, config["scenario"]) @@ -128,7 +128,7 @@ class _ScenarioFilterState: def __init__(self, db_map, scenario): """ Args: - db_map (DatabaseMappingBase): database the state applies to + db_map (DatabaseMapping): database the state applies to scenario (str or int): scenario name or ids """ self.original_entity_sq = db_map.entity_sq @@ -146,7 +146,7 @@ def _scenario_id(db_map, scenario): Finds id for given scenario. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map scenario (str or int): scenario name or id Returns: @@ -171,7 +171,7 @@ def _scenario_alternative_ids(self, db_map): Finds scenario alternative and alternative ids of current scenario. 
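+
+        A minimal usage sketch of the public entry point (the scenario name is illustrative):
+
+        .. code-block:: python
+
+            apply_scenario_filter_to_subqueries(db_map, "baseline")
+            entities = db_map.query(db_map.entity_sq).all()  # only entities visible in the scenario
+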
Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map Returns: tuple: scenario alternative ids and alternative ids @@ -216,12 +216,12 @@ def _ext_entity_sq(db_map, state): def _make_scenario_filtered_entity_element_sq(db_map, state): - """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_element_sq`. + """Returns a scenario filtering subquery similar to :func:`DatabaseMapping.entity_element_sq`. - This function can be used as replacement for entity_element subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for entity_element subquery maker in :class:`DatabaseMapping`. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: @@ -247,12 +247,12 @@ def _make_scenario_filtered_entity_element_sq(db_map, state): def _make_scenario_filtered_entity_sq(db_map, state): - """Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.entity_sq`. + """Returns a scenario filtering subquery similar to :func:`DatabaseMapping.entity_sq`. - This function can be used as replacement for entity subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for entity subquery maker in :class:`DatabaseMapping`. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: @@ -298,12 +298,12 @@ def _make_scenario_filtered_entity_sq(db_map, state): def _make_scenario_filtered_parameter_value_sq(db_map, state): """ - Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.parameter_value_sq`. + Returns a scenario filtering subquery similar to :func:`DatabaseMapping.parameter_value_sq`. - This function can be used as replacement for parameter value subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for parameter value subquery maker in :class:`DatabaseMapping`. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: @@ -330,12 +330,12 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state): def _make_scenario_filtered_alternative_sq(db_map, state): """ - Returns an alternative filtering subquery similar to :func:`DatabaseMappingBase.alternative_sq`. + Returns an alternative filtering subquery similar to :func:`DatabaseMapping.alternative_sq`. - This function can be used as replacement for alternative subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for alternative subquery maker in :class:`DatabaseMapping`. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: @@ -347,12 +347,12 @@ def _make_scenario_filtered_alternative_sq(db_map, state): def _make_scenario_filtered_scenario_sq(db_map, state): """ - Returns a scenario filtering subquery similar to :func:`DatabaseMappingBase.scenario_sq`. + Returns a scenario filtering subquery similar to :func:`DatabaseMapping.scenario_sq`. - This function can be used as replacement for scenario subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for scenario subquery maker in :class:`DatabaseMapping`. 
Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: @@ -364,12 +364,12 @@ def _make_scenario_filtered_scenario_sq(db_map, state): def _make_scenario_filtered_scenario_alternative_sq(db_map, state): """ - Returns a scenario alternative filtering subquery similar to :func:`DatabaseMappingBase.scenario_alternative_sq`. + Returns a scenario alternative filtering subquery similar to :func:`DatabaseMapping.scenario_alternative_sq`. - This function can be used as replacement for scenario alternative subquery maker in :class:`DatabaseMappingBase`. + This function can be used as replacement for scenario alternative subquery maker in :class:`DatabaseMapping`. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ScenarioFilterState): a state bound to ``db_map`` Returns: diff --git a/spinedb_api/filters/tools.py b/spinedb_api/filters/tools.py index d8af50fe..94d60545 100644 --- a/spinedb_api/filters/tools.py +++ b/spinedb_api/filters/tools.py @@ -70,7 +70,7 @@ def apply_filter_stack(db_map, stack): Applies stack of filters and manipulator to given database map. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map stack (list): a stack of database filters and manipulators """ appliers = { diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 0d35d256..956de19d 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -31,7 +31,7 @@ def apply_value_transform_to_parameter_value_sq(db_map, instructions): Applies renaming to parameter definition subquery. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map instructions (dict): mapping from entity class name to mapping from parameter name to list of instructions """ @@ -59,7 +59,7 @@ def value_transformer_from_dict(db_map, config): Applies value transformer manipulator to given database map. Args: - db_map (DatabaseMappingBase): target database map + db_map (DatabaseMapping): target database map config (dict): transformer configuration """ apply_value_transform_to_parameter_value_sq(db_map, config["instructions"]) @@ -120,7 +120,7 @@ class _ValueTransformerState: def __init__(self, db_map, instructions): """ Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map instructions (dict): mapping from entity class name to parameter name to list of instructions """ self.original_parameter_value_sq = db_map.parameter_value_sq @@ -131,7 +131,7 @@ def _transform(db_map, instructions): """Transforms applicable parameter values for caching. Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map instructions (dict): mapping from entity class name to parameter name to list of instructions Returns: @@ -164,7 +164,7 @@ def _make_parameter_value_transforming_sq(db_map, state): Returns subquery which applies transformations to parameter values. 
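+
+    A structural sketch of the ``instructions`` mapping; the instruction payload shown is an
+    assumption, not the verified operation format:
+
+    .. code-block:: python
+
+        # entity class name -> parameter name -> list of instructions
+        instructions = {"unit": {"capacity": [{"operation": "multiply", "rhs": 1.1}]}}
+        apply_value_transform_to_parameter_value_sq(db_map, instructions)
+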
Args: - db_map (DatabaseMappingBase): a database map + db_map (DatabaseMapping): a database map state (_ValueTransformerState): state Returns: diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 30280a21..e592dc1e 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -455,19 +455,19 @@ def import_relationship_parameter_value_metadata(db_map, data): def _get_items_for_import(db_map, item_type, data, check_skip_keys=()): - table_cache = db_map.cache.table_cache(item_type) + mapped_table = db_map.mapped_table(item_type) errors = [] to_add = [] to_update = [] seen = {} for item in data: - checked_item, add_error = table_cache.check_item(item, skip_keys=check_skip_keys) + checked_item, add_error = mapped_table.check_item(item, skip_keys=check_skip_keys) if not add_error: if not _check_unique(item_type, checked_item, seen, errors): continue to_add.append(checked_item) continue - checked_item, update_error = table_cache.check_item(item, for_update=True, skip_keys=check_skip_keys) + checked_item, update_error = mapped_table.check_item(item, for_update=True, skip_keys=check_skip_keys) if not update_error: if checked_item: if not _check_unique(item_type, checked_item, seen, errors): @@ -567,7 +567,7 @@ def _data_iterator(): "value": None, "type": None, } - pv = db_map.cache.table_cache("parameter_value").find_item(item) + pv = db_map.mapped_table("parameter_value").find_item(item) if pv is not None: value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) item.update({"value": value, "type": type_}) @@ -593,7 +593,7 @@ def _get_scenarios_for_import(db_map, data): def _get_scenario_alternatives_for_import(db_map, data): alt_name_list_by_scen_name, errors = {}, [] for scen_name, alt_name, *optionals in data: - scen = db_map.cache.table_cache("scenario").find_item({"name": scen_name}) + scen = db_map.mapped_table("scenario").find_item({"name": scen_name}) if scen is None: errors.append(f"no scenario with name {scen_name} to set alternatives for") continue @@ -632,11 +632,11 @@ def _data_iterator(): value, type_ = unparse_value(value) index = index_by_list_name.get(list_name) if index is None: - current_list = db_map.cache.table_cache("parameter_value_list").find_item({"name": list_name}) + current_list = db_map.mapped_table("parameter_value_list").find_item({"name": list_name}) index = max( ( x["index"] - for x in db_map.cache.table_cache("list_value").valid_values() + for x in db_map.mapped_table("list_value").valid_values() if x["parameter_value_list_id"] == current_list["id"] ), default=-1, diff --git a/spinedb_api/db_cache_impl.py b/spinedb_api/mapped_items.py similarity index 87% rename from spinedb_api/db_cache_impl.py rename to spinedb_api/mapped_items.py index 3edf6224..9cfe8982 100644 --- a/spinedb_api/db_cache_impl.py +++ b/spinedb_api/mapped_items.py @@ -14,68 +14,30 @@ import uuid from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError -from .db_cache_base import DBCacheBase, CacheItemBase +from .db_mapping_base import MappedItemBase from .temp_id import TempId -class DBCache(DBCacheBase): - _sq_name_by_item_type = { - "entity_class": "wide_entity_class_sq", - "entity": "wide_entity_sq", - "entity_alternative": "entity_alternative_sq", - "parameter_value_list": "parameter_value_list_sq", - "list_value": "list_value_sq", - "alternative": "alternative_sq", - "scenario": "scenario_sq", - "scenario_alternative": "scenario_alternative_sq", - 
"entity_group": "entity_group_sq", - "parameter_definition": "parameter_definition_sq", - "parameter_value": "parameter_value_sq", - "metadata": "metadata_sq", - "entity_metadata": "entity_metadata_sq", - "parameter_value_metadata": "parameter_value_metadata_sq", - "commit": "commit_sq", - } - - def __init__(self, db_map): - """ - Args: - db_map (DatabaseMapping) - """ - super().__init__() - self._db_map = db_map - - @property - def item_types(self): - return list(self._sq_name_by_item_type) - - @staticmethod - def item_factory(item_type): - return { - "entity_class": EntityClassItem, - "entity": EntityItem, - "entity_alternative": EntityAlternativeItem, - "entity_group": EntityGroupItem, - "parameter_definition": ParameterDefinitionItem, - "parameter_value": ParameterValueItem, - "parameter_value_list": ParameterValueListItem, - "list_value": ListValueItem, - "alternative": AlternativeItem, - "scenario": ScenarioItem, - "scenario_alternative": ScenarioAlternativeItem, - "metadata": MetadataItem, - "entity_metadata": EntityMetadataItem, - "parameter_value_metadata": ParameterValueMetadataItem, - }.get(item_type, CacheItemBase) - - def query(self, item_type): - if self._db_map.closed: - return None - sq_name = self._sq_name_by_item_type[item_type] - return self._db_map.query(getattr(self._db_map, sq_name)) - - -class EntityClassItem(CacheItemBase): +def item_factory(item_type): + return { + "entity_class": EntityClassItem, + "entity": EntityItem, + "entity_alternative": EntityAlternativeItem, + "entity_group": EntityGroupItem, + "parameter_definition": ParameterDefinitionItem, + "parameter_value": ParameterValueItem, + "parameter_value_list": ParameterValueListItem, + "list_value": ListValueItem, + "alternative": AlternativeItem, + "scenario": ScenarioItem, + "scenario_alternative": ScenarioAlternativeItem, + "metadata": MetadataItem, + "entity_metadata": EntityMetadataItem, + "parameter_value_metadata": ParameterValueMetadataItem, + }.get(item_type, MappedItemBase) + + +class EntityClassItem(MappedItemBase): _fields = { "name": ("str", "The class name."), "dimension_name_list": ("tuple, optional", "The dimension names for a multi-dimensional class."), @@ -112,7 +74,7 @@ def commit(self, _commit_id): super().commit(None) -class EntityItem(CacheItemBase): +class EntityItem(MappedItemBase): _fields = { "class_name": ("str", "The entity class name."), "name": ("str, optional", "The entity name - must be given for a zero-dimensional entity."), @@ -154,13 +116,13 @@ def polish(self): return base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) name = base_name - table_cache = self._db_cache.table_cache(self._item_type) - while table_cache.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: + mapped_table = self._db_cache.mapped_table(self._item_type) + while mapped_table.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: name = base_name + "_" + uuid.uuid4().hex self["name"] = name -class EntityGroupItem(CacheItemBase): +class EntityGroupItem(MappedItemBase): _fields = { "class_name": ("str", "The entity class name."), "group_name": ("str", "The group entity name."), @@ -187,7 +149,7 @@ def __getitem__(self, key): return super().__getitem__(key) -class EntityAlternativeItem(CacheItemBase): +class EntityAlternativeItem(MappedItemBase): _fields = { "entity_class_name": ("str", "The entity class name."), "entity_byname": ( @@ -216,7 +178,7 @@ class EntityAlternativeItem(CacheItemBase): } -class 
ParsedValueBase(CacheItemBase): +class ParsedValueBase(MappedItemBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parsed_value = None @@ -306,7 +268,7 @@ def polish(self): parsed_value = from_database(default_value, default_type) if parsed_value is None: return - list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( + list_value_id = self._db_cache.mapped_table("list_value").unique_key_value_to_id( ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) ) if list_value_id is None: @@ -327,7 +289,7 @@ def merge(self, other): and other_parameter_value_list_id != self["parameter_value_list_id"] and any( x["parameter_definition_id"] == self["id"] - for x in self._db_cache.table_cache("parameter_value").valid_values() + for x in self._db_cache.mapped_table("parameter_value").valid_values() ) ): del other["parameter_value_list_id"] @@ -413,7 +375,7 @@ def polish(self): parsed_value = from_database(value, type_) if parsed_value is None: return - list_value_id = self._db_cache.table_cache("list_value").unique_key_value_to_id( + list_value_id = self._db_cache.mapped_table("list_value").unique_key_value_to_id( ("parameter_value_list_name", "value", "type"), (list_name, value, type_) ) if list_value_id is None: @@ -431,7 +393,7 @@ def callback(new_id): list_value_id.add_resolve_callback(callback) -class ParameterValueListItem(CacheItemBase): +class ParameterValueListItem(MappedItemBase): _fields = {"name": ("str", "The parameter value list name.")} _unique_keys = (("name",),) @@ -456,7 +418,7 @@ def _make_parsed_value(self): return error -class AlternativeItem(CacheItemBase): +class AlternativeItem(MappedItemBase): _fields = { "name": ("str", "The alternative name."), "description": ("str, optional", "The alternative description."), @@ -465,7 +427,7 @@ class AlternativeItem(CacheItemBase): _unique_keys = (("name",),) -class ScenarioItem(CacheItemBase): +class ScenarioItem(MappedItemBase): _fields = { "name": ("str", "The scenario name."), "description": ("str, optional", "The scenario description."), @@ -480,11 +442,11 @@ def __getitem__(self, key): if key == "alternative_name_list": return [x["alternative_name"] for x in self.sorted_scenario_alternatives] if key == "sorted_scenario_alternatives": - self._db_cache.fetch_all("scenario_alternative") + self._db_cache.do_fetch_all("scenario_alternative") return sorted( ( x - for x in self._db_cache.table_cache("scenario_alternative").valid_values() + for x in self._db_cache.mapped_table("scenario_alternative").valid_values() if x["scenario_id"] == self["id"] ), key=itemgetter("rank"), @@ -492,7 +454,7 @@ def __getitem__(self, key): return super().__getitem__(key) -class ScenarioAlternativeItem(CacheItemBase): +class ScenarioAlternativeItem(MappedItemBase): _fields = { "scenario_name": ("str", "The scenario name."), "alternative_name": ("str", "The alternative name."), @@ -524,12 +486,12 @@ def __getitem__(self, key): return super().__getitem__(key) -class MetadataItem(CacheItemBase): +class MetadataItem(MappedItemBase): _fields = {"name": ("str", "The metadata entry name."), "value": ("str", "The metadata entry value.")} _unique_keys = (("name", "value"),) -class EntityMetadataItem(CacheItemBase): +class EntityMetadataItem(MappedItemBase): _fields = { "entity_name": ("str", "The entity name."), "metadata_name": ("str", "The metadata entry name."), @@ -547,7 +509,7 @@ class EntityMetadataItem(CacheItemBase): } -class ParameterValueMetadataItem(CacheItemBase): 
+class ParameterValueMetadataItem(MappedItemBase): _fields = { "parameter_definition_name": ("str", "The parameter name."), "entity_byname": ( diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 20cfed62..efe64814 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -20,20 +20,6 @@ from .helpers import remove_credentials_from_url -def _ids_for_item_type(db_map, item_type): - """Queries ids for given database item type. - - Args: - db_map (DatabaseMapping): database map - item_type (str): database item type - - Returns: - set of int: item ids - """ - sq_attr = db_map.cache_sqs[item_type] - return {row.id for row in db_map.query(getattr(db_map, sq_attr))} - - def purge_url(url, purge_settings, logger=None): """Removes all items of selected types from the database at a given URL. diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index ced29530..65990aa9 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -19,7 +19,7 @@ Typically, you would start the server in a background Python process by specifying the URL of the target Spine DB, getting back the URL where the server is listening. You can then use that URL in any number of instances of your application that would connect to the server -- via a socket - and then send requests to retrieve or modify the data in the DB. +via a socket and then send requests to retrieve or modify the data in the DB. Requests to the server must be encoded using JSON. Each request must be a JSON array with the following elements: diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py index a569f7f3..26698854 100644 --- a/spinedb_api/spine_io/exporters/writer.py +++ b/spinedb_api/spine_io/exporters/writer.py @@ -27,7 +27,7 @@ def write(db_map, writer, *mappings, empty_data_header=True, max_tables=None, ma Writes given mapping. 
Args: - db_map (DatabaseMappingBase): database map + db_map (DatabaseMapping): database map writer (Writer): target writer mappings (Mapping): root mappings empty_data_header (bool or Iterable of bool): True to write at least header rows even if there is no data, diff --git a/tests/filters/test_execution_filter.py b/tests/filters/test_execution_filter.py index bc44c7a7..6a092ee6 100644 --- a/tests/filters/test_execution_filter.py +++ b/tests/filters/test_execution_filter.py @@ -24,9 +24,9 @@ def test_import_alternative_after_applying_execution_filter(self): apply_execution_filter(db_map, execution) alternative_name = db_map.get_import_alternative_name() self.assertEqual(alternative_name, "low_on_steam_wasting_my_time__Importing importer@2023-09-06T01:23:45") - alternatives = {item["name"] for item in db_map.cache["alternative"].values()} + alternatives = {item["name"] for item in db_map.mapped_table("alternative").valid_values()} self.assertIn(alternative_name, alternatives) - scenarios = {item["name"] for item in db_map.cache["scenario"].values()} + scenarios = {item["name"] for item in db_map.mapped_table("scenario").valid_values()} self.assertEqual(scenarios, {"low_on_steam", "wasting_my_time"}) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 687a25c5..7e347832 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -46,9 +46,9 @@ def query_wrapper(*args, orig_query=db_map.query, **kwargs): class TestDatabaseMappingConstruction(unittest.TestCase): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with mock.patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, create=True) db_map.close() @@ -58,9 +58,9 @@ def test_construction_with_filters(self): def test_construction_with_sqlalchemy_url_and_filters(self): db_url = IN_MEMORY_DB_URL + "/?spinedbfilter=fltr1&spinedbfilter=fltr2" sa_url = make_url(db_url) - with mock.patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: + with mock.patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with mock.patch( - "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) db_map.close() @@ -84,7 +84,7 @@ def test_shorthand_filter_query_works(self): db_map.close() -class TestDatabaseMappingBase(unittest.TestCase): +class TestDatabaseMapping(unittest.TestCase): _db_map = None @classmethod @@ -97,9 +97,9 @@ def tearDownClass(cls): def test_construction_with_filters(self): db_url = IN_MEMORY_DB_URL + "?spinedbfilter=fltr1&spinedbfilter=fltr2" - with patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: + with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with patch( - "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(db_url, 
create=True) db_map.close() @@ -109,9 +109,9 @@ def test_construction_with_filters(self): def test_construction_with_sqlalchemy_url_and_filters(self): sa_url = URL("sqlite") sa_url.query = {"spinedbfilter": ["fltr1", "fltr2"]} - with patch("spinedb_api.db_mapping_base.apply_filter_stack") as mock_apply: + with patch("spinedb_api.db_mapping.apply_filter_stack") as mock_apply: with patch( - "spinedb_api.db_mapping_base.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] + "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: db_map = DatabaseMapping(sa_url, create=True) db_map.close() @@ -343,7 +343,7 @@ def test_get_import_alternative_returns_base_alternative_by_default(self): self.assertEqual(alternative_name, "Base") -class TestDatabaseMappingBaseQueries(unittest.TestCase): +class TestDatabaseMappingQueries(unittest.TestCase): def setUp(self): self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) @@ -2217,10 +2217,10 @@ def test_rollback_addition(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") import_functions.import_object_classes(self._db_map, ("second_class",)) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2230,10 +2230,10 @@ def test_rollback_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.remove_items("entity_class", 1) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2243,10 +2243,10 @@ def test_rollback_update(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.get_item("entity_class", name="my_class").update(name="new_name") - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) self._db_map.rollback_session() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class"}) with self.assertRaises(SpineDBAPIError): # Nothing to commit @@ -2256,33 +2256,33 @@ 
def test_refresh_addition(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") import_functions.import_object_classes(self._db_map, ("second_class",)) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) def test_refresh_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.remove_items("entity_class", 1) - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) def test_refresh_update(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") self._db_map.get_item("entity_class", name="my_class").update(name="new_name") - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) self._db_map.refresh_session() self._db_map.fetch_all() - entity_class_names = {x["name"] for x in self._db_map.cache.table_cache("entity_class").valid_values()} + entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) def test_cascade_remove_unfetched(self): @@ -2290,7 +2290,7 @@ def test_cascade_remove_unfetched(self): import_functions.import_objects(self._db_map, (("my_class", "my_object"),)) self._db_map.commit_session("test commit") self._db_map.refresh_session() - self._db_map.cache.clear() + self._db_map.clear() self._db_map.remove_items("entity_class", 1) self._db_map.commit_session("test commit") ents = self._db_map.query(self._db_map.entity_sq).all() diff --git a/tests/test_db_cache_base.py b/tests/test_db_cache_base.py index 524bde5b..76397d2a 100644 --- a/tests/test_db_cache_base.py +++ b/tests/test_db_cache_base.py @@ -10,10 +10,10 @@ ###################################################################################################################### import unittest -from spinedb_api.db_cache_base import CacheItemBase, DBCacheBase +from spinedb_api.db_mapping_base import MappedItemBase, DatabaseMappingBase -class TestCache(DBCacheBase): +class TestDBMapping(DatabaseMappingBase): @property def item_types(self): return ["cutlery"] @@ -21,51 +21,51 @@ def item_types(self): @staticmethod def item_factory(item_type): if item_type == "cutlery": - return CacheItemBase + return 
MappedItemBase raise RuntimeError(f"unknown item_type '{item_type}'") class TestDBCacheBase(unittest.TestCase): def test_rolling_back_new_item_invalidates_its_id(self): - cache = TestCache() - table_cache = cache.table_cache("cutlery") - item = table_cache.add_item({}, new=True) + db_map = TestDBMapping() + mapped_table = db_map.mapped_table("cutlery") + item = mapped_table.add_item({}, new=True) self.assertTrue(item.is_id_valid) self.assertIn("id", item) id_ = item["id"] - cache.rollback() + db_map.rollback() self.assertFalse(item.is_id_valid) self.assertEqual(item["id"], id_) class TestTableCache(unittest.TestCase): def test_readding_item_with_invalid_id_creates_new_id(self): - cache = TestCache() - table_cache = cache.table_cache("cutlery") - item = table_cache.add_item({}, new=True) + db_map = TestDBMapping() + mapped_table = db_map.mapped_table("cutlery") + item = mapped_table.add_item({}, new=True) id_ = item["id"] - cache.rollback() + db_map.rollback() self.assertFalse(item.is_id_valid) - table_cache.add_item(item, new=True) + mapped_table.add_item(item, new=True) self.assertTrue(item.is_id_valid) self.assertNotEqual(item["id"], id_) -class TestCacheItemBase(unittest.TestCase): +class TestMappedItemBase(unittest.TestCase): def test_id_is_valid_initially(self): - cache = TestCache() - item = CacheItemBase(cache, "cutlery") + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") self.assertTrue(item.is_id_valid) def test_id_can_be_invalidated(self): - cache = TestCache() - item = CacheItemBase(cache, "cutlery") + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") item.invalidate_id() self.assertFalse(item.is_id_valid) def test_setting_new_id_validates_it(self): - cache = TestCache() - item = CacheItemBase(cache, "cutlery") + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") item.invalidate_id() self.assertFalse(item.is_id_valid) item["id"] = 23 From 6d2fad2def3f4ec8b754aa72d7c870f4de9048c5 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 11:27:40 +0200 Subject: [PATCH 108/317] Fix export functions --- spinedb_api/export_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 08a265b5..de09933a 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -88,7 +88,7 @@ def _get_items_from_db_map(db_map, tablename, ids): yield from db_map.mapped_table(tablename).valid_values() return for id_ in ids: - item = db_map.get_item(tablename, id=id_) + item = db_map.get_mapped_item(tablename, id_) if item.is_valid(): yield item From 0fbadc56832be5629731115d1f6e4fbd3a2a6c6e Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 12:53:04 +0200 Subject: [PATCH 109/317] Improve parameter_value docs plus minor improvements --- spinedb_api/db_mapping_base.py | 8 +- spinedb_api/db_mapping_commit_mixin.py | 6 +- spinedb_api/parameter_value.py | 175 +++++++++--------- ..._cache_base.py => test_db_mapping_base.py} | 8 +- 4 files changed, 101 insertions(+), 96 deletions(-) rename tests/{test_db_cache_base.py => test_db_mapping_base.py} (95%) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 917b67af..cc7bdf5e 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -103,7 +103,7 @@ def dirty_ids(self, item_type): if item.status in (Status.to_add, Status.to_update) } - def dirty_items(self): + def _dirty_items(self): """Returns a list of tuples of the form 
(item_type, (to_add, to_update, to_remove)) corresponding to items that have been modified but not yet committed. @@ -141,7 +141,7 @@ def dirty_items(self): dirty_items.append((item_type, (to_add, to_update, to_remove))) return dirty_items - def rollback(self): + def _rollback(self): """Discards uncommitted changes. Namely, removes all the added items, resets all the updated items, and restores all the removed items. @@ -149,7 +149,7 @@ def rollback(self): Returns: bool: False if there is no uncommitted items, True if successful. """ - dirty_items = self.dirty_items() + dirty_items = self._dirty_items() if not dirty_items: return False to_add_by_type = [] @@ -174,7 +174,7 @@ def rollback(self): item.invalidate_id() return True - def refresh(self): + def _refresh(self): """Clears fetch progress, so the DB is queried again.""" self._offsets.clear() self._fetched_item_types.clear() diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 3aaba6f4..20cc8897 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -26,7 +26,7 @@ def commit_session(self, comment): """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") - dirty_items = self.dirty_items() + dirty_items = self._dirty_items() if not dirty_items: raise SpineDBAPIError("Nothing to commit.") user = self.username @@ -50,11 +50,11 @@ def commit_session(self, comment): def rollback_session(self): """Discards all the changes from the in-memory mapping.""" - if not self.rollback(): + if not self._rollback(): raise SpineDBAPIError("Nothing to rollback.") if self._memory: self._memory_dirty = False def refresh_session(self): """Resets the fetch status so new items from the DB can be retrieved.""" - self.refresh() + self._refresh() diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index c2cd23ed..4814bf9d 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -12,12 +12,12 @@ """ Support utilities and classes to deal with Spine parameter values. -The :func:`from_database` function receives the parameter value and type fields from the database returning -a float, Datetime, Duration, Array, TimePattern, TimeSeriesFixedResolution, -TimeSeriesVariableResolution or Map objects. +The :func:`from_database` function receives a DB representation of a parameter value (value and type) and returns +a float, str, bool, :class:`DateTime`, :class:`Duration`, :class:`Array`, :class:`TimePattern`, +:class:`TimeSeriesFixedResolution`, :class:`TimeSeriesVariableResolution` or :class:`Map` object. The above objects can be converted back to the database format by the :func:`to_database` free function -or by their `to_database` member functions. +or by their :meth:`~ParameterValue.to_database` member function. Individual datetimes are represented as datetime objects from the standard Python library. Individual time steps are represented as relativedelta objects from the dateutil package. @@ -53,10 +53,10 @@ def duration_to_relativedelta(duration): Converts a duration to a relativedelta object. Args: - duration (str): a duration specification + duration (str): a duration string. Returns: - a relativedelta object corresponding to the given duration + :class:`~dateutil.relativedelta.relativedelta`: a relativedelta object corresponding to the given duration. 
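To make the conversion concrete, a minimal round-trip sketch using the two helpers documented above (the exact output strings are indicative, not exact):

    from spinedb_api.parameter_value import duration_to_relativedelta, relativedelta_to_duration

    delta = duration_to_relativedelta("3 months")  # a dateutil relativedelta with months=3
    duration_to_relativedelta("1h")                # hours=1
    relativedelta_to_duration(delta)               # back to a duration string such as "3M"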
""" try: count, abbreviation, full_unit = re.split("\\s|([a-z]|[A-Z])", duration, maxsplit=1) @@ -84,10 +84,10 @@ def relativedelta_to_duration(delta): Converts a relativedelta to duration. Args: - delta (relativedelta): the relativedelta to convert + delta (:class:`~dateutil.relativedelta.relativedelta`): the relativedelta to convert. Returns: - a duration string + str: a duration string """ if delta.seconds > 0: seconds = delta.seconds @@ -119,15 +119,15 @@ def relativedelta_to_duration(delta): def load_db_value(db_value, value_type=None): """ - Loads a database parameter value into a Python object using JSON. - Adds the "type" property to dicts representing complex types. + Parses a database representation of a parameter value (value and type) into a Python object, using JSON. + If the result is a dict, adds the "type" property to it. Args: - db_value (bytes, optional): a value in the database - value_type (str, optional): the type in case of complex ones + db_value (bytes, optional): the database value. + value_type (str, optional): the value type. Returns: - Any: the parsed parameter value + any: the parsed parameter value """ if db_value is None: return None @@ -142,15 +142,14 @@ def load_db_value(db_value, value_type=None): def dump_db_value(parsed_value): """ - Dumps a Python object into a database parameter value using JSON. - Extracts the "type" property from dicts representing complex types. + Unparses a Python object into a database representation of a parameter value (value and type), using JSON. + If the given object is a dict, extracts the "type" property from it. Args: - parsed_value (Any): the Python object + parsed_value (any): a Python object, typically obtained by calling :func:`load_db_value`. Returns: - str: the database parameter value - str: the type + tuple(str,str): database representation (value and type). """ value_type = parsed_value.pop("type") if isinstance(parsed_value, dict) else None db_value = json.dumps(parsed_value).encode("UTF8") @@ -161,14 +160,14 @@ def dump_db_value(parsed_value): def from_database(database_value, value_type=None): """ - Converts a parameter value from its database representation into an encoded Python object. + Converts a database representation of a parameter value (value and type) into an encoded parameter value. Args: - database_value (bytes, optional): a value in the database - value_type (str, optional): the type in case of complex ones + database_value (bytes, optional): the database value + value_type (str, optional): the value type Returns: - Any: the encoded parameter value + :class:`ParameterValue`, float, str, bool or None: the encoded parameter value. """ parsed = load_db_value(database_value, value_type) if isinstance(parsed, dict): @@ -182,16 +181,14 @@ def from_database(database_value, value_type=None): def from_database_to_single_value(database_value, value_type): """ - Converts a value from its database representation into a single value. - - Indexed values get converted to their type string. + Same as :func:`from_database`, but in the case of indexed types it returns just the type as a string. Args: - database_value (bytes): a value in the database - value_type (str, optional): value's type + database_value (bytes): the database value + value_type (str, optional): the value type Returns: - Any: single-value representation + :class:`ParameterValue`, float, str, bool or None: the encoded parameter value or its type. 
""" if value_type is None or value_type not in ("map", "time_series", "time_pattern", "array"): return from_database(database_value, value_type) @@ -200,11 +197,11 @@ def from_database_to_single_value(database_value, value_type): def from_database_to_dimension_count(database_value, value_type): """ - Counts dimensions of value's database representation + Counts the dimensions in a database representation of a parameter value (value and type). Args: - database_value (bytes): a value in the database - value_type (str, optional): value's type + database_value (bytes): the database value + value_type (str, optional): the value type Returns: int: number of dimensions @@ -220,14 +217,13 @@ def from_database_to_dimension_count(database_value, value_type): def to_database(parsed_value): """ - Converts an encoded Python object into its database representation. + Converts an encoded parameter value into its database representation (value and type). Args: - value: the value to convert. It can be the result of either ``load_db_value`` or ``from_database```. + value(any): a Python object, typically obtained by calling :func:`load_db_value` or :func:`from_database`. Returns: - bytes: value's database representation as bytes - str: the value type + tuple(bytes,str): database representation (value and type). """ if hasattr(parsed_value, "to_database"): return parsed_value.to_database() @@ -235,30 +231,30 @@ def to_database(parsed_value): return db_value, None -def from_dict(value_dict): +def from_dict(value): """ - Converts a complex (relationship) parameter value from its dictionary representation to a Python object. + Converts a dictionary representation of a parameter value into an encoded parameter value. Args: - value_dict (dict): value's dictionary; a parsed JSON object with the "type" key + value (dict): the value dictionary including the "type" key. Returns: - the encoded (relationship) parameter value + :class:`ParameterValue`, float, str, bool or None: the encoded parameter value. """ - value_type = value_dict["type"] + value_type = value["type"] try: if value_type == "date_time": - return _datetime_from_database(value_dict["data"]) + return _datetime_from_database(value["data"]) if value_type == "duration": - return _duration_from_database(value_dict["data"]) + return _duration_from_database(value["data"]) if value_type == "map": - return _map_from_database(value_dict) + return _map_from_database(value) if value_type == "time_pattern": - return _time_pattern_from_database(value_dict) + return _time_pattern_from_database(value) if value_type == "time_series": - return _time_series_from_database(value_dict) + return _time_series_from_database(value) if value_type == "array": - return _array_from_database(value_dict) + return _array_from_database(value) raise ParameterValueFormatError(f'Unknown parameter value type "{value_type}"') except KeyError as error: raise ParameterValueFormatError(f'"{error.args[0]}" is missing in the parameter value description') @@ -268,15 +264,15 @@ def fix_conflict(new, old, on_conflict="merge"): """Resolves conflicts between parameter values: Args: - new (any): new parameter value to write - old (any): existing parameter value in the db + new (:class:`ParameterValue`, float, str, bool or None): new parameter value to be written. + old (:class:`ParameterValue`, float, str, bool or None): an existing parameter value in the db. 
on_conflict (str): conflict resolution strategy: - - 'merge': Merge indexes if possible, otherwise replace - - 'replace': Replace old with new - - 'keep': keep old + - 'merge': Merge indexes if possible, otherwise replace. + - 'replace': Replace old with new. + - 'keep': Keep old. Returns: - any: a parameter value with conflicts resolved + :class:`ParameterValue`, float, str, bool or None: a new parameter value with conflicts resolved. """ funcs = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge} func = funcs.get(on_conflict) @@ -288,13 +284,14 @@ def fix_conflict(new, old, on_conflict="merge"): def merge(value, other): - """Merges other into value, returns the result. + """Merges the DB representation of two parameter values. + Args: - value (tuple): recipient value and type - other (tuple): other value and type + value (tuple(bytes,str)): recipient value and type. + other (tuple(bytes,str)): other value and type. Returns: - tuple: value and type of merged value + tuple(bytes,str): the DB representation of the merged value. """ parsed_value = from_database(*value) if not hasattr(parsed_value, "merge"): @@ -631,7 +628,31 @@ def to_database(self): return json.dumps(self._list_value_id).encode("UTF8"), self.type_() -class DateTime: +class ParameterValue: + """Base class for all encoded parameter values.""" + + def to_dict(self): + """Returns the dictionary representation of this object. + + Returns: + dict: a dictionary including the "type" key. + """ + raise NotImplementedError() + + @staticmethod + def type_(): + raise NotImplementedError() + + def to_database(self): + """Returns the database representation of this object as JSON bytes and type. + + Returns: + tuple(bytes,str): the DB value and type. + """ + raise NotImplementedError() + + +class DateTime(ParameterValue): """A single datetime value.""" VALUE_TYPE = "single value" @@ -639,7 +660,7 @@ class DateTime: def __init__(self, value=None): """ Args: - value (DataTime or str or datetime.datetime): a timestamp + value (:class:`DateTime` or str or datetime.datetime): a timestamp """ if value is None: value = datetime(year=2000, month=1, day=1) @@ -676,7 +697,6 @@ def value_to_database_data(self): return self._value.isoformat() def to_dict(self): - """Returns the database representation of this object.""" return {"data": self.value_to_database_data()} @staticmethod @@ -684,7 +704,6 @@ def type_(): return "date_time" def to_database(self): - """Returns the database representation of this object as JSON.""" return json.dumps(self.to_dict()).encode("UTF8"), self.type_() @property @@ -693,9 +712,9 @@ def value(self): return self._value -class Duration: +class Duration(ParameterValue): """ - This class represents a duration in time. + A duration in time. Durations are always handled as relativedeltas. 
""" @@ -705,7 +724,7 @@ class Duration: def __init__(self, value=None): """ Args: - value (str or relativedelta): the time step + value (str or :class:`~dateutil.dateutil.relativedelta`): the time step """ if value is None: value = relativedelta(hours=1) @@ -734,7 +753,6 @@ def value_to_database_data(self): return relativedelta_to_duration(self._value) def to_dict(self): - """Returns the database representation of the duration.""" return {"data": self.value_to_database_data()} @staticmethod @@ -742,7 +760,6 @@ def type_(): return "duration" def to_database(self): - """Returns the database representation of the duration as JSON.""" return json.dumps(self.to_dict()).encode("UTF8"), self.type_() @property @@ -784,9 +801,9 @@ def __bool__(self): return np.size(self) != 0 -class IndexedValue: +class IndexedValue(ParameterValue): """ - An abstract base class for indexed values. + Base class for all indexed values. Attributes: index_name (str): index name @@ -831,7 +848,6 @@ def indexes(self, indexes): self._indexes = _Indexes(indexes) def to_database(self): - """Return the database representation of the value.""" return json.dumps(self.to_dict()).encode("UTF8"), self.type_() @property @@ -862,11 +878,6 @@ def set_value(self, index, value): self.values[pos] = value def to_dict(self): - """Converts the value to a Python dictionary. - - Returns: - dict(): mapping from indexes to values - """ raise NotImplementedError() def merge(self, other): @@ -923,7 +934,6 @@ def type_(): return "array" def to_dict(self): - """See base class.""" value_type_id = { float: "float", str: "str", # String could also mean time_period but we don't have any way to distinguish that, yet. @@ -949,7 +959,7 @@ def value_type(self): class IndexedNumberArray(IndexedValue): """ - An abstract base class for indexed floats. + Abstract base class for all values mapping indexes to floats. The indexes and numbers are stored in numpy.ndarrays. """ @@ -975,12 +985,11 @@ def type_(): raise NotImplementedError() def to_dict(self): - """Return the database representation of the value.""" raise NotImplementedError() class TimeSeries(IndexedNumberArray): - """An abstract base class for time series.""" + """Abstract base class for time-series values.""" VALUE_TYPE = "time series" DEFAULT_INDEX_NAME = "t" @@ -1093,7 +1102,7 @@ def __setitem__(self, position, index): class TimePattern(IndexedNumberArray): - """Represents a time pattern (relationship) parameter value.""" + """A time-pattern parameter value.""" VALUE_TYPE = "time pattern" DEFAULT_INDEX_NAME = "p" @@ -1132,7 +1141,6 @@ def type_(): return "time_pattern" def to_dict(self): - """Returns the database representation of this time pattern.""" value_dict = {"data": dict(zip(self._indexes, self._values))} if self.index_name != "p": value_dict["index_name"] = self.index_name @@ -1141,7 +1149,7 @@ def to_dict(self): class TimeSeriesFixedResolution(TimeSeries): """ - A time series with fixed durations between the time stamps. + A time-series value with fixed durations between the time stamps. When getting the indexes the durations are applied cyclically. 
@@ -1265,7 +1273,6 @@ def resolution(self, resolution): self._indexes = None def to_dict(self): - """Returns the value in its database representation.""" if len(self._resolution) > 1: resolution_as_json = [relativedelta_to_duration(step) for step in self._resolution] else: @@ -1285,7 +1292,7 @@ def to_dict(self): class TimeSeriesVariableResolution(TimeSeries): - """A class representing time series data with variable time steps.""" + """A time-series value with variable time steps.""" def __init__(self, indexes, values, ignore_year, repeat, index_name=""): """ @@ -1327,7 +1334,6 @@ def __eq__(self, other): ) def to_dict(self): - """Returns the value in its database representation""" value_dict = dict() value_dict["data"] = {str(index): float(value) for index, value in zip(self._indexes, self._values)} # Add "index" entry only if its contents are not set to their default values. @@ -1388,7 +1394,6 @@ def type_(): return "map" def to_dict(self): - """Returns map's database representation.""" value_dict = { "index_type": _map_index_type_to_database(self._index_type), "data": self.value_to_database_data(), @@ -1482,7 +1487,7 @@ def convert_map_to_table(map_, make_square=True, row_this_far=None, empty=None): map_ (Map): map to convert make_square (bool): if True, append None to shorter rows, otherwise leave the row as is row_this_far (list, optional): current row; used for recursion - empty (Any, optional): object to fill empty cells with + empty (any, optional): object to fill empty cells with Returns: list of list: map's rows diff --git a/tests/test_db_cache_base.py b/tests/test_db_mapping_base.py similarity index 95% rename from tests/test_db_cache_base.py rename to tests/test_db_mapping_base.py index 76397d2a..fbeac2be 100644 --- a/tests/test_db_cache_base.py +++ b/tests/test_db_mapping_base.py @@ -25,7 +25,7 @@ def item_factory(item_type): raise RuntimeError(f"unknown item_type '{item_type}'") -class TestDBCacheBase(unittest.TestCase): +class TestDBMappingBase(unittest.TestCase): def test_rolling_back_new_item_invalidates_its_id(self): db_map = TestDBMapping() mapped_table = db_map.mapped_table("cutlery") @@ -33,18 +33,18 @@ def test_rolling_back_new_item_invalidates_its_id(self): self.assertTrue(item.is_id_valid) self.assertIn("id", item) id_ = item["id"] - db_map.rollback() + db_map._rollback() self.assertFalse(item.is_id_valid) self.assertEqual(item["id"], id_) -class TestTableCache(unittest.TestCase): +class TestMappedTable(unittest.TestCase): def test_readding_item_with_invalid_id_creates_new_id(self): db_map = TestDBMapping() mapped_table = db_map.mapped_table("cutlery") item = mapped_table.add_item({}, new=True) id_ = item["id"] - db_map.rollback() + db_map._rollback() self.assertFalse(item.is_id_valid) mapped_table.add_item(item, new=True) self.assertTrue(item.is_id_valid) From 6a4b1a8400cbb17d546505c0884cc27fd044167c Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 12:57:47 +0200 Subject: [PATCH 110/317] Attempt to fix docs building online --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9d7f39af..4e71355a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -93,7 +93,7 @@ autoapi_ignore = [ '*/spinedb_api/alembic/*', ] # ignored modules -autoapi_keep_files = True +# autoapi_keep_files = True def _skip_member(app, what, name, obj, skip, options): From 3ad90162124f80b22f886461397c4542895153a5 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 13:01:38 +0200 Subject: 
[PATCH 111/317] Second attempt to fix docs build --- docs/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 0fb2adf7..02f01599 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ # Requirements for compiling the documentation -sphinx >= 2.2.0 +sphinx >= 7.1.2 sphinx_rtd_theme >= 0.4.3 recommonmark >= 0.6.0 -sphinx-autoapi >= 1.1.0 +sphinx-autoapi >= 2.0.0 From ffbc82e01be4de3f78e12f42a9c30ef794dda8ec Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 13:18:06 +0200 Subject: [PATCH 112/317] Update readthedocs.yml --- .readthedocs.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 7ab0106a..9e9405d3 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -5,6 +5,12 @@ # Required version: 2 +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.9" + # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/source/conf.py @@ -16,8 +22,6 @@ formats: # Optionally set the version of Python and requirements required to build your docs python: - version: 3.8 install: - - method: pip - path: . + - requirements: requirements.txt - requirements: docs/requirements.txt From 6af817e5b41facd9e0e9faa22c23214bce61fe50 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 13:20:53 +0200 Subject: [PATCH 113/317] One more try to fix docs online build --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4e71355a..bb527f25 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -93,7 +93,7 @@ autoapi_ignore = [ '*/spinedb_api/alembic/*', ] # ignored modules -# autoapi_keep_files = True +autoapi_keep_files = False def _skip_member(app, what, name, obj, skip, options): From 4fdb442d062bd950b337cc67861d87f51f4f115e Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 15:02:58 +0200 Subject: [PATCH 114/317] Try to fix the docs build again --- docs/source/conf.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index bb527f25..626164a3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -46,7 +46,7 @@ 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', + # 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', 'sphinx.ext.napoleon', 'sphinx.ext.intersphinx', @@ -93,7 +93,6 @@ autoapi_ignore = [ '*/spinedb_api/alembic/*', ] # ignored modules -autoapi_keep_files = False def _skip_member(app, what, name, obj, skip, options): From 6213c88acbd9d6d81b55344983aeeaa55320a602 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 16:24:11 +0200 Subject: [PATCH 115/317] Improve documentation --- spinedb_api/db_mapping.py | 25 ++- spinedb_api/parameter_value.py | 273 +++++++++++++++++++-------------- 2 files changed, 179 insertions(+), 119 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 01252863..741c603e 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -51,11 +51,11 @@ class DatabaseMapping( - DatabaseMappingQueryMixin, DatabaseMappingAddMixin, DatabaseMappingUpdateMixin, DatabaseMappingRemoveMixin, DatabaseMappingCommitMixin, + DatabaseMappingQueryMixin, DatabaseMappingBase, ): """Enables communication with a Spine DB. 
@@ -75,8 +75,10 @@ class DatabaseMapping(
     These methods also fetch data from the DB into the in-memory mapping to perform the necessary integrity checks
     (unique and foreign key constraints).

+    The :attr:`item_types` property contains the supported item types (equivalent to the table names in the DB).
+
     To retrieve an item or to manipulate it, you typically need to specify certain fields.
-    The :meth:`describe_item_type` method is provided to help you identify these fields.
+    The :meth:`describe_item_type` method is provided to help you with this.

     Modifications to the in-memory mapping are committed (written) to the DB via :meth:`commit_session`,
     or rolled back (discarded) via :meth:`rollback_session`.
@@ -90,6 +92,11 @@ class DatabaseMapping(

     The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
     while bypassing the in-memory mapping entirely.
+
+    The class is intended to be used as a context manager. For example::
+
+        with DatabaseMapping(db_url) as db_map:
+            print(db_map.item_types)
     """

     ITEM_TYPES = (
@@ -367,7 +374,7 @@ def override_create_import_alternative(self, method):
         self._import_alternative_name = None

     def get_filter_configs(self):
-        """Returns filters applicable to this DB mapping.
+        """Returns the filters used to build this DB mapping.

         Returns:
             list(dict):
@@ -381,11 +388,16 @@ def get_filter_configs(self):
     def get_table(self, tablename):

     def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """Finds and returns an item matching the arguments, or None if none found.

+        Example::
+
+            with DatabaseMapping(db_url) as db_map:
+                bar = db_map.get_item("entity", class_name="foo", name="bar")
+                print(bar["description"])  # Prints the description field
+
         Args:
             item_type (str): The type of the item.
             fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory.
             skip_removed (bool, optional): Whether to ignore removed items.
-            **kwargs: Fields of one of the item type's unique keys and their values for the requested item.
+            **kwargs: Fields and values for one of the unique keys of the item type.

         Returns:
             :class:`PublicItem` or None
@@ -401,6 +413,11 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
     def get_items(self, item_type, fetch=True, skip_removed=True):
         """Finds and returns all the items of the given type.

+        Example::
+
+            with DatabaseMapping(db_url) as db_map:
+                all_entities = db_map.get_items("entity")
+
         Args:
             item_type (str): The type of items to get.
             fetch (bool, optional): Whether to fetch the DB before returning the items.
diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py
index 4814bf9d..7ef96b85 100644
--- a/spinedb_api/parameter_value.py
+++ b/spinedb_api/parameter_value.py
@@ -615,19 +615,6 @@ def _array_from_database(value_dict):
     return Array(data, value_type, index_name)


-class ListValueRef:
-    def __init__(self, list_value_id):
-        self._list_value_id = list_value_id
-
-    @staticmethod
-    def type_():
-        return "list_value_ref"
-
-    def to_database(self):
-        """Returns the database representation of this object as JSON."""
-        return json.dumps(self._list_value_id).encode("UTF8"), self.type_()
-
-
 class ParameterValue:
     """Base class for all encoded parameter values."""

@@ -641,19 +628,36 @@ class ParameterValue:
         raise NotImplementedError()

     @staticmethod
     def type_():
+        """Returns the value type for this object.
+
+        Returns:
+            str: the value type.
+        """
         raise NotImplementedError()

     def to_database(self):
-        """Returns the database representation of this object as JSON bytes and type.
+ """Returns the database representation of this object (value and type). Returns: tuple(bytes,str): the DB value and type. """ - raise NotImplementedError() + return json.dumps(self.to_dict()).encode("UTF8"), self.type_() + + +class ListValueRef: + def __init__(self, list_value_id): + self._list_value_id = list_value_id + + @staticmethod + def type_(): + return "list_value_ref" + + def to_database(self): + return json.dumps(self._list_value_id).encode("UTF8"), self.type_() class DateTime(ParameterValue): - """A single datetime value.""" + """A moment in time.""" VALUE_TYPE = "single value" @@ -676,7 +680,6 @@ def __init__(self, value=None): self._value = value def __eq__(self, other): - """Returns True if other is equal to this object.""" if not isinstance(other, DateTime): return NotImplemented return self._value == other._value @@ -703,12 +706,13 @@ def to_dict(self): def type_(): return "date_time" - def to_database(self): - return json.dumps(self.to_dict()).encode("UTF8"), self.type_() - @property def value(self): - """Returns the value as a datetime object.""" + """The value. + + Returns: + :class:`~datetime.datetime`: + """ return self._value @@ -716,7 +720,7 @@ class Duration(ParameterValue): """ A duration in time. - Durations are always handled as relativedeltas. + Durations are always handled as :class:`~dateutil.dateutil.relativedelta`s. """ VALUE_TYPE = "single value" @@ -724,7 +728,7 @@ class Duration(ParameterValue): def __init__(self, value=None): """ Args: - value (str or :class:`~dateutil.dateutil.relativedelta`): the time step + value (str or :class:`~dateutil.dateutil.relativedelta`): the duration """ if value is None: value = relativedelta(hours=1) @@ -737,7 +741,6 @@ def __init__(self, value=None): self._value = value def __eq__(self, other): - """Returns True if other is equal to this object.""" if not isinstance(other, Duration): return NotImplemented return self._value == other._value @@ -749,7 +752,7 @@ def __str__(self): return str(relativedelta_to_duration(self._value)) def value_to_database_data(self): - """Returns the 'data' attribute part of Duration's database representation.""" + """Returns the 'data' property of this object's database representation.""" return relativedelta_to_duration(self._value) def to_dict(self): @@ -759,9 +762,6 @@ def to_dict(self): def type_(): return "duration" - def to_database(self): - return json.dumps(self.to_dict()).encode("UTF8"), self.type_() - @property def value(self): """Returns the duration as a :class:`relativedelta`.""" @@ -803,7 +803,7 @@ def __bool__(self): class IndexedValue(ParameterValue): """ - Base class for all indexed values. + Base class for all values that have indexes. Attributes: index_name (str): index name @@ -814,7 +814,7 @@ class IndexedValue(ParameterValue): def __init__(self, index_name): """ Args: - index_name (str): index name + index_name (str): index name. """ self._indexes = None self._values = None @@ -825,39 +825,46 @@ def __bool__(self): return bool(self.indexes) def __len__(self): - """Returns the number of values.""" return len(self.indexes) @staticmethod def type_(): - """Returns a type identifier string. - - Returns: - str: type identifier - """ raise NotImplementedError() @property def indexes(self): - """Returns the indexes.""" + """The indexes. + + Returns: + :class:`~numpy.ndarray` + """ return self._indexes @indexes.setter def indexes(self, indexes): - """Sets the indexes.""" - self._indexes = _Indexes(indexes) + """Sets the indexes. 
- def to_database(self): - return json.dumps(self.to_dict()).encode("UTF8"), self.type_() + Args: + indexes (:class:`~numpy.ndarray`) + """ + self._indexes = _Indexes(indexes) @property def values(self): - """Returns the data values.""" + """The values. + + Returns: + :class:`~numpy.ndarray` + """ return self._values @values.setter def values(self, values): - """Sets the values.""" + """Sets the values. + + Args: + values (:class:`~numpy.ndarray`) + """ self._values = values def get_nearest(self, index): @@ -865,14 +872,26 @@ def get_nearest(self, index): return self.values[pos] def get_value(self, index): - """Returns the value at the given index.""" + """Returns the value at a given index. + + Args: + index (any): The index. + + Returns: + any: The value. + """ pos = self.indexes.position_lookup.get(index) if pos is None: return None return self.values[pos] def set_value(self, index, value): - """Sets the value at the given index.""" + """Sets the value at a given index. + + Args: + index (any): The index. + value (any): The value. + """ pos = self.indexes.position_lookup.get(index) if pos is not None: self.values[pos] = value @@ -901,10 +920,10 @@ class Array(IndexedValue): def __init__(self, values, value_type=None, index_name=""): """ Args: - values (Sequence): array's values - value_type (Type, optional): array element type; will be deduced from the array if not given - and defaults to float if ``values`` is empty - index_name (str): index name + values (Sequence): the values in the array. + value_type (Type, optional): array element type; will be deduced from ``values`` if not given + and defaults to float if ``values`` is empty. + index_name (str): index name. """ super().__init__(index_name if index_name else Array.DEFAULT_INDEX_NAME) if value_type is None: @@ -953,7 +972,11 @@ def to_dict(self): @property def value_type(self): - """Returns the type of array's elements.""" + """Returns the type of the values. + + Returns: + str: + """ return self._value_type @@ -961,21 +984,20 @@ class IndexedNumberArray(IndexedValue): """ Abstract base class for all values mapping indexes to floats. - The indexes and numbers are stored in numpy.ndarrays. + The indexes and numbers are stored in :class:`~numpy.ndarray`s. """ def __init__(self, index_name, values): """ Args: - index_name (str): index name - values (Sequence): array's values; index handling should be implemented by subclasses + index_name (str): index name. + values (Sequence): the values in the array; index handling should be implemented by subclasses. """ super().__init__(index_name) self.values = values @IndexedValue.values.setter def values(self, values): - """Sets the values.""" if not isinstance(values, np.ndarray) or not values.dtype == np.dtype(float): values = np.array(values, dtype=float) self._values = values @@ -989,7 +1011,7 @@ def to_dict(self): class TimeSeries(IndexedNumberArray): - """Abstract base class for time-series values.""" + """Abstract base class for time-series.""" VALUE_TYPE = "time series" DEFAULT_INDEX_NAME = "t" @@ -997,10 +1019,10 @@ class TimeSeries(IndexedNumberArray): def __init__(self, values, ignore_year, repeat, index_name=""): """ Args: - values (Sequence): an array of values - ignore_year (bool): True if the year should be ignored in the time stamps - repeat (bool): True if the series should be repeated from the beginning - index_name (str): index name + values (Sequence): the values in the time-series. + ignore_year (bool): True if the year should be ignored. 
+ repeat (bool): True if the series is repeating.
+ index_name (str): index name.
 """
 if len(values) < 1:
 raise ParameterValueFormatError("Time series too short. Must have one or more values")
@@ -1009,25 +1031,42 @@ def __init__(self, values, ignore_year, repeat, index_name=""):
 self._repeat = repeat
 
 def __len__(self):
- """Returns the number of values."""
 return len(self._values)
 
 @property
 def ignore_year(self):
- """Returns True if the year should be ignored."""
+ """Whether the year should be ignored.
+
+ Returns:
+ bool:
+ """
 return self._ignore_year
 
 @ignore_year.setter
 def ignore_year(self, ignore_year):
+ """Sets the ignore_year property.
+
+ Args:
+ ignore_year (bool): new value.
+ """
 self._ignore_year = bool(ignore_year)
 
 @property
 def repeat(self):
- """Returns True if the series should be repeated."""
+ """Whether the series is repeating.
+
+ Returns:
+ bool:
+ """
 return self._repeat
 
 @repeat.setter
 def repeat(self, repeat):
+ """Sets the repeat property.
+
+ Args:
+ repeat (bool): new value.
+ """
 self._repeat = bool(repeat)
 
 @staticmethod
@@ -1035,7 +1074,6 @@ def type_():
 return "time_series"
 
 def to_dict(self):
- """Return the database representation of the value."""
 raise NotImplementedError()
 
 
@@ -1111,7 +1149,7 @@ def __init__(self, indexes, values, index_name=""):
 """
 Args:
 indexes (list): a list of time pattern strings
- values (Sequence): an array of values corresponding to the time patterns
+ values (Sequence): the value for each time pattern.
 index_name (str): index name
 """
 if len(indexes) != len(values):
@@ -1122,7 +1160,6 @@ def __init__(self, indexes, values, index_name=""):
 self.indexes = indexes
 
 def __eq__(self, other):
- """Returns True if other is equal to this object."""
 if not isinstance(other, TimePattern):
 return NotImplemented
 return (
@@ -1133,7 +1170,6 @@ def __eq__(self, other):
 
 @IndexedNumberArray.indexes.setter
 def indexes(self, indexes):
- """Sets the indexes."""
 self._indexes = _TimePatternIndexes(indexes, dtype=np.object_)
 
 @staticmethod
@@ -1149,7 +1185,7 @@ def to_dict(self):
 
 class TimeSeriesFixedResolution(TimeSeries):
 """
- A time-series value with fixed durations between the time stamps.
+ A time-series with fixed durations between the time stamps.
 
 When getting the indexes the durations are applied cyclically.
 
@@ -1162,12 +1198,12 @@ class TimeSeriesFixedResolution(TimeSeries):
 def __init__(self, start, resolution, values, ignore_year, repeat, index_name=""):
 """
 Args:
- start (str or datetime or datetime64): the first time stamp
- resolution (str, relativedelta, list): duration(s) between the time stamps
- values (Sequence): data values at each time stamp
- ignore_year (bool): whether or not the time-series should apply to any year
- repeat (bool): whether or not the time series should repeat cyclically
- index_name (str): index name
+ start (str or :class:`~datetime.datetime` or :class:`~numpy.datetime64`): the first time stamp
+ resolution (str, :class:`~dateutil.relativedelta.relativedelta`, list): duration(s) between the time stamps.
+ values (Sequence): the values in the time-series.
+ ignore_year (bool): True if the year should be ignored.
+ repeat (bool): True if the series is repeating.
+ index_name (str): index name.
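+
+ A minimal usage sketch (the data here is made up)::
+
+ ts = TimeSeriesFixedResolution(
+ "2023-01-01T00:00", "1h", [1.0, 2.0, 3.0], ignore_year=False, repeat=False
+ )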
""" super().__init__(values, ignore_year, repeat, index_name) self._start = None @@ -1176,7 +1212,6 @@ def __init__(self, start, resolution, values, ignore_year, repeat, index_name="" self.resolution = resolution def __eq__(self, other): - """Returns True if other is equal to this object.""" if not isinstance(other, TimeSeriesFixedResolution): return NotImplemented return ( @@ -1211,29 +1246,31 @@ def _get_memoized_indexes(self): @property def indexes(self): - """Returns the time stamps as a numpy.ndarray of numpy.datetime64 objects.""" if self._indexes is None: self.indexes = self._get_memoized_indexes() return IndexedValue.indexes.fget(self) @indexes.setter def indexes(self, indexes): - """Sets the indexes.""" # Needed because we redefine the setter self._indexes = _Indexes(indexes) @property def start(self): - """Returns the start index.""" + """Returns the start index. + + Returns: + :class:`~numpy.datetime64`: + """ return self._start @start.setter def start(self, start): """ - Sets the start datetime. + Sets the start index. Args: - start (datetime or datetime64 or str): the start of the series + start (:class:`~datetime.datetime` or :class:`~numpy.datetime64` or str): the start of the series """ if isinstance(start, str): try: @@ -1248,7 +1285,11 @@ def start(self, start): @property def resolution(self): - """Returns the resolution as list of durations.""" + """Returns the resolution as list of durations. + + Returns: + list(:class:`Duration`): + """ return self._resolution @resolution.setter @@ -1257,7 +1298,7 @@ def resolution(self, resolution): Sets the resolution. Args: - resolution (str, relativedelta, list): resolution or a list thereof + resolution (str, :class:`~.dateutil.relativedelta.relativedelta`, list): resolution or a list thereof """ if isinstance(resolution, str): resolution = [duration_to_relativedelta(resolution)] @@ -1292,16 +1333,16 @@ def to_dict(self): class TimeSeriesVariableResolution(TimeSeries): - """A time-series value with variable time steps.""" + """A time-series with variable time steps.""" def __init__(self, indexes, values, ignore_year, repeat, index_name=""): """ Args: - indexes (Sequence): time stamps as numpy.datetime64 objects - values (Sequence): the values corresponding to the time stamps - ignore_year (bool): True if the stamp year should be ignored - repeat (bool): True if the series should be repeated from the beginning - index_name (str): index name + indexes (Sequence(:class:`~numpy.datetime64`)): the time stamps. + values (Sequence): the value for each time stamp. + ignore_year (bool): True if the year should be ignored. + repeat (bool): True if the series is repeating. + index_name (str): index name. """ super().__init__(values, ignore_year, repeat, index_name) if len(indexes) != len(values): @@ -1322,7 +1363,6 @@ def __init__(self, indexes, values, ignore_year, repeat, index_name=""): self.indexes = indexes def __eq__(self, other): - """Returns True if other is equal to this object.""" if not isinstance(other, TimeSeriesVariableResolution): return NotImplemented return ( @@ -1355,10 +1395,10 @@ class Map(IndexedValue): def __init__(self, indexes, values, index_type=None, index_name=""): """ Args: - indexes (Sequence): map's indexes - values (Sequence): map's values - index_type (type or NoneType): index type or None to deduce from indexes - index_name (str): index name + indexes (Sequence): the indexes in the map. + values (Sequence): the value for each index. 
+ index_type (type or NoneType): index type or None to deduce from ``indexes``.
+ index_name (str): index name.
 """
 if not indexes and index_type is None:
 raise ParameterValueFormatError("Cannot deduce index type from empty indexes list.")
@@ -1377,7 +1417,11 @@ def __eq__(self, other):
 return other._indexes == self._indexes and other._values == self._values and self.index_name == other.index_name
 
 def is_nested(self):
- """Returns True if any of the values is also a map."""
+ """Whether any of the values is also a map.
+
+ Returns:
+ bool:
+ """
 return any(isinstance(value, Map) for value in self._values)
 
 def value_to_database_data(self):
@@ -1404,10 +1448,10 @@ def to_dict(self):
 
 
 def map_dimensions(map_):
- """Counts Map's dimensions.
+ """Counts the dimensions in a map.
 
 Args:
- map_ (Map): a Map
+ map_ (:class:`Map`): the map to process.
 
 Returns:
 int: number of dimensions
@@ -1423,17 +1467,18 @@ def map_dimensions(map_):
 
 def convert_leaf_maps_to_specialized_containers(map_):
 """
- Converts suitable leaf maps to corresponding specialized containers.
+ Converts suitable leaf maps to specialized containers.
 
- Currently supported conversions:
+ Current conversion rules:
 
- - index_type: :class:`DateTime`, all values ``float`` -> :class"`TimeSeries`
+ - If the ``index_type`` is a :class:`DateTime` and all ``values`` are float,
+ then the leaf is converted to a :class:`TimeSeries`.
 
 Args:
- map_ (Map): a map to process
+ map_ (:class:`Map`): a map to process.
 
 Returns:
- IndexedValue: a map with leaves converted or specialized container if map was convertible in itself
+ :class:`IndexedValue`: a new map with leaves converted, or a specialized container if the map itself was convertible.
 """
 converted_container = _try_convert_to_container(map_)
 if converted_container is not None:
@@ -1452,13 +1497,13 @@ def convert_containers_to_maps(value):
 """
 Converts indexed values into maps.
 
- if ``value`` is :class:`Map` converts leaf values into Maps recursively.
+ If ``value`` is a :class:`Map` then converts leaf values into maps recursively.
 
 Args:
- value (IndexedValue): a value to convert
+ value (:class:`IndexedValue`): an indexed value to convert.
 
 Returns:
- Map: converted Map
+ :class:`Map`: converted Map
 """
 if isinstance(value, Map):
 if not value:
@@ -1484,10 +1529,10 @@ def convert_map_to_table(map_, make_square=True, row_this_far=None, empty=None):
 Converts :class:`Map` into list of rows recursively.
 
 Args:
- map_ (Map): map to convert
- make_square (bool): if True, append None to shorter rows, otherwise leave the row as is
- row_this_far (list, optional): current row; used for recursion
- empty (any, optional): object to fill empty cells with
+ map_ (:class:`Map`): map to convert.
+ make_square (bool): if True, then pad rows with None so they all have the same length.
+ row_this_far (list, optional): current row; used for recursion.
+ empty (any, optional): object to fill empty cells with.
 
 Returns:
 list of list: map's rows
@@ -1514,13 +1559,13 @@ def convert_map_to_table(map_, make_square=True, row_this_far=None, empty=None):
 
 def convert_map_to_dict(map_):
 """
- Converts :class:`Map` to nested dictionaries.
+ Converts a :class:`Map` to a nested dictionary.
 
 Args:
- map_ (Map): map to convert
+ map_ (:class:`Map`): map to convert
 
 Returns:
- dict: Map as a dict
+ dict:
 """
 d = dict()
 for index, x in zip(map_.indexes, map_.values):
@@ -1567,7 +1612,7 @@ def join_value_and_type(db_value, db_type):
 db_type (str, optional): value type
 
 Returns:
- str: parameter value as JSON with an additional `type` field.
+ str: parameter value as JSON with an additional ``type`` field.
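+
+ For example (a sketch; the exact key order may differ), a duration stored in the DB as
+ ``b'{"data": "1h"}'`` with type ``"duration"`` comes out as ``'{"data": "1h", "type": "duration"}'``.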
""" try: parsed = load_db_value(db_value, db_type) @@ -1578,14 +1623,12 @@ def join_value_and_type(db_value, db_type): def split_value_and_type(value_and_type): """Splits the given string into value and type. - The string must be the result of calling ``join_value_and_type`` or have the same form. Args: - value_and_type (str) + value_and_type (str): a string joining value and type, as obtained by calling :func:`join_value_and_type`. Returns: - bytes - str or NoneType + tuple(bytes,str): database value and type. """ try: parsed = json.loads(value_and_type) From 7406ac9602a2b477b1eeb988fd7f27898fb5ff61 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 4 Oct 2023 16:34:12 +0200 Subject: [PATCH 116/317] Minor touch to the docs --- spinedb_api/db_mapping.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 741c603e..aca67266 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -75,17 +75,17 @@ class DatabaseMapping( These methods also fetch data from the DB into the in-memory mapping to perform the necessary integrity checks (unique and foreign key constraints). - The :attr:`item_types` property contains the supported item types (equivalent to the table names in the DB). - - To retrieve an item or to manipulate it, you typically need to specify certain fields. - The :meth:`describe_item_type` method is provided to help you with this. - Modifications to the in-memory mapping are committed (written) to the DB via :meth:`commit_session`, or rolled back (discarded) via :meth:`rollback_session`. The DB fetch status is reset via :meth:`refresh_session`. This allows new items in the DB (added by other clients in the meantime) to be retrieved as well. + The :attr:`item_types` property contains the supported item types (equivalent to the table names in the DB). + + To retrieve an item or to manipulate it, you typically need to specify certain fields. + The :meth:`describe_item_type` method is provided to help you with this. + You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`. For example, a UI application might want to fetch data in the background so the UI is not blocked in the process. In that case they can call e.g. :meth:`fetch_more` asynchronously as the user scrolls or expands the views. From af34079821e87ff79bfc0e1df69ece8ad83cabfd Mon Sep 17 00:00:00 2001 From: Henrik Koski <98282892+PiispaH@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:17:26 +0300 Subject: [PATCH 117/317] Fix importer fewer columns than expected Traceback (#277) Now when an importer spec is modified to have a column reference that is out of range for the source data, an error is shown. If such importer spec is executed an error will also be thrown and logged. 
Re spine-tools/Spine-Toolbox#2333
---
 spinedb_api/import_mapping/generator.py | 3 +-
 spinedb_api/import_mapping/import_mapping.py | 33 +++++++++++++++++---
 spinedb_api/spine_io/importers/reader.py | 4 +--
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py
index 1af95d9d..b73037a4 100644
--- a/spinedb_api/import_mapping/generator.py
+++ b/spinedb_api/import_mapping/generator.py
@@ -83,6 +83,7 @@ def get_mapped_data(
 rows = list(data_source)
 if not rows:
 return mapped_data, errors
+ column_count = len(max(rows, key=lambda x: len(x) if x else 0))
 if column_convert_fns is None:
 column_convert_fns = {}
 if row_convert_fns is None:
@@ -92,7 +93,7 @@ def get_mapped_data(
 for mapping in mappings:
 read_state = {}
 mapping = deepcopy(mapping)
- mapping.polish(table_name, data_header)
+ mapping.polish(table_name, data_header, column_count)
 mapping_errors = check_validity(mapping)
 if mapping_errors:
 errors += mapping_errors
diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py
index 5259a7f1..c07d4f3c 100644
--- a/spinedb_api/import_mapping/import_mapping.py
+++ b/spinedb_api/import_mapping/import_mapping.py
@@ -135,22 +135,44 @@ def read_start_row(self, row):
 raise ValueError(f"row must be >= 0 ({row})")
 self._read_start_row = row
 
- def polish(self, table_name, source_header, for_preview=False):
+ def check_for_invalid_column_refs(self, header, table_name):
+ """Checks that the mapping's column references are not out of range for the source table.
+
+ Args:
+ header (list): The header of the table as a list
+ table_name (str): The name of the source table
+
+ Returns:
+ str: Error message if a column ref exceeds the column count of the source table,
+ empty string otherwise
+ """
+ if self.child is not None:
+ error = self.child.check_for_invalid_column_refs(header, table_name)
+ if error:
+ return error
+ if isinstance(self.position, int) and self.position >= len(header) > 0:
+ msg = f"Column ref {self.position + 1} is out of range for the source table \"{table_name}\""
+ return msg
+ return ""
+
+ def polish(self, table_name, source_header, column_count=0, for_preview=False):
 """Polishes the mapping before an import operation.
 
 'Expands' transient ``position`` and ``value`` attributes into their final value.
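+
+ For example (column names here are made up): with ``source_header=["time", "value"]``,
+ a mapping whose ``position`` is the string ``"value"`` gets its position resolved to column index 1.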
Args: table_name (str) source_header (list(str)) + column_count (int, optional) + for_preview (bool, optional) """ - self._polish_for_import(table_name, source_header) + self._polish_for_import(table_name, source_header, column_count) if for_preview: self._polish_for_preview(source_header) - def _polish_for_import(self, table_name, source_header): + def _polish_for_import(self, table_name, source_header, column_count): # FIXME: Polish skip columns if self.child is not None: - self.child._polish_for_import(table_name, source_header) + self.child._polish_for_import(table_name, source_header, column_count) if isinstance(self.position, str): # Column mapping with string position, we need to find the index in the header try: @@ -185,6 +207,9 @@ def _polish_for_import(self, table_name, source_header): except IndexError: msg = f"'{self.value}' is not a valid index in header '{source_header}'" raise InvalidMappingComponent(msg) + if isinstance(self.position, int) and self.position >= column_count > 0: + msg = f"Column ref {self.position + 1} is out of range for the source table \"{table_name}\"" + raise InvalidMappingComponent(msg) def _polish_for_preview(self, source_header): if self.position == Position.header and self.value is not None: diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index 92eda12f..051618b9 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -16,7 +16,7 @@ from itertools import islice -from spinedb_api.exception import ConnectorError +from spinedb_api.exception import ConnectorError, InvalidMappingComponent from spinedb_api.import_mapping.generator import get_mapped_data, identity from spinedb_api.import_mapping.import_mapping_compat import parse_named_mapping_spec from spinedb_api import DateTime, Duration, ParameterValueFormatError @@ -151,7 +151,7 @@ def get_mapped_data( row_convert_fns, unparse_value, ) - except (ConnectorError, ParameterValueFormatError) as error: + except (ConnectorError, ParameterValueFormatError, InvalidMappingComponent) as error: errors.append(str(error)) continue for key, value in data.items(): From 12e6bb85311cbc6c5bc3fdc3c2d7ba8d70fa5859 Mon Sep 17 00:00:00 2001 From: Pekka T Savolainen Date: Thu, 5 Oct 2023 12:06:51 +0300 Subject: [PATCH 118/317] Make docs/requirements.txt compatible with toolbox requirements - Set language = "en" to prevent a warning - Add builder: html just in case Re spine-tools/Spine-Toolbox#2342 --- .readthedocs.yml | 1 + docs/requirements.txt | 12 ++++++++---- docs/source/conf.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 9e9405d3..cecbda6d 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -13,6 +13,7 @@ build: # Build documentation in the docs/ directory with Sphinx sphinx: + builder: html configuration: docs/source/conf.py # Optionally build your docs in additional formats such as PDF and ePub diff --git a/docs/requirements.txt b/docs/requirements.txt index 02f01599..e42a014d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,9 @@ # Requirements for compiling the documentation -sphinx >= 7.1.2 -sphinx_rtd_theme >= 0.4.3 -recommonmark >= 0.6.0 -sphinx-autoapi >= 2.0.0 +markupsafe < 2.1 # Jinja2<3.0 tries to import soft_unicode, which has been removed in markupsafe 2.1 +jinja2 < 3.0 # Dagster 0.12.8 requires Jinja2<3.0 +docutils < 0.17 +sphinx < 5.2 +sphinx_rtd_theme +recommonmark +astroid < 3.0 # sphinx-autoapi installs the 
latest astroid. We are not compatible with astroid v3.0
+sphinx-autoapi < 2.1 # 2.1 removed support for sphinx < 5.2.0
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 626164a3..5ccb2d6d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -74,7 +74,7 @@
 #
 # This is also used if you do content translation via gettext catalogs.
 # Usually you set "language" from the command line for these cases.
-language = None
+language = "en"
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
From 2acdeb0d87b5110747d114c3efe9e336df5fa6a5 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 5 Oct 2023 15:56:56 +0200
Subject: [PATCH 119/317] Move things around a little more, keep trying to
 improve docs

---
 docs/source/conf.py | 50 ++-
 spinedb_api/db_mapping.py | 420 ++++++++++++++++--------
 spinedb_api/db_mapping_add_mixin.py | 179 ----------
 spinedb_api/db_mapping_base.py | 12 +-
 spinedb_api/db_mapping_commit_mixin.py | 180 +++++++---
 spinedb_api/db_mapping_query_mixin.py | 29 --
 spinedb_api/db_mapping_remove_mixin.py | 105 ------
 spinedb_api/db_mapping_update_mixin.py | 202 ------------
 spinedb_api/filters/execution_filter.py | 9 +-
 spinedb_api/mapped_items.py | 4 +-
 spinedb_api/purge.py | 2 +-
 spinedb_api/spine_db_server.py | 8 +-
 tests/custom_db_mapping.py | 117 +++++++
 tests/test_DatabaseMapping.py | 72 ++--
 tests/test_db_mapping_base.py | 4 +-
 tests/test_purge.py | 2 +-
 16 files changed, 646 insertions(+), 749 deletions(-)
 delete mode 100644 spinedb_api/db_mapping_add_mixin.py
 delete mode 100644 spinedb_api/db_mapping_remove_mixin.py
 delete mode 100644 spinedb_api/db_mapping_update_mixin.py
 create mode 100644 tests/custom_db_mapping.py

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 626164a3..cd506f0f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -14,6 +14,7 @@
 #
 import os
 import sys
+from spinedb_api import DatabaseMapping
 
 root_path = os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
 sys.path.insert(0, os.path.abspath(root_path))
@@ -46,7 +47,7 @@
 'sphinx.ext.todo',
 'sphinx.ext.coverage',
 'sphinx.ext.ifconfig',
- # 'sphinx.ext.viewcode',
+ 'sphinx.ext.viewcode',
 'sphinx.ext.githubpages',
 'sphinx.ext.napoleon',
 'sphinx.ext.intersphinx',
@@ -85,7 +86,7 @@
 pygments_style = 'sphinx'
 
 # Settings for Sphinx AutoAPI
-autoapi_options = ['members', 'inherited-members', 'show-module-summary']
+autoapi_options = ['members', 'show-module-summary', 'show-inheritance']
 autoapi_python_class_content = "both"
 autoapi_add_toctree_entry = True
 autoapi_root = "autoapi"
@@ -103,8 +104,53 @@ def _skip_member(app, what, name, obj, skip, options):
 return skip
 
 
+def _process_docstring(app, what, name, obj, options, lines):
+ try:
+ i = lines.index("<db_mapping_schema>")
+ except ValueError:
+ pass
+ else:
+ new_lines = []
+ for item_type in DatabaseMapping.item_types():
+ factory = DatabaseMapping.item_factory(item_type)
+ if not factory._fields:
+ continue
+ new_lines.extend([item_type, len(item_type) * "-", ""])
+ new_lines.extend(
+ [
+ ".. list-table:: Fields and values",
+ " :header-rows: 1",
+ "",
+ " * - field",
+ " - type",
+ " - value",
+ ]
+ )
+ for f_name, (f_type, f_value) in factory._fields.items():
+ new_lines.extend([f" * - {f_name}", f" - {f_type}", f" - {f_value}"])
+ new_lines.append("")
+ new_lines.extend(
+ [
+ "..
list-table:: Unique keys", + " :header-rows: 0", + "", + ] + ) + for f_names in factory._unique_keys: + f_names = ", ".join(f_names) + new_lines.extend([f" * - {f_names}"]) + lines[i : i + 1] = new_lines + return + if what == "method": + spine_item_types = ", ".join([f"`{x}`" for x in DatabaseMapping.item_types()]) + for k, line in enumerate(lines): + if "" in line: + lines[k] = line.replace("", spine_item_types) + + def setup(sphinx): sphinx.connect("autoapi-skip-member", _skip_member) + sphinx.connect("autodoc-process-docstring", _process_docstring) # -- Options for HTML output ------------------------------------------------- diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index aca67266..adc44313 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -11,17 +11,23 @@ """ This module defines the :class:`.DatabaseMapping` class. + + +DB mapping schema +================= + """ import hashlib import os import time import logging +from datetime import datetime, timezone from types import MethodType from sqlalchemy import create_engine, MetaData, inspect from sqlalchemy.pool import NullPool from sqlalchemy.event import listen -from sqlalchemy.exc import DatabaseError +from sqlalchemy.exc import DatabaseError, DBAPIError from sqlalchemy.engine.url import make_url, URL from alembic.migration import MigrationContext from alembic.environment import EnvironmentContext @@ -33,31 +39,24 @@ from .spine_db_client import get_db_url_from_server from .mapped_items import item_factory from .db_mapping_base import DatabaseMappingBase -from .db_mapping_query_mixin import DatabaseMappingQueryMixin -from .db_mapping_add_mixin import DatabaseMappingAddMixin -from .db_mapping_update_mixin import DatabaseMappingUpdateMixin -from .db_mapping_remove_mixin import DatabaseMappingRemoveMixin from .db_mapping_commit_mixin import DatabaseMappingCommitMixin -from .exception import SpineDBAPIError, SpineDBVersionError +from .db_mapping_query_mixin import DatabaseMappingQueryMixin +from .exception import SpineDBAPIError, SpineDBVersionError, SpineIntegrityError +from .query import Query +from .compatibility import compatibility_transformations from .helpers import ( _create_first_spine_database, create_new_spine_database, compare_schemas, model_meta, copy_database_bind, + Asterisk, ) logging.getLogger("alembic").setLevel(logging.CRITICAL) -class DatabaseMapping( - DatabaseMappingAddMixin, - DatabaseMappingUpdateMixin, - DatabaseMappingRemoveMixin, - DatabaseMappingCommitMixin, - DatabaseMappingQueryMixin, - DatabaseMappingBase, -): +class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, DatabaseMappingBase): """Enables communication with a Spine DB. The DB is incrementally mapped into memory as data is requested/modified. @@ -81,11 +80,6 @@ class DatabaseMapping( The DB fetch status is reset via :meth:`refresh_session`. This allows new items in the DB (added by other clients in the meantime) to be retrieved as well. - The :attr:`item_types` property contains the supported item types (equivalent to the table names in the DB). - - To retrieve an item or to manipulate it, you typically need to specify certain fields. - The :meth:`describe_item_type` method is provided to help you with this. - You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`. For example, a UI application might want to fetch data in the background so the UI is not blocked in the process. In that case they can call e.g. 
:meth:`fetch_more` asynchronously as the user scrolls or expands the views. @@ -93,28 +87,8 @@ class DatabaseMapping( The :meth:`query` method is also provided as an alternative way to retrieve data from the DB while bypassing the in-memory mapping entirely. - The class is intended to be used as a context manager. For example:: - - with DatabaseMapping(db_url) as db_map: - print(db_map.item_types) """ - ITEM_TYPES = ( - "entity_class", - "entity", - "entity_group", - "alternative", - "scenario", - "scenario_alternative", - "entity_alternative", - "parameter_value_list", - "list_value", - "parameter_definition", - "parameter_value", - "metadata", - "entity_metadata", - "parameter_value_metadata", - ) _sq_name_by_item_type = { "entity_class": "wide_entity_class_sq", "entity": "wide_entity_sq", @@ -148,14 +122,14 @@ def __init__( Args: db_url (str or :class:`~sqlalchemy.engine.url.URL`): A URL in RFC-1738 format pointing to the database to be mapped, or to a DB server. - username (str, optional): A user name. If not given, it gets replaced by the string ``"anon"``. - upgrade (bool, optional): Whether the db at the given URL should be upgraded to the most recent + username (str, optional): A user name. If not given, it gets replaced by the string `anon`. + upgrade (bool, optional): Whether the DB at the given `url` should be upgraded to the most recent version. - codename (str, optional): A name to associate with the DB mapping. - create (bool, optional): Whether to create a Spine db at the given URL if it's not one already. - apply_filters (bool, optional): Whether to apply filters in the URL's query part. - memory (bool, optional): Whether or not to use a sqlite memory db as replacement for this DB map. - sqlite_timeout (int, optional): How many seconds to wait before raising connection errors. + codename (str, optional): A name to identify this object in your application. + create (bool, optional): Whether to create a new Spine DB at the given `url` if it's not already one. + apply_filters (bool, optional): Whether to apply filters in the `url`'s query segment. + memory (bool, optional): Whether to use a SQLite memory DB as replacement for the original one. + sqlite_timeout (int, optional): The number of seconds to wait before raising SQLite connection errors. 
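+
+ A minimal usage sketch (the URL and item fields are just examples)::
+
+ with DatabaseMapping("sqlite:///my.db", create=True) as db_map:
+ db_map.add_item("entity_class", name="fish", description="It swims.")
+ db_map.commit_session("Add the fish class.")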
""" super().__init__() # FIXME: We should also check the server memory property and use it here @@ -188,19 +162,6 @@ def __init__( if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) - # Table primary ids map: - self._id_fields = { - "entity_class_dimension": "entity_class_id", - "entity_element": "entity_id", - "object_class": "entity_class_id", - "relationship_class": "entity_class_id", - "object": "entity_id", - "relationship": "entity_id", - } - self.composite_pks = { - "entity_element": ("entity_id", "position"), - "entity_class_dimension": ("entity_class_id", "position"), - } def __enter__(self): return self @@ -211,9 +172,9 @@ def __exit__(self, _exc_type, _exc_val, _exc_tb): def __del__(self): self.close() - @property - def item_types(self): - return list(self._sq_name_by_item_type) + @staticmethod + def item_types(): + return list(DatabaseMapping._sq_name_by_item_type) @staticmethod def item_factory(item_type): @@ -319,13 +280,6 @@ def _receive_engine_close(self, dbapi_con, _connection_record): if self._memory_dirty: copy_database_bind(self._original_engine, self.engine) - def _get_primary_key(self, tablename): - pk = self.composite_pks.get(tablename) - if pk is None: - id_field = self._id_fields.get(tablename, "id") - pk = (id_field,) - return pk - @staticmethod def _real_tablename(tablename): return { @@ -386,18 +340,13 @@ def get_table(self, tablename): return self._metadata.tables[tablename] def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): - """Finds and returns and item matching the arguments, or None if none found. - - Example:: - with DatabaseMapping(db_url) as db_map: - bar = db_map.get_item("entity", class_name="foo", name="bar") - print(bar["description"]) # Prints the description field + """Finds and returns an item matching the arguments, or None if none found. Args: - item_type (str): The type of the item. + item_type (str): One of . fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory. skip_removed (bool, optional): Whether to ignore removed items. - **kwargs: Fields and values for one of the unique keys of the item type. + **kwargs: Fields and values for one the unique keys as specified for the item type in `DB mapping schema`_. Returns: :class:`PublicItem` or None @@ -411,20 +360,15 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): return PublicItem(self, cache_item) def get_items(self, item_type, fetch=True, skip_removed=True): - """Finds and returns and item matching the arguments, or None if none found. - - - Example:: - with DatabaseMapping(db_url) as db_map: - all_entities = db_map.get_items("entity") + """Finds and returns all the items of one type. Args: - item_type (str): The type of items to get. + item_type (str): One of . fetch (bool, optional): Whether to fetch the DB before returning the items. skip_removed (bool, optional): Whether to ignore removed items. Returns: - :class:`PublicItem` or None + list(:class:`PublicItem`): The items. """ item_type = self._real_tablename(item_type) if fetch and item_type not in self.fetched_item_types: @@ -436,16 +380,10 @@ def get_items(self, item_type, fetch=True, skip_removed=True): def add_item(self, item_type, check=True, **kwargs): """Adds an item to the in-memory mapping. - Example:: - - with DatabaseMapping(url) as db_map: - db_map.add_item("entity", class_name="dog", name="Pete") - - Args: - item_type (str): The type of the item. + item_type (str): One of . 
 check (bool, optional): Whether to carry out integrity checks.
- **kwargs: Mandatory fields for the item type and their values.
+ **kwargs: Fields and values as specified for the item type in `DB mapping schema`_.
 
 Returns:
 tuple(:class:`PublicItem` or None, str): The added item and any errors.
@@ -461,23 +399,42 @@ def add_item(self, item_type, check=True, **kwargs):
 error,
 )
 
- def update_item(self, item_type, check=True, **kwargs):
- """Updates an item in the in-memory mapping.
+ def add_items(self, item_type, *items, check=True, strict=False):
+ """Adds many items to the in-memory mapping.
 
- Example::
+ Args:
+ item_type (str): One of <spine_item_types>.
+ *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values,
+ as specified for the item type in `DB mapping schema`_.
+ check (bool): Whether or not to run integrity checks.
+ strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
+ if the insertion of one of the items violates an integrity constraint.
 
- with DatabaseMapping(url) as db_map:
- my_dog = db_map.get_item("entity", class_name="dog", name="Pete")
- db_map.update_item("entity", id=my_dog["id], name="Pluto")
+ Returns:
+ tuple(list(:class:`PublicItem`),list(str)): items successfully added and found violations.
+ """
+ added, errors = [], []
+ for item in items:
+ item, error = self.add_item(item_type, check, **item)
+ if error:
+ if strict:
+ raise SpineIntegrityError(error)
+ errors.append(error)
+ continue
+ added.append(item)
+ return added, errors
+
+ def update_item(self, item_type, check=True, **kwargs):
+ """Updates an item in the in-memory mapping.
 
 Args:
- item_type (str): The type of the item.
+ item_type (str): One of <spine_item_types>.
 check (bool, optional): Whether to carry out integrity checks.
 id (int): The id of the item to update.
- **kwargs: Fields to update and their new values.
+ **kwargs: Fields to update and their new values as specified for the item type in `DB mapping schema`_.
 
 Returns:
- tuple(:class:`PublicItem` or None, str): The added item and any errors.
+ tuple(:class:`PublicItem` or None, str): The updated item and any errors.
 """
 item_type = self._real_tablename(item_type)
 mapped_table = self.mapped_table(item_type)
@@ -487,6 +444,31 @@ def update_item(self, item_type, check=True, **kwargs):
 checked_item, error = mapped_table.check_item(kwargs, for_update=True)
 return (PublicItem(self, mapped_table.update_item(checked_item._asdict())) if checked_item else None, error)
 
+ def update_items(self, item_type, *items, check=True, strict=False):
+ """Updates many items in the in-memory mapping.
+
+ Args:
+ item_type (str): One of <spine_item_types>.
+ *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values,
+ as specified for the item type in `DB mapping schema`_ and including the `id`.
+ check (bool): Whether or not to run integrity checks.
+ strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
+ if the update of one of the items violates an integrity constraint.
+
+ Returns:
+ tuple(list(:class:`PublicItem`),list(str)): items successfully updated and found violations.
+ """
+ updated, errors = [], []
+ for item in items:
+ item, error = self.update_item(item_type, check=check, **item)
+ if error:
+ if strict:
+ raise SpineIntegrityError(error)
+ errors.append(error)
+ if item:
+ updated.append(item)
+ return updated, errors
+
 def remove_item(self, item_type, id_):
 """Removes an item from the in-memory mapping.
@@ -498,7 +480,7 @@ def remove_item(self, item_type, id_):
 
 Args:
- item_type (str): The type of the item.
+ item_type (str): One of <spine_item_types>.
 id (int): The id of the item to remove.
 
 Returns:
@@ -508,6 +490,29 @@ def remove_item(self, item_type, id_):
 mapped_table = self.mapped_table(item_type)
 return PublicItem(self, mapped_table.remove_item(id_))
 
+ def remove_items(self, item_type, *ids):
+ """Removes many items from the in-memory mapping.
+
+ Args:
+ item_type (str): One of <spine_item_types>.
+ *ids (Iterable(int)): Ids of items to be removed.
+
+ Returns:
+ list(:class:`PublicItem`): the removed items.
+ """
+ if not ids:
+ return []
+ item_type = self._real_tablename(item_type)
+ mapped_table = self.mapped_table(item_type)
+ if Asterisk in ids:
+ self.fetch_all(item_type)
+ ids = mapped_table
+ ids = set(ids)
+ if item_type == "alternative":
+ # Do not remove the Base alternative
+ ids.discard(1)
+ return [self.remove_item(item_type, id_) for id_ in ids]
+
 def restore_item(self, item_type, id_):
 """Restores a previously removed item into the in-memory mapping.
 
@@ -518,7 +523,7 @@ def restore_item(self, item_type, id_):
 Example::
 
 with DatabaseMapping(url) as db_map:
 my_dog = db_map.get_item("entity", class_name="dog", name="Pete")
 db_map.restore_item("entity", my_dog["id])
 
 Args:
- item_type (str): The type of the item.
+ item_type (str): One of <spine_item_types>.
 id (int): The id of the item to restore.
 
 Returns:
@@ -528,11 +533,36 @@ def restore_item(self, item_type, id_):
 mapped_table = self.mapped_table(item_type)
 return PublicItem(self, mapped_table.restore_item(id_))
 
+ def restore_items(self, item_type, *ids):
+ """Restores many previously removed items into the in-memory mapping.
+
+ Args:
+ item_type (str): One of <spine_item_types>.
+ *ids (Iterable(int)): Ids of items to be restored.
+
+ Returns:
+ list(:class:`PublicItem`): the restored items.
+ """
+ if not ids:
+ return []
+ return [self.restore_item(item_type, id_) for id_ in ids]
+
+ def purge_items(self, item_type):
+ """Removes all items of one type.
+
+ Args:
+ item_type (str): One of <spine_item_types>.
+
+ Returns:
+ list(:class:`PublicItem`): the removed items.
+ """
+ return self.remove_items(item_type, Asterisk)
+
 def can_fetch_more(self, item_type):
 """Whether or not more data can be fetched from the DB for the given item type.
 
 Args:
- item_type (str): The item type (table) to check.
+ item_type (str): One of <spine_item_types>.
 
 Returns:
 bool
@@ -543,7 +573,7 @@ def can_fetch_more(self, item_type):
 def fetch_more(self, item_type, limit=None):
 """Fetches items from the DB into the in-memory mapping, incrementally.
 
 Args:
- item_type (str): The item type (table) to fetch.
+ item_type (str): One of <spine_item_types>.
 limit (int): The maximum number of items to fetch. Successive calls to this function
 will start from the point where the last one left.
 In other words, each item is fetched from the DB exactly once.
+ + To perform custom ``SELECT`` statements, call this method with one or more of the documented + subquery properties of :class:`~spinedb_api.DatabaseMappingQueryMixin` returning + :class:`~sqlalchemy.sql.expression.Alias` objetcs. + For example, to select the entity class with ``id`` equal to 1:: + + from spinedb_api import DatabaseMapping + url = 'sqlite:///spine.db' + ... + db_map = DatabaseMapping(url) + db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() + + To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface + (which is a close clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`). + For example, to select all entity class names and the names of their entities concatenated in a comma-separated + string:: + + from sqlalchemy import func + + db_map.query( + db_map.entity_class_sq.c.name, func.group_concat(db_map.entity_sq.c.name) + ).filter( + db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id + ).group_by(db_map.entity_class_sq.c.name).all() + """ + return Query(self.engine, *args) + + def commit_session(self, comment): + """Commits the changes from the in-memory mapping to the database. + + Args: + comment (str): commit message + """ + if not comment: + raise SpineDBAPIError("Commit message cannot be empty.") + dirty_items = self._dirty_items() + if not dirty_items: + raise SpineDBAPIError("Nothing to commit.") + user = self.username + date = datetime.now(timezone.utc) + ins = self._metadata.tables["commit"].insert() + with self.engine.begin() as connection: + try: + commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + except DBAPIError as e: + raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e + for tablename, (to_add, to_update, to_remove) in dirty_items: + for item in to_add + to_update + to_remove: + item.commit(commit_id) + # Remove before add, to help with keeping integrity constraints + self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) + self._do_update_items(connection, tablename, *to_update) + self._do_add_items(connection, tablename, *to_add) + if self._memory: + self._memory_dirty = True + return compatibility_transformations(connection) + + def rollback_session(self): + """Discards all the changes from the in-memory mapping.""" + if not self._rollback(): + raise SpineDBAPIError("Nothing to rollback.") + if self._memory: + self._memory_dirty = False + + def refresh_session(self): + """Resets the fetch status so new items from the DB can be retrieved.""" + self._refresh() + + def add_ext_entity_metadata(self, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) + self.add_items("metadata", *metadata_items, **kwargs) + return self.add_items("entity_metadata", *items, **kwargs) + + def add_ext_parameter_value_metadata(self, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) + self.add_items("metadata", *metadata_items, **kwargs) + return self.add_items("parameter_value_metadata", *items, **kwargs) + + def get_metadata_to_add_with_item_metadata_items(self, *items): + metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items) + return [x for x in metadata_items if not self.mapped_table("metadata").find_item(x)] + + def _update_ext_item_metadata(self, tablename, *items, **kwargs): + metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) + added, errors = 
self.add_items("metadata", *metadata_items, **kwargs) + updated, more_errors = self.update_items(tablename, *items, **kwargs) + return added + updated, errors + more_errors + + def update_ext_entity_metadata(self, *items, **kwargs): + return self._update_ext_item_metadata("entity_metadata", *items, **kwargs) + + def update_ext_parameter_value_metadata(self, *items, **kwargs): + return self._update_ext_item_metadata("parameter_value_metadata", *items, **kwargs) + + def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): + """Returns data to add and remove, in order to set wide scenario alternatives. Args: - item_type (str): The type of item to describe. + *scenarios: One or more wide scenario :class:`dict` objects to set. + Each item must include the following keys: + + - "id": integer scenario id + - "alternative_id_list": list of alternative ids for that scenario + + Returns + list: scenario_alternative :class:`dict` objects to add. + set: integer scenario_alternative ids to remove """ - factory = self.item_factory(item_type) - sections = ("Fields:", "Unique keys:") - width = max(len(s) for s in sections) + 4 - print() - print(item_type) - print("-" * len(item_type)) - section = sections[0] - field_iter = (f"{field} ({type_}) - {description}" for field, (type_, description) in factory._fields.items()) - _print_section(section, width, field_iter) - print() - section = sections[1] - unique_key_iter = ("(" + ", ".join(key) + ")" for key in factory._unique_keys) - _print_section(section, width, unique_key_iter) - print() - - -def _print_section(section, width, iterator): - row = next(iterator) - bullet = "- " - print(f"{section:<{width}}" + bullet + row) - for row in iterator: - print(" " * width + bullet + row) + + def _is_equal(to_add, to_rm): + return all(to_rm[k] == v for k, v in to_add.items()) + + scen_alts_to_add = [] + scen_alt_ids_to_remove = {} + errors = [] + for scen in scenarios: + current_scen = self.mapped_table("scenario").find_item(scen) + if current_scen is None: + error = f"no scenario matching {scen} to set alternatives for" + if strict: + raise SpineIntegrityError(error) + errors.append(error) + continue + for k, alternative_id in enumerate(scen.get("alternative_id_list", ())): + item_to_add = {"scenario_id": current_scen["id"], "alternative_id": alternative_id, "rank": k + 1} + scen_alts_to_add.append(item_to_add) + for k, alternative_name in enumerate(scen.get("alternative_name_list", ())): + item_to_add = {"scenario_id": current_scen["id"], "alternative_name": alternative_name, "rank": k + 1} + scen_alts_to_add.append(item_to_add) + for alternative_id in current_scen["alternative_id_list"]: + scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id} + current_scen_alt = self.mapped_table("scenario_alternative").find_item(scen_alt) + scen_alt_ids_to_remove[current_scen_alt["id"]] = current_scen_alt + # Remove items that are both to add and to remove + for id_, to_rm in list(scen_alt_ids_to_remove.items()): + i = next((i for i, to_add in enumerate(scen_alts_to_add) if _is_equal(to_add, to_rm)), None) + if i is not None: + del scen_alts_to_add[i] + del scen_alt_ids_to_remove[id_] + return scen_alts_to_add, set(scen_alt_ids_to_remove), errors + + def remove_unused_metadata(self): + used_metadata_ids = set() + for x in self.mapped_table("entity_metadata").valid_values(): + used_metadata_ids.add(x["metadata_id"]) + for x in self.mapped_table("parameter_value_metadata").valid_values(): + used_metadata_ids.add(x["metadata_id"]) + 
unused_metadata_ids = {x["id"] for x in self.mapped_table("metadata").valid_values()} - used_metadata_ids + self.remove_items("metadata", *unused_metadata_ids) class PublicItem: diff --git a/spinedb_api/db_mapping_add_mixin.py b/spinedb_api/db_mapping_add_mixin.py deleted file mode 100644 index 6a7cc734..00000000 --- a/spinedb_api/db_mapping_add_mixin.py +++ /dev/null @@ -1,179 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - -# TODO: improve docstrings - -from sqlalchemy.exc import DBAPIError -from .exception import SpineIntegrityError, SpineDBAPIError -from .temp_id import TempId, resolve - - -class DatabaseMappingAddMixin: - """Provides methods to perform ``INSERT`` operations over a Spine db.""" - - def add_items(self, tablename, *items, check=True, strict=False): - """Add items to the in-memory mapping. - - Args: - tablename (str): The table where items are inserted. - items (Iterable): One or more :class:`dict` objects representing the items to be inserted. - check (bool): Whether or not to run integrity checks. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if the insertion of one of the items violates an integrity constraint. - - Returns: - tuple(list(dict),list(str)): items successfully added and found violations. 
- """ - added, errors = [], [] - for item in items: - item, error = self.add_item(tablename, check, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - continue - added.append(item) - return added, errors - - def _do_add_items(self, connection, tablename, *items_to_add): - """Add items to DB without checking integrity.""" - if not items_to_add: - return - try: - table = self._metadata.tables[self._real_tablename(tablename)] - id_items, temp_id_items = [], [] - for item in items_to_add: - if isinstance(item["id"], TempId): - temp_id_items.append(item) - else: - id_items.append(item) - if id_items: - connection.execute(table.insert(), [resolve(x._asdict()) for x in id_items]) - if temp_id_items: - current_ids = {x["id"] for x in connection.execute(table.select())} - next_id = max(current_ids, default=0) + 1 - available_ids = set(range(1, next_id)) - current_ids - required_id_count = len(temp_id_items) - len(available_ids) - new_ids = set(range(next_id, next_id + required_id_count)) - ids = sorted(available_ids | new_ids) - for id_, item in zip(ids, temp_id_items): - temp_id = item["id"] - temp_id.resolve(id_) - connection.execute(table.insert(), [resolve(x._asdict()) for x in temp_id_items]) - for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): - if not items_to_add_: - continue - table = self._metadata.tables[self._real_tablename(tablename_)] - connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) - except DBAPIError as e: - msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) from e - - @staticmethod - def _extra_items_to_add_per_table(tablename, items_to_add): - """ - Yields tuples of string tablename, list of items to insert. Needed because some insert queries - actually need to insert records to more than one table. 
- - Args: - tablename (str): target database table name - items_to_add (list): items to add - - Yields: - tuple: database table name, items to add - """ - if tablename == "entity_class": - ecd_items_to_add = [ - {"entity_class_id": item["id"], "position": position, "dimension_id": dimension_id} - for item in items_to_add - for position, dimension_id in enumerate(item["dimension_id_list"]) - ] - yield ("entity_class_dimension", ecd_items_to_add) - elif tablename == "entity": - ee_items_to_add = [ - { - "entity_id": item["id"], - "entity_class_id": item["class_id"], - "position": position, - "element_id": element_id, - "dimension_id": dimension_id, - } - for item in items_to_add - for position, (element_id, dimension_id) in enumerate( - zip(item["element_id_list"], item["dimension_id_list"]) - ) - ] - yield ("entity_element", ee_items_to_add) - - def add_object_classes(self, *items, **kwargs): - return self.add_items("object_class", *items, **kwargs) - - def add_objects(self, *items, **kwargs): - return self.add_items("object", *items, **kwargs) - - def add_entity_classes(self, *items, **kwargs): - return self.add_items("entity_class", *items, **kwargs) - - def add_entities(self, *items, **kwargs): - return self.add_items("entity", *items, **kwargs) - - def add_wide_relationship_classes(self, *items, **kwargs): - return self.add_items("relationship_class", *items, **kwargs) - - def add_wide_relationships(self, *items, **kwargs): - return self.add_items("relationship", *items, **kwargs) - - def add_parameter_definitions(self, *items, **kwargs): - return self.add_items("parameter_definition", *items, **kwargs) - - def add_parameter_values(self, *items, **kwargs): - return self.add_items("parameter_value", *items, **kwargs) - - def add_parameter_value_lists(self, *items, **kwargs): - return self.add_items("parameter_value_list", *items, **kwargs) - - def add_list_values(self, *items, **kwargs): - return self.add_items("list_value", *items, **kwargs) - - def add_alternatives(self, *items, **kwargs): - return self.add_items("alternative", *items, **kwargs) - - def add_scenarios(self, *items, **kwargs): - return self.add_items("scenario", *items, **kwargs) - - def add_scenario_alternatives(self, *items, **kwargs): - return self.add_items("scenario_alternative", *items, **kwargs) - - def add_entity_groups(self, *items, **kwargs): - return self.add_items("entity_group", *items, **kwargs) - - def add_metadata(self, *items, **kwargs): - return self.add_items("metadata", *items, **kwargs) - - def add_entity_metadata(self, *items, **kwargs): - return self.add_items("entity_metadata", *items, **kwargs) - - def add_parameter_value_metadata(self, *items, **kwargs): - return self.add_items("parameter_value_metadata", *items, **kwargs) - - def add_ext_entity_metadata(self, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) - self.add_items("metadata", *metadata_items, **kwargs) - return self.add_items("entity_metadata", *items, **kwargs) - - def add_ext_parameter_value_metadata(self, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) - self.add_items("metadata", *metadata_items, **kwargs) - return self.add_items("parameter_value_metadata", *items, **kwargs) - - def get_metadata_to_add_with_item_metadata_items(self, *items): - metadata_items = ({"name": item["metadata_name"], "value": item["metadata_value"]} for item in items) - return [x for x in metadata_items if not self.mapped_table("metadata").find_item(x)] diff 
--git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index cc7bdf5e..9a2d12b0 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -34,7 +34,7 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :attr:`item_types`, :meth:`item_factory`, and :meth:`make_query`. + When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`make_query`. """ def __init__(self): @@ -42,7 +42,7 @@ def __init__(self): self._offsets = {} self._offset_lock = threading.Lock() self._fetched_item_types = set() - item_types = self.item_types + item_types = self.item_types() self._sorted_item_types = [] while item_types: item_type = item_types.pop(0) @@ -60,9 +60,9 @@ def fetched_item_types(self): """ return self._fetched_item_types - @property - def item_types(self): - """Returns a list of item types from the DB schema (equivalent to the table names). + @staticmethod + def item_types(): + """Returns a list of item types from the DB mapping schema (equivalent to the table names). Returns: list(str) @@ -131,7 +131,7 @@ def _dirty_items(self): # This ensures cascade removal. # FIXME: We should also fetch the current item type because of multi-dimensional entities and # classes which also depend on zero-dimensional ones - for other_item_type in self.item_types: + for other_item_type in self.item_types(): if ( other_item_type not in self.fetched_item_types and item_type in self.item_factory(other_item_type).ref_types() diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 20cc8897..165ed007 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -9,52 +9,152 @@ # this program. If not, see . ###################################################################################################################### -from datetime import datetime, timezone -import sqlalchemy.exc +from sqlalchemy import and_, or_ +from sqlalchemy.sql.expression import bindparam +from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError -from .compatibility import compatibility_transformations +from .temp_id import TempId, resolve +from .helpers import group_consecutive class DatabaseMappingCommitMixin: - """Provides methods to commit or rollback pending changes onto a Spine database.""" + _id_fields = { + "entity_class_dimension": "entity_class_id", + "entity_element": "entity_id", + "object_class": "entity_class_id", + "relationship_class": "entity_class_id", + "object": "entity_id", + "relationship": "entity_id", + } + composite_pks = { + "entity_element": ("entity_id", "position"), + "entity_class_dimension": ("entity_class_id", "position"), + } - def commit_session(self, comment): - """Commits the changes from the in-memory mapping to the database. 
+ def _do_add_items(self, connection, tablename, *items_to_add): + """Add items to DB without checking integrity.""" + if not items_to_add: + return + try: + table = self._metadata.tables[self._real_tablename(tablename)] + id_items, temp_id_items = [], [] + for item in items_to_add: + if isinstance(item["id"], TempId): + temp_id_items.append(item) + else: + id_items.append(item) + if id_items: + connection.execute(table.insert(), [resolve(x._asdict()) for x in id_items]) + if temp_id_items: + current_ids = {x["id"] for x in connection.execute(table.select())} + next_id = max(current_ids, default=0) + 1 + available_ids = set(range(1, next_id)) - current_ids + required_id_count = len(temp_id_items) - len(available_ids) + new_ids = set(range(next_id, next_id + required_id_count)) + ids = sorted(available_ids | new_ids) + for id_, item in zip(ids, temp_id_items): + temp_id = item["id"] + temp_id.resolve(id_) + connection.execute(table.insert(), [resolve(x._asdict()) for x in temp_id_items]) + for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): + if not items_to_add_: + continue + table = self._metadata.tables[self._real_tablename(tablename_)] + connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) + except DBAPIError as e: + msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" + raise SpineDBAPIError(msg) from e + + @staticmethod + def _dimensions_for_classes(classes): + return [ + {"entity_class_id": x["id"], "position": position, "dimension_id": dimension_id} + for x in classes + for position, dimension_id in enumerate(x["dimension_id_list"]) + ] + + @staticmethod + def _elements_for_entities(entities): + return [ + { + "entity_id": x["id"], + "entity_class_id": x["class_id"], + "position": position, + "element_id": element_id, + "dimension_id": dimension_id, + } + for x in entities + for position, (element_id, dimension_id) in enumerate(zip(x["element_id_list"], x["dimension_id_list"])) + ] + + def _extra_items_to_add_per_table(self, tablename, items_to_add): + if tablename == "entity_class": + yield ("entity_class_dimension", self._dimensions_for_classes(items_to_add)) + elif tablename == "entity": + yield ("entity_element", self._elements_for_entities(items_to_add)) + + def _extra_items_to_update_per_table(self, tablename, items_to_update): + if tablename == "entity": + yield ("entity_element", self._elements_for_entities(items_to_update)) + + def _get_primary_key(self, tablename): + pk = self.composite_pks.get(tablename) + if pk is None: + id_field = self._id_fields.get(tablename, "id") + pk = (id_field,) + return pk + + def _make_update_stmt(self, tablename, keys): + table = self._metadata.tables[self._real_tablename(tablename)] + upd = table.update() + for k in self._get_primary_key(tablename): + upd = upd.where(getattr(table.c, k) == bindparam(k)) + return upd.values({key: bindparam(key) for key in table.columns.keys() & keys}) + + def _do_update_items(self, connection, tablename, *items_to_update): + """Update items in DB without checking integrity.""" + if not items_to_update: + return + try: + upd = self._make_update_stmt(tablename, items_to_update[0].keys()) + connection.execute(upd, [resolve(item._asdict()) for item in items_to_update]) + for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): + if not items_to_update_: + continue + upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) + connection.execute(upd, [resolve(x) for x in 
items_to_update_]) + except DBAPIError as e: + msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" + raise SpineDBAPIError(msg) from e + + def _do_remove_items(self, connection, tablename, *ids): + """Removes items from the db. Args: - comment (str): commit message + *ids: ids to remove """ - if not comment: - raise SpineDBAPIError("Commit message cannot be empty.") - dirty_items = self._dirty_items() - if not dirty_items: - raise SpineDBAPIError("Nothing to commit.") - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert() - with self.engine.begin() as connection: + tablename = self._real_tablename(tablename) + ids = {resolve(id_) for id_ in ids} + if tablename == "alternative": + # Do not remove the Base alternative + ids.discard(1) + if not ids: + return + tablenames = [tablename] + if tablename == "entity_class": + # Also remove the items corresponding to the id in entity_class_dimension + tablenames.append("entity_class_dimension") + elif tablename == "entity": + # Also remove the items corresponding to the id in entity_element + tablenames.append("entity_element") + for tablename_ in tablenames: + table = self._metadata.tables[tablename_] + id_field = self._id_fields.get(tablename_, "id") + id_column = getattr(table.c, id_field) + cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) + delete = table.delete().where(cond) try: - commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] - except sqlalchemy.exc.DBAPIError as e: - raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e - for tablename, (to_add, to_update, to_remove) in dirty_items: - for item in to_add + to_update + to_remove: - item.commit(commit_id) - # Remove before add, to help with keeping integrity constraints - self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) - self._do_update_items(connection, tablename, *to_update) - self._do_add_items(connection, tablename, *to_add) - if self._memory: - self._memory_dirty = True - return compatibility_transformations(connection) - - def rollback_session(self): - """Discards all the changes from the in-memory mapping.""" - if not self._rollback(): - raise SpineDBAPIError("Nothing to rollback.") - if self._memory: - self._memory_dirty = False - - def refresh_session(self): - """Resets the fetch status so new items from the DB can be retrieved.""" - self._refresh() + connection.execute(delete) + except DBAPIError as e: + msg = f"DBAPIError while removing {tablename_} items: {e.orig.args}" + raise SpineDBAPIError(msg) from e diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index fb0344e3..54b4c0ca 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -14,7 +14,6 @@ from sqlalchemy.sql.expression import Alias, label from sqlalchemy.orm import aliased from .helpers import forward_sweep, group_concat -from .query import Query class DatabaseMappingQueryMixin: @@ -111,34 +110,6 @@ def _clear_subqueries(self, *tablenames): for attr_name in attr_names: setattr(self, attr_name, None) - def query(self, *args, **kwargs): - """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. - - To perform custom ``SELECT`` statements, call this method with one or more of the class documented - subquery properties (of :class:`~sqlalchemy.sql.expression.Alias` type). 
- For example, to select the entity class with ``id`` equal to 1:: - - from spinedb_api import DatabaseMapping - url = 'sqlite:///spine.db' - ... - db_map = DatabaseMapping(url) - db_map.query(db_map.entity_class_sq).filter_by(id=1).one_or_none() - - To perform more complex queries, just use the :class:`~spinedb_api.query.Query` interface - (which is a close clone of SQL Alchemy's :class:`~sqlalchemy.orm.query.Query`). - For example, to select all entity class names and the names of their entities concatenated in a comma-separated - string:: - - from sqlalchemy import func - - db_map.query( - db_map.entity_class_sq.c.name, func.group_concat(db_map.entity_sq.c.name) - ).filter( - db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id - ).group_by(db_map.entity_class_sq.c.name).all() - """ - return Query(self.engine, *args) - def _subquery(self, tablename): """A subquery of the form: diff --git a/spinedb_api/db_mapping_remove_mixin.py b/spinedb_api/db_mapping_remove_mixin.py deleted file mode 100644 index f49c96e1..00000000 --- a/spinedb_api/db_mapping_remove_mixin.py +++ /dev/null @@ -1,105 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### - - -from sqlalchemy import and_, or_ -from sqlalchemy.exc import DBAPIError -from .exception import SpineDBAPIError -from .helpers import Asterisk, group_consecutive -from .temp_id import resolve - -# TODO: improve docstrings - - -class DatabaseMappingRemoveMixin: - """Provides methods to perform ``REMOVE`` operations over a Spine db.""" - - def remove_items(self, tablename, *ids): - """Removes items from the DB. - - Args: - tablename (str): Target database table name - *ids (int): Ids of items to be removed. - - Returns: - set: ids or items successfully updated - list(SpineIntegrityError): found violations - """ - if not ids: - return [] - tablename = self._real_tablename(tablename) - mapped_table = self.mapped_table(tablename) - if Asterisk in ids: - self.fetch_all(tablename) - ids = mapped_table - ids = set(ids) - if tablename == "alternative": - # Do not remove the Base alternative - ids.discard(1) - return [mapped_table.remove_item(id_) for id_ in ids] - - def restore_items(self, tablename, *ids): - if not ids: - return [] - tablename = self._real_tablename(tablename) - mapped_table = self.mapped_table(tablename) - return [mapped_table.restore_item(id_) for id_ in ids] - - def purge_items(self, tablename): - """Removes all items from given table. 
- - Args: - tablename (str): name of table - - Returns: - bool: True if operation was successful, False otherwise - """ - return self.remove_items(tablename, Asterisk) - - def _do_remove_items(self, connection, tablename, *ids): - """Removes items from the db. - - Args: - *ids: ids to remove - """ - tablenames = [self._real_tablename(tablename)] - ids = {resolve(id_) for id_ in ids} - if tablenames[0] == "alternative": - # Do not remove the Base alternative - ids.discard(1) - if not ids: - return - if tablenames[0] == "entity_class": - # Also remove the items corresponding to the id in entity_class_dimension - tablenames.append("entity_class_dimension") - elif tablenames[0] == "entity": - # Also remove the items corresponding to the id in entity_element - tablenames.append("entity_element") - for tablename in tablenames: - table = self._metadata.tables[tablename] - id_field = self._id_fields.get(tablename, "id") - id_column = getattr(table.c, id_field) - cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) - delete = table.delete().where(cond) - try: - connection.execute(delete) - except DBAPIError as e: - msg = f"DBAPIError while removing {tablename} items: {e.orig.args}" - raise SpineDBAPIError(msg) from e - - def remove_unused_metadata(self): - used_metadata_ids = set() - for x in self.mapped_table("entity_metadata").valid_values(): - used_metadata_ids.add(x["metadata_id"]) - for x in self.mapped_table("parameter_value_metadata").valid_values(): - used_metadata_ids.add(x["metadata_id"]) - unused_metadata_ids = {x["id"] for x in self.mapped_table("metadata").valid_values()} - used_metadata_ids - self.remove_items("metadata", *unused_metadata_ids) diff --git a/spinedb_api/db_mapping_update_mixin.py b/spinedb_api/db_mapping_update_mixin.py deleted file mode 100644 index d64b4ddb..00000000 --- a/spinedb_api/db_mapping_update_mixin.py +++ /dev/null @@ -1,202 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . 
-###################################################################################################################### - -from sqlalchemy.exc import DBAPIError -from sqlalchemy.sql.expression import bindparam -from .exception import SpineIntegrityError, SpineDBAPIError -from .temp_id import resolve - - -class DatabaseMappingUpdateMixin: - """Provides methods to perform ``UPDATE`` operations over a Spine db.""" - - def _make_update_stmt(self, tablename, keys): - table = self._metadata.tables[self._real_tablename(tablename)] - upd = table.update() - for k in self._get_primary_key(tablename): - upd = upd.where(getattr(table.c, k) == bindparam(k)) - return upd.values({key: bindparam(key) for key in table.columns.keys() & keys}) - - def _do_update_items(self, connection, tablename, *items_to_update): - """Update items in DB without checking integrity.""" - if not items_to_update: - return - try: - upd = self._make_update_stmt(tablename, items_to_update[0].keys()) - connection.execute(upd, [resolve(item._asdict()) for item in items_to_update]) - for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): - if not items_to_update_: - continue - upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) - connection.execute(upd, [resolve(x) for x in items_to_update_]) - except DBAPIError as e: - msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" - raise SpineDBAPIError(msg) from e - - @staticmethod - def _extra_items_to_update_per_table(tablename, items_to_update): - """ - Yields tuples of string tablename, list of items to update. Needed because some update queries - actually need to update records in more than one table. - - Args: - tablename (str): target database table name - items_to_update (list): items to update - - Yields: - tuple: database table name, items to update - """ - if tablename == "entity": - ee_items_to_update = [ - { - "entity_id": item["id"], - "entity_class_id": item["class_id"], - "position": position, - "element_id": element_id, - "dimension_id": dimension_id, - } - for item in items_to_update - for position, (element_id, dimension_id) in enumerate( - zip(item["element_id_list"], item["dimension_id_list"]) - ) - ] - yield ("entity_element", ee_items_to_update) - - def update_items(self, tablename, *items, check=True, strict=False): - """Updates items in the in-memory mapping. - - Args: - tablename (str): The table where items are updated - *items: One or more :class:`dict` objects representing the items to be updated. - check (bool): Whether or not to run integrity checks. - strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` - if the update of one of the items violates an integrity constraint. - - Returns: - tuple(list(dict),list(str)): items successfully updated and found violations. 
- """ - updated, errors = [], [] - for item in items: - item, error = self.update_item(tablename, check=check, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - if item: - updated.append(item) - return updated, errors - - def update_alternatives(self, *items, **kwargs): - return self.update_items("alternative", *items, **kwargs) - - def update_scenarios(self, *items, **kwargs): - return self.update_items("scenario", *items, **kwargs) - - def update_scenario_alternatives(self, *items, **kwargs): - return self.update_items("scenario_alternative", *items, **kwargs) - - def update_entity_classes(self, *items, **kwargs): - return self.update_items("entity_class", *items, **kwargs) - - def update_entities(self, *items, **kwargs): - return self.update_items("entity", *items, **kwargs) - - def update_object_classes(self, *items, **kwargs): - return self.update_items("object_class", *items, **kwargs) - - def update_objects(self, *items, **kwargs): - return self.update_items("object", *items, **kwargs) - - def update_wide_relationship_classes(self, *items, **kwargs): - return self.update_items("relationship_class", *items, **kwargs) - - def update_wide_relationships(self, *items, **kwargs): - return self.update_items("relationship", *items, **kwargs) - - def update_parameter_definitions(self, *items, **kwargs): - return self.update_items("parameter_definition", *items, **kwargs) - - def update_parameter_values(self, *items, **kwargs): - return self.update_items("parameter_value", *items, **kwargs) - - def update_parameter_value_lists(self, *items, **kwargs): - return self.update_items("parameter_value_list", *items, **kwargs) - - def update_list_values(self, *items, **kwargs): - return self.update_items("list_value", *items, **kwargs) - - def update_metadata(self, *items, **kwargs): - return self.update_items("metadata", *items, **kwargs) - - def update_entity_metadata(self, *items, **kwargs): - return self.update_items("entity_metadata", *items, **kwargs) - - def update_parameter_value_metadata(self, *items, **kwargs): - return self.update_items("parameter_value_metadata", *items, **kwargs) - - def _update_ext_item_metadata(self, tablename, *items, **kwargs): - metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) - added, errors = self.add_items("metadata", *metadata_items, **kwargs) - updated, more_errors = self.update_items(tablename, *items, **kwargs) - return added + updated, errors + more_errors - - def update_ext_entity_metadata(self, *items, **kwargs): - return self._update_ext_item_metadata("entity_metadata", *items, **kwargs) - - def update_ext_parameter_value_metadata(self, *items, **kwargs): - return self._update_ext_item_metadata("parameter_value_metadata", *items, **kwargs) - - def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True): - """Returns data to add and remove, in order to set wide scenario alternatives. - - Args: - *scenarios: One or more wide scenario :class:`dict` objects to set. - Each item must include the following keys: - - - "id": integer scenario id - - "alternative_id_list": list of alternative ids for that scenario - - Returns - list: scenario_alternative :class:`dict` objects to add. 
- set: integer scenario_alternative ids to remove - """ - scen_alts_to_add = [] - scen_alt_ids_to_remove = {} - errors = [] - for scen in scenarios: - current_scen = self.mapped_table("scenario").find_item(scen) - if current_scen is None: - error = f"no scenario matching {scen} to set alternatives for" - if strict: - raise SpineIntegrityError(error) - errors.append(error) - continue - for k, alternative_id in enumerate(scen.get("alternative_id_list", ())): - item_to_add = {"scenario_id": current_scen["id"], "alternative_id": alternative_id, "rank": k + 1} - scen_alts_to_add.append(item_to_add) - for k, alternative_name in enumerate(scen.get("alternative_name_list", ())): - item_to_add = {"scenario_id": current_scen["id"], "alternative_name": alternative_name, "rank": k + 1} - scen_alts_to_add.append(item_to_add) - for alternative_id in current_scen["alternative_id_list"]: - scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id} - current_scen_alt = self.mapped_table("scenario_alternative").find_item(scen_alt) - scen_alt_ids_to_remove[current_scen_alt["id"]] = current_scen_alt - # Remove items that are both to add and to remove - for id_, to_rm in list(scen_alt_ids_to_remove.items()): - i = next((i for i, to_add in enumerate(scen_alts_to_add) if _is_equal(to_add, to_rm)), None) - if i is not None: - del scen_alts_to_add[i] - del scen_alt_ids_to_remove[id_] - return scen_alts_to_add, set(scen_alt_ids_to_remove), errors - - -def _is_equal(to_add, to_rm): - return all(to_rm[k] == v for k, v in to_add.items()) diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index bc647b41..51a9f2db 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -159,11 +159,12 @@ def _create_import_alternative(db_map, state): timestamp = state.timestamp sep = "__" if scenarios else "" db_map._import_alternative_name = f"{'_'.join(scenarios)}{sep}{execution_item}@{timestamp}" - db_map.add_alternatives({"name": db_map._import_alternative_name}) - db_map.add_scenarios(*({"name": scen_name} for scen_name in scenarios)) + db_map.add_item("alternative", name=db_map._import_alternative_name) + for scen_name in scenarios: + db_map.add_item("scenario", name=scen_name) for scen_name in scenarios: scen = db_map.get_item("scenario", name=scen_name) rank = len(scen["sorted_scenario_alternatives"]) + 1 # ranks are 1-based - db_map.add_scenario_alternatives( - {"scenario_name": scen_name, "alternative_name": db_map._import_alternative_name, "rank": rank} + db_map.add_item( + "scenario_alternative", scenario_name=scen_name, alternative_name=db_map._import_alternative_name, rank=rank ) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 9cfe8982..9e30b30a 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -43,8 +43,8 @@ class EntityClassItem(MappedItemBase): "dimension_name_list": ("tuple, optional", "The dimension names for a multi-dimensional class."), "description": ("str, optional", "The class description."), "display_icon": ("int, optional", "An integer representing an icon within your application."), - "display_order": ("int, optional", "Not in use at the moment"), - "hidden": ("bool, optional", "Not in use at the moment"), + "display_order": ("int, optional", "Not in use at the moment."), + "hidden": ("bool, optional", "Not in use at the moment."), } _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = 
(("name",),) diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index efe64814..904e4294 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -56,7 +56,7 @@ def purge(db_map, purge_settings, logger=None): """ if purge_settings is None: # Bring all the pain - purge_settings = {item_type: True for item_type in DatabaseMapping.ITEM_TYPES} + purge_settings = {item_type: True for item_type in DatabaseMapping.item_types()} removable_db_map_data = {item_type for item_type, checked in purge_settings.items() if checked} if removable_db_map_data: try: diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 65990aa9..140a31a9 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -25,8 +25,8 @@ Each request must be a JSON array with the following elements: #. A JSON string with one of the available request names: - ``"get_db_url"``, ``"import_data"``, ``"export_data"``, ``"query"``, ``"filtered_query"``, - ``"call_method"``, ``"db_checkin"``, ``"db_checkout"``. + ``get_db_url``, ``import_data``, ``export_data``, ``query``, ``filtered_query``, ``apply_filters``, + ``clear_filters``, ``call_method``, ``db_checkin``, ``db_checkout``. #. A JSON array with positional arguments to the request. #. A JSON object with keyword arguments to the request. #. A JSON integer indicating the version of the server you want to talk to. @@ -38,8 +38,8 @@ The point of the server version is to allow client developers to adapt to changes in the Spine DB server API. Say we update ``spinedb_api`` and change the signature of one of the requests - in this case, we will also bump the current server version to the next integer. -If you then upgrade your ``spinedb_api`` installation but not your client, the server will see the version mismatch -and will respond that the client is outdated. +If you then upgrade your ``spinedb_api`` installation but not your client, the server will be able to respond +with an error message saying that you need to update your client. The current server version can be queried by calling :func:`get_current_server_version`. The order in which multiple servers should write to the same DB can also be controlled using DB servers. diff --git a/tests/custom_db_mapping.py b/tests/custom_db_mapping.py new file mode 100644 index 00000000..ab578e3f --- /dev/null +++ b/tests/custom_db_mapping.py @@ -0,0 +1,117 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### + +""" +Unit tests for DatabaseMapping class. 
+ +""" +from spinedb_api import DatabaseMapping, SpineIntegrityError + + +class CustomDatabaseMapping(DatabaseMapping): + def add_object_classes(self, *items, **kwargs): + return self.add_items("object_class", *items, **kwargs) + + def add_objects(self, *items, **kwargs): + return self.add_items("object", *items, **kwargs) + + def add_entity_classes(self, *items, **kwargs): + return self.add_items("entity_class", *items, **kwargs) + + def add_entities(self, *items, **kwargs): + return self.add_items("entity", *items, **kwargs) + + def add_wide_relationship_classes(self, *items, **kwargs): + return self.add_items("relationship_class", *items, **kwargs) + + def add_wide_relationships(self, *items, **kwargs): + return self.add_items("relationship", *items, **kwargs) + + def add_parameter_definitions(self, *items, **kwargs): + return self.add_items("parameter_definition", *items, **kwargs) + + def add_parameter_values(self, *items, **kwargs): + return self.add_items("parameter_value", *items, **kwargs) + + def add_parameter_value_lists(self, *items, **kwargs): + return self.add_items("parameter_value_list", *items, **kwargs) + + def add_list_values(self, *items, **kwargs): + return self.add_items("list_value", *items, **kwargs) + + def add_alternatives(self, *items, **kwargs): + return self.add_items("alternative", *items, **kwargs) + + def add_scenarios(self, *items, **kwargs): + return self.add_items("scenario", *items, **kwargs) + + def add_scenario_alternatives(self, *items, **kwargs): + return self.add_items("scenario_alternative", *items, **kwargs) + + def add_entity_groups(self, *items, **kwargs): + return self.add_items("entity_group", *items, **kwargs) + + def add_metadata(self, *items, **kwargs): + return self.add_items("metadata", *items, **kwargs) + + def add_entity_metadata(self, *items, **kwargs): + return self.add_items("entity_metadata", *items, **kwargs) + + def add_parameter_value_metadata(self, *items, **kwargs): + return self.add_items("parameter_value_metadata", *items, **kwargs) + + def update_alternatives(self, *items, **kwargs): + return self.update_items("alternative", *items, **kwargs) + + def update_scenarios(self, *items, **kwargs): + return self.update_items("scenario", *items, **kwargs) + + def update_scenario_alternatives(self, *items, **kwargs): + return self.update_items("scenario_alternative", *items, **kwargs) + + def update_entity_classes(self, *items, **kwargs): + return self.update_items("entity_class", *items, **kwargs) + + def update_entities(self, *items, **kwargs): + return self.update_items("entity", *items, **kwargs) + + def update_object_classes(self, *items, **kwargs): + return self.update_items("object_class", *items, **kwargs) + + def update_objects(self, *items, **kwargs): + return self.update_items("object", *items, **kwargs) + + def update_wide_relationship_classes(self, *items, **kwargs): + return self.update_items("relationship_class", *items, **kwargs) + + def update_wide_relationships(self, *items, **kwargs): + return self.update_items("relationship", *items, **kwargs) + + def update_parameter_definitions(self, *items, **kwargs): + return self.update_items("parameter_definition", *items, **kwargs) + + def update_parameter_values(self, *items, **kwargs): + return self.update_items("parameter_value", *items, **kwargs) + + def update_parameter_value_lists(self, *items, **kwargs): + return self.update_items("parameter_value_list", *items, **kwargs) + + def update_list_values(self, *items, **kwargs): + return self.update_items("list_value", 
*items, **kwargs) + + def update_metadata(self, *items, **kwargs): + return self.update_items("metadata", *items, **kwargs) + + def update_entity_metadata(self, *items, **kwargs): + return self.update_items("entity_metadata", *items, **kwargs) + + def update_parameter_value_metadata(self, *items, **kwargs): + return self.update_items("parameter_value_metadata", *items, **kwargs) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 7e347832..22023b39 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -20,14 +20,8 @@ from unittest.mock import patch from sqlalchemy.engine.url import make_url, URL from sqlalchemy.util import KeyedTuple -from spinedb_api import ( - DatabaseMapping, - import_functions, - from_database, - to_database, - SpineDBAPIError, - SpineIntegrityError, -) +from spinedb_api import import_functions, from_database, to_database, SpineDBAPIError, SpineIntegrityError +from .custom_db_mapping import CustomDatabaseMapping def create_query_wrapper(db_map): @@ -50,7 +44,7 @@ def test_construction_with_filters(self): with mock.patch( "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DatabaseMapping(db_url, create=True) + db_map = CustomDatabaseMapping(db_url, create=True) db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -62,7 +56,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): with mock.patch( "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DatabaseMapping(sa_url, create=True) + db_map = CustomDatabaseMapping(sa_url, create=True) db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -71,15 +65,15 @@ def test_shorthand_filter_query_works(self): with TemporaryDirectory() as temp_dir: url = URL("sqlite") url.database = os.path.join(temp_dir, "test_shorthand_filter_query_works.json") - out_db_map = DatabaseMapping(url, create=True) + out_db_map = CustomDatabaseMapping(url, create=True) out_db_map.add_scenarios({"name": "scen1"}) out_db_map.add_scenario_alternatives({"scenario_name": "scen1", "alternative_name": "Base", "rank": 1}) out_db_map.commit_session("Add scen.") out_db_map.close() try: - db_map = DatabaseMapping(url) + db_map = CustomDatabaseMapping(url) except: - self.fail("DatabaseMapping.__init__() should not raise.") + self.fail("CustomDatabaseMapping.__init__() should not raise.") else: db_map.close() @@ -89,7 +83,7 @@ class TestDatabaseMapping(unittest.TestCase): @classmethod def setUpClass(cls): - cls._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + cls._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) @classmethod def tearDownClass(cls): @@ -101,7 +95,7 @@ def test_construction_with_filters(self): with patch( "spinedb_api.db_mapping.load_filters", return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DatabaseMapping(db_url, create=True) + db_map = CustomDatabaseMapping(db_url, create=True) db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -113,7 +107,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): with patch( "spinedb_api.db_mapping.load_filters", 
return_value=[{"fltr1": "config1", "fltr2": "config2"}] ) as mock_load: - db_map = DatabaseMapping(sa_url, create=True) + db_map = CustomDatabaseMapping(sa_url, create=True) db_map.close() mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -345,7 +339,7 @@ def test_get_import_alternative_returns_base_alternative_by_default(self): class TestDatabaseMappingQueries(unittest.TestCase): def setUp(self): - self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() @@ -622,7 +616,7 @@ def test_filter_query_accepts_multiple_criteria(self): class TestDatabaseMappingAdd(unittest.TestCase): def setUp(self): - self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() @@ -750,10 +744,10 @@ def test_add_relationship_classes_with_same_name(self): def test_add_relationship_class_with_same_name_as_existing_one(self): """Test that adding a relationship class with an already taken name raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_class_sq" + with mock.patch.object(CustomDatabaseMapping, "query") as mock_query, mock.patch.object( + CustomDatabaseMapping, "object_class_sq" ) as mock_object_class_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" + CustomDatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq: mock_query.side_effect = query_wrapper mock_object_class_sq.return_value = [ @@ -771,9 +765,9 @@ def test_add_relationship_class_with_same_name_as_existing_one(self): def test_add_relationship_class_with_invalid_object_class(self): """Test that adding a relationship class with a non existing object class raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_class_sq" - ) as mock_object_class_sq, mock.patch.object(DatabaseMapping, "wide_relationship_class_sq"): + with mock.patch.object(CustomDatabaseMapping, "query") as mock_query, mock.patch.object( + CustomDatabaseMapping, "object_class_sq" + ) as mock_object_class_sq, mock.patch.object(CustomDatabaseMapping, "wide_relationship_class_sq"): mock_query.side_effect = query_wrapper mock_object_class_sq.return_value = [KeyedTuple([1, "fish"], labels=["id", "name"])] with self.assertRaises(SpineIntegrityError): @@ -824,12 +818,12 @@ def test_add_relationship_identical_to_existing_one(self): raises an integrity error. 
""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" + with mock.patch.object(CustomDatabaseMapping, "query") as mock_query, mock.patch.object( + CustomDatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" + CustomDatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" + CustomDatabaseMapping, "wide_relationship_sq" ) as mock_wide_rel_sq: mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -850,12 +844,12 @@ def test_add_relationship_identical_to_existing_one(self): def test_add_relationship_with_invalid_class(self): """Test that adding a relationship with an invalid class raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" + with mock.patch.object(CustomDatabaseMapping, "query") as mock_query, mock.patch.object( + CustomDatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" + CustomDatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" + CustomDatabaseMapping, "wide_relationship_sq" ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -873,12 +867,12 @@ def test_add_relationship_with_invalid_class(self): def test_add_relationship_with_invalid_object(self): """Test that adding a relationship with an invalid object raises an integrity error.""" query_wrapper = create_query_wrapper(self._db_map) - with mock.patch.object(DatabaseMapping, "query") as mock_query, mock.patch.object( - DatabaseMapping, "object_sq" + with mock.patch.object(CustomDatabaseMapping, "query") as mock_query, mock.patch.object( + CustomDatabaseMapping, "object_sq" ) as mock_object_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_class_sq" + CustomDatabaseMapping, "wide_relationship_class_sq" ) as mock_wide_rel_cls_sq, mock.patch.object( - DatabaseMapping, "wide_relationship_sq" + CustomDatabaseMapping, "wide_relationship_sq" ): mock_query.side_effect = query_wrapper mock_object_sq.return_value = [ @@ -1394,7 +1388,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): class TestDatabaseMappingUpdate(unittest.TestCase): def setUp(self): - self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() @@ -1791,7 +1785,7 @@ def test_update_metadata(self): class TestDatabaseMappingRemoveMixin(unittest.TestCase): def setUp(self): - self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() @@ -2194,7 +2188,7 @@ def test_remove_parameter_value2(self): class TestDatabaseMappingCommitMixin(unittest.TestCase): def setUp(self): - self._db_map = DatabaseMapping(IN_MEMORY_DB_URL, create=True) + self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True) def tearDown(self): self._db_map.close() diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index fbeac2be..608499e8 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -14,8 +14,8 @@ 
 class TestDBMapping(DatabaseMappingBase):
-    @property
-    def item_types(self):
+    @staticmethod
+    def item_types():
         return ["cutlery"]
 
     @staticmethod
diff --git a/tests/test_purge.py index 1aacdd8e..906e00d1 100644
--- a/tests/test_purge.py
+++ b/tests/test_purge.py
@@ -27,7 +27,7 @@ def tearDown(self):
     def test_purge_entity_classes(self):
         with DatabaseMapping(self._url, create=True) as db_map:
-            db_map.add_entity_classes({"name": "Soup"})
+            db_map.add_item("entity_class", name="Soup")
             db_map.commit_session("Add test data")
         purge_url(self._url, {"alternative": False, "entity_class": True})
         with DatabaseMapping(self._url) as db_map:

From 41e953fdb45bf1752ba631cb98ab6216461a5771 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 5 Oct 2023 16:00:48 +0200
Subject: [PATCH 120/317] Minor fixes

---
 docs/source/conf.py            |  4 ++--
 spinedb_api/db_mapping.py      |  2 +-
 spinedb_api/db_mapping_base.py |  4 ++--
 spinedb_api/mapped_items.py    | 28 ++++++++++++++--------------
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/docs/source/conf.py index 8b719ebd..6d0fbe16 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -113,7 +113,7 @@ def _process_docstring(app, what, name, obj, options, lines):
         new_lines = []
         for item_type in DatabaseMapping.item_types():
             factory = DatabaseMapping.item_factory(item_type)
-            if not factory._fields:
+            if not factory.fields:
                 continue
             new_lines.extend([item_type, len(item_type) * "-", ""])
             new_lines.extend(
@@ -126,7 +126,7 @@
                     " - value",
                 ]
             )
-            for f_name, (f_type, f_value) in factory._fields.items():
+            for f_name, (f_type, f_value) in factory.fields.items():
                 new_lines.extend([f" * - {f_name}", f" - {f_type}", f" - {f_value}"])
             new_lines.append("")
             new_lines.extend(
diff --git a/spinedb_api/db_mapping.py index adc44313..2e839b7e 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -174,7 +174,7 @@ def __del__(self):
     @staticmethod
     def item_types():
-        return list(DatabaseMapping._sq_name_by_item_type)
+        return [x for x in DatabaseMapping._sq_name_by_item_type if item_factory(x).fields]
 
     @staticmethod
     def item_factory(item_type):
diff --git a/spinedb_api/db_mapping_base.py index 9a2d12b0..e858bcc3 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -448,8 +448,8 @@ class MappedItemBase(dict):
     """A dictionary that represents a db item."""
 
-    _fields = {}
-    """A dictionaty mapping fields to a tuple of (type, description)"""
+    fields = {}
+    """A dictionary mapping fields to a tuple of (type, value description)"""
     _defaults = {}
     """A dictionary mapping keys to their default values"""
     _unique_keys = ()
diff --git a/spinedb_api/mapped_items.py index 9e30b30a..7d8e0ae8 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -38,7 +38,7 @@ def item_factory(item_type):
 class EntityClassItem(MappedItemBase):
-    _fields = {
+    fields = {
         "name": ("str", "The class name."),
         "dimension_name_list": ("tuple, optional", "The dimension names for a multi-dimensional class."),
         "description": ("str, optional", "The class description."),
         "display_icon": ("int, optional", "An integer representing an icon within your application."),
@@ -75,7 +75,7 @@ def commit(self, _commit_id):
 class EntityItem(MappedItemBase):
-    _fields = {
+    fields = {
         "class_name": ("str", "The entity class name."),
         "name": ("str, optional", "The entity name - must be given for a zero-dimensional entity."),
"element_name_list": ("tuple, optional", "The element names - must be given for a multi-dimensional entity."), @@ -123,7 +123,7 @@ def polish(self): class EntityGroupItem(MappedItemBase): - _fields = { + fields = { "class_name": ("str", "The entity class name."), "group_name": ("str", "The group entity name."), "member_name": ("str", "The member entity name."), @@ -150,7 +150,7 @@ def __getitem__(self, key): class EntityAlternativeItem(MappedItemBase): - _fields = { + fields = { "entity_class_name": ("str", "The entity class name."), "entity_byname": ( "str or tuple", @@ -203,7 +203,7 @@ def __getitem__(self, key): class ParameterDefinitionItem(ParsedValueBase): - _fields = { + fields = { "entity_class_name": ("str", "The entity class name."), "name": ("str", "The parameter name."), "default_value": ("any, optional", "The default value."), @@ -301,7 +301,7 @@ def merge(self, other): class ParameterValueItem(ParsedValueBase): - _fields = { + fields = { "entity_class_name": ("str", "The entity class name."), "parameter_definition_name": ("str", "The parameter name."), "entity_byname": ( @@ -394,12 +394,12 @@ def callback(new_id): class ParameterValueListItem(MappedItemBase): - _fields = {"name": ("str", "The parameter value list name.")} + fields = {"name": ("str", "The parameter value list name.")} _unique_keys = (("name",),) class ListValueItem(ParsedValueBase): - _fields = { + fields = { "parameter_value_list_name": ("str", "The parameter value list name."), "value": ("any", "The value."), "type": ("str", "The value type."), @@ -419,7 +419,7 @@ def _make_parsed_value(self): class AlternativeItem(MappedItemBase): - _fields = { + fields = { "name": ("str", "The alternative name."), "description": ("str, optional", "The alternative description."), } @@ -428,7 +428,7 @@ class AlternativeItem(MappedItemBase): class ScenarioItem(MappedItemBase): - _fields = { + fields = { "name": ("str", "The scenario name."), "description": ("str, optional", "The scenario description."), "active": ("bool, optional", "Not in use at the moment."), @@ -455,7 +455,7 @@ def __getitem__(self, key): class ScenarioAlternativeItem(MappedItemBase): - _fields = { + fields = { "scenario_name": ("str", "The scenario name."), "alternative_name": ("str", "The alternative name."), "rank": ("int", "The rank - the higher has precedence."), @@ -487,12 +487,12 @@ def __getitem__(self, key): class MetadataItem(MappedItemBase): - _fields = {"name": ("str", "The metadata entry name."), "value": ("str", "The metadata entry value.")} + fields = {"name": ("str", "The metadata entry name."), "value": ("str", "The metadata entry value.")} _unique_keys = (("name", "value"),) class EntityMetadataItem(MappedItemBase): - _fields = { + fields = { "entity_name": ("str", "The entity name."), "metadata_name": ("str", "The metadata entry name."), "metadata_value": ("str", "The metadata entry value."), @@ -510,7 +510,7 @@ class EntityMetadataItem(MappedItemBase): class ParameterValueMetadataItem(MappedItemBase): - _fields = { + fields = { "parameter_definition_name": ("str", "The parameter name."), "entity_byname": ( "str or tuple", From f3b99849a44ac568205cde1659adcd58fbc556e1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 5 Oct 2023 17:07:54 +0200 Subject: [PATCH 121/317] Keep improving docs --- docs/source/conf.py | 14 ++-- docs/source/tutorial.rst | 2 +- spinedb_api/db_mapping.py | 130 ++++++++++++++++++++++----------- spinedb_api/db_mapping_base.py | 14 ++-- tests/test_db_mapping_base.py | 2 +- 5 files changed, 104 insertions(+), 58 
deletions(-)

diff --git a/docs/source/conf.py index 6d0fbe16..051a192e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -105,6 +105,7 @@ def _skip_member(app, what, name, obj, skip, options):
 
 def _process_docstring(app, what, name, obj, options, lines):
+    # Expand
     try:
         i = lines.index("")
     except ValueError:
@@ -112,7 +113,7 @@
     else:
         new_lines = []
         for item_type in DatabaseMapping.item_types():
-            factory = DatabaseMapping.item_factory(item_type)
+            factory = DatabaseMapping._item_factory(item_type)
             if not factory.fields:
                 continue
             new_lines.extend([item_type, len(item_type) * "-", ""])
@@ -140,12 +141,11 @@
             f_names = ", ".join(f_names)
             new_lines.extend([f" * - {f_names}"])
         lines[i : i + 1] = new_lines
-    return
-    if what == "method":
-        spine_item_types = ", ".join([f"`{x}`" for x in DatabaseMapping.item_types()])
-        for k, line in enumerate(lines):
-            if "" in line:
-                lines[k] = line.replace("", spine_item_types)
+    # Expand
+    spine_item_types = ", ".join([f"`{x}`" for x in DatabaseMapping.item_types()])
+    for k, line in enumerate(lines):
+        if "" in line:
+            lines[k] = line.replace("", spine_item_types)
 
 
 def setup(sphinx):
diff --git a/docs/source/tutorial.rst index 057342b6..96981f1b 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -18,7 +18,7 @@ To begin, make sure Spine database API is installed as described at :ref:`instal
 Database Mapping
 ----------------
 
-The main means of communication with a Spine DB is the :class:`.DatabaseMapping`,
+The main means of communication with a Spine DB is the :class:`.DatabaseMapping` class,
 specially designed to retrieve and modify data from the DB.
 To create a :class:`.DatabaseMapping`, we just pass the URL of the DB to the class constructor::
diff --git a/spinedb_api/db_mapping.py index 2e839b7e..dc8f93f4 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -10,11 +10,22 @@
 ######################################################################################################################
 
 """
-This module defines the :class:`.DatabaseMapping` class.
+This module defines the :class:`.DatabaseMapping` class, the main means of communicating with a Spine DB.
+If you're planning to use this class, it is a good idea to first familiarize yourself with the
+DB mapping schema.
 
 DB mapping schema
 =================
+
+The DB mapping schema is a close cousin of the Spine DB schema, with some extra flexibility, mainly
+the ability to define references by name rather than by numerical id.
+The schema defines the following item types: . As you can see, these follow the names
+of some of the tables in the Spine DB schema.
+
+The following subsections provide all the details you need to know about the different item types, namely,
+their fields, values, and unique keys.
+
 """

@@ -59,7 +70,7 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, DatabaseMappingBase):
     """Enables communication with a Spine DB.
 
-    The DB is incrementally mapped into memory as data is requested/modified.
+    The DB is incrementally mapped into memory as data is requested/modified, following the `DB mapping schema`_.
     Data is typically retrieved using :meth:`get_item` or :meth:`get_items`.
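+    For instance, to print the name of every entity (a sketch; ``db_url`` being any supported DB URL,
+    and assuming the mapped DB contains entities)::
+
+        with DatabaseMapping(db_url) as db_map:
+            # get_items() serves items already in memory and fetches the rest from the DB.
+            for entity in db_map.get_items("entity"):
+                print(entity["name"])
+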
    If the requested data is already in memory, it is returned from there;
@@ -87,6 +98,12 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
     while bypassing the in-memory mapping entirely.
 
+    You can use this class as a context manager, e.g.::
+
+        with DatabaseMapping(db_url) as db_map:
+            # Do stuff with db_map
+            ...
+
     """
 
     _sq_name_by_item_type = {
@@ -177,17 +194,35 @@ def item_types():
         return [x for x in DatabaseMapping._sq_name_by_item_type if item_factory(x).fields]
 
     @staticmethod
-    def item_factory(item_type):
+    def _item_factory(item_type):
         return item_factory(item_type)
 
-    def make_query(self, item_type):
+    def _make_query(self, item_type):
         if self.closed:
             return None
         sq_name = self._sq_name_by_item_type[item_type]
         return self.query(getattr(self, sq_name))
 
     def close(self):
-        """Closes this DB mapping."""
+        """Closes this DB mapping. This is only needed if you're keeping a long-lived session.
+        For instance::
+
+            class MyDBMappingWrapper:
+                def __init__(self, url):
+                    self._db_map = DatabaseMapping(url)
+
+                # More methods that do stuff with self._db_map
+
+                def __del__(self):
+                    self._db_map.close()
+
+        Otherwise, using it as a context manager is recommended::
+
+            with DatabaseMapping(url) as db_map:
+                # Do stuff with db_map
+                ...
+            # db_map.close() is automatically called when leaving this block
+        """
         self.closed = True
 
     def _make_codename(self, codename):
@@ -203,16 +238,6 @@ def _make_codename(self, codename):
 
     @staticmethod
     def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800):
-        """Creates engine.
-
-        Args:
-            sa_url (URL)
-            upgrade (bool, optional): If True, upgrade the db to the latest version.
-            create (bool, optional): If True, create a new Spine db at the given url if none found.
-
-        Returns:
-            :class:`~sqlalchemy.engine.Engine`
-        """
         if sa_url.drivername == "sqlite":
             connect_args = {'timeout': sqlite_timeout}
         else:
@@ -327,14 +352,6 @@ def override_create_import_alternative(self, method):
         self._create_import_alternative = MethodType(method, self)
         self._import_alternative_name = None
 
-    def get_filter_configs(self):
-        """Returns the filters used to build this DB mapping.
-
-        Returns:
-            list(dict):
-        """
-        return self._filter_configs
-
     def get_table(self, tablename):
         # For tests
         return self._metadata.tables[tablename]
@@ -342,11 +359,16 @@ def get_table(self, tablename):
     def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """Finds and returns an item matching the arguments, or None if none found.
 
+        Example::
+
+            with DatabaseMapping(db_url) as db_map:
+                prince = db_map.get_item("entity", class_name="musician", name="Prince")
+
         Args:
             item_type (str): One of .
             fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory.
             skip_removed (bool, optional): Whether to ignore removed items.
-            **kwargs: Fields and values for one the unique keys as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields and values for one of the unique keys of the item type as specified in `DB mapping schema`_.
 
         Returns:
             :class:`PublicItem` or None
@@ -380,10 +402,16 @@ def get_items(self, item_type, fetch=True, skip_removed=True):
 
     def add_item(self, item_type, check=True, **kwargs):
         """Adds an item to the in-memory mapping.
+ Example:: + + with DatabaseMapping(db_url) as db_map: + db_map.add_item("entity_class", name="musician") + db_map.add_item("entity", class_name="musician", name="Prince") + Args: item_type (str): One of . check (bool, optional): Whether to carry out integrity checks. - **kwargs: Fields and values as specified for the item type in `DB mapping schema`_. + **kwargs: Fields and values of the item type as specified in `DB mapping schema`_. Returns: tuple(:class:`PublicItem` or None, str): The added item and any errors. @@ -404,8 +432,8 @@ def add_items(self, item_type, *items, check=True, strict=False): Args: item_type (str): One of . - *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values, - as specified for the item type in `DB mapping schema`_. + *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + as specified in `DB mapping schema`_. check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. @@ -427,11 +455,19 @@ def add_items(self, item_type, *items, check=True, strict=False): def update_item(self, item_type, check=True, **kwargs): """Updates an item in the in-memory mapping. + Example:: + + with DatabaseMapping(db_url) as db_map: + prince = db_map.get_item("entity", class_name="musician", name="Prince") + db_map.update_item( + "entity", id=prince["id"], name="the Artist", description="Formerly known as Prince." + ) + Args: item_type (str): One of . check (bool, optional): Whether to carry out integrity checks. id (int): The id of the item to update. - **kwargs: Fields to update and their new values as specified for the item type in `DB mapping schema`_. + **kwargs: Fields to update and their new values as specified in `DB mapping schema`_. Returns: tuple(:class:`PublicItem` or None, str): The updated item and any errors. @@ -449,8 +485,8 @@ def update_items(self, item_type, *items, check=True, strict=False): Args: item_type (str): One of . - *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values, - as specified for the item type in `DB mapping schema`_ and including the `id`. + *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + as specified in `DB mapping schema`_ and including the `id`. check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the update of one of the items violates an integrity constraint. @@ -474,10 +510,9 @@ def remove_item(self, item_type, id_): Example:: - with DatabaseMapping(url) as db_map: - my_dog = db_map.get_item("entity", class_name="dog", name="Pluto") - db_map.remove_item("entity", my_dog["id]) - + with DatabaseMapping(db_url) as db_map: + prince = db_map.get_item("entity", class_name="musician", name="Prince") + db_map.remove_item("entity", prince["id"]) Args: item_type (str): One of . @@ -518,9 +553,9 @@ def restore_item(self, item_type, id_): Example:: - with DatabaseMapping(url) as db_map: - my_dog = db_map.get_item("entity", skip_removed=False, class_name="dog", name="Pluto") - db_map.restore_item("entity", my_dog["id]) + with DatabaseMapping(db_url) as db_map: + prince = db_map.get_item("entity", skip_remove=False, class_name="musician", name="Prince") + db_map.restore_item("entity", prince["id"]) Args: item_type (str): One of . 
@@ -554,9 +589,9 @@ def purge_items(self, item_type): item_type (str): One of . Returns: - bool: True if operation was successful, False otherwise + bool: True if any data was removed, False otherwise. """ - return self.remove_items(item_type, Asterisk) + return bool(self.remove_items(item_type, Asterisk)) def can_fetch_more(self, item_type): """Whether or not more data can be fetched from the DB for the given item type. @@ -565,7 +600,7 @@ def can_fetch_more(self, item_type): item_type (str): One of . Returns: - bool + bool: True if more data can be fetched. """ return item_type not in self.fetched_item_types @@ -579,7 +614,7 @@ def fetch_more(self, item_type, limit=None): In other words, each item is fetched from the DB exactly once. Returns: - list(PublicItem): The items fetched. + list(:class:`PublicItem`): The items fetched. """ item_type = self._real_tablename(item_type) return [PublicItem(self, x) for x in self.do_fetch_more(item_type, limit=limit)] @@ -601,7 +636,7 @@ def query(self, *args, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. To perform custom ``SELECT`` statements, call this method with one or more of the documented - subquery properties of :class:`~spinedb_api.DatabaseMappingQueryMixin` returning + subquery properties of :class:`~spinedb_api.db_mapping_query_mixin.DatabaseMappingQueryMixin` returning :class:`~sqlalchemy.sql.expression.Alias` objetcs. For example, to select the entity class with ``id`` equal to 1:: @@ -623,6 +658,9 @@ def query(self, *args, **kwargs): ).filter( db_map.entity_sq.c.class_id == db_map.entity_class_sq.c.id ).group_by(db_map.entity_class_sq.c.name).all() + + Returns: + :class:`~spinedb_api.query.Query`: The resulting query. """ return Query(self.engine, *args) @@ -749,6 +787,14 @@ def remove_unused_metadata(self): unused_metadata_ids = {x["id"] for x in self.mapped_table("metadata").valid_values()} - used_metadata_ids self.remove_items("metadata", *unused_metadata_ids) + def get_filter_configs(self): + """Returns the filters used to build this DB mapping. + + Returns: + list(dict): + """ + return self._filter_configs + class PublicItem: def __init__(self, db_map, cache_item): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e858bcc3..a2997f80 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -34,7 +34,7 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`make_query`. + When subclassing, you need to implement :meth:`item_types`, :meth:`_item_factory`, and :meth:`_make_query`. """ def __init__(self): @@ -46,7 +46,7 @@ def __init__(self): self._sorted_item_types = [] while item_types: item_type = item_types.pop(0) - if self.item_factory(item_type).ref_types() & set(item_types): + if self._item_factory(item_type).ref_types() & set(item_types): item_types.append(item_type) else: self._sorted_item_types.append(item_type) @@ -70,7 +70,7 @@ def item_types(): raise NotImplementedError() @staticmethod - def item_factory(item_type): + def _item_factory(item_type): """Returns a subclass of :class:`.MappedItemBase` to make items of given type. 
        Args:
@@ -81,7 +81,7 @@ def item_factory(item_type):
         """
         raise NotImplementedError()

-    def make_query(self, item_type):
+    def _make_query(self, item_type):
         """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.

         Args:
             item_type (str)
@@ -93,7 +93,7 @@ def make_query(self, item_type):
         raise NotImplementedError()

     def make_item(self, item_type, **item):
-        factory = self.item_factory(item_type)
+        factory = self._item_factory(item_type)
         return factory(self, item_type, **item)

     def dirty_ids(self, item_type):
@@ -134,7 +134,7 @@ def _dirty_items(self):
             for other_item_type in self.item_types():
                 if (
                     other_item_type not in self.fetched_item_types
-                    and item_type in self.item_factory(other_item_type).ref_types()
+                    and item_type in self._item_factory(other_item_type).ref_types()
                 ):
                     self.fetch_all(other_item_type)
             if to_add or to_update or to_remove:
@@ -180,7 +180,7 @@ def _refresh(self):
         self._fetched_item_types.clear()

     def _get_next_chunk(self, item_type, limit):
-        qry = self.make_query(item_type)
+        qry = self._make_query(item_type)
         if not qry:
             return []
         if not limit:
diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py
index 608499e8..6c3ccd40 100644
--- a/tests/test_db_mapping_base.py
+++ b/tests/test_db_mapping_base.py
@@ -19,7 +19,7 @@ def item_types():
         return ["cutlery"]

     @staticmethod
-    def item_factory(item_type):
+    def _item_factory(item_type):
         if item_type == "cutlery":
             return MappedItemBase
         raise RuntimeError(f"unknown item_type '{item_type}'")

From 9e4f5eb360eda58b922b62204a8b5099d282497f Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 5 Oct 2023 17:20:19 +0200
Subject: [PATCH 122/317] Remove get_data_to_set_scenario_alternatives

---
 spinedb_api/db_mapping.py | 47 ---------------------------------------
 1 file changed, 47 deletions(-)

diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index dc8f93f4..4fbac1fc 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -731,53 +731,6 @@ def update_ext_entity_metadata(self, *items, **kwargs):
     def update_ext_parameter_value_metadata(self, *items, **kwargs):
         return self._update_ext_item_metadata("parameter_value_metadata", *items, **kwargs)

-    def get_data_to_set_scenario_alternatives(self, *scenarios, strict=True):
-        """Returns data to add and remove, in order to set wide scenario alternatives.
-
-        Args:
-            *scenarios: One or more wide scenario :class:`dict` objects to set.
-                Each item must include the following keys:
-
-                - "id": integer scenario id
-                - "alternative_id_list": list of alternative ids for that scenario
-
-        Returns
-            list: scenario_alternative :class:`dict` objects to add.
-                - set: integer scenario_alternative ids to remove
-        """
-
-        def _is_equal(to_add, to_rm):
-            return all(to_rm[k] == v for k, v in to_add.items())
-
-        scen_alts_to_add = []
-        scen_alt_ids_to_remove = {}
-        errors = []
-        for scen in scenarios:
-            current_scen = self.mapped_table("scenario").find_item(scen)
-            if current_scen is None:
-                error = f"no scenario matching {scen} to set alternatives for"
-                if strict:
-                    raise SpineIntegrityError(error)
-                errors.append(error)
-                continue
-            for k, alternative_id in enumerate(scen.get("alternative_id_list", ())):
-                item_to_add = {"scenario_id": current_scen["id"], "alternative_id": alternative_id, "rank": k + 1}
-                scen_alts_to_add.append(item_to_add)
-            for k, alternative_name in enumerate(scen.get("alternative_name_list", ())):
-                item_to_add = {"scenario_id": current_scen["id"], "alternative_name": alternative_name, "rank": k + 1}
-                scen_alts_to_add.append(item_to_add)
-            for alternative_id in current_scen["alternative_id_list"]:
-                scen_alt = {"scenario_id": current_scen["id"], "alternative_id": alternative_id}
-                current_scen_alt = self.mapped_table("scenario_alternative").find_item(scen_alt)
-                scen_alt_ids_to_remove[current_scen_alt["id"]] = current_scen_alt
-        # Remove items that are both to add and to remove
-        for id_, to_rm in list(scen_alt_ids_to_remove.items()):
-            i = next((i for i, to_add in enumerate(scen_alts_to_add) if _is_equal(to_add, to_rm)), None)
-            if i is not None:
-                del scen_alts_to_add[i]
-                del scen_alt_ids_to_remove[id_]
-        return scen_alts_to_add, set(scen_alt_ids_to_remove), errors
-
     def remove_unused_metadata(self):
         used_metadata_ids = set()
         for x in self.mapped_table("entity_metadata").valid_values():

From 47cb454645fa4b9adcf88dd55cfbe4fa48be6028 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Fri, 6 Oct 2023 14:53:48 +0300
Subject: [PATCH 123/317] Fix MAP_TYPE

Re #279
---
 spinedb_api/export_mapping/export_mapping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py
index 597ae1d5..df8e33c9 100644
--- a/spinedb_api/export_mapping/export_mapping.py
+++ b/spinedb_api/export_mapping/export_mapping.py
@@ -953,7 +953,7 @@ class ParameterDefaultValueTypeMapping(ParameterDefaultValueMapping):
     an :class:`AlternativeMapping` as parents.
     """

-    MAP_TYPE = "ParameterValueType"
+    MAP_TYPE = "ParameterDefaultValueType"

     def _data(self, db_row):
         type_ = db_row.default_type

From 48427c1455e80d1fcf73975b04b52d3c1ad56224 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Mon, 9 Oct 2023 15:03:53 +0300
Subject: [PATCH 124/317] Fix tutorial

The code in the tutorial was broken in a few ways. It still doesn't run
due to a probable bug, but this is a step in the right direction.

Re #282
---
 docs/source/parameter_value_format.rst |  7 ++
 docs/source/tutorial.rst               | 77 ++++++++++++++-------
 tests/test_DatabaseMapping.py          | 96 +++++++++++++++++++++++++-
 3 files changed, 154 insertions(+), 26 deletions(-)

diff --git a/docs/source/parameter_value_format.rst b/docs/source/parameter_value_format.rst
index 9e67faa8..b2bc5e21 100644
--- a/docs/source/parameter_value_format.rst
+++ b/docs/source/parameter_value_format.rst
@@ -2,6 +2,13 @@
 Parameter value format
 **********************

+.. note::
+
+   Client code should almost never convert parameter values to JSON and back manually.
+   For most cases, JSON should be considered an implementation detail.
+   Clients should rather use :func:`.to_database` and :func:`.from_database` which shield
+   client code from abrupt changes in the database representation.
+
 Parameter values are specified using JSON in the ``value`` field of the ``parameter_value`` table.
 This document describes the JSON specification for parameter values of special type
 (namely, date-time, duration, time-pattern, time-series, array, and map.)
diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst
index 96981f1b..6fd2f309 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -12,7 +12,7 @@ The Spine DB API allows one to create and manipulate
 Spine databases in a standard way, using SQLAlchemy_ as the underlying engine.
 This tutorial provides a quick introduction to the usage of the package.

-To begin, make sure Spine database API is installed as described at :ref:`installation`.
+To begin, make sure Spine database API is installed as described in :ref:`installation`.


 Database Mapping
@@ -35,7 +35,7 @@ The URL should be formatted following the RFC-1738 standard, as described

 .. note::

-  Currently supported database backends are only SQLite and MySQL. More will be added later.
+  Currently supported database backends are SQLite and MySQL.

 Creating a DB
 -------------
@@ -52,13 +52,14 @@ We can remediate this by creating a SQLite DB (which is just a file in your syst
        pass

 The above will create a file called ``first.sqlite`` in your current working directory.
-Note that we pass the keyword argument ``create=True`` to :class:`.DatabaseMapping` to explicitely say
+Note that we pass the keyword argument ``create=True`` to :class:`.DatabaseMapping` to explicitly say
 that we want the DB to be created at the given URL.

 .. note::

   In the remainder we will skip the above step and work directly with ``db_map``. In other words,
-  all the examples below assume we are inside the ``with`` block above.
+  all the examples below assume we are inside the ``with`` block above
+  except when we need to modify the ``import`` line.

 Adding data
 -----------
@@ -71,7 +72,7 @@ Let's begin the party by adding a couple of entity classes::

     db_map.add_item("entity_class", name="cat", description="Eats fish.")

 Now let's add a multi-dimensional entity class between the two above. For this we need to specify the class names
-as `dimensions`::
+as `dimension_name_list`::

     db_map.add_item(
         "entity_class",
@@ -80,7 +81,6 @@ as `dimensions`::
         description="A fish getting eaten by a cat?",
     )

-
 Let's add entities to our zero-dimensional classes::

     db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (for now).")
@@ -89,7 +89,7 @@ Let's add entities to our zero-dimensional classes::
     )

 Let's add a multi-dimensional entity to our multi-dimensional class. For this we need to specify the entity names
-as `elements`::
+as `element_name_list`::

     db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))

@@ -97,20 +97,33 @@ Let's add a parameter definition for one of our entity classes::

     db_map.add_item("parameter_definition", entity_class_name="fish", name="color")

-Finally, let's specify a parameter value for one of our entities::
+Finally, let's specify a parameter value for one of our entities.
+For this we need the :func:`.to_database` function, which converts the value into its database representation.
+Let's modify the import statement at the beginning of our script::
+
+    from spinedb_api import DatabaseMapping, to_database
+
+Now we're ready to go::
+
+    color, value_type = to_database("mainly orange")
     db_map.add_item(
         "parameter_value",
         entity_class_name="fish",
-        entity_name="Nemo",
+        entity_byname=("Nemo",),
         parameter_definition_name="color",
-        value="mainly orange"
+        alternative_name="Base",
+        value=color,
+        type=value_type
     )

+Note that in the above, we must refer to the entity by its *byname*,
+which for a zero-dimensional entity is a tuple containing just the entity name.
+We also set the value to belong to an *alternative* called ``"Base"``
+which is readily available in new databases.

 .. note::

-    The data we've added so far is not yet in the DB, but only in a in-memory mapping within our ``db_map`` object.
+    The data we've added so far is not yet in the DB, but only in an in-memory mapping within our ``db_map`` object.
+    You need to call :meth:`~.DatabaseMapping.commit_session` to actually store the data.

 Retrieving data
 ---------------
@@ -121,17 +134,32 @@ For example, let's find one of the entities we inserted above::

     felix = db_map.get_item("entity", class_name="cat", name="Felix")
     print(felix["description"])  # Prints 'The wonderful wonderful cat.'

-
 Above, ``felix`` is a :class:`~.PublicItem` object, representing an item (or row) in a Spine DB.

 Let's find our multi-dimensional entity::

-    nemo_felix = db_map.get_item("entity", class_name="fish__cat", byname=("Nemo", "Felix"))
-    print(nemo_felix["dimension_name_list"])  # Prints "('fish', 'cat')""
+    nemo_felix = db_map.get_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))
+    print(nemo_felix["dimension_name_list"])  # Prints "('fish', 'cat')"
+
+Parameter values need to be converted to Python values using :func:`.from_database` before we can use them.
+First we need to import the function::
+
+    from spinedb_api import DatabaseMapping, to_database, from_database
+
+Then we can retrieve the ``"color"`` of ``"Nemo"`` (in the ``"Base"`` alternative)::
+
+    color_value = db_map.get_item(
+        "parameter_value",
+        class_name="fish",
+        entity_byname=("Nemo",),
+        alternative="Base"
+    )
+    color = from_database(color_value["value"], color_value["type"])
+    print(color)  # Prints 'mainly orange'

 To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`::

-    print(entity["byname"] for entity in db_map.get_items("entity"))
+    print(list(entity["byname"] for entity in db_map.get_items("entity")))  # Prints [("Nemo",), ("Felix",), ("Nemo", "Felix")]

 Now you should use the above to try and find Nemo.
@@ -148,13 +176,14 @@ Let's rename our fish entity to avoid any copyright infringements::

 To be safe, let's also change the color::

+    new_color, value_type = to_database("not that orange")
     db_map.get_item(
         "parameter_value",
         entity_class_name="fish",
+        entity_byname=("NotNemo",),
         parameter_definition_name="color",
-        entity_name="NotNemo"
-    ).update(value="not that orange")
-
+        alternative_name="Base",
+    ).update(value=new_color, type=value_type)

 Note how we need to use the new entity name ``"NotNemo"`` to retrieve the parameter value.
 This makes sense.
@@ -166,9 +195,7 @@ To do this we use the :meth:`~.PublicItem.remove` method of :class:`~.PublicItem

     db_map.get_item("entity", class_name="fish", name="NotNemo").remove()

-
-
-
-
-
-
+Note that the above call removes items in *cascade*,
+meaning that items that depend on ``"NotNemo"`` will get removed as well.
+We have one such item in the database, namely the ``"color"`` parameter value +which also gets dropped when the above method is called. diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 22023b39..0cf2b495 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -20,7 +20,14 @@ from unittest.mock import patch from sqlalchemy.engine.url import make_url, URL from sqlalchemy.util import KeyedTuple -from spinedb_api import import_functions, from_database, to_database, SpineDBAPIError, SpineIntegrityError +from spinedb_api import ( + DatabaseMapping, + import_functions, + from_database, + to_database, + SpineDBAPIError, + SpineIntegrityError, +) from .custom_db_mapping import CustomDatabaseMapping @@ -79,6 +86,93 @@ def test_shorthand_filter_query_works(self): class TestDatabaseMapping(unittest.TestCase): + def test_commit_parameter_value(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with DatabaseMapping(url, create=True) as db_map: + _, error = db_map.add_item("entity_class", name="fish", description="It swims.") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity", class_name="fish", name="Nemo", description="Peacefully swimming away." + ) + self.assertIsNone(error) + _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") + self.assertIsNone(error) + value, type_ = to_database("mainly orange") + _, error = db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) + self.assertIsNone(error) + db_map.commit_session("Added data") + with DatabaseMapping(url) as db_map: + color = db_map.get_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + ) + value = from_database(color["value"], color["type"]) + self.assertEqual(value, "mainly orange") + + def test_commit_multidimensional_parameter_value(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with DatabaseMapping(url, create=True) as db_map: + _, error = db_map.add_item("entity_class", name="fish", description="It swims.") + self.assertIsNone(error) + _, error = db_map.add_item("entity_class", name="cat", description="Eats fish.") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity_class", + name="fish__cat", + dimension_name_list=("fish", "cat"), + description="A fish getting eaten by a cat?", + ) + self.assertIsNone(error) + _, error = db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (soon).") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat." 
+            )
+            self.assertIsNone(error)
+            _, error = db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))
+            self.assertIsNone(error)
+            _, error = db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate")
+            self.assertIsNone(error)
+            value, type_ = to_database(0.23)
+            _, error = db_map.add_item(
+                "parameter_value",
+                entity_class_name="fish__cat",
+                entity_byname=("Nemo", "Felix"),
+                parameter_definition_name="rate",
+                alternative_name="Base",
+                value=value,
+                type=type_,
+            )
+            self.assertIsNone(error)
+            db_map.commit_session("Added data")
+        with DatabaseMapping(url) as db_map:
+            color = db_map.get_item(
+                "parameter_value",
+                entity_class_name="fish__cat",
+                entity_byname=("Nemo", "Felix"),
+                parameter_definition_name="rate",
+                alternative_name="Base",
+            )
+            value = from_database(color["value"], color["type"])
+            self.assertEqual(value, 0.23)
+
+
+class TestDatabaseMappingLegacy(unittest.TestCase):
+    """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure."""
+
     _db_map = None

     @classmethod

From 703dd9336b921b98eebd95ed22ad4c7847e55c39 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Tue, 10 Oct 2023 13:05:28 +0300
Subject: [PATCH 125/317] Fix get_item() after dependee item update

We need to update _id_by_unique_key_value dictionaries in _MappedTables
in cascade when an item gets updated so get_item() works e.g. when asked
to retrieve a parameter value after the value's entity has been renamed.

Re #284
---
 spinedb_api/db_mapping_base.py | 40 +++++++++++++++++++--------
 tests/test_DatabaseMapping.py  | 50 ++++++++++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 12 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index a2997f80..164974de 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -391,14 +391,16 @@ def check_item(self, item, for_update=False, skip_keys=()):
             candidate_item["id"] = self._new_id()
         return candidate_item, merge_error

-    def _add_unique(self, item):
+    def add_unique(self, item):
+        id_ = item["id"]
         for key, value in item.unique_values():
-            self._id_by_unique_key_value.setdefault(key, {})[value] = item["id"]
+            self._id_by_unique_key_value.setdefault(key, {})[value] = id_

-    def _remove_unique(self, item):
+    def remove_unique(self, item):
+        id_ = item["id"]
         for key, value in item.unique_values():
             id_by_value = self._id_by_unique_key_value.get(key, {})
-            if id_by_value.get(value) == item["id"]:
+            if id_by_value.get(value) == id_:
                 del id_by_value[value]

     def add_item(self, item, new=False):
@@ -419,28 +421,28 @@ def add_item(self, item, new=False):
             if "id" not in item or not item.is_id_valid:
                 item["id"] = self._new_id()
         self[item["id"]] = item
-        self._add_unique(item)
+        self.add_unique(item)
         return item

     def update_item(self, item):
         current_item = self.find_item(item)
-        self._remove_unique(current_item)
+        current_item.cascade_remove_unique()
         current_item.update(item)
-        self._add_unique(current_item)
+        current_item.cascade_add_unique()
         current_item.cascade_update()
         return current_item

     def remove_item(self, id_):
         current_item = self.find_item({"id": id_})
         if current_item is not None:
-            self._remove_unique(current_item)
+            self.remove_unique(current_item)
             current_item.cascade_remove()
         return current_item

     def restore_item(self, id_):
         current_item = self.find_item({"id": id_})
         if current_item is not None:
-            self._add_unique(current_item)
+            self.add_unique(current_item)
             current_item.cascade_restore()
         return
current_item

@@ -449,7 +451,7 @@ class MappedItemBase(dict):
     """A dictionary that represents a db item."""

     fields = {}
-    """A dictionaty mapping fields to a tuple of (type, value description)"""
+    """A dictionary mapping fields to a tuple of (type, value description)"""
     _defaults = {}
     """A dictionary mapping keys to their default values"""
     _unique_keys = ()
@@ -719,8 +721,8 @@ def _invalidate_ref(self, ref_type, ref_id):
         """Invalidates a reference previously collected from the cache.

         Args:
-            ref_type (str): The references's type
-            ref_id (int): The references's id
+            ref_type (str): The reference's type
+            ref_id (int): The reference's id
         """
         ref = self._db_cache.get_mapped_item(ref_type, ref_id)
         ref.remove_referrer(self)
@@ -850,6 +852,20 @@ def call_update_callbacks(self):
                 obsolete.add(callback)
         self.update_callbacks -= obsolete

+    def cascade_add_unique(self):
+        """Adds the item's and all its referrers' unique keys and ids in cascade."""
+        mapped_table = self._db_cache.mapped_table(self._item_type)
+        mapped_table.add_unique(self)
+        for referrer in self._referrers.values():
+            referrer.cascade_add_unique()
+
+    def cascade_remove_unique(self):
+        """Removes the item's and all its referrers' unique keys and ids in cascade."""
+        mapped_table = self._db_cache.mapped_table(self._item_type)
+        mapped_table.remove_unique(self)
+        for referrer in self._referrers.values():
+            referrer.cascade_remove_unique()
+
     def is_committed(self):
         """Returns whether or not this item is committed to the DB.

diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 0cf2b495..a0809e57 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -169,6 +169,56 @@ def test_commit_multidimensional_parameter_value(self):
             value = from_database(color["value"], color["type"])
             self.assertEqual(value, 0.23)

+    def test_updating_entity_name_updates_the_name_in_parameter_value_too(self):
+        with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map:
+            _, error = db_map.add_item("entity_class", name="fish", description="It swims.")
+            self.assertIsNone(error)
+            _, error = db_map.add_item(
+                "entity", class_name="fish", name="Nemo", description="Peacefully swimming away."
+            )
+            self.assertIsNone(error)
+            _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color")
+            self.assertIsNone(error)
+            value, type_ = to_database("mainly orange")
+            _, error = db_map.add_item(
+                "parameter_value",
+                entity_class_name="fish",
+                entity_byname=("Nemo",),
+                parameter_definition_name="color",
+                alternative_name="Base",
+                value=value,
+                type=type_,
+            )
+            self.assertIsNone(error)
+            color = db_map.get_item(
+                "parameter_value",
+                entity_class_name="fish",
+                entity_byname=("Nemo",),
+                parameter_definition_name="color",
+                alternative_name="Base",
+            )
+            self.assertIsNotNone(color)
+            fish = db_map.get_item("entity", class_name="fish", name="Nemo")
+            self.assertIsNotNone(fish)
+            fish.update(name="NotNemo")
+            self.assertEqual(fish["name"], "NotNemo")
+            not_color_anymore = db_map.get_item(
+                "parameter_value",
+                entity_class_name="fish",
+                entity_byname=("Nemo",),
+                parameter_definition_name="color",
+                alternative_name="Base",
+            )
+            self.assertIsNone(not_color_anymore)
+            color = db_map.get_item(
+                "parameter_value",
+                entity_class_name="fish",
+                entity_byname=("NotNemo",),
+                parameter_definition_name="color",
+                alternative_name="Base",
+            )
+            self.assertIsNotNone(color)
+

 class TestDatabaseMappingLegacy(unittest.TestCase):
     """'Backward compatibility' tests, i.e.
pre-entity tests converted to work with the entity structure.""" From 2aaaa993b5e3a061d1ab44df7b2d61550da0f7de Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 10 Oct 2023 14:10:45 +0300 Subject: [PATCH 126/317] Fix example code in tutorial. --- docs/source/tutorial.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 6fd2f309..619cf58e 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -150,9 +150,10 @@ Then we can retrieve the ``"color"`` of ``"Nemo"`` (in the ``"Base"`` alternativ color_value = db_map.get_item( "parameter_value", - class_name="fish", + entity_class_name="fish", entity_byname=("Nemo",), - alternative="Base" + parameter_definition_name="color", + alternative_name="Base" ) color = from_database(color_value["value"], color_value["type"]) print(color) # Prints 'mainly orange' From 86deb4b54af525185a59046ca518c86e2be0a67d Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 10 Oct 2023 16:06:15 +0200 Subject: [PATCH 127/317] Create PublicItems only once plus some renaming cache -> mapping --- spinedb_api/db_mapping.py | 74 +++------------------- spinedb_api/db_mapping_base.py | 109 ++++++++++++++++++++++++--------- spinedb_api/mapped_items.py | 14 ++--- 3 files changed, 97 insertions(+), 100 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 4fbac1fc..486dd840 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -374,12 +374,12 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): :class:`PublicItem` or None """ item_type = self._real_tablename(item_type) - cache_item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch) - if not cache_item: + item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch) + if not item: return None - if skip_removed and not cache_item.is_valid(): + if skip_removed and not item.is_valid(): return None - return PublicItem(self, cache_item) + return item.public_item def get_items(self, item_type, fetch=True, skip_removed=True): """Finds and returns all the items of one type. @@ -397,7 +397,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True): self.fetch_all(item_type) mapped_table = self.mapped_table(item_type) get_items = mapped_table.valid_values if skip_removed else mapped_table.values - return [PublicItem(self, x) for x in get_items()] + return [x.public_item for x in get_items()] def add_item(self, item_type, check=True, **kwargs): """Adds an item to the in-memory mapping. @@ -423,7 +423,7 @@ def add_item(self, item_type, check=True, **kwargs): return mapped_table.add_item(kwargs, new=True), None checked_item, error = mapped_table.check_item(kwargs) return ( - PublicItem(self, mapped_table.add_item(checked_item, new=True)) if checked_item and not error else None, + mapped_table.add_item(checked_item, new=True).public_item if checked_item and not error else None, error, ) @@ -478,7 +478,7 @@ def update_item(self, item_type, check=True, **kwargs): if not check: return mapped_table.update_item(kwargs), None checked_item, error = mapped_table.check_item(kwargs, for_update=True) - return (PublicItem(self, mapped_table.update_item(checked_item._asdict())) if checked_item else None, error) + return (mapped_table.update_item(checked_item._asdict()).public_item if checked_item else None, error) def update_items(self, item_type, *items, check=True, strict=False): """Updates many items in the in-memory mapping. 
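Since every mapped item now carries a single :class:`PublicItem`, callbacks registered through one handle keep working through any later handle. A hypothetical sketch of this (it assumes the callback convention visible in ``call_update_callbacks``, where a callback returning a falsy value is discarded)::

    def on_update(item):
        print("updated:", item["name"])
        return True  # keep this callback registered

    felix = db_map.get_item("entity", class_name="cat", name="Felix")
    felix.add_update_callback(on_update)
    # get_item() now returns the same PublicItem instance every time, so an
    # update through a fresh handle still fires the callback registered above.
    db_map.get_item("entity", class_name="cat", name="Felix").update(name="Felix II")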
@@ -523,7 +523,7 @@ def remove_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) mapped_table = self.mapped_table(item_type) - return PublicItem(self, mapped_table.remove_item(id_)) + return mapped_table.remove_item(id_).public_item def remove_items(self, item_type, *ids): """Removes many items from the in-memory mapping. @@ -566,7 +566,7 @@ def restore_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) mapped_table = self.mapped_table(item_type) - return PublicItem(self, mapped_table.restore_item(id_)) + return mapped_table.restore_item(id_).public_item def restore_items(self, item_type, *ids): """Restores many previously removed items into the in-memory mapping. @@ -617,7 +617,7 @@ def fetch_more(self, item_type, limit=None): list(:class:`PublicItem`): The items fetched. """ item_type = self._real_tablename(item_type) - return [PublicItem(self, x) for x in self.do_fetch_more(item_type, limit=limit)] + return [x.public_item for x in self.do_fetch_more(item_type, limit=limit)] def fetch_all(self, *item_types): """Fetches items from the DB into the in-memory mapping. @@ -747,57 +747,3 @@ def get_filter_configs(self): list(dict): """ return self._filter_configs - - -class PublicItem: - def __init__(self, db_map, cache_item): - self._db_map = db_map - self._cache_item = cache_item - - @property - def item_type(self): - return self._cache_item.item_type - - def __getitem__(self, key): - return self._cache_item[key] - - def __eq__(self, other): - if isinstance(other, dict): - return self._cache_item == other - return super().__eq__(other) - - def __repr__(self): - return repr(self._cache_item) - - def __str__(self): - return str(self._cache_item) - - def get(self, key, default=None): - return self._cache_item.get(key, default) - - def is_valid(self): - return self._cache_item.is_valid() - - def is_committed(self): - return self._cache_item.is_committed() - - def _asdict(self): - return self._cache_item._asdict() - - def update(self, **kwargs): - self._db_map.update_item(self.item_type, id=self["id"], **kwargs) - - def remove(self): - return self._db_map.remove_item(self.item_type, self["id"]) - - def restore(self): - return self._db_map.restore_item(self.item_type, self["id"]) - - def add_update_callback(self, callback): - self._cache_item.update_callbacks.add(callback) - - def add_remove_callback(self, callback): - self._cache_item.remove_callbacks.add(callback) - - def add_restore_callback(self, callback): - self._cache_item.restore_callbacks.add(callback) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 164974de..e4334df1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -20,7 +20,7 @@ @unique class Status(Enum): - """Cache item status.""" + """Mapped item status.""" committed = auto() to_add = auto() @@ -194,7 +194,7 @@ def _get_next_chunk(self, item_type, limit): def _advance_query(self, item_type, limit): """Advances the DB query that fetches items of given type - and adds the results to the corresponding table cache. + and adds the results to the corresponding mapped table. 
Args: item_type (str) @@ -207,9 +207,7 @@ def _advance_query(self, item_type, limit): self._fetched_item_types.add(item_type) return [] mapped_table = self.mapped_table(item_type) - for item in chunk: - mapped_table.add_item(item) - return chunk + return [mapped_table.add_item(item) for item in chunk] def mapped_table(self, item_type): return self._mapped_tables.setdefault(item_type, _MappedTable(self, item_type)) @@ -260,14 +258,14 @@ def fetch_ref(self, item_type, id_): class _MappedTable(dict): - def __init__(self, db_cache, item_type, *args, **kwargs): + def __init__(self, db_map, item_type, *args, **kwargs): """ Args: - db_cache (DBCacheBase): the DB cache where this table cache belongs. + db_map (DatabaseMappingBase): the DB mapping where this mapped table belongs. item_type (str): the item type, equal to a table name """ super().__init__(*args, **kwargs) - self._db_cache = db_cache + self._db_map = db_map self._item_type = item_type self._id_by_unique_key_value = {} self._temp_id_by_db_id = {} @@ -314,15 +312,15 @@ def valid_values(self): return (x for x in self.values() if x.is_valid()) def _make_item(self, item): - """Returns a cache item. + """Returns a mapped item. Args: item (dict): the 'db item' to use as base Returns: - CacheItem + MappedItem """ - return self._db_cache.make_item(self._item_type, **item) + return self._db_map.make_item(self._item_type, **item) def find_item(self, item, skip_keys=(), fetch=True): """Returns a MappedItemBase that matches the given dictionary-item. @@ -339,17 +337,17 @@ def find_item(self, item, skip_keys=(), fetch=True): item = self.get(id_) if item or not fetch: return item - return self._db_cache.fetch_ref(self._item_type, id_) + return self._db_map.fetch_ref(self._item_type, id_) # No id. Try to locate the item by the value of one of the unique keys. # Used by import_data (and more...) - cache_item = self._make_item(item) - error = cache_item.resolve_inverse_references(item.keys()) + mapped_item = self._make_item(item) + error = mapped_item.resolve_inverse_references(item.keys()) if error: return None - error = cache_item.polish() + error = mapped_item.polish() if error: return None - for key, value in cache_item.unique_values(skip_keys=skip_keys): + for key, value in mapped_item.unique_values(skip_keys=skip_keys): current_item = self._unique_key_value_to_item(key, value, fetch=fetch) if current_item: return current_item @@ -411,10 +409,10 @@ def add_item(self, item, new=False): # Item comes from the DB id_ = item["id"] if id_ in self or id_ in self._temp_id_by_db_id: - # The item is already in the cache + # The item is already in the mapping return if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()): - # An item with the same unique key is already in the cache + # An item with the same unique key is already in the mapping return else: item.status = Status.to_add @@ -478,13 +476,13 @@ class MappedItemBase(dict): 3. return the id of that item. """ - def __init__(self, db_cache, item_type, **kwargs): + def __init__(self, db_map, item_type, **kwargs): """ Args: - db_cache (DBCacheBase): the DB cache where this item belongs. + db_map (DatabaseMappingBase): the DB where this item belongs. 
""" super().__init__(**kwargs) - self._db_cache = db_cache + self._db_map = db_map self._item_type = item_type self._referrers = {} self._weak_referrers = {} @@ -500,6 +498,7 @@ def __init__(self, db_cache, item_type, **kwargs): self._removal_source = None self._status_when_removed = None self._backup = None + self.public_item = PublicItem(self._db_map, self) @classmethod def ref_types(cls): @@ -661,7 +660,7 @@ def resolve_inverse_references(self, skip_keys=()): id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) if None in id_value: continue - mapped_table = self._db_cache.mapped_table(ref_type) + mapped_table = self._db_map.mapped_table(ref_type) try: self[src_key] = ( tuple(mapped_table.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) @@ -685,7 +684,7 @@ def polish(self): return "" def _get_ref(self, ref_type, ref_id, strong=True): - """Collects a reference from the cache. + """Collects a reference from the in-memory mapping. Adds this item to the reference's list of referrers if strong is True; or weak referrers if strong is False. If the reference is not found, sets some flags. @@ -698,11 +697,11 @@ def _get_ref(self, ref_type, ref_id, strong=True): Returns: MappedItemBase or dict """ - ref = self._db_cache.get_mapped_item(ref_type, ref_id) + ref = self._db_map.get_mapped_item(ref_type, ref_id) if not ref: if not strong: return {} - ref = self._db_cache.fetch_ref(ref_type, ref_id) + ref = self._db_map.fetch_ref(ref_type, ref_id) if not ref: self._corrupted = True return {} @@ -718,18 +717,18 @@ def _get_ref(self, ref_type, ref_id, strong=True): return ref def _invalidate_ref(self, ref_type, ref_id): - """Invalidates a reference previously collected from the cache. + """Invalidates a reference previously collected from the in-memory mapping. Args: ref_type (str): The reference's type ref_id (int): The reference's id """ - ref = self._db_cache.get_mapped_item(ref_type, ref_id) + ref = self._db_map.get_mapped_item(ref_type, ref_id) ref.remove_referrer(self) def is_valid(self): """Checks if this item has all its references. - Removes the item from the cache if not valid by calling ``cascade_remove``. + Removes the item from the in-memory mapping if not valid by calling ``cascade_remove``. 
Returns: bool @@ -932,3 +931,57 @@ def update(self, other): super().update(other) if self._asdict() == self._backup: self._status = Status.committed + + +class PublicItem: + def __init__(self, db_map, mapped_item): + self._db_map = db_map + self._mapped_item = mapped_item + + @property + def item_type(self): + return self._mapped_item.item_type + + def __getitem__(self, key): + return self._mapped_item[key] + + def __eq__(self, other): + if isinstance(other, dict): + return self._mapped_item == other + return super().__eq__(other) + + def __repr__(self): + return repr(self._mapped_item) + + def __str__(self): + return str(self._mapped_item) + + def get(self, key, default=None): + return self._mapped_item.get(key, default) + + def is_valid(self): + return self._mapped_item.is_valid() + + def is_committed(self): + return self._mapped_item.is_committed() + + def _asdict(self): + return self._mapped_item._asdict() + + def update(self, **kwargs): + self._db_map.update_item(self.item_type, id=self["id"], **kwargs) + + def remove(self): + return self._db_map.remove_item(self.item_type, self["id"]) + + def restore(self): + return self._db_map.restore_item(self.item_type, self["id"]) + + def add_update_callback(self, callback): + self._mapped_item.update_callbacks.add(callback) + + def add_remove_callback(self, callback): + self._mapped_item.remove_callbacks.add(callback) + + def add_restore_callback(self, callback): + self._mapped_item.restore_callbacks.add(callback) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 7d8e0ae8..4f2cf423 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -9,8 +9,6 @@ # this program. If not, see . ###################################################################################################################### -# The Spine implementation for DBCacheBase - import uuid from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError @@ -116,7 +114,7 @@ def polish(self): return base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) name = base_name - mapped_table = self._db_cache.mapped_table(self._item_type) + mapped_table = self._db_map.mapped_table(self._item_type) while mapped_table.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: name = base_name + "_" + uuid.uuid4().hex self["name"] = name @@ -268,7 +266,7 @@ def polish(self): parsed_value = from_database(default_value, default_type) if parsed_value is None: return - list_value_id = self._db_cache.mapped_table("list_value").unique_key_value_to_id( + list_value_id = self._db_map.mapped_table("list_value").unique_key_value_to_id( ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) ) if list_value_id is None: @@ -289,7 +287,7 @@ def merge(self, other): and other_parameter_value_list_id != self["parameter_value_list_id"] and any( x["parameter_definition_id"] == self["id"] - for x in self._db_cache.mapped_table("parameter_value").valid_values() + for x in self._db_map.mapped_table("parameter_value").valid_values() ) ): del other["parameter_value_list_id"] @@ -375,7 +373,7 @@ def polish(self): parsed_value = from_database(value, type_) if parsed_value is None: return - list_value_id = self._db_cache.mapped_table("list_value").unique_key_value_to_id( + list_value_id = self._db_map.mapped_table("list_value").unique_key_value_to_id( ("parameter_value_list_name", "value", "type"), (list_name, value, type_) ) if 
list_value_id is None:
@@ -442,11 +440,11 @@ def __getitem__(self, key):
         if key == "alternative_name_list":
             return [x["alternative_name"] for x in self.sorted_scenario_alternatives]
         if key == "sorted_scenario_alternatives":
-            self._db_cache.do_fetch_all("scenario_alternative")
+            self._db_map.do_fetch_all("scenario_alternative")
             return sorted(
                 (
                     x
-                    for x in self._db_cache.mapped_table("scenario_alternative").valid_values()
+                    for x in self._db_map.mapped_table("scenario_alternative").valid_values()
                     if x["scenario_id"] == self["id"]
                 ),
                 key=itemgetter("rank"),

From 8ce8f0144139a743da0cb87cd93b3a4c62529f8d Mon Sep 17 00:00:00 2001
From: Manuel
Date: Tue, 10 Oct 2023 20:05:50 +0200
Subject: [PATCH 128/317] Fix typos

---
 spinedb_api/db_mapping_base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index e4334df1..b8b4a04d 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -297,7 +297,7 @@ def unique_key_value_to_id(self, key, value, strict=False, fetch=True):
         """
         id_by_unique_value = self._id_by_unique_key_value.get(key, {})
         if not id_by_unique_value and fetch:
-            id_by_unique_value = self._db_cache.fetch_value(
+            id_by_unique_value = self._db_map.fetch_value(
                 self._item_type, lambda: self._id_by_unique_key_value.get(key, {})
             )
         value = tuple(tuple(x) if isinstance(x, list) else x for x in value)
@@ -853,14 +853,14 @@ def call_update_callbacks(self):

     def cascade_add_unique(self):
         """Adds the item's and all its referrers' unique keys and ids in cascade."""
-        mapped_table = self._db_cache.mapped_table(self._item_type)
+        mapped_table = self._db_map.mapped_table(self._item_type)
         mapped_table.add_unique(self)
         for referrer in self._referrers.values():
             referrer.cascade_add_unique()

     def cascade_remove_unique(self):
         """Removes the item's and all its referrers' unique keys and ids in cascade."""
-        mapped_table = self._db_cache.mapped_table(self._item_type)
+        mapped_table = self._db_map.mapped_table(self._item_type)
         mapped_table.remove_unique(self)
         for referrer in self._referrers.values():
             referrer.cascade_remove_unique()

From e4d1f3b273289e550df52b8375379c2a8653bdc8 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 11 Oct 2023 09:00:27 +0200
Subject: [PATCH 129/317] Try to generate DBMapping convenience methods and
 docs lazily

---
 docs/source/conf.py            | 51 +++++++++++++++++++++++++++++-
 spinedb_api/db_mapping.py      | 46 +++++++++++++++++++++------
 spinedb_api/db_mapping_base.py | 14 ++++++++
 spinedb_api/mapped_items.py    | 27 ++++++++++------
 4 files changed, 118 insertions(+), 20 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 051a192e..241a6285 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -142,10 +142,59 @@ def _process_docstring(app, what, name, obj, options, lines):
             new_lines.extend([f"    * - {f_names}"])
         lines[i : i + 1] = new_lines
     # Expand
-    spine_item_types = ", ".join([f"`{x}`" for x in DatabaseMapping.item_types()])
+    spine_item_types = ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()])
     for k, line in enumerate(lines):
         if "" in line:
             lines[k] = line.replace("", spine_item_types)
+    # Expand
+    if lines[0] == "":
+        item_type = name.split("get_")[1].split("_item")[0]
+        a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a"
+        factory = DatabaseMapping._item_factory(item_type)
+        new_lines = [f"Finds and returns {a} `{item_type}` matching the arguments, or None if none found.", ""]
+        new_lines.extend(
+            [
+                f":param
fetch: Whether to fetch the DB in case the `{item_type}` is not found in memory.",
+                ":type fetch: bool, optional",
+                "",
+                ":param skip_removed: Whether to ignore removed items.",
+                ":type skip_removed: bool, optional",
+                "",
+            ]
+        )
+        uq_f_names = {f_name: None for f_names in factory._unique_keys for f_name in f_names}
+        for f_name in uq_f_names:
+            f_type, f_value = factory.fields[f_name]
+            new_lines.extend([f":param {f_name}: {f_value}", f":type {f_name}: {f_type}", ""])
+        new_lines.extend([f":returns: The `{item_type}` if found.", ":rtype: :class:`PublicItem` or None", ""])
+        lines[0:1] = new_lines
+    # Expand ,
+    if lines[0] in ("", ""):
+        update = lines[0] == ""
+        head = "update_" if update else "add_"
+        item_type = name.split(head)[1].split("_item")[0]
+        a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a"
+        factory = DatabaseMapping._item_factory(item_type)
+        synopsis = (
+            f"Updates {a} `{item_type}` in the in-memory mapping."
+            if update
+            else f"Adds {a} `{item_type}` to the in-memory mapping."
+        )
+        new_lines = [synopsis, ""]
+        new_lines.extend([":param check: Whether to carry out integrity checks.", ":type check: bool, optional", ""])
+        if update:
+            new_lines.extend([f":param id: The id of the `{item_type}` to update.", ":type id: int", ""])
+        uq_f_names = {f_name: None for f_names in factory._unique_keys for f_name in f_names}
+        for f_name, (f_type, f_value) in factory.fields.items():
+            new_lines.extend([f":param {f_name}: {f_value}", f":type {f_name}: {f_type}", ""])
+        new_lines.extend(
+            [
+                f":returns: The {'updated' if update else 'added'} `{item_type}` and any errors.",
+                ":rtype: tuple(:class:`PublicItem` or None, str)",
+                "",
+            ]
+        )
+        lines[0:1] = new_lines


 def setup(sphinx):
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 486dd840..d7bbc34e 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -12,14 +12,14 @@
 """
 This module defines the :class:`.DatabaseMapping` class, the main means to communicate with a Spine DB.
 If you're planning to use this class, it is probably a good idea to first familiarize yourself a little bit with the
-DB mapping schema.
+DB mapping schema below.


 DB mapping schema
 =================

-The DB mapping schema is a close cousin of the Spine DB schema, with some extra flexibility such as
-(or should I say, mainly) the ability to define references by name rather than by numerical id.
+The DB mapping schema is a close cousin of the Spine DB schema with some extra flexibility,
+like the ability to specify references by name rather than by numerical id.
 The schema defines the following item types: . As you can see, these follow the names of some of the tables
 in the Spine DB schema.
@@ -33,6 +33,8 @@
 import os
 import time
 import logging
+import astroid
+from functools import partialmethod
 from datetime import datetime, timezone
 from types import MethodType
 from sqlalchemy import create_engine, MetaData, inspect
@@ -77,6 +79,8 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     otherwise it is fetched from the DB, stored in memory, and then returned.
     In other words, the data is fetched from the DB exactly once.

+    For convenience, we also provide specialized 'get' methods for each item type, e.g., :meth:`get_entity_item`.
+
     Data is added via :meth:`add_item`;
     updated via :meth:`update_item`;
     removed via :meth:`remove_item`;
@@ -85,6 +89,10 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     These methods also fetch data from the DB into the in-memory mapping to perform the necessary integrity checks
     (unique and foreign key constraints).

+    For convenience, we also provide specialized 'add', 'update', 'remove', and 'restore' methods
+    for each item type, e.g.,
+    :meth:`add_entity_item`, :meth:`update_entity_item`, :meth:`remove_entity_item`, :meth:`restore_entity_item`.
+
     Modifications to the in-memory mapping are committed (written) to the DB via :meth:`commit_session`,
     or rolled back (discarded) via :meth:`rollback_session`.
@@ -107,15 +115,15 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     """

     _sq_name_by_item_type = {
+        "alternative": "alternative_sq",
+        "scenario": "scenario_sq",
+        "scenario_alternative": "scenario_alternative_sq",
         "entity_class": "wide_entity_class_sq",
         "entity": "wide_entity_sq",
+        "entity_group": "entity_group_sq",
         "entity_alternative": "entity_alternative_sq",
         "parameter_value_list": "parameter_value_list_sq",
         "list_value": "list_value_sq",
-        "alternative": "alternative_sq",
-        "scenario": "scenario_sq",
-        "scenario_alternative": "scenario_alternative_sq",
-        "entity_group": "entity_group_sq",
         "parameter_definition": "parameter_definition_sq",
         "parameter_value": "parameter_value_sq",
         "metadata": "metadata_sq",
@@ -368,7 +376,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
             item_type (str): One of .
             fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory.
             skip_removed (bool, optional): Whether to ignore removed items.
-            **kwargs: Fields and values for one the unique keys of the item type as specified in `DB mapping schema`_.
+            **kwargs: Fields and values for one of the unique keys as specified for the item type in `DB mapping schema`_.

         Returns:
             :class:`PublicItem` or None
@@ -411,7 +419,7 @@ def add_item(self, item_type, check=True, **kwargs):
         Args:
             item_type (str): One of .
             check (bool, optional): Whether to carry out integrity checks.
-            **kwargs: Fields and values of the item type as specified in `DB mapping schema`_.
+            **kwargs: Fields and values as specified for the item type in `DB mapping schema`_.

         Returns:
             tuple(:class:`PublicItem` or None, str): The added item and any errors.
@@ -467,7 +475,7 @@ def update_item(self, item_type, check=True, **kwargs):
             item_type (str): One of .
             check (bool, optional): Whether to carry out integrity checks.
             id (int): The id of the item to update.
-            **kwargs: Fields to update and their new values as specified in `DB mapping schema`_.
+            **kwargs: Fields to update and their new values as specified for the item type in `DB mapping schema`_.

         Returns:
             tuple(:class:`PublicItem` or None, str): The updated item and any errors.
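The hunk below generates the specialized methods with ``functools.partialmethod``. A self-contained sketch of the pattern (the class and item type here are placeholders, not the patch's code)::

    from functools import partialmethod

    class Mapping:
        def add_item(self, item_type, **kwargs):
            return f"would add {item_type} with {kwargs}"

    # One assignment like this per item type is what the loop below produces:
    Mapping.add_entity_class = partialmethod(Mapping.add_item, "entity_class")

    m = Mapping()
    print(m.add_entity_class(name="fish"))  # would add entity_class with {'name': 'fish'}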
@@ -747,3 +755,21 @@ def get_filter_configs(self): list(dict): """ return self._filter_configs + + +for x in DatabaseMapping.item_types(): + setattr(DatabaseMapping, "add_" + x, partialmethod(DatabaseMapping.add_item, x)) + + +def format_to_fstring_transform2(node): + if node.name == "DatabaseMapping": + f = astroid.FunctionDef("get_entity_class", lineno=node.lineno, col_offset=node.col_offset, parent=node) + f.postinit(doc_node=astroid.nodes.Const("Do stuff")) + print(f) + # node.postinit(body=node.body + [x]) + # print(node.body) + return node + + +astroid.MANAGER.register_transform(astroid.ClassDef, format_to_fstring_transform2) +# astroid.MANAGER.register_transform(astroid.FunctionDef, format_to_fstring_transform2) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index b8b4a04d..01136b8f 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -11,7 +11,9 @@ import threading from enum import Enum, unique, auto +from difflib import SequenceMatcher from .temp_id import TempId +from .exception import SpineDBAPIError # TODO: Implement MappedItem.pop() to do lookup? @@ -209,13 +211,21 @@ def _advance_query(self, item_type, limit): mapped_table = self.mapped_table(item_type) return [mapped_table.add_item(item) for item in chunk] + def _check_item_type(self, item_type): + if item_type not in self.item_types(): + candidate = max(self.item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio()) + raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?") + def mapped_table(self, item_type): + self._check_item_type(item_type) return self._mapped_tables.setdefault(item_type, _MappedTable(self, item_type)) def get(self, item_type, default=None): + self._check_item_type(item_type) return self._mapped_tables.get(item_type, default) def pop(self, item_type, default): + self._check_item_type(item_type) return self._mapped_tables.pop(item_type, default) def clear(self): @@ -340,6 +350,10 @@ def find_item(self, item, skip_keys=(), fetch=True): return self._db_map.fetch_ref(self._item_type, id_) # No id. Try to locate the item by the value of one of the unique keys. # Used by import_data (and more...) + # FIXME: Do we really need to make the MappedItem here? + # Can't we just obtain the unique_values directly from item? + # I guess it's needed in case the user specifies stuff like 'class_id', as tests do, + # but that should be a corner case... 
mapped_item = self._make_item(item) error = mapped_item.resolve_inverse_references(item.keys()) if error: diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 4f2cf423..14c9411e 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -75,8 +75,13 @@ def commit(self, _commit_id): class EntityItem(MappedItemBase): fields = { "class_name": ("str", "The entity class name."), - "name": ("str, optional", "The entity name - must be given for a zero-dimensional entity."), - "element_name_list": ("tuple, optional", "The element names - must be given for a multi-dimensional entity."), + "name": ("str", "The entity name."), + "element_name_list": ("tuple", "The element names if the entity is multi-dimensional."), + "byname": ( + "tuple", + "A tuple with the entity name as single element if the entity is zero-dimensional, " + "or the element name list if it is multi-dimensional.", + ), "description": ("str, optional", "The entity description."), } _defaults = {"description": None} @@ -126,7 +131,7 @@ class EntityGroupItem(MappedItemBase): "group_name": ("str", "The group entity name."), "member_name": ("str", "The member entity name."), } - _unique_keys = (("group_name", "member_name"),) + _unique_keys = (("class_name", "group_name", "member_name"),) _references = { "class_name": ("entity_class_id", ("entity_class", "name")), "group_name": ("entity_id", ("entity", "name")), @@ -151,8 +156,9 @@ class EntityAlternativeItem(MappedItemBase): fields = { "entity_class_name": ("str", "The entity class name."), "entity_byname": ( - "str or tuple", - "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + "tuple", + "A tuple with the entity name as single element if the entity is zero-dimensional, " + "or the element name list if it is multi-dimensional.", ), "alternative_name": ("str", "The alternative name."), "active": ("bool, optional", "Whether the entity is active in the alternative - defaults to True."), @@ -303,8 +309,9 @@ class ParameterValueItem(ParsedValueBase): "entity_class_name": ("str", "The entity class name."), "parameter_definition_name": ("str", "The parameter name."), "entity_byname": ( - "str or tuple", - "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + "tuple", + "A tuple with the entity name as single element if the entity is zero-dimensional, " + "or the element name list if the entity is multi-dimensional.", ), "value": ("any", "The value."), "type": ("str", "The value type."), @@ -511,10 +518,12 @@ class ParameterValueMetadataItem(MappedItemBase): fields = { "parameter_definition_name": ("str", "The parameter name."), "entity_byname": ( - "str or tuple", - "The entity name for a zero-dimensional entity, or the element name list for a multi-dimensional one.", + "tuple", + "A tuple with the entity name as single element if the entity is zero-dimensional, " + "or the element name list if it is multi-dimensional.", ), "alternative_name": ("str", "The alternative name."), + "metadata_name": ("str", "The metadata entry name."), "metadata_value": ("str", "The metadata entry value."), } _unique_keys = ( From f5086d4b25d3d309278764309b560024cf686154 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 11 Oct 2023 12:23:50 +0300 Subject: [PATCH 130/317] Fix bug with alternative export mapping Export AlternativeMapping expected to be part of parameter value export when it wasn't the root mapping. 
However, AlternativeMapping should be usable as a standalone mapping even
when not root.

Re spine-tools/Spine-Toolbox#2339
---
 spinedb_api/export_mapping/export_mapping.py |  9 +++--
 spinedb_api/import_functions.py              |  2 +-
 tests/export_mapping/test_export_mapping.py  | 42 ++++++++++++++++++++
 3 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py
index df8e33c9..e5ccef3e 100644
--- a/spinedb_api/export_mapping/export_mapping.py
+++ b/spinedb_api/export_mapping/export_mapping.py
@@ -1279,9 +1279,12 @@ def add_query_columns(self, db_map, query):
         )

     def filter_query(self, db_map, query):
-        if self.parent is None:
-            return query
-        return query.filter(db_map.alternative_sq.c.id == db_map.parameter_value_sq.c.alternative_id)
+        parent = self.parent
+        while parent is not None:
+            if isinstance(parent, ParameterDefinitionMapping):
+                return query.filter(db_map.alternative_sq.c.id == db_map.parameter_value_sq.c.alternative_id)
+            parent = parent.parent
+        return query

     @staticmethod
     def name_field():
diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py
index e592dc1e..59a8ddf8 100644
--- a/spinedb_api/import_functions.py
+++ b/spinedb_api/import_functions.py
@@ -351,7 +351,7 @@ def import_scenarios(db_map, data):

     Args:
         db_map (spinedb_api.DiffDatabaseMapping): database mapping
-        data (list(str,str)): tuples of (name, description)
+        data (list(str, bool, str)): tuples of (name, , description)

     Returns:
         int: number of items imported
diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py
index ccafe03f..b8ea8063 100644
--- a/tests/export_mapping/test_export_mapping.py
+++ b/tests/export_mapping/test_export_mapping.py
@@ -38,6 +38,7 @@
     entity_export,
 )
 from spinedb_api.export_mapping.export_mapping import (
+    AlternativeDescriptionMapping,
     AlternativeMapping,
     drop_non_positioned_tail,
     FixedValueMapping,
@@ -60,6 +61,7 @@
     ElementMapping,
     ScenarioActiveFlagMapping,
     ScenarioAlternativeMapping,
+    ScenarioDescriptionMapping,
     ScenarioMapping,
 )
 from spinedb_api.mapping import unflatten
@@ -1568,6 +1570,46 @@ def test_export_object_parameters_while_exporting_relationships_with_multiple_pa
         self.assertEqual(list(rows(root_mapping, db_map)), expected)
         db_map.close()

+    def test_alternative_mapping_with_header_and_description(self):
+        root_mapping = AlternativeMapping(0, header="alternative")
+        root_mapping.child = AlternativeDescriptionMapping(1, header="description")
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            expected = [["alternative", "description"], ["Base", "Base alternative"]]
+            self.assertEqual(list(rows(root_mapping, db_map)), expected)
+
+    def test_fixed_value_and_alternative_mappings_with_header_and_description(self):
+        root_mapping = FixedValueMapping(Position.table_name, value="Alternative")
+        alternative_mapping = root_mapping.child = AlternativeMapping(0, header="alternative")
+        alternative_mapping.child = AlternativeDescriptionMapping(1, header="description")
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            expected = [["alternative", "description"], ["Base", "Base alternative"]]
+            self.assertEqual(list(rows(root_mapping, db_map)), expected)
+
+    def test_fixed_value_and_scenario_mappings_with_header_and_description(self):
+        root_mapping = FixedValueMapping(Position.table_name, value="Scenario")
+        scenario_mapping = root_mapping.child = ScenarioMapping(0, header="scenario")
+        scenario_mapping.child =
ScenarioDescriptionMapping(1, header="description") + with DatabaseMapping("sqlite://", create=True) as db_map: + import_scenarios(db_map, (("scenario1", False, "Scenario with Base alternative"),)) + db_map.commit_session("Add test data.") + expected = [["scenario", "description"], ["scenario1", "Scenario with Base alternative"]] + self.assertEqual(list(rows(root_mapping, db_map)), expected) + + def test_rows_from_scenario_mappings_after_rows_from_alternative_mappings(self): + root_mapping1 = FixedValueMapping(Position.table_name, value="Alternative") + alternative_mapping = root_mapping1.child = AlternativeMapping(0, header="alternative") + alternative_mapping.child = AlternativeDescriptionMapping(1, header="description") + root_mapping2 = FixedValueMapping(Position.table_name, value="Scenario") + scenario_mapping = root_mapping2.child = ScenarioMapping(0, header="scenario") + scenario_mapping.child = ScenarioDescriptionMapping(1, header="description") + with DatabaseMapping("sqlite://", create=True) as db_map: + import_scenarios(db_map, (("scenario1", False, "Scenario with Base alternative"),)) + db_map.commit_session("Add test data.") + expected1 = [["alternative", "description"], ["Base", "Base alternative"]] + self.assertEqual(list(rows(root_mapping1, db_map)), expected1) + expected2 = [["scenario", "description"], ["scenario1", "Scenario with Base alternative"]] + self.assertEqual(list(rows(root_mapping2, db_map)), expected2) + if __name__ == "__main__": unittest.main() From 20ddfa0ec40d35d52afb25b9774cac52b47cb185 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 11 Oct 2023 12:25:51 +0200 Subject: [PATCH 131/317] Add convenience methods to act on specific item types --- docs/source/conf.py | 49 -------------- spinedb_api/db_mapping.py | 133 ++++++++++++++++++++++++++++++++------ 2 files changed, 115 insertions(+), 67 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 241a6285..2c0b7836 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -146,55 +146,6 @@ def _process_docstring(app, what, name, obj, options, lines): for k, line in enumerate(lines): if "" in line: lines[k] = line.replace("", spine_item_types) - # Expand - if lines[0] == "": - item_type = name.split("get_")[1].split("_item")[0] - a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" - factory = DatabaseMapping._item_factory(item_type) - new_lines = [f"Finds and returns {a} `{item_type}` matching the arguments, or None if none found.", ""] - new_lines.extend( - [ - f":param fetch: Whether to fetch the DB in case the `{item_type}` is not found in memory.", - ":type fetch: bool, optional", - "", - ":param skip_removed: Whether to ignore removed items.", - ":type skip_removed: bool, optional", - "", - ] - ) - uq_f_names = {f_name: None for f_names in factory._unique_keys for f_name in f_names} - for f_name in uq_f_names: - f_type, f_value = factory.fields[f_name] - new_lines.extend([f":param {f_name}: {f_value}", f":type {f_name}: {f_type}", ""]) - new_lines.extend([f":returns: The `{item_type}` if found.", ":rtype: :class:`PublicItem` or None", ""]) - lines[0:1] = new_lines - # Expand , - if lines[0] in ("", ""): - update = lines[0] == "" - head = "update_" if update else "add_" - item_type = name.split(head)[1].split("_item")[0] - a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" - factory = DatabaseMapping._item_factory(item_type) - synopsis = ( - f"Updates {a} `{item_type}` in the in-memory mapping." 
- if update - else f"Adds {a} `{item_type}` to the in-memory mapping." - ) - new_lines = [synopsis, ""] - new_lines.extend([":param check: Whether to carry out integrity checks.", ":type check: bool, optional", ""]) - if update: - new_lines.extend([f":param id: The id of the `{item_type}` to update.", ":type id: int", ""]) - uq_f_names = {f_name: None for f_names in factory._unique_keys for f_name in f_names} - for f_name, (f_type, f_value) in factory.fields.items(): - new_lines.extend([f":param {f_name}: {f_value}", f":type {f_name}: {f_type}", ""]) - new_lines.extend( - [ - f":returns: The {'updated' if update else 'added'} `{item_type}` and any errors.", - ":rtype: tuple(:class:`PublicItem` or None, str)", - "", - ] - ) - lines[0:1] = new_lines def setup(sphinx): diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index d7bbc34e..afd74987 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -33,7 +33,6 @@ import os import time import logging -import astroid from functools import partialmethod from datetime import datetime, timezone from types import MethodType @@ -513,7 +512,7 @@ def update_items(self, item_type, *items, check=True, strict=False): updated.append(item) return updated, errors - def remove_item(self, item_type, id_): + def remove_item(self, item_type, id): """Removes an item from the in-memory mapping. Example:: @@ -531,7 +530,7 @@ def remove_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) mapped_table = self.mapped_table(item_type) - return mapped_table.remove_item(id_).public_item + return mapped_table.remove_item(id).public_item def remove_items(self, item_type, *ids): """Removes many items from the in-memory mapping. @@ -556,7 +555,7 @@ def remove_items(self, item_type, *ids): ids.discard(1) return [self.remove_item(item_type, id_) for id_ in ids] - def restore_item(self, item_type, id_): + def restore_item(self, item_type, id): """Restores a previously removed item into the in-memory mapping. Example:: @@ -574,7 +573,7 @@ def restore_item(self, item_type, id_): """ item_type = self._real_tablename(item_type) mapped_table = self.mapped_table(item_type) - return mapped_table.restore_item(id_).public_item + return mapped_table.restore_item(id).public_item def restore_items(self, item_type, *ids): """Restores many previously removed items into the in-memory mapping. 
@@ -757,19 +756,117 @@ def get_filter_configs(self): return self._filter_configs -for x in DatabaseMapping.item_types(): - setattr(DatabaseMapping, "add_" + x, partialmethod(DatabaseMapping.add_item, x)) - - -def format_to_fstring_transform2(node): - if node.name == "DatabaseMapping": - f = astroid.FunctionDef("get_entity_class", lineno=node.lineno, col_offset=node.col_offset, parent=node) - f.postinit(doc_node=astroid.nodes.Const("Do stuff")) - print(f) - # node.postinit(body=node.body + [x]) - # print(node.body) +# Define convenience methods +for it in DatabaseMapping.item_types(): + setattr(DatabaseMapping, "get_" + it + "_item", partialmethod(DatabaseMapping.get_item, it)) + setattr(DatabaseMapping, "add_" + it + "_item", partialmethod(DatabaseMapping.add_item, it)) + setattr(DatabaseMapping, "update_" + it + "_item", partialmethod(DatabaseMapping.update_item, it)) + setattr(DatabaseMapping, "remove_" + it + "_item", partialmethod(DatabaseMapping.remove_item, it)) + setattr(DatabaseMapping, "restore_" + it + "_item", partialmethod(DatabaseMapping.restore_item, it)) + +# Astroid transform so DatabaseMapping looks like it has the convenience methods defined above +def _add_convenience_methods(node): + import astroid + + if node.name != "DatabaseMapping": + return node + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping._item_factory(item_type) + uq_fields = {f_name: factory.fields[f_name] for f_names in factory._unique_keys for f_name in f_names} + a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" + padding = 20 * " " + get_kwargs = f"\n{padding}".join( + [f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in uq_fields.items()] + ) + add_kwargs = f"\n{padding}".join( + [f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in factory.fields.items()] + ) + update_kwargs = f"id (int): The id of the item to update.\n{padding}" + add_kwargs + child = astroid.extract_node( + f''' + def get_{item_type}_item(self, fetch=True, skip_removed=True, **kwargs): + """Finds and returns {a} `{item_type}` item matching the arguments, or None if none found. + + Args: + fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory. + skip_removed (bool, optional): Whether to ignore removed items. + {get_kwargs} + + Returns: + :class:`PublicItem` or None + """ + ''' + ) + child.parent = node + node.body.append(child) + child = astroid.extract_node( + f''' + def add_{item_type}_item(self, check=True, **kwargs): + """Adds {a} `{item_type}` item to the in-memory mapping. + + Args: + check (bool, optional): Whether to carry out integrity checks. + {add_kwargs} + + Returns: + tuple(:class:`PublicItem` or None, str): The added item and any errors. + """ + ''' + ) + child.parent = node + node.body.append(child) + child = astroid.extract_node( + f''' + def update_{item_type}_item(self, check=True, **kwargs): + """Updates {a} `{item_type}` item in the in-memory mapping. + + Args: + check (bool, optional): Whether to carry out integrity checks. + {update_kwargs} + + Returns: + tuple(:class:`PublicItem` or None, str): The updated item and any errors. + """ + ''' + ) + child.parent = node + node.body.append(child) + child = astroid.extract_node( + f''' + def remove_{item_type}_item(self, id): + """Removes {a} `{item_type}` item from the in-memory mapping. + + Args: + id (int): the id of the item to remove. + + Returns: + tuple(:class:`PublicItem` or None, str): The removed item if any. 
+ """ + ''' + ) + child.parent = node + node.body.append(child) + child = astroid.extract_node( + f''' + def restore_{item_type}_item(self, id): + """Restores a previously removed `{item_type}` item into the in-memory mapping. + + Args: + id (int): the id of the item to restore. + + Returns: + tuple(:class:`PublicItem` or None, str): The restored item if any. + """ + ''' + ) + child.parent = node + node.body.append(child) return node -astroid.MANAGER.register_transform(astroid.ClassDef, format_to_fstring_transform2) -# astroid.MANAGER.register_transform(astroid.FunctionDef, format_to_fstring_transform2) +try: + import astroid + + astroid.MANAGER.register_transform(astroid.ClassDef, _add_convenience_methods) +except ModuleNotFoundError: + pass From 72c3008537eceabe10a41cce7bfa7017ecf13360 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 11 Oct 2023 15:18:23 +0200 Subject: [PATCH 132/317] Implement _extended for PublicItem --- spinedb_api/db_mapping_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 01136b8f..b405871a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -982,6 +982,9 @@ def is_committed(self): def _asdict(self): return self._mapped_item._asdict() + def _extended(self): + return self._mapped_item._extended() + def update(self, **kwargs): self._db_map.update_item(self.item_type, id=self["id"], **kwargs) From 86a2a0f562fb0e376ba052297a56079099c0324e Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 11 Oct 2023 16:22:23 +0200 Subject: [PATCH 133/317] Make byname recursive to accommodate meta-multi-D entities --- spinedb_api/mapped_items.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 14c9411e..b6a3e634 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -106,9 +106,18 @@ def __init__(self, *args, **kwargs): kwargs["element_id_list"] = tuple(element_id_list) super().__init__(*args, **kwargs) + def _byname_iter(self, id_, strong=False): + entity = self._get_ref("entity", id_, strong=strong) + element_id_list = entity["element_id_list"] + if not element_id_list: + yield entity["name"] + else: + for el_id in element_id_list: + yield from self._byname_iter(el_id, strong=True) + def __getitem__(self, key): if key == "byname": - return self["element_name_list"] or (self["name"],) + return tuple(self._byname_iter(self["id"])) return super().__getitem__(key) def polish(self): From adf1ea51bd925f8065d1ee6a851a8fa30ceef850 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 11 Oct 2023 20:15:50 +0200 Subject: [PATCH 134/317] Fix tests --- spinedb_api/db_mapping_base.py | 2 +- spinedb_api/mapped_items.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index b405871a..5911c89a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -214,7 +214,7 @@ def _advance_query(self, item_type, limit): def _check_item_type(self, item_type): if item_type not in self.item_types(): candidate = max(self.item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio()) - raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?") + # FIXME: raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?") def mapped_table(self, item_type): self._check_item_type(item_type) diff --git 
a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index b6a3e634..1dea2daf 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -117,7 +117,8 @@ def _byname_iter(self, id_, strong=False): def __getitem__(self, key): if key == "byname": - return tuple(self._byname_iter(self["id"])) + return self["element_name_list"] or (self["name"],) + # FIXME: Try to use this instead return tuple(self._byname_iter(self["id"])) return super().__getitem__(key) def polish(self): From cba5a2d56826f04f5e0e2e5b2cb01e88e4b0273a Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 12 Oct 2023 09:48:19 +0300 Subject: [PATCH 135/317] Fix fetch_more() Traceback after calling refresh_session() We need to filter items that are None in _advance_query(). Re #287 --- spinedb_api/db_mapping_base.py | 6 +++--- tests/test_DatabaseMapping.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 5911c89a..db44d1b9 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -209,7 +209,7 @@ def _advance_query(self, item_type, limit): self._fetched_item_types.add(item_type) return [] mapped_table = self.mapped_table(item_type) - return [mapped_table.add_item(item) for item in chunk] + return list(filter(lambda i: i is not None, (mapped_table.add_item(item) for item in chunk))) def _check_item_type(self, item_type): if item_type not in self.item_types(): @@ -704,8 +704,8 @@ def _get_ref(self, ref_type, ref_id, strong=True): If the reference is not found, sets some flags. Args: - ref_type (str): The references's type - ref_id (int): The references's id + ref_type (str): The reference's type + ref_id (int): The reference's id strong (bool): True if the reference corresponds to a foreign key, False otherwise Returns: diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index a0809e57..4238395b 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -219,6 +219,21 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): ) self.assertIsNotNone(color) + def test_fetch_more(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + alternatives = db_map.fetch_more("alternative") + expected = [{"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}] + self.assertEqual([a._asdict() for a in alternatives], expected) + + def test_fetch_more_after_commit_and_refresh(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_item("entity_class", name="Widget") + db_map.add_item("entity", class_name="Widget", name="gadget") + db_map.commit_session("Add test data.") + db_map.refresh_session() + entities = db_map.fetch_more("entity") + self.assertEqual(entities, []) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. 
pre-entity tests converted to work with the entity structure."""

From bd6cf40bca78769337ffc7c3e87fe9ce7a832e46 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 12 Oct 2023 08:48:23 +0200
Subject: [PATCH 136/317] Fix tests better

---
 spinedb_api/db_mapping_base.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 5911c89a..3c327612 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -214,18 +214,16 @@ def _advance_query(self, item_type, limit):
     def _check_item_type(self, item_type):
         if item_type not in self.item_types():
             candidate = max(self.item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio())
-            # FIXME: raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?")
+            raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?")
 
     def mapped_table(self, item_type):
         self._check_item_type(item_type)
         return self._mapped_tables.setdefault(item_type, _MappedTable(self, item_type))
 
     def get(self, item_type, default=None):
-        self._check_item_type(item_type)
         return self._mapped_tables.get(item_type, default)
 
     def pop(self, item_type, default):
-        self._check_item_type(item_type)
         return self._mapped_tables.pop(item_type, default)
 
     def clear(self):

From 2581754157a9580aba90b4fa18a358483011ea2c Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Thu, 12 Oct 2023 10:45:00 +0300
Subject: [PATCH 137/317] Make unsupported database dialects official

spinedb_api does not support all database dialects that SqlAlchemy
supports. However, you should still be able to use spine_io to import
from and export to these dialects. To make the other dialects work, we
need to add psycopg2 and cx_Oracle to the dependencies.

Re spine-tools/Spine-Toolbox#2329
---
 pyproject.toml         | 16 +++++++++-------
 spinedb_api/helpers.py | 16 ++++++++++------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 86a510da..47410815 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,9 +7,9 @@ description = "An API to talk to Spine databases."
 keywords = ["energy system modelling", "workflow", "optimisation", "database"]
 readme = {file = "README.md", content-type = "text/markdown"}
 classifiers = [
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
-        "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
+    "Operating System :: OS Independent",
 ]
 requires-python = ">=3.8.1, <3.12"
 dependencies = [
@@ -27,6 +27,8 @@ dependencies = [
     "ijson >=3.1.4",
     "chardet >=4.0.0",
     "pymysql >=1.0.2",
+    "psycopg2",
+    "cx_Oracle",
 ]
 
 [project.urls]
@@ -49,10 +51,10 @@ include-package-data = true
 
 [tool.setuptools.packages.find]
 exclude = [
-        "bin*",
-        "docs*",
-        "fig*",
-        "tests*",
+    "bin*",
+    "docs*",
+    "fig*",
+    "tests*",
 ]
 
 [tool.coverage.run]
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index 7cfc6653..2dfbb7fe 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -56,16 +56,20 @@
 from alembic.environment import EnvironmentContext
 from .exception import SpineDBAPIError, SpineDBVersionError
 
-# Supported dialects and recommended dbapi.
Restricted to mysql and sqlite for now: -# - sqlite works -# - mysql is trying to work SUPPORTED_DIALECTS = { "mysql": "pymysql", "sqlite": "sqlite3", - # "mssql": "pyodbc", - # "postgresql": "psycopg2", - # "oracle": "cx_oracle", } +"""Currently supported dialects and recommended dbapi.""" + + +UNSUPPORTED_DIALECTS = { + "mssql": "pyodbc", + "postgresql": "psycopg2", + "oracle": "cx_oracle", +} +"""Dialects and recommended dbapi that are not supported by DatabaseMapping but are supported by SqlAlchemy.""" + naming_convention = { "pk": "pk_%(table_name)s", From 6d00079f5974c992fd7cee0fc2e9c81e08f5215d Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 12 Oct 2023 10:20:30 +0200 Subject: [PATCH 138/317] Fix byname_iter --- spinedb_api/mapped_items.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 1dea2daf..7de71c29 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -106,19 +106,18 @@ def __init__(self, *args, **kwargs): kwargs["element_id_list"] = tuple(element_id_list) super().__init__(*args, **kwargs) - def _byname_iter(self, id_, strong=False): - entity = self._get_ref("entity", id_, strong=strong) + def _byname_iter(self, entity): element_id_list = entity["element_id_list"] if not element_id_list: yield entity["name"] else: for el_id in element_id_list: - yield from self._byname_iter(el_id, strong=True) + element = self._get_ref("entity", el_id) + yield from self._byname_iter(element) def __getitem__(self, key): if key == "byname": - return self["element_name_list"] or (self["name"],) - # FIXME: Try to use this instead return tuple(self._byname_iter(self["id"])) + return tuple(self._byname_iter(self)) return super().__getitem__(key) def polish(self): From 1f8f4c10e640f09fd3bad01764cf84ce646892e7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 13 Oct 2023 11:19:46 +0200 Subject: [PATCH 139/317] Improve MappedItem.merge to better find out there's nothing to update --- spinedb_api/db_mapping_base.py | 20 +++++++++++++++++++- spinedb_api/temp_id.py | 6 ------ tests/test_import_functions.py | 1 - 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index cd162926..51304e4d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -10,6 +10,7 @@ ###################################################################################################################### import threading +import json from enum import Enum, unique, auto from difflib import SequenceMatcher from .temp_id import TempId @@ -616,7 +617,24 @@ def merge(self, other): dict: merged item. str: error description if any. 
""" - if all(self.get(key) == value for key, value in other.items()): + + def _convert(x): + if isinstance(x, list): + return tuple(x) + if isinstance(x, bytes): + try: + return json.loads(x) + except json.decoder.JSONDecodeError: + pass + return x + + def _equals(left, right): + if isinstance(left, TempId): + return left == right or (left.db_id is not None and left.db_id == right) + return _convert(left) == _convert(right) + + if all(_equals(self.get(key), value) for key, value in other.items()): + # Nothing to update, that's fine return None, "" merged = {**self._extended(), **other} if not isinstance(merged["id"], int): diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 3ec329d8..b3280a20 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -24,12 +24,6 @@ def __init__(self, item_type): self._resolve_callbacks = [] self._db_id = None - def __eq__(self, other): - return super().__eq__(other) or (self._db_id is not None and other == self._db_id) - - def __hash__(self): - return int(self) - @property def db_id(self): return self._db_id diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 9309b565..659d21ff 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -406,7 +406,6 @@ def test_import_existing_relationship(self): self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) self.assertFalse(errors) - db_map.commit_session("test") self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() From 5e62042ca6fdce9f04318e864aa0f27f18db3490 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sat, 14 Oct 2023 16:38:43 +0200 Subject: [PATCH 140/317] Introduce DBMapping.reset(), fix typos, make PublicItem serializable --- spinedb_api/db_mapping_base.py | 12 +++++++----- spinedb_api/db_mapping_query_mixin.py | 21 +++++++++------------ spinedb_api/server_client_helpers.py | 3 +++ tests/test_DatabaseMapping.py | 3 +-- 4 files changed, 20 insertions(+), 19 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 51304e4d..af62edec 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -224,11 +224,13 @@ def mapped_table(self, item_type): def get(self, item_type, default=None): return self._mapped_tables.get(item_type, default) - def pop(self, item_type, default): - return self._mapped_tables.pop(item_type, default) - - def clear(self): - self._mapped_tables.clear() + def reset(self, *item_types): + """Resets the mapping for given item types as if never was fetched from the DB.""" + item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types()) + for item_type in item_types: + self._mapped_tables.pop(item_type, None) + self._offsets.pop(item_type, None) + self._fetched_item_types.discard(item_type) def get_mapped_item(self, item_type, id_): mapped_table = self.mapped_table(item_type) diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index 54b4c0ca..93ab4e68 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -102,10 +102,7 @@ def _clear_subqueries(self, *tablenames): """Set to `None` subquery attributes involving the affected tables. This forces the subqueries to be refreshed when the corresponding property is accessed. 
""" - tablenames = list(tablenames) - for tablename in tablenames: - if self.pop(tablename, False): - self.fetch_all(tablename) + self.reset(*tablenames) attr_names = set(attr for tablename in tablenames for attr in self._get_table_to_sq_attr().get(tablename, [])) for attr_name in attr_names: setattr(self, attr_name, None) @@ -1349,42 +1346,42 @@ def override_scenario_alternative_sq_maker(self, method): def restore_entity_class_sq_maker(self): """Restores the original function that creates the ``entity_class_sq`` property.""" - self._make_entity_class_sq = MethodType(DatabaseMapping._make_entity_class_sq, self) + self._make_entity_class_sq = MethodType(DatabaseMappingQueryMixin._make_entity_class_sq, self) self._clear_subqueries("entity_class") def restore_entity_sq_maker(self): """Restores the original function that creates the ``entity_sq`` property.""" - self._make_entity_sq = MethodType(DatabaseMapping._make_entity_sq, self) + self._make_entity_sq = MethodType(DatabaseMappingQueryMixin._make_entity_sq, self) self._clear_subqueries("entity") def restore_entity_element_sq_maker(self): """Restores the original function that creates the ``entity_element_sq`` property.""" - self._make_entity_element_sq = MethodType(DatabaseMapping._make_entity_element_sq, self) + self._make_entity_element_sq = MethodType(DatabaseMappingQueryMixin._make_entity_element_sq, self) self._clear_subqueries("entity_element") def restore_parameter_definition_sq_maker(self): """Restores the original function that creates the ``parameter_definition_sq`` property.""" - self._make_parameter_definition_sq = MethodType(DatabaseMapping._make_parameter_definition_sq, self) + self._make_parameter_definition_sq = MethodType(DatabaseMappingQueryMixin._make_parameter_definition_sq, self) self._clear_subqueries("parameter_definition") def restore_parameter_value_sq_maker(self): """Restores the original function that creates the ``parameter_value_sq`` property.""" - self._make_parameter_value_sq = MethodType(DatabaseMapping._make_parameter_value_sq, self) + self._make_parameter_value_sq = MethodType(DatabaseMappingQueryMixin._make_parameter_value_sq, self) self._clear_subqueries("parameter_value") def restore_alternative_sq_maker(self): """Restores the original function that creates the ``alternative_sq`` property.""" - self._make_alternative_sq = MethodType(DatabaseMapping._make_alternative_sq, self) + self._make_alternative_sq = MethodType(DatabaseMappingQueryMixin._make_alternative_sq, self) self._clear_subqueries("alternative") def restore_scenario_sq_maker(self): """Restores the original function that creates the ``scenario_sq`` property.""" - self._make_scenario_sq = MethodType(DatabaseMapping._make_scenario_sq, self) + self._make_scenario_sq = MethodType(DatabaseMappingQueryMixin._make_scenario_sq, self) self._clear_subqueries("scenario") def restore_scenario_alternative_sq_maker(self): """Restores the original function that creates the ``scenario_alternative_sq`` property.""" - self._make_scenario_alternative_sq = MethodType(DatabaseMapping._make_scenario_alternative_sq, self) + self._make_scenario_alternative_sq = MethodType(DatabaseMappingQueryMixin._make_scenario_alternative_sq, self) self._clear_subqueries("scenario_alternative") def _object_class_id(self): diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index 3a39f9de..b43cf5e7 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -11,6 +11,7 @@ import json from .exception import 
SpineDBAPIError
+from .db_mapping_base import PublicItem
 
 # Encode decode server messages
 _START_OF_TAIL = '\u001f'  # Unit separator
@@ -64,6 +65,8 @@ def default(self, o):
             return list(o)
         if isinstance(o, SpineDBAPIError):
             return str(o)
+        if isinstance(o, PublicItem):
+            return o._asdict()
         return super().default(o)
 
     @property
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 4238395b..d2148661 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -2442,8 +2442,7 @@ def test_cascade_remove_unfetched(self):
         import_functions.import_object_classes(self._db_map, ("my_class",))
         import_functions.import_objects(self._db_map, (("my_class", "my_object"),))
         self._db_map.commit_session("test commit")
-        self._db_map.refresh_session()
-        self._db_map.clear()
+        self._db_map.reset()
         self._db_map.remove_items("entity_class", 1)
         self._db_map.commit_session("test commit")
         ents = self._db_map.query(self._db_map.entity_sq).all()

From 4a4b8330b62a631effd593963aed9a02ed7cb6d8 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Sun, 15 Oct 2023 12:44:52 +0200
Subject: [PATCH 141/317] Fix DBMapping.reset to also clear tables in cascade

---
 spinedb_api/db_mapping_base.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index af62edec..260c6b89 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -225,8 +225,20 @@ def get(self, item_type, default=None):
         return self._mapped_tables.get(item_type, default)
 
     def reset(self, *item_types):
-        """Resets the mapping for given item types as if never was fetched from the DB."""
+        """Resets the mapping for given item types as if nothing was fetched from the DB or modified in the mapping.
+        Any modifications in the mapping that aren't committed to the DB are lost after this.
+        """
         item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types())
+        # Include descendants, otherwise references are broken
+        while True:
+            changed = False
+            for item_type in item_types - set(self.item_types()):
+                if self._item_factory(item_type).ref_types() & item_types:
+                    item_types.add(item_type)
+                    changed = True
+            if not changed:
+                break
+        # Now clear things
         for item_type in item_types:
             self._mapped_tables.pop(item_type, None)
             self._offsets.pop(item_type, None)
             self._fetched_item_types.discard(item_type)

From a912992d42d7ad7ff9150d91a9472bbb81964f57 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Sun, 15 Oct 2023 12:45:59 +0200
Subject: [PATCH 142/317] Fix order in export_entities so clients stay the same

SpineOpt in particular is sensitive to the order of entities in a class
(for determining the sense of connection flows) so we need to keep the
right sorting.
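To see the effect of the extended key (a plain-Python sketch with made-up
names, not part of the diff): zero-dimensional entities still sort first and
classes still sort by name, but entities of the same multi-dimensional class
are now ordered by their element name lists as well, so the output no longer
depends on the order in which rows arrive:

    rows = [
        ("connection", ("n2", "n1"), None),
        ("connection", ("n1", "n2"), None),
        ("node", "n1", None),
    ]
    # Same key as in the diff below: dimension count, class name, then byname.
    key = lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0], (x[1],) if isinstance(x[1], str) else x[1])
    assert sorted(rows, key=key) == [
        ("node", "n1", None),
        ("connection", ("n1", "n2"), None),
        ("connection", ("n2", "n1"), None),
    ]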
--- spinedb_api/export_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index de09933a..efe53c97 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -131,7 +131,7 @@ def export_entity_classes(db_map, ids=Asterisk): def export_entities(db_map, ids=Asterisk): return sorted( ((x.class_name, x.element_name_list or x.name, x.description) for x in _get_items(db_map, "entity", ids)), - key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0]), + key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0], (x[1],) if isinstance(x[1], str) else x[1]), ) From 0344004a2b58406351d77cd1488d9a43ce2d5c80 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 15 Oct 2023 12:58:28 +0200 Subject: [PATCH 143/317] Fix mistake --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 260c6b89..8eef7ce0 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -232,7 +232,7 @@ def reset(self, *item_types): # Include descendants, otherwise references are broken while True: changed = False - for item_type in item_types - set(self.item_types()): + for item_type in set(self.item_types()) - item_types: if self._item_factory(item_type).ref_types() & item_types: item_types.add(item_type) changed = True From a031b056598cdc115d5fc5fd7b6db3098ade3534 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 15 Oct 2023 14:33:16 +0200 Subject: [PATCH 144/317] Introduce MappedItem._something_to_update To allow subclasses with parsed values to identify situations better. --- spinedb_api/db_mapping_base.py | 26 +++++--------- spinedb_api/mapped_items.py | 62 +++++++++++++++++++++++++--------- spinedb_api/parameter_value.py | 2 +- spinedb_api/temp_id.py | 6 ++++ 4 files changed, 61 insertions(+), 35 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 8eef7ce0..f415a88a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -10,7 +10,6 @@ ###################################################################################################################### import threading -import json from enum import Enum, unique, auto from difflib import SequenceMatcher from .temp_id import TempId @@ -631,23 +630,7 @@ def merge(self, other): dict: merged item. str: error description if any. """ - - def _convert(x): - if isinstance(x, list): - return tuple(x) - if isinstance(x, bytes): - try: - return json.loads(x) - except json.decoder.JSONDecodeError: - pass - return x - - def _equals(left, right): - if isinstance(left, TempId): - return left == right or (left.db_id is not None and left.db_id == right) - return _convert(left) == _convert(right) - - if all(_equals(self.get(key), value) for key, value in other.items()): + if not self._something_to_update(other): # Nothing to update, that's fine return None, "" merged = {**self._extended(), **other} @@ -655,6 +638,13 @@ def _equals(left, right): merged["id"] = self["id"] return merged, "" + def _something_to_update(self, other): + def _convert(x): + if isinstance(x, list): + return tuple(x) + + return all(_convert(self.get(key)) != _convert(value) for key, value in other.items()) + def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first one that cannot be resolved. 
diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 7de71c29..f0aefffb 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -202,9 +202,20 @@ def parsed_value(self): self._parsed_value = self._make_parsed_value() return self._parsed_value - def _make_parsed_value(self): + @property + def _value_key(self): + raise NotImplementedError() + + @property + def _type_key(self): raise NotImplementedError() + def _make_parsed_value(self): + try: + return from_database(self[self._value_key], self[self._type_key]) + except ParameterValueFormatError as error: + return error + def update(self, other): self._parsed_value = None super().update(other) @@ -214,6 +225,19 @@ def __getitem__(self, key): return self.parsed_value return super().__getitem__(key) + def _something_to_update(self, other): + other = other.copy() + if self._value_key in other and self._type_key in other: + try: + other_parsed_value = from_database(other[self._value_key], other[self._type_key]) + if self.parsed_value != other_parsed_value: + return True + except ParameterValueFormatError: + pass + _ = other.pop(self._value_key, None) + _ = other.pop(self._type_key, None) + return super()._something_to_update(other) + class ParameterDefinitionItem(ParsedValueBase): fields = { @@ -242,11 +266,13 @@ def list_value_id(self): return int(dict.__getitem__(self, "default_value")) return None - def _make_parsed_value(self): - try: - return from_database(self["default_value"], self["default_type"]) - except ParameterValueFormatError as error: - return error + @property + def _value_key(self): + return "default_value" + + @property + def _type_key(self): + return "default_type" def __getitem__(self, key): if key == "parameter_name": @@ -356,11 +382,13 @@ def list_value_id(self): return int(dict.__getitem__(self, "value")) return None - def _make_parsed_value(self): - try: - return from_database(self["value"], self["type"]) - except ParameterValueFormatError as error: - return error + @property + def _value_key(self): + return "value" + + @property + def _type_key(self): + return "type" def __getitem__(self, key): if key == "parameter_id": @@ -425,11 +453,13 @@ class ListValueItem(ParsedValueBase): "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), } - def _make_parsed_value(self): - try: - return from_database(self["value"], self["type"]) - except ParameterValueFormatError as error: - return error + @property + def _value_key(self): + return "value" + + @property + def _type_key(self): + return "type" class AlternativeItem(MappedItemBase): diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 7ef96b85..5388db8b 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -946,7 +946,7 @@ def __init__(self, values, value_type=None, index_name=""): def __eq__(self, other): if not isinstance(other, Array): return NotImplemented - return np.array_equal(self._values, other._values) and self.index_name == other.index_name + return np.array_equal(self._values, other._values, equal_nan=True) and self.index_name == other.index_name @staticmethod def type_(): diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index b3280a20..79066941 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -28,6 +28,12 @@ def __init__(self, item_type): def db_id(self): return self._db_id + def __eq__(self, other): + return super().__eq__(other) or (self._db_id is not None and other == self._db_id) + + def 
__hash__(self): + return int(self) + def __repr__(self): return f"TempId({self._item_type}, {super().__repr__()})" From 4cdbc776f81ecbda02dfe7c9e70bdfd892fe9796 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 15 Oct 2023 16:51:37 +0200 Subject: [PATCH 145/317] Fix tests --- spinedb_api/db_mapping_base.py | 5 ++--- spinedb_api/mapped_items.py | 4 ++-- spinedb_api/parameter_value.py | 5 ++++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f415a88a..78a54032 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -640,10 +640,9 @@ def merge(self, other): def _something_to_update(self, other): def _convert(x): - if isinstance(x, list): - return tuple(x) + return tuple(x) if isinstance(x, list) else x - return all(_convert(self.get(key)) != _convert(value) for key, value in other.items()) + return not all(_convert(self.get(key)) == _convert(value) for key, value in other.items()) def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first one diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index f0aefffb..3473dda6 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -232,10 +232,10 @@ def _something_to_update(self, other): other_parsed_value = from_database(other[self._value_key], other[self._type_key]) if self.parsed_value != other_parsed_value: return True + _ = other.pop(self._value_key, None) + _ = other.pop(self._type_key, None) except ParameterValueFormatError: pass - _ = other.pop(self._value_key, None) - _ = other.pop(self._type_key, None) return super()._something_to_update(other) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 5388db8b..26f5ed1b 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -946,7 +946,10 @@ def __init__(self, values, value_type=None, index_name=""): def __eq__(self, other): if not isinstance(other, Array): return NotImplemented - return np.array_equal(self._values, other._values, equal_nan=True) and self.index_name == other.index_name + try: + return np.array_equal(self._values, other._values, equal_nan=True) and self.index_name == other.index_name + except TypeError: + return np.array_equal(self._values, other._values) and self.index_name == other.index_name @staticmethod def type_(): From be976a85088321483fcb55becf3f0bbe7da04f6f Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 16 Oct 2023 15:22:37 +0200 Subject: [PATCH 146/317] API to import multi-D entities where the elements are also multi-D ents Re #291 --- spinedb_api/db_mapping_base.py | 62 ++++++++++++--------- spinedb_api/import_functions.py | 95 ++++++++++++++++++--------------- spinedb_api/mapped_items.py | 37 +++++++++++-- tests/test_import_functions.py | 31 +++++++++++ 4 files changed, 153 insertions(+), 72 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 78a54032..8f41f00d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -362,10 +362,16 @@ def find_item(self, item, skip_keys=(), fetch=True): return self._db_map.fetch_ref(self._item_type, id_) # No id. Try to locate the item by the value of one of the unique keys. # Used by import_data (and more...) - # FIXME: Do we really need to make the MappedItem here? - # Can't we just obtain the unique_values directly from item? 
- # I guess it's needed in case the user specifies stuff like 'class_id', as tests do, - # but that should be a corner case... + for key in self._db_map._item_factory(self._item_type)._unique_keys: + if key in skip_keys: + continue + value = tuple(item.get(k) for k in key) + if None in value: + continue + current_item = self._unique_key_value_to_item(key, value, fetch=fetch) + if current_item: + return current_item + # Last hope: maybe item is missing some key stuff, so try with a resolved and polished MappedItem instead... mapped_item = self._make_item(item) error = mapped_item.resolve_inverse_references(item.keys()) if error: @@ -475,24 +481,24 @@ class MappedItemBase(dict): """A dictionary that represents a db item.""" fields = {} - """A dictionary mapping fields to a tuple of (type, value description)""" + """A dictionary mapping keys to a tuple of (type, value description)""" _defaults = {} """A dictionary mapping keys to their default values""" _unique_keys = () - """A tuple where each element is itself a tuple of keys that are unique""" + """A tuple where each element is itself a tuple of keys corresponding to a unique constraint""" _references = {} """A dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the field they reference in another item. + to a recipe for finding the key they reference in another item. - The recipe is a tuple of the form (original_field, (ref_item_type, ref_field)), + The recipe is a tuple of the form (src_key, (ref_item_type, ref_key)), to be interpreted as follows: - 1. take the value from the original_field of this item, which should be an id, + 1. take the value from the src_key of this item, which should be an id, 2. locate the item of type ref_item_type that has that id, - 3. return the value from the ref_field of that item. + 3. return the value from the ref_key of that item. """ _inverse_references = {} """Another dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the field they reference in another item. + to a recipe for finding the key they reference in another item. Used only for creating new items, when the user provides names and we want to find the ids. The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)), @@ -687,22 +693,28 @@ def resolve_inverse_references(self, skip_keys=()): Returns: str or None: error description if any. 
""" - for src_key, (id_key, (ref_type, ref_key)) in self._inverse_references.items(): + for src_key in self._inverse_references: if src_key in skip_keys: continue - id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) - if None in id_value: - continue - mapped_table = self._db_map.mapped_table(ref_type) - try: - self[src_key] = ( - tuple(mapped_table.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) - if all(isinstance(v, (tuple, list)) for v in id_value) - else mapped_table.unique_key_value_to_id(ref_key, id_value, strict=True) - ) - except KeyError as err: - # Happens at unique_key_value_to_id(..., strict=True) - return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" + error = self._do_resolve_inverse_reference(src_key) + if error: + return error + + def _do_resolve_inverse_reference(self, src_key): + id_key, (ref_type, ref_key) = self._inverse_references[src_key] + id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) + if None in id_value: + return + mapped_table = self._db_map.mapped_table(ref_type) + try: + self[src_key] = ( + tuple(mapped_table.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) + if all(isinstance(v, (tuple, list)) for v in id_value) + else mapped_table.unique_key_value_to_id(ref_key, id_value, strict=True) + ) + except KeyError as err: + # Happens at unique_key_value_to_id(..., strict=True) + return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" def polish(self): """Polishes this item once all it's references have been resolved. Returns any error. diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 59a8ddf8..0d2363fc 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -176,11 +176,11 @@ def get_data_for_import( yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) if entity_classes: - yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=True)) - yield ("entity_class", _get_entity_classes_for_import(db_map, entity_classes, zero_dim=False)) + for bucket in _get_entity_classes_for_import(db_map, entity_classes): + yield ("entity_class", bucket) if entities: - yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=True)) - yield ("entity", _get_entities_for_import(db_map, entities, zero_dim=False)) + for bucket in _get_entities_for_import(db_map, entities): + yield ("entity", bucket) if entity_alternatives: yield ("entity_alternative", _get_entity_alternatives_for_import(db_map, entity_alternatives)) if entity_groups: @@ -207,31 +207,28 @@ def get_data_for_import( yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) # Legacy if object_classes: - yield ("object_class", _get_object_classes_for_import(db_map, object_classes)) + yield from get_data_for_import(db_map, entity_classes=object_classes) if relationship_classes: - yield ("relationship_class", _get_entity_classes_for_import(db_map, relationship_classes, zero_dim=False)) + yield from get_data_for_import(db_map, entity_classes=relationship_classes) if object_parameters: - yield ("parameter_definition", _get_parameter_definitions_for_import(db_map, object_parameters, unparse_value)) + yield from get_data_for_import(db_map, unparse_value=unparse_value, parameter_definitions=object_parameters) if 
relationship_parameters: - yield ( - "parameter_definition", - _get_parameter_definitions_for_import(db_map, relationship_parameters, unparse_value), + yield from get_data_for_import( + db_map, unparse_value=unparse_value, parameter_definitions=relationship_parameters ) if objects: - yield ("object", _get_entities_for_import(db_map, objects, zero_dim=True)) + yield from get_data_for_import(db_map, entities=objects) if relationships: - yield ("relationship", _get_entities_for_import(db_map, relationships, zero_dim=False)) + yield from get_data_for_import(db_map, entities=relationships) if object_groups: - yield ("entity_group", _get_entity_groups_for_import(db_map, object_groups)) + yield from get_data_for_import(db_map, entity_groups=object_groups) if object_parameter_values: - yield ( - "parameter_value", - _get_parameter_values_for_import(db_map, object_parameter_values, unparse_value, on_conflict), + yield from get_data_for_import( + db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=object_parameter_values ) if relationship_parameter_values: - yield ( - "parameter_value", - _get_parameter_values_for_import(db_map, relationship_parameter_values, unparse_value, on_conflict), + yield from get_data_for_import( + db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=relationship_parameter_values ) if object_metadata: yield from get_data_for_import(db_map, entity_metadata=object_metadata) @@ -494,32 +491,44 @@ def _add_to_seen(checked_item, seen): seen.setdefault(key, set()).add(value) -def _get_entity_classes_for_import(db_map, data, zero_dim): - def _data_iterator(): - for x in data: - if isinstance(x, str): - x = x, () - name, *optionals = x - dim_name_list = optionals.pop(0) if optionals else () - if (dim_name_list and zero_dim) or (not dim_name_list and not zero_dim): - continue - yield name, dim_name_list, *optionals - +def _get_entity_classes_for_import(db_map, data): + dim_name_list_by_name = {} + items = [] key = ("name", "dimension_name_list", "description", "display_icon") - return _get_items_for_import(db_map, "entity_class", (dict(zip(key, x)) for x in _data_iterator())) - + for x in data: + if isinstance(x, str): + x = x, () + name, *optionals = x + dim_name_list = optionals.pop(0) if optionals else () + item = dict(zip(key, (name, dim_name_list, *optionals))) + items.append(item) + dim_name_list_by_name[name] = dim_name_list + + def _ref_count(name): + dim_name_list = dim_name_list_by_name.get(name, ()) + return len(dim_name_list) + sum((_ref_count(dim_name) for dim_name in dim_name_list), start=0) + + items_by_ref_count = {} + for item in items: + items_by_ref_count.setdefault(_ref_count(item["name"]), []).append(item) + return ( + _get_items_for_import(db_map, "entity_class", items_by_ref_count[ref_count]) + for ref_count in sorted(items_by_ref_count) + ) -def _get_entities_for_import(db_map, data, zero_dim): - def _data_iterator(): - for class_name, name_or_element_name_list, *optionals in data: - is_zero_dim = isinstance(name_or_element_name_list, str) - if (is_zero_dim and not zero_dim) or (not is_zero_dim and zero_dim): - continue - byname_key = "name" if is_zero_dim else "element_name_list" - key = ("class_name", byname_key, "description") - yield dict(zip(key, (class_name, name_or_element_name_list, *optionals))) - return _get_items_for_import(db_map, "entity", _data_iterator()) +def _get_entities_for_import(db_map, data): + items_by_el_count = {} + key = ("class_name", "byname", "description") + for class_name, 
name_or_element_name_list, *optionals in data: + is_zero_dim = isinstance(name_or_element_name_list, str) + byname = (name_or_element_name_list,) if is_zero_dim else tuple(name_or_element_name_list) + item = dict(zip(key, (class_name, byname, *optionals))) + el_count = 0 if is_zero_dim else len(name_or_element_name_list) + items_by_el_count.setdefault(el_count, []).append(item) + return ( + _get_items_for_import(db_map, "entity", items_by_el_count[el_count]) for el_count in sorted(items_by_el_count) + ) def _get_entity_alternatives_for_import(db_map, data): @@ -699,4 +708,4 @@ def _data_iterator(): name, *optionals = x yield name, (), *optionals - return _get_entity_classes_for_import(db_map, _data_iterator(), zero_dim=True) + return _get_entity_classes_for_import(db_map, _data_iterator()) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 3473dda6..6898d712 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -80,7 +80,7 @@ class EntityItem(MappedItemBase): "byname": ( "tuple", "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element name list if it is multi-dimensional.", + "or the element names if it is multi-dimensional.", ), "description": ("str, optional", "The entity description."), } @@ -120,6 +120,35 @@ def __getitem__(self, key): return tuple(self._byname_iter(self)) return super().__getitem__(key) + def resolve_inverse_references(self, skip_keys=()): + error = super().resolve_inverse_references(skip_keys=skip_keys) + if error: + return error + byname = dict.pop(self, "byname", None) + if byname is None: + return + if not self["dimension_id_list"]: + self["name"] = byname[0] + return + byname_remainder = list(byname) + self["element_name_list"] = self._element_name_list_recursive(self["class_name"], byname_remainder) + return self._do_resolve_inverse_reference("element_id_list") + + def _element_name_list_recursive(self, class_name, byname_remainder): + dimension_name_list = self._db_map.get_item("entity_class", name=class_name).get("dimension_name_list") + if not dimension_name_list: + name = byname_remainder.pop(0) + return (name,) + return tuple( + ( + self._db_map.get_item( + "entity", class_name=dim_name, byname=self._element_name_list_recursive(dim_name, byname_remainder) + ) + or {} + ).get("name") + for dim_name in dimension_name_list + ) + def polish(self): error = super().polish() if error: @@ -167,7 +196,7 @@ class EntityAlternativeItem(MappedItemBase): "entity_byname": ( "tuple", "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element name list if it is multi-dimensional.", + "or the element names if it is multi-dimensional.", ), "alternative_name": ("str", "The alternative name."), "active": ("bool, optional", "Whether the entity is active in the alternative - defaults to True."), @@ -346,7 +375,7 @@ class ParameterValueItem(ParsedValueBase): "entity_byname": ( "tuple", "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element name list if the entity is multi-dimensional.", + "or the element names if the entity is multi-dimensional.", ), "value": ("any", "The value."), "type": ("str", "The value type."), @@ -559,7 +588,7 @@ class ParameterValueMetadataItem(MappedItemBase): "entity_byname": ( "tuple", "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element name list if it is multi-dimensional.", + "or the element names if it is 
multi-dimensional.", ), "alternative_name": ("str", "The alternative name."), "metadata_name": ("str", "The metadata entry name."), diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 659d21ff..93f1892e 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -419,6 +419,37 @@ def test_import_relationship_with_one_None_object(self): self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() + def test_import_relationship_of_relationships(self): + db_map = create_diff_db_map() + self.populate(db_map) + import_data( + db_map, + entity_classes=[ + ["relationship_class1", ["object_class1", "object_class2"]], + ["relationship_class2", ["object_class2", "object_class1"]], + ["meta_relationship_class", ["relationship_class1", "relationship_class2"]], + ], + entities=[ + ["relationship_class1", ["object1", "object2"]], + ["relationship_class2", ["object2", "object1"]], + ], + ) + _, errors = import_data( + db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1"]]] + ) + self.assertFalse(errors) + db_map.commit_session("test") + entities = { + tuple(r.element_name_list.split(",")) if r.element_name_list else r.name: r.name + for r in db_map.query(db_map.wide_entity_sq) + } + self.assertTrue("object1" in entities) + self.assertTrue("object2" in entities) + self.assertTrue(("object1", "object2") in entities) + self.assertTrue(("object2", "object1") in entities) + self.assertTrue((entities["object1", "object2"], entities["object2", "object1"]) in entities) + self.assertEqual(len(entities), 5) + class TestImportParameterDefinition(unittest.TestCase): def setUp(self): From 09baab5f12d55b0f1ae7d6f8d77c1023d534bb68 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 17 Oct 2023 12:19:10 +0200 Subject: [PATCH 147/317] Allow python-level filtering in get_items --- spinedb_api/db_mapping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index afd74987..8b9990f5 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -388,7 +388,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): return None return item.public_item - def get_items(self, item_type, fetch=True, skip_removed=True): + def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): """Finds and returns all the items of one type. Args: @@ -404,7 +404,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True): self.fetch_all(item_type) mapped_table = self.mapped_table(item_type) get_items = mapped_table.valid_values if skip_removed else mapped_table.values - return [x.public_item for x in get_items()] + return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] def add_item(self, item_type, check=True, **kwargs): """Adds an item to the in-memory mapping. From 02e489ec3fd1961d47757f52f51b817e3ae1aa3e Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 17 Oct 2023 12:32:56 +0200 Subject: [PATCH 148/317] Remove confusing key parameter_name for ParameterDefinition just "name" is enough. 
---
 spinedb_api/mapped_items.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py
index 6898d712..7d42fee3 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -304,8 +304,6 @@ def _type_key(self):
         return "default_type"
 
     def __getitem__(self, key):
-        if key == "parameter_name":
-            return super().__getitem__("name")
         if key == "value_list_id":
             return super().__getitem__("parameter_value_list_id")
         if key == "parameter_value_list_id":

From ef24e1115b94f90238e5e9057c09125eb28944fe Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Tue, 17 Oct 2023 12:52:45 +0200
Subject: [PATCH 149/317] Fix attribute access

---
 spinedb_api/export_functions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py
index efe53c97..a4e07fff 100644
--- a/spinedb_api/export_functions.py
+++ b/spinedb_api/export_functions.py
@@ -150,7 +150,7 @@ def export_parameter_definitions(db_map, ids=Asterisk, parse_value=from_database
     return sorted(
         (
             x.entity_class_name,
-            x.parameter_name,
+            x.name,
             parse_value(x.default_value, x.default_type),
             x.parameter_value_list_name,
             x.description,

From d1db6c99f90405ceadeaa63d0fbf505c949375d3 Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Tue, 17 Oct 2023 14:55:01 +0200
Subject: [PATCH 150/317] Don't skip commit in item_types()

Re spine-tools/Spine-Toolbox#2354
---
 docs/source/conf.py | 4 +++-
 spinedb_api/db_mapping.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2c0b7836..d3ad8ebe 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -142,7 +142,9 @@ def _process_docstring(app, what, name, obj, options, lines):
         new_lines.extend([f" * - {f_names}"])
         lines[i : i + 1] = new_lines
     # Expand
-    spine_item_types = ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()])
+    spine_item_types = ", ".join(
+        [f"``{x}``" for x in DatabaseMapping.item_types() if DatabaseMapping._item_factory(x).fields]
+    )
     for k, line in enumerate(lines):
         if "<spine_item_types>" in line:
             lines[k] = line.replace("<spine_item_types>", spine_item_types)
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 8b9990f5..027311b6 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -198,7 +198,7 @@ def __del__(self):
 
     @staticmethod
     def item_types():
-        return [x for x in DatabaseMapping._sq_name_by_item_type if item_factory(x).fields]
+        return list(DatabaseMapping._sq_name_by_item_type)
 
     @staticmethod
     def _item_factory(item_type):

From 91c51e0dbf1f3c691f0528729f222f673a942a78 Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Tue, 17 Oct 2023 15:56:22 +0200
Subject: [PATCH 151/317] Introduce all_item_types to distinguish read-only
 stuff like commit

Re spine-tools/Spine-Toolbox#2354
---
 docs/source/conf.py | 6 +-----
 spinedb_api/db_mapping.py | 4 ++++
 spinedb_api/db_mapping_base.py | 15 ++++++++++++---
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index d3ad8ebe..9e11999d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -114,8 +114,6 @@ def _process_docstring(app, what, name, obj, options, lines):
     new_lines = []
     for item_type in DatabaseMapping.item_types():
         factory = DatabaseMapping._item_factory(item_type)
-        if not factory.fields:
-            continue
         new_lines.extend([item_type, len(item_type) * "-", ""])
         new_lines.extend(
             [
@@ -142,7 +142,9 @@ def _process_docstring(app, what, name, obj, options, lines):
         new_lines.extend([f" * - 
{f_names}"])
         lines[i : i + 1] = new_lines
     # Expand
-    spine_item_types = ", ".join(
-        [f"``{x}``" for x in DatabaseMapping.item_types() if DatabaseMapping._item_factory(x).fields]
-    )
+    spine_item_types = ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()])
     for k, line in enumerate(lines):
         if "<spine_item_types>" in line:
             lines[k] = line.replace("<spine_item_types>", spine_item_types)
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 027311b6..ea90660b 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -198,6 +198,10 @@ def __del__(self):
 
     @staticmethod
     def item_types():
+        return [x for x in DatabaseMapping._sq_name_by_item_type if item_factory(x).fields]
+
+    @staticmethod
+    def all_item_types():
         return list(DatabaseMapping._sq_name_by_item_type)
 
     @staticmethod
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 8f41f00d..c739f689 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -64,7 +64,16 @@ def fetched_item_types(self):
 
     @staticmethod
     def item_types():
-        """Returns a list of item types from the DB mapping schema (equivalent to the table names).
+        """Returns a list of public item types from the DB mapping schema (equivalent to the table names).
+
+        Returns:
+            list(str)
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    def all_item_types():
+        """Returns a list of all item types from the DB mapping schema (equivalent to the table names).
 
         Returns:
             list(str)
@@ -212,8 +221,8 @@ def _advance_query(self, item_type, limit):
         return list(filter(lambda i: i is not None, (mapped_table.add_item(item) for item in chunk)))
 
     def _check_item_type(self, item_type):
-        if item_type not in self.item_types():
-            candidate = max(self.item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio())
+        if item_type not in self.all_item_types():
+            candidate = max(self.all_item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio())
             raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?")
 
     def mapped_table(self, item_type):

From c7ae3f81a240d047434337cf0416a68f766a69a4 Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Tue, 17 Oct 2023 15:59:46 +0200
Subject: [PATCH 152/317] Fix tests

---
 tests/test_db_mapping_base.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py
index 6c3ccd40..c677a659 100644
--- a/tests/test_db_mapping_base.py
+++ b/tests/test_db_mapping_base.py
@@ -18,6 +18,10 @@ class TestDBMapping(DatabaseMappingBase):
     def item_types():
         return ["cutlery"]
 
+    @staticmethod
+    def all_item_types():
+        return ["cutlery"]
+
     @staticmethod
     def _item_factory(item_type):
         if item_type == "cutlery":

From fee5e83c2660eea2aaf6d0c9869054919c378149 Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Wed, 18 Oct 2023 14:30:04 +0200
Subject: [PATCH 153/317] Support filtering in SQL via get_items

This is so we don't bring more rows than needed into memory and also to speed up things in some cases.
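For example (a sketch; "unit" is a made-up class name and url a placeholder),
a filtered lookup like

    with DatabaseMapping(url) as db_map:
        units = db_map.get_items("entity", class_name="unit")

can now resolve the class_name filter in the query itself instead of first
fetching the whole entity table into memory.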
---
 spinedb_api/db_mapping.py | 33 +++-----
 spinedb_api/db_mapping_base.py | 146 ++++++++++++++++-----------------
 tests/test_DatabaseMapping.py | 2 +-
 3 files changed, 80 insertions(+), 101 deletions(-)

diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index ea90660b..7d215126 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -99,8 +99,9 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat
     This allows new items in the DB (added by other clients in the meantime) to be retrieved as well.
 
     You can also control the fetching process via :meth:`fetch_more` and/or :meth:`fetch_all`.
-    For example, a UI application might want to fetch data in the background so the UI is not blocked in the process.
-    In that case they can call e.g. :meth:`fetch_more` asynchronously as the user scrolls or expands the views.
+    For example, you can call :meth:`fetch_more` in a dedicated thread while you do some work on the main thread.
+    This will nicely place items in the in-memory mapping so you can access them later, without
+    the overhead of fetching them from the DB.
 
     The :meth:`query` method is also provided as an alternative way to retrieve data from the DB
     while bypassing the in-memory mapping entirely.
@@ -208,11 +209,9 @@ def all_item_types():
     def _item_factory(item_type):
         return item_factory(item_type)
 
-    def _make_query(self, item_type):
-        if self.closed:
-            return None
+    def _make_sq(self, item_type):
         sq_name = self._sq_name_by_item_type[item_type]
-        return self.query(getattr(self, sq_name))
+        return getattr(self, sq_name)
 
     def close(self):
         """Closes this DB mapping. This is only needed if you're keeping a long-lived session.
@@ -404,7 +403,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
             list(:class:`PublicItem`): The items.
         """
         item_type = self._real_tablename(item_type)
-        if fetch and item_type not in self.fetched_item_types:
+        if fetch:
             self.fetch_all(item_type)
         mapped_table = self.mapped_table(item_type)
         get_items = mapped_table.valid_values if skip_removed else mapped_table.values
@@ -604,31 +603,19 @@ def purge_items(self, item_type):
         """
         return bool(self.remove_items(item_type, Asterisk))
 
-    def can_fetch_more(self, item_type):
-        """Whether or not more data can be fetched from the DB for the given item type.
-
-        Args:
-            item_type (str): One of <spine_item_types>.
-
-        Returns:
-            bool: True if more data can be fetched.
-        """
-        return item_type not in self.fetched_item_types
-
-    def fetch_more(self, item_type, limit=None):
+    def fetch_more(self, item_type, offset=0, limit=None):
        """Fetches items from the DB into the in-memory mapping, incrementally.
 
         Args:
             item_type (str): One of <spine_item_types>.
-            limit (int): The maximum number of items to fetch. Successive calls to this function
-                will start from the point where the last one left.
-                In other words, each item is fetched from the DB exactly once.
+            offset (int): The initial row.
+            limit (int): The maximum number of rows to fetch.
 
         Returns:
             list(:class:`PublicItem`): The items fetched.
         """
         item_type = self._real_tablename(item_type)
-        return [x.public_item for x in self.do_fetch_more(item_type, limit=limit)]
+        return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit)]
 
     def fetch_all(self, *item_types):
         """Fetches items from the DB into the in-memory mapping.
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c739f689..21b9bf99 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -9,7 +9,6 @@ # this program. If not, see . ###################################################################################################################### -import threading from enum import Enum, unique, auto from difflib import SequenceMatcher from .temp_id import TempId @@ -17,8 +16,6 @@ # TODO: Implement MappedItem.pop() to do lookup? -_LIMIT = 10000 - @unique class Status(Enum): @@ -41,9 +38,7 @@ class DatabaseMappingBase: def __init__(self): self._mapped_tables = {} - self._offsets = {} - self._offset_lock = threading.Lock() - self._fetched_item_types = set() + self._completed_queries = {} item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -53,15 +48,6 @@ def __init__(self): else: self._sorted_item_types.append(item_type) - @property - def fetched_item_types(self): - """Returns a set with the item types that are already fetched. - - Returns: - set - """ - return self._fetched_item_types - @staticmethod def item_types(): """Returns a list of public item types from the DB mapping schema (equivalent to the table names). @@ -92,14 +78,40 @@ def _item_factory(item_type): """ raise NotImplementedError() - def _make_query(self, item_type): - """Returns a :class:`~spinedb_api.query.Query` object to fecth items of given type. + def _make_query(self, item_type, **kwargs): + """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. + + Args: + item_type (str) + **kwargs: query filters + + Returns: + :class:`~spinedb_api.query.Query` or None if the mapping is closed. + """ + if self.closed: + return None + sq = self._make_sq(item_type) + qry = self.query(sq) + for key, value in kwargs.items(): + if hasattr(sq.c, key): + qry = qry.filter(getattr(sq.c, key) == value) + elif key in self._item_factory(item_type)._references: + src_key, (ref_type, ref_key) = self._item_factory(item_type)._references[key] + ref_sq = self._make_sq(ref_type) + qry = qry.filter(getattr(sq.c, src_key) == ref_sq.c.id, getattr(ref_sq.c, ref_key) == value) + else: + raise SpineDBAPIError(f"invalid filter {key}={value} for {item_type}") + return qry + + def _make_sq(self, item_type): + """Returns a :class:`~sqlalchemy.sql.expression.Alias` object representing a subquery + to collect items of given type. 
Args: item_type (str) Returns: - :class:`~spinedb_api.query.Query` + :class:`~sqlalchemy.sql.expression.Alias` """ raise NotImplementedError() @@ -143,10 +155,7 @@ def _dirty_items(self): # FIXME: We should also fetch the current item type because of multi-dimensional entities and # classes which also depend on zero-dimensional ones for other_item_type in self.item_types(): - if ( - other_item_type not in self.fetched_item_types - and item_type in self._item_factory(other_item_type).ref_types() - ): + if item_type in self._item_factory(other_item_type).ref_types(): self.fetch_all(other_item_type) if to_add or to_update or to_remove: dirty_items.append((item_type, (to_add, to_update, to_remove))) @@ -187,23 +196,28 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" - self._offsets.clear() - self._fetched_item_types.clear() - - def _get_next_chunk(self, item_type, limit): - qry = self._make_query(item_type) + self._completed_queries.clear() + + def _get_next_chunk(self, item_type, offset, limit, **kwargs): + completed_queries = self._completed_queries.setdefault(item_type, set()) + qry_key = tuple(sorted(kwargs.items())) + if qry_key in completed_queries: + items = [x for x in self.mapped_table(item_type).values() if all(x.get(k) == v for k, v in kwargs.items())] + if limit is None: + return items[offset:] + return items[offset : offset + limit] + qry = self._make_query(item_type, **kwargs) if not qry: return [] if not limit: - self._fetched_item_types.add(item_type) + completed_queries.add(qry_key) return [dict(x) for x in qry] - with self._offset_lock: - offset = self._offsets.setdefault(item_type, 0) - chunk = [dict(x) for x in qry.limit(limit).offset(offset)] - self._offsets[item_type] += len(chunk) + chunk = [dict(x) for x in qry.limit(limit).offset(offset)] + if len(chunk) < limit: + completed_queries.add(qry_key) return chunk - def _advance_query(self, item_type, limit): + def _advance_query(self, item_type, offset, limit, **kwargs): """Advances the DB query that fetches items of given type and adds the results to the corresponding mapped table. 
@@ -213,12 +227,11 @@ def _advance_query(self, item_type, limit): Returns: list: items fetched from the DB """ - chunk = self._get_next_chunk(item_type, limit) + chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) if not chunk: - self._fetched_item_types.add(item_type) return [] mapped_table = self.mapped_table(item_type) - return list(filter(lambda i: i is not None, (mapped_table.add_item(item) for item in chunk))) + return [mapped_table.add_item(item) for item in chunk] def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -249,8 +262,7 @@ def reset(self, *item_types): # Now clear things for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._offsets.pop(item_type, None) - self._fetched_item_types.discard(item_type) + self._completed_queries.pop(item_type, None) def get_mapped_item(self, item_type, id_): mapped_table = self.mapped_table(item_type) @@ -259,30 +271,14 @@ def get_mapped_item(self, item_type, id_): return {} return item - def do_fetch_more(self, item_type, limit=_LIMIT): - if item_type in self._fetched_item_types: - return [] - return self._advance_query(item_type, limit) - - def do_fetch_all(self, item_type): - while self.do_fetch_more(item_type): - pass + def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): + return self._advance_query(item_type, offset, limit, **kwargs) - def fetch_value(self, item_type, return_fn): - while self.do_fetch_more(item_type): - return_value = return_fn() - if return_value: - return return_value - return return_fn() + def do_fetch_all(self, item_type, **kwargs): + self.do_fetch_more(item_type, **kwargs) def fetch_ref(self, item_type, id_): - while self.do_fetch_more(item_type): - ref = self.get_mapped_item(item_type, id_) - if ref: - return ref - # It is possible that fetching was completed between deciding to call this function - # and starting the while loop above resulting in self.do_fetch_more() to return False immediately. - # Therefore, we should try one last time if the ref is available. + self.do_fetch_all(item_type) ref = self.get_mapped_item(item_type, id_) if ref: return ref @@ -328,9 +324,8 @@ def unique_key_value_to_id(self, key, value, strict=False, fetch=True): """ id_by_unique_value = self._id_by_unique_key_value.get(key, {}) if not id_by_unique_value and fetch: - id_by_unique_value = self._db_map.fetch_value( - self._item_type, lambda: self._id_by_unique_key_value.get(key, {}) - ) + self._db_map.do_fetch_all(self._item_type) + id_by_unique_value = self._id_by_unique_key_value.get(key, {}) value = tuple(tuple(x) if isinstance(x, list) else x for x in value) if strict: return id_by_unique_value[value] @@ -365,11 +360,12 @@ def find_item(self, item, skip_keys=(), fetch=True): id_ = item.get("id") if id_ is not None: # id is given, easy - item = self.get(id_) - if item or not fetch: - return item - return self._db_map.fetch_ref(self._item_type, id_) - # No id. Try to locate the item by the value of one of the unique keys. + current_item = self.get(id_) + if not current_item and fetch: + current_item = self._db_map.fetch_ref(self._item_type, id_) + if current_item: + return current_item + # No id or not found by id. Try to locate the item by the value of one of the unique keys. # Used by import_data (and more...) 
for key in self._db_map._item_factory(self._item_type)._unique_keys:
             if key in skip_keys:
                 continue
@@ -443,19 +439,15 @@ def remove_unique(self, item):
                 del id_by_value[value]
 
     def add_item(self, item, new=False):
+        if not new:
+            # Item comes from the DB; don't add it twice
+            existing = self.find_item(item, fetch=False)
+            if existing:
+                return existing
         if not isinstance(item, MappedItemBase):
             item = self._make_item(item)
             item.polish()
-        if not new:
-            # Item comes from the DB
-            id_ = item["id"]
-            if id_ in self or id_ in self._temp_id_by_db_id:
-                # The item is already in the mapping
-                return
-            if any(value in self._id_by_unique_key_value.get(key, {}) for key, value in item.unique_values()):
-                # An item with the same unique key is already in the mapping
-                return
-        else:
+        if new:
             item.status = Status.to_add
         if "id" not in item or not item.is_id_valid:
             item["id"] = self._new_id()
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index d2148661..256b45be 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -232,7 +232,7 @@ def test_fetch_more_after_commit_and_refresh(self):
         db_map.commit_session("Add test data.")
         db_map.refresh_session()
         entities = db_map.fetch_more("entity")
-        self.assertEqual(entities, [])
+        self.assertEqual([(x["class_name"], x["name"]) for x in entities], [("Widget", "gadget")])
 

From 062697e558f8586e70bb8cab072b05f04a1895af Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Wed, 18 Oct 2023 15:26:45 +0200
Subject: [PATCH 154/317] Fix importing object classes

---
 spinedb_api/import_functions.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py
index 0d2363fc..0560f924 100644
--- a/spinedb_api/import_functions.py
+++ b/spinedb_api/import_functions.py
@@ -207,7 +207,7 @@ def get_data_for_import(
         yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata))
     # Legacy
     if object_classes:
-        yield from get_data_for_import(db_map, entity_classes=object_classes)
+        yield from get_data_for_import(db_map, entity_classes=_object_classes_from_entity_classes(object_classes))
     if relationship_classes:
         yield from get_data_for_import(db_map, entity_classes=relationship_classes)
     if object_parameters:
@@ -699,13 +699,10 @@ def _data_iterator():
 # Legacy
 
 
-def _get_object_classes_for_import(db_map, data):
-    def _data_iterator():
-        for x in data:
-            if isinstance(x, str):
-                yield x, ()
-            else:
-                name, *optionals = x
-                yield name, (), *optionals
-
-    return _get_entity_classes_for_import(db_map, _data_iterator())
+def _object_classes_from_entity_classes(data):
+    for x in data:
+        if isinstance(x, str):
+            yield x, ()
+        else:
+            name, *optionals = x
+            yield name, (), *optionals

From 05544b7b6a9c0ce577f9912f3fbd93617f6bb19e Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Thu, 19 Oct 2023 14:25:52 +0200
Subject: [PATCH 155/317] Lazy purge - instead of fetching the entire table
 just to purge it...

...we mark the table as purged and then purge stuff as it is fetched by the system.
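For example (a sketch based on the tests in this patch; url is a placeholder),
a purge no longer fetches the table first:

    from spinedb_api import DatabaseMapping
    from spinedb_api.helpers import Asterisk

    with DatabaseMapping(url, create=True) as db_map:
        db_map.remove_item("entity_class", Asterisk)  # just marks the table as purged
        db_map.commit_session("Purge all entity classes")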
Re #262
---
 spinedb_api/compatibility.py | 4 +-
 spinedb_api/db_mapping.py | 21 ++++----
 spinedb_api/db_mapping_base.py | 58 ++++++++++++++++------
 spinedb_api/db_mapping_commit_mixin.py | 12 +++--
 tests/test_purge.py | 69 +++++++++++++++++++++++++-
 5 files changed, 129 insertions(+), 35 deletions(-)

diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py
index 2e748ce6..38bc144c 100644
--- a/spinedb_api/compatibility.py
+++ b/spinedb_api/compatibility.py
@@ -108,8 +108,8 @@ def compatibility_transformations(connection):
         connection (Connection)
 
     Returns:
-        list: list of tuples (tablename, (items_added, items_updated, ids_removed))
-        list: list of strings indicating the changes
+        tuple(list, list): list of tuples (tablename, (items_added, items_updated, ids_removed)), and
+            list of strings indicating the changes
     """
     ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative(connection)
     transformations = []
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 7d215126..4cbc8937 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -515,7 +515,7 @@ def update_items(self, item_type, *items, check=True, strict=False):
             updated.append(item)
         return updated, errors
 
-    def remove_item(self, item_type, id):
+    def remove_item(self, item_type, id_):
         """Removes an item from the in-memory mapping.
 
         Example::
@@ -526,14 +526,14 @@
 
         Args:
             item_type (str): One of <spine_item_types>.
-            id (int): The id of the item to remove.
+            id_ (int): The id of the item to remove.
 
         Returns:
             tuple(:class:`PublicItem` or None, str): The removed item if any.
         """
         item_type = self._real_tablename(item_type)
         mapped_table = self.mapped_table(item_type)
-        return mapped_table.remove_item(id).public_item
+        return mapped_table.remove_item(id_).public_item
 
     def remove_items(self, item_type, *ids):
         """Removes many items from the in-memory mapping.
@@ -548,17 +548,13 @@
         if not ids:
             return []
         item_type = self._real_tablename(item_type)
-        mapped_table = self.mapped_table(item_type)
-        if Asterisk in ids:
-            self.fetch_all(item_type)
-            ids = mapped_table
         ids = set(ids)
         if item_type == "alternative":
             # Do not remove the Base alternative
             ids.discard(1)
         return [self.remove_item(item_type, id_) for id_ in ids]
 
-    def restore_item(self, item_type, id):
+    def restore_item(self, item_type, id_):
         """Restores a previously removed item into the in-memory mapping.
 
         Example::
@@ -569,14 +565,14 @@
 
         Args:
             item_type (str): One of <spine_item_types>.
-            id (int): The id of the item to restore.
+            id_ (int): The id of the item to restore.
 
         Returns:
             tuple(:class:`PublicItem` or None, str): The restored item if any.
         """
         item_type = self._real_tablename(item_type)
         mapped_table = self.mapped_table(item_type)
-        return mapped_table.restore_item(id).public_item
+        return mapped_table.restore_item(id_).public_item
 
     def restore_items(self, item_type, *ids):
         """Restores many previously removed items into the in-memory mapping.
@@ -667,6 +663,9 @@ def commit_session(self, comment): Args: comment (str): commit message + + Returns: + tuple(list, list): compatibility transformations """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") @@ -757,8 +756,6 @@ def get_filter_configs(self): # Astroid transform so DatabaseMapping looks like it has the convenience methods defined above def _add_convenience_methods(node): - import astroid - if node.name != "DatabaseMapping": return node for item_type in DatabaseMapping.item_types(): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 21b9bf99..c26deb3b 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -13,6 +13,7 @@ from difflib import SequenceMatcher from .temp_id import TempId from .exception import SpineDBAPIError +from .helpers import Asterisk # TODO: Implement MappedItem.pop() to do lookup? @@ -134,21 +135,25 @@ def _dirty_items(self): list """ dirty_items = [] + purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} + self._add_descendants(purged_item_types) for item_type in self._sorted_item_types: - mapped_table = self.get(item_type) - if mapped_table is None: - continue + mapped_table = self.mapped_table(item_type) to_add = [] to_update = [] to_remove = [] - for item in mapped_table.values(): - _ = item.is_valid() + for item in mapped_table.valid_values(): if item.status == Status.to_add: to_add.append(item) elif item.status == Status.to_update: to_update.append(item) - elif item.status == Status.to_remove: - to_remove.append(item) + if item_type in purged_item_types: + to_remove.append(mapped_table.wildcard_item) + else: + for item in mapped_table.values(): + _ = item.is_valid() + if item.status == Status.to_remove: + to_remove.append(item) if to_remove: # Fetch descendants, so that they are validated in next iterations of the loop. # This ensures cascade removal. @@ -250,7 +255,12 @@ def reset(self, *item_types): Any modifications in the mapping that aren't committed to the DB are lost after this. 
""" item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types()) - # Include descendants, otherwise references are broken + self._add_descendants(item_types) + for item_type in item_types: + self._mapped_tables.pop(item_type, None) + self._completed_queries.pop(item_type, None) + + def _add_descendants(self, item_types): while True: changed = False for item_type in set(self.item_types()) - item_types: @@ -259,10 +269,6 @@ def reset(self, *item_types): changed = True if not changed: break - # Now clear things - for item_type in item_types: - self._mapped_tables.pop(item_type, None) - self._completed_queries.pop(item_type, None) def get_mapped_item(self, item_type, id_): mapped_table = self.mapped_table(item_type) @@ -296,6 +302,15 @@ def __init__(self, db_map, item_type, *args, **kwargs): self._item_type = item_type self._id_by_unique_key_value = {} self._temp_id_by_db_id = {} + self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) + + @property + def purged(self): + return self.wildcard_item.status == Status.to_remove + + @purged.setter + def purged(self, purged): + self.wildcard_item.status = Status.to_remove if purged else Status.committed def get(self, id_, default=None): id_ = self._temp_id_by_db_id.get(id_, id_) @@ -447,12 +462,17 @@ def add_item(self, item, new=False): if not isinstance(item, MappedItemBase): item = self._make_item(item) item.polish() - if new: - item.status = Status.to_add if "id" not in item or not item.is_id_valid: item["id"] = self._new_id() self[item["id"]] = item self.add_unique(item) + if new: + item.status = Status.to_add + elif self.purged: + # Sorry, item, you're coming from the DB and I have been purged: so... + item.cascade_remove(source=self.wildcard_item) + # More seriously, this is like a lazy purge: insteaf of fetching all at purge time, + # we purge stuff as it comes. 
return item def update_item(self, item): @@ -464,6 +484,11 @@ def update_item(self, item): return current_item def remove_item(self, id_): + if id_ is Asterisk: + self.purged = True + for item in self.valid_values(): + item.cascade_remove(source=self.wildcard_item) + return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item is not None: self.remove_unique(current_item) @@ -471,6 +496,11 @@ def remove_item(self, id_): return current_item def restore_item(self, id_): + if id_ is Asterisk: + self.purged = False + for item in self.values(): + item.cascade_restore(source=self.wildcard_item) + return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item is not None: self.add_unique(current_item) diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 165ed007..d37e4eb7 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -14,7 +14,7 @@ from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError from .temp_id import TempId, resolve -from .helpers import group_consecutive +from .helpers import group_consecutive, Asterisk class DatabaseMappingCommitMixin: @@ -149,10 +149,12 @@ def _do_remove_items(self, connection, tablename, *ids): tablenames.append("entity_element") for tablename_ in tablenames: table = self._metadata.tables[tablename_] - id_field = self._id_fields.get(tablename_, "id") - id_column = getattr(table.c, id_field) - cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) - delete = table.delete().where(cond) + delete = table.delete() + if Asterisk not in ids: + id_field = self._id_fields.get(tablename_, "id") + id_column = getattr(table.c, id_field) + cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) + delete = delete.where(cond) try: connection.execute(delete) except DBAPIError as e: diff --git a/tests/test_purge.py b/tests/test_purge.py index 906e00d1..15f84fab 100644 --- a/tests/test_purge.py +++ b/tests/test_purge.py @@ -14,6 +14,7 @@ from spinedb_api import DatabaseMapping from spinedb_api.purge import purge_url +from spinedb_api.helpers import Asterisk class TestPurgeUrl(unittest.TestCase): @@ -31,11 +32,75 @@ def test_purge_entity_classes(self): db_map.commit_session("Add test data") purge_url(self._url, {"alternative": False, "entity_class": True}) with DatabaseMapping(self._url) as db_map: - entities = db_map.query(db_map.entity_class_sq).all() - self.assertEqual(entities, []) + classes = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(classes, []) alternatives = db_map.query(db_map.alternative_sq).all() self.assertEqual(len(alternatives), 1) + def test_purge_then_add(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.remove_item("entity_class", Asterisk) + db_map.add_item("entity_class", name="Soup") + db_map.commit_session("Yummy") + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Soup"]) + with DatabaseMapping(self._url, create=True) as db_map: + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Soup"]) + + def test_add_then_purge_then_unpurge(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_item("entity_class", name="Soup") + db_map.remove_item("entity_class", Asterisk) + self.assertFalse(db_map.get_items("entity_class")) + db_map.restore_item("entity_class", Asterisk) + self.assertEqual([x["name"] for x in 
db_map.get_items("entity_class")], ["Soup"]) + + def test_add_then_purge_then_add(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_item("entity_class", name="Soup") + db_map.remove_item("entity_class", Asterisk) + self.assertFalse(db_map.get_items("entity_class")) + db_map.add_item("entity_class", name="Poison") + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Poison"]) + + def test_add_then_purge_then_add_then_purge_again(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_item("entity_class", name="Soup") + db_map.remove_item("entity_class", Asterisk) + self.assertFalse(db_map.get_items("entity_class")) + db_map.add_item("entity_class", name="Poison") + db_map.remove_item("entity_class", Asterisk) + self.assertFalse(db_map.get_items("entity_class")) + + def test_dont_keep_purging_after_commit(self): + """Tests that if I purge and then commit, then add more stuff the commit again, the stuff I added + after the first commit is not purged. In other words, the commit resets the purge need.""" + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_item("entity_class", name="Soup") + db_map.remove_item("entity_class", Asterisk) + db_map.commit_session("Yummy but nope") + self.assertFalse(db_map.get_items("entity_class")) + db_map.add_item("entity_class", name="Poison") + db_map.commit_session("Deadly") + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Poison"]) + with DatabaseMapping(self._url, create=True) as db_map: + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Poison"]) + + def test_purge_externally(self): + with DatabaseMapping(self._url, create=True) as db_map: + db_map.add_item("entity_class", name="Soup") + db_map.commit_session("Add test data") + with DatabaseMapping(self._url, create=True) as db_map: + db_map.fetch_all() + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Soup"]) + purge_url(self._url, {"entity_class": True}) + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["Soup"]) + # Mapped items survive an external purge! + # It is up to the client to resolve the situation. + # For example, toolbox does it via SpineDBManager.notify_session_committed + # which calls DatabaseMapping.reset + with DatabaseMapping(self._url, create=True) as db_map: + self.assertFalse(db_map.get_items("entity_class")) + if __name__ == '__main__': unittest.main() From f7b7684e542ebc3cd6719b2890b45f2b0e050d1e Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 19 Oct 2023 16:12:17 +0300 Subject: [PATCH 156/317] Add method to check if database has commits from other sources. Re spine-tools/Spine-Toolbox#2353 --- spinedb_api/db_mapping.py | 13 ++++++++++++- tests/test_DatabaseMapping.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 7d215126..d68383bf 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -187,6 +187,7 @@ def __init__( if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) + self._commit_count = self.query(self.commit_sq).count() def __enter__(self): return self @@ -349,6 +350,14 @@ def _convert_legacy(tablename, item): if entity_id: item["entity_id"] = entity_id + def has_external_commits(self): + """Test whether the database has had commits from other sources than this mapping. 
+ + Returns: + bool: True if database has external commits, False otherwise + """ + return self._commit_count != self.query(self.commit_sq).count() + def get_import_alternative_name(self): if self._import_alternative_name is None: self._create_import_alternative() @@ -690,7 +699,9 @@ def commit_session(self, comment): self._do_add_items(connection, tablename, *to_add) if self._memory: self._memory_dirty = True - return compatibility_transformations(connection) + transformation_info = compatibility_transformations(connection) + self._commit_count = self.query(self.commit_sq).count() + return transformation_info def rollback_session(self): """Discards all the changes from the in-memory mapping.""" diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 256b45be..5f51fba2 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -234,6 +234,30 @@ def test_fetch_more_after_commit_and_refresh(self): entities = db_map.fetch_more("entity") self.assertEqual([(x["class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) + def test_has_external_commits_returns_false_initially(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self.assertFalse(db_map.has_external_commits()) + + def test_has_external_commits_returns_true_when_another_db_mapping_has_made_commits(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + with DatabaseMapping(url) as other_db_map: + other_db_map.add_item("entity_class", name="cc") + other_db_map.commit_session("Added a class") + self.assertTrue(db_map.has_external_commits()) + + def test_has_external_commits_returns_false_after_commit_session(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + with DatabaseMapping(url) as other_db_map: + other_db_map.add_item("entity_class", name="cc") + other_db_map.commit_session("Added a class") + db_map.add_item("entity_class", name="omega") + db_map.commit_session("Added a class") + self.assertFalse(db_map.has_external_commits()) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure.""" From b213b5a8a01c012de779f07f8c20806ac3e174af Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 20 Oct 2023 09:10:11 +0300 Subject: [PATCH 157/317] Add CommitItem This makes it possible to get_items("commit"). Handy for things like Database editor's Commit viewer. 
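For example (a sketch; url is a placeholder), a commit viewer can list commits
through the same item API as everything else:

    with DatabaseMapping(url) as db_map:
        for commit in db_map.get_items("commit"):
            print(commit["user"], commit["date"], commit["comment"])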
Re #2366
---
 spinedb_api/mapped_items.py | 13 +++++++++++++
 tests/test_DatabaseMapping.py | 6 ++++++
 2 files changed, 19 insertions(+)

diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py
index 7d42fee3..a3f5386b 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -18,6 +18,7 @@ def item_factory(item_type):
     return {
+        "commit": CommitItem,
         "entity_class": EntityClassItem,
         "entity": EntityItem,
         "entity_alternative": EntityAlternativeItem,
@@ -35,6 +36,18 @@
     }.get(item_type, MappedItemBase)
 
 
+class CommitItem(MappedItemBase):
+    fields = {
+        "comment": ("str", "A comment describing the commit."),
+        "date": ("datetime", "Date and time of the commit."),
+        "user": ("str", "Username of the committer."),
+    }
+    _unique_keys = (("date",),)
+
+    def commit(self, commit_id):
+        raise RuntimeError("Commits are created automatically when session is committed.")
+
+
 class EntityClassItem(MappedItemBase):
     fields = {
         "name": ("str", "The class name."),
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 5f51fba2..22022f0c 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -258,6 +258,12 @@ def test_has_external_commits_returns_false_after_commit_session(self):
                 db_map.commit_session("Added a class")
                 self.assertFalse(db_map.has_external_commits())
 
+    def test_get_items_gives_commits(self):
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            items = db_map.get_items("commit")
+            self.assertEqual(len(items), 1)
+            self.assertEqual(items[0].item_type, "commit")
+
 
 class TestDatabaseMappingLegacy(unittest.TestCase):
     """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure."""

From 58cb7b47b17dd381d4257f2659ff6e03cb7f2b9f Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Fri, 20 Oct 2023 12:18:13 +0200
Subject: [PATCH 158/317] Fix adding/updating values from list

Re #295
---
 spinedb_api/db_mapping.py | 4 +-
 spinedb_api/db_mapping_base.py | 42 ++++---
 spinedb_api/db_mapping_commit_mixin.py | 6 +-
 spinedb_api/mapped_items.py | 149 +++++++++++--------------
 tests/test_DatabaseMapping.py | 20 +++-
 tests/test_import_functions.py | 19 ++++
 tests/test_purge.py | 4 +-
 7 files changed, 136 insertions(+), 108 deletions(-)

diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 68ad938e..94fba674 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -395,9 +395,9 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
         item_type = self._real_tablename(item_type)
         item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch)
         if not item:
-            return None
+            return {}
         if skip_removed and not item.is_valid():
-            return None
+            return {}
         return item.public_item
 
     def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index c26deb3b..a0502914 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -11,7 +11,7 @@
 
 from enum import Enum, unique, auto
 from difflib import SequenceMatcher
-from .temp_id import TempId
+from .temp_id import TempId, resolve
 from .exception import SpineDBAPIError
 from .helpers import Asterisk
 
@@ -374,14 +374,17 @@ def find_item(self, item, skip_keys=(), fetch=True):
         """
         id_ = item.get("id")
         if id_ is not None:
-            # id is given, easy
-            current_item = self.get(id_)
-            if not current_item and fetch:
-                current_item = self._db_map.fetch_ref(self._item_type, id_)
-            if current_item:
-                return current_item
-        # No id or not found by id. Try to locate the item by the value of one of the unique keys.
-        # Used by import_data (and more...)
+            return self._find_item_by_id(id_, fetch=fetch)
+        return self._find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch)
+
+    def _find_item_by_id(self, id_, fetch=True):
+        current_item = self.get(id_)
+        if current_item is None and fetch:
+            current_item = self._db_map.fetch_ref(self._item_type, id_)
+        if 
current_item:
+            return current_item
+
+    def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True):
         for key in self._db_map._item_factory(self._item_type)._unique_keys:
             if key in skip_keys:
                 continue
@@ -391,7 +394,7 @@
             current_item = self._unique_key_value_to_item(key, value, fetch=fetch)
             if current_item:
                 return current_item
-        # Last hope: maybe item is missing some key stuff, so try with a resolved and polished MappedItem instead...
+        # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too...
         mapped_item = self._make_item(item)
         error = mapped_item.resolve_inverse_references(item.keys())
         if error:
@@ -456,9 +459,9 @@ def remove_unique(self, item):
     def add_item(self, item, new=False):
         if not new:
            # Item comes from the DB; don't add it twice
-            existing = self.find_item(item, fetch=False)
-            if existing:
-                return existing
+            current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key(item, fetch=False)
+            if current:
+                return current
         if not isinstance(item, MappedItemBase):
             item = self._make_item(item)
             item.polish()
@@ -486,8 +489,9 @@ def update_item(self, item):
     def remove_item(self, id_):
         if id_ is Asterisk:
             self.purged = True
-            for item in self.valid_values():
-                item.cascade_remove(source=self.wildcard_item)
+            for current_item in self.valid_values():
+                self.remove_unique(current_item)
+                current_item.cascade_remove(source=self.wildcard_item)
             return self.wildcard_item
         current_item = self.find_item({"id": id_})
         if current_item is not None:
@@ -498,8 +502,9 @@ def restore_item(self, id_):
         if id_ is Asterisk:
             self.purged = False
-            for item in self.values():
-                item.cascade_restore(source=self.wildcard_item)
+            for current_item in self.values():
+                self.add_unique(current_item)
+                current_item.cascade_restore(source=self.wildcard_item)
             return self.wildcard_item
         current_item = self.find_item({"id": id_})
         if current_item is not None:
@@ -656,6 +661,9 @@ def _asdict(self):
         """
         return dict(self)
 
+    def resolve(self):
+        return {k: resolve(v) for k, v in self._asdict().items()}
+
     def merge(self, other):
         """Merges this item with another and returns the merged item together with any errors.
         Used for updating items.
diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index d37e4eb7..d89542a8 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -44,7 +44,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): else: id_items.append(item) if id_items: - connection.execute(table.insert(), [resolve(x._asdict()) for x in id_items]) + connection.execute(table.insert(), [x.resolve() for x in id_items]) if temp_id_items: current_ids = {x["id"] for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 @@ -55,7 +55,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): for id_, item in zip(ids, temp_id_items): temp_id = item["id"] temp_id.resolve(id_) - connection.execute(table.insert(), [resolve(x._asdict()) for x in temp_id_items]) + connection.execute(table.insert(), [x.resolve() for x in temp_id_items]) for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue @@ -117,7 +117,7 @@ def _do_update_items(self, connection, tablename, *items_to_update): return try: upd = self._make_update_stmt(tablename, items_to_update[0].keys()) - connection.execute(upd, [resolve(item._asdict()) for item in items_to_update]) + connection.execute(upd, [x.resolve() for x in items_to_update]) for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): if not items_to_update_: continue diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index a3f5386b..754bd826 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -13,7 +13,7 @@ from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase -from .temp_id import TempId +from .temp_id import TempId, resolve def item_factory(item_type): @@ -153,11 +153,8 @@ def _element_name_list_recursive(self, class_name, byname_remainder): name = byname_remainder.pop(0) return (name,) return tuple( - ( - self._db_map.get_item( - "entity", class_name=dim_name, byname=self._element_name_list_recursive(dim_name, byname_remainder) - ) - or {} + self._db_map.get_item( + "entity", class_name=dim_name, byname=self._element_name_list_recursive(dim_name, byname_remainder) ).get("name") for dim_name in dimension_name_list ) @@ -281,7 +278,59 @@ def _something_to_update(self, other): return super()._something_to_update(other) -class ParameterDefinitionItem(ParsedValueBase): +class ParameterItemBase(ParsedValueBase): + @property + def _value_key(self): + raise NotImplementedError() + + @property + def _type_key(self): + raise NotImplementedError() + + def _value_not_in_list_error(self, parsed_value, list_name): + raise NotImplementedError() + + @classmethod + def ref_types(cls): + return super().ref_types() | {"list_value"} + + @property + def list_value_id(self): + return self["list_value_id"] + + def resolve(self): + d = super().resolve() + list_value_id = d.get("list_value_id") + if list_value_id is not None: + d[self._value_key] = to_database(list_value_id)[0] + return d + + def polish(self): + error = super().polish() + if error: + return error + list_name = self["parameter_value_list_name"] + if list_name is None: + self["list_value_id"] = None + return + type_ = super().__getitem__(self._type_key) + if type_ == "list_value_ref": + return + # value = self[self._value_key] + value = super().__getitem__(self._value_key) + 
parsed_value = from_database(value, type_) + if parsed_value is None: + return + list_value_id = self._db_map.get_item( + "list_value", parameter_value_list_name=list_name, value=value, type=type_ + ).get("id") + if list_value_id is None: + return self._value_not_in_list_error(parsed_value, list_name) + self["list_value_id"] = list_value_id + self[self._type_key] = "list_value_ref" + + +class ParameterDefinitionItem(ParameterItemBase): fields = { "entity_class_name": ("str", "The entity class name."), "name": ("str", "The parameter name."), @@ -302,12 +351,6 @@ class ParameterDefinitionItem(ParsedValueBase): "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), } - @property - def list_value_id(self): - if dict.get(self, "default_type") == "list_value_ref": - return int(dict.__getitem__(self, "default_value")) - return None - @property def _value_key(self): return "default_value" @@ -324,43 +367,13 @@ def __getitem__(self, key): if key == "parameter_value_list_name": return self._get_ref("parameter_value_list", self["parameter_value_list_id"], strong=False).get("name") if key in ("default_value", "default_type"): - list_value_id = self.list_value_id + list_value_id = self["list_value_id"] if list_value_id is not None: list_value_key = {"default_value": "value", "default_type": "type"}[key] return self._get_ref("list_value", list_value_id, strong=False).get(list_value_key) return dict.get(self, key) - if key == "list_value_id": - return self.list_value_id return super().__getitem__(key) - def polish(self): - error = super().polish() - if error: - return error - default_type = self["default_type"] - default_value = self["default_value"] - list_name = self["parameter_value_list_name"] - if list_name is None: - return - if default_type == "list_value_ref": - return - parsed_value = from_database(default_value, default_type) - if parsed_value is None: - return - list_value_id = self._db_map.mapped_table("list_value").unique_key_value_to_id( - ("parameter_value_list_name", "value", "type"), (list_name, default_value, default_type) - ) - if list_value_id is None: - return f"default value {parsed_value} of {self['name']} is not in {list_name}" - self["default_value"] = to_database(list_value_id)[0] - self["default_type"] = "list_value_ref" - if isinstance(list_value_id, TempId): - - def callback(new_id): - self["default_value"] = to_database(new_id)[0] - - list_value_id.add_resolve_callback(callback) - def merge(self, other): other_parameter_value_list_id = other.get("parameter_value_list_id") if ( @@ -378,8 +391,11 @@ def merge(self, other): merged, super_error = super().merge(other) return merged, " and ".join([x for x in (super_error, error) if x]) + def _value_not_in_list_error(self, parsed_value, list_name): + return f"default value {parsed_value} of {self['name']} is not in {list_name}" + -class ParameterValueItem(ParsedValueBase): +class ParameterValueItem(ParameterItemBase): fields = { "entity_class_name": ("str", "The entity class name."), "parameter_definition_name": ("str", "The parameter name."), @@ -416,12 +432,6 @@ class ParameterValueItem(ParsedValueBase): "alternative_id": (("alternative_name",), ("alternative", ("name",))), } - @property - def list_value_id(self): - if dict.__getitem__(self, "type") == "list_value_ref": - return int(dict.__getitem__(self, "value")) - return None - @property def _value_key(self): return "value" @@ -436,43 +446,16 @@ def __getitem__(self, key): if key == "parameter_name": return 
super().__getitem__("parameter_definition_name") if key in ("value", "type"): - list_value_id = self.list_value_id + list_value_id = self["list_value_id"] if list_value_id: return self._get_ref("list_value", list_value_id, strong=False).get(key) - if key == "list_value_id": - return self.list_value_id return super().__getitem__(key) - def polish(self): - error = super().polish() - if error: - return error - list_name = self["parameter_value_list_name"] - if list_name is None: - return - type_ = self["type"] - if type_ == "list_value_ref": - return - value = self["value"] - parsed_value = from_database(value, type_) - if parsed_value is None: - return - list_value_id = self._db_map.mapped_table("list_value").unique_key_value_to_id( - ("parameter_value_list_name", "value", "type"), (list_name, value, type_) + def _value_not_in_list_error(self, parsed_value, list_name): + return ( + f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " + f"is not in {list_name}" ) - if list_value_id is None: - return ( - f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " - f"is not in {list_name}" - ) - self["value"] = to_database(list_value_id)[0] - self["type"] = "list_value_ref" - if isinstance(list_value_id, TempId): - - def callback(new_id): - self["value"] = to_database(new_id)[0] - - list_value_id.add_resolve_callback(callback) class ParameterValueListItem(MappedItemBase): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 22022f0c..9d980ba9 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -209,7 +209,7 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): parameter_definition_name="color", alternative_name="Base", ) - self.assertIsNone(not_color_anymore) + self.assertEqual(not_color_anymore, {}) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -1734,6 +1734,24 @@ def test_update_parameter_value_by_id_only(self): pval = pvals[0] self.assertEqual(pval.value, b"something else") + def test_update_parameter_value_to_an_uncommitted_list_value(self): + import_functions.import_object_classes(self._db_map, ("object_class1",)) + import_functions.import_parameter_value_lists(self._db_map, (("values_1", 5.0),)) + import_functions.import_object_parameters(self._db_map, (("object_class1", "parameter1", None, "values_1"),)) + import_functions.import_objects(self._db_map, (("object_class1", "object1"),)) + import_functions.import_object_parameter_values( + self._db_map, (("object_class1", "object1", "parameter1", 5.0),) + ) + self._db_map.commit_session("Update data.") + import_functions.import_parameter_value_lists(self._db_map, (("values_1", 7.0),)) + value, type_ = to_database(7.0) + items, errors = self._db_map.update_parameter_values({"id": 1, "value": value, "type": type_}) + self.assertEqual(errors, []) + self.assertEqual(len(items), 1) + self._db_map.commit_session("Update data.") + pvals = self._db_map.query(self._db_map.parameter_value_sq).all() + self.assertEqual(from_database(pvals[0].value, pvals[0].type), 7.0) + def test_update_parameter_definition_by_id_only(self): import_functions.import_object_classes(self._db_map, ("object_class1",)) import_functions.import_object_parameters(self._db_map, (("object_class1", "parameter1"),)) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 93f1892e..1b071478 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ 
-761,6 +761,25 @@ def test_import_object_parameter_value_fails_with_nonexistent_alternative(self): self.assertEqual(count, 0) db_map.close() + def test_import_parameter_values_from_committed_value_list(self): + db_map = create_diff_db_map() + import_data(db_map, parameter_value_lists=(("values_1", 5.0),)) + db_map.commit_session("test") + count, errors = import_data( + db_map, + object_classes=("object_class",), + object_parameters=(("object_class", "parameter", None, "values_1"),), + objects=(("object_class", "my_object"),), + object_parameter_values=(("object_class", "my_object", "parameter", 5.0),), + ) + self.assertEqual(count, 4) + self.assertEqual(errors, []) + db_map.commit_session("test") + values = db_map.query(db_map.object_parameter_value_sq).all() + value = values[0] + self.assertEqual(from_database(value.value), 5.0) + db_map.close() + def test_valid_object_parameter_value_from_value_list(self): db_map = create_diff_db_map() import_parameter_value_lists(db_map, (("values_1", 5.0),)) diff --git a/tests/test_purge.py b/tests/test_purge.py index 15f84fab..f6093e1c 100644 --- a/tests/test_purge.py +++ b/tests/test_purge.py @@ -72,8 +72,8 @@ def test_add_then_purge_then_add_then_purge_again(self): self.assertFalse(db_map.get_items("entity_class")) def test_dont_keep_purging_after_commit(self): - """Tests that if I purge and then commit, then add more stuff the commit again, the stuff I added - after the first commit is not purged. In other words, the commit resets the purge need.""" + """Tests that if I purge and then commit, then add more stuff then commit again, the stuff I added + after the first commit is not purged afterwards. In other words, the commit resets the purge need.""" with DatabaseMapping(self._url, create=True) as db_map: db_map.add_item("entity_class", name="Soup") db_map.remove_item("entity_class", Asterisk) From f723d9a021b8396d0a67442e7df1627fbbf5ae1a Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 20 Oct 2023 14:02:55 +0300 Subject: [PATCH 159/317] Check _Index dimensions before comparison to avoid Traceback Re #296 --- spinedb_api/parameter_value.py | 2 +- tests/test_parameter_value.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 26f5ed1b..9d699de1 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -795,7 +795,7 @@ def __setitem__(self, position, index): super().__setitem__(position, index) def __eq__(self, other): - return np.all(super().__eq__(other)) + return len(self) == len(other) and np.all(super().__eq__(other)) def __bool__(self): return np.size(self) != 0 diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index 659cc2d8..c4a0339f 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -864,6 +864,10 @@ def test_Map_equality(self): map_value = Map(["A"], [nested_map]) self.assertEqual(map_value, Map(["A"], [Map(["a"], [-2.3])])) + def test_Map_inequality(self): + map_value = Map(["1", "2", "3", "4", "5"], [-2.3, 2.3, 2.3, 2.3, 2.3]) + self.assertNotEqual(map_value, Map(["a", "b"], [2.3, -2.3])) + def test_TimePattern_equality(self): pattern = TimePattern(["D1-2", "D3-7"], np.array([-2.3, -5.0])) self.assertEqual(pattern, pattern) From 73c24bc2ff4a091f70fbfdbe60f9094960986f67 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sat, 21 Oct 2023 14:39:12 +0200 Subject: [PATCH 160/317] Sanitize obsolete purge settings --- spinedb_api/db_mapping.py | 20 ++++++++++---------- 
spinedb_api/db_mapping_base.py | 4 +--- spinedb_api/db_mapping_commit_mixin.py | 8 ++++---- spinedb_api/mapped_items.py | 3 +-- spinedb_api/purge.py | 11 +++++++---- spinedb_api/server_client_helpers.py | 2 +- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 94fba674..cc9f2774 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -317,7 +317,7 @@ def _receive_engine_close(self, dbapi_con, _connection_record): copy_database_bind(self._original_engine, self.engine) @staticmethod - def _real_tablename(tablename): + def real_item_type(tablename): return { "object_class": "entity_class", "relationship_class": "entity_class", @@ -392,7 +392,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): Returns: :class:`PublicItem` or None """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch) if not item: return {} @@ -411,7 +411,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): Returns: list(:class:`PublicItem`): The items. """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) if fetch: self.fetch_all(item_type) mapped_table = self.mapped_table(item_type) @@ -435,7 +435,7 @@ def add_item(self, item_type, check=True, **kwargs): Returns: tuple(:class:`PublicItem` or None, str): The added item and any errors. """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) self._convert_legacy(item_type, kwargs) if not check: @@ -491,7 +491,7 @@ def update_item(self, item_type, check=True, **kwargs): Returns: tuple(:class:`PublicItem` or None, str): The updated item and any errors. """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) self._convert_legacy(item_type, kwargs) if not check: @@ -540,7 +540,7 @@ def remove_item(self, item_type, id_): Returns: tuple(:class:`PublicItem` or None, str): The removed item if any. """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) return mapped_table.remove_item(id_).public_item @@ -556,7 +556,7 @@ def remove_items(self, item_type, *ids): """ if not ids: return [] - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) ids = set(ids) if item_type == "alternative": # Do not remove the Base alternative @@ -579,7 +579,7 @@ def restore_item(self, item_type, id_): Returns: tuple(:class:`PublicItem` or None, str): The restored item if any. """ - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) return mapped_table.restore_item(id_).public_item @@ -619,7 +619,7 @@ def fetch_more(self, item_type, offset=0, limit=None): Returns: list(:class:`PublicItem`): The items fetched. 
""" - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit)] def fetch_all(self, *item_types): @@ -632,7 +632,7 @@ def fetch_all(self, *item_types): """ item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types()) for item_type in item_types: - item_type = self._real_tablename(item_type) + item_type = self.real_item_type(item_type) self.do_fetch_all(item_type) def query(self, *args, **kwargs): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index a0502914..647fb4c1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -472,10 +472,8 @@ def add_item(self, item, new=False): if new: item.status = Status.to_add elif self.purged: - # Sorry, item, you're coming from the DB and I have been purged: so... + # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. item.cascade_remove(source=self.wildcard_item) - # More seriously, this is like a lazy purge: insteaf of fetching all at purge time, - # we purge stuff as it comes. return item def update_item(self, item): diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index d89542a8..ce105140 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -36,7 +36,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): if not items_to_add: return try: - table = self._metadata.tables[self._real_tablename(tablename)] + table = self._metadata.tables[self.real_item_type(tablename)] id_items, temp_id_items = [], [] for item in items_to_add: if isinstance(item["id"], TempId): @@ -59,7 +59,7 @@ def _do_add_items(self, connection, tablename, *items_to_add): for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue - table = self._metadata.tables[self._real_tablename(tablename_)] + table = self._metadata.tables[self.real_item_type(tablename_)] connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" @@ -105,7 +105,7 @@ def _get_primary_key(self, tablename): return pk def _make_update_stmt(self, tablename, keys): - table = self._metadata.tables[self._real_tablename(tablename)] + table = self._metadata.tables[self.real_item_type(tablename)] upd = table.update() for k in self._get_primary_key(tablename): upd = upd.where(getattr(table.c, k) == bindparam(k)) @@ -133,7 +133,7 @@ def _do_remove_items(self, connection, tablename, *ids): Args: *ids: ids to remove """ - tablename = self._real_tablename(tablename) + tablename = self.real_item_type(tablename) ids = {resolve(id_) for id_ in ids} if tablename == "alternative": # Do not remove the Base alternative diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 754bd826..3fddf4d8 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -13,7 +13,6 @@ from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase -from .temp_id import TempId, resolve def item_factory(item_type): @@ -168,7 +167,7 @@ def polish(self): base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) name = base_name mapped_table = self._db_map.mapped_table(self._item_type) - while 
mapped_table.unique_key_value_to_id(("class_name", "name"), (self["class_name"], name)) is not None: + while mapped_table.find_item({"class_name": self["class_name"], "name": name}): name = base_name + "_" + uuid.uuid4().hex self["name"] = name diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 904e4294..87f35d53 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -57,18 +57,21 @@ def purge(db_map, purge_settings, logger=None): if purge_settings is None: # Bring all the pain purge_settings = {item_type: True for item_type in DatabaseMapping.item_types()} - removable_db_map_data = {item_type for item_type, checked in purge_settings.items() if checked} + removable_db_map_data = { + DatabaseMapping.real_item_type(item_type) for item_type, checked in purge_settings.items() if checked + } if removable_db_map_data: try: if logger: logger.msg.emit("Purging database...") - for item_type in removable_db_map_data: + for item_type in removable_db_map_data & set(DatabaseMapping.item_types()): db_map.purge_items(item_type) db_map.commit_session("Purge database") if logger: logger.msg.emit("Database purged") - except SpineDBAPIError: + except SpineDBAPIError as err: if logger: - logger.msg_error.emit("Failed to purge database.") + sanitized_url = clear_filter_configs(remove_credentials_from_url(db_map.db_url)) + logger.msg_error.emit(f"Failed to purge {sanitized_url}: {err}") return False return True diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index b43cf5e7..f9b3b0b1 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -66,7 +66,7 @@ def default(self, o): if isinstance(o, SpineDBAPIError): return str(o) if isinstance(o, PublicItem): - return o._asdict() + return o._extended() return super().default(o) @property From 24b40ecf0837a1bc8b423685687cb314afbad904 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 25 Oct 2023 12:46:56 +0200 Subject: [PATCH 161/317] Fix fetching entities that refer to unfetched entities Re #298 --- spinedb_api/db_mapping.py | 4 +- spinedb_api/db_mapping_base.py | 127 +++++++++++++++++++-------------- tests/test_DatabaseMapping.py | 23 ++++++ tests/test_db_mapping_base.py | 6 +- 4 files changed, 101 insertions(+), 59 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index cc9f2774..44a2da06 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -439,10 +439,10 @@ def add_item(self, item_type, check=True, **kwargs): mapped_table = self.mapped_table(item_type) self._convert_legacy(item_type, kwargs) if not check: - return mapped_table.add_item(kwargs, new=True), None + return mapped_table.add_item(kwargs), None checked_item, error = mapped_table.check_item(kwargs) return ( - mapped_table.add_item(checked_item, new=True).public_item if checked_item and not error else None, + mapped_table.add_item(checked_item).public_item if checked_item and not error else None, error, ) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 647fb4c1..169b2fd6 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -203,41 +203,6 @@ def _refresh(self): """Clears fetch progress, so the DB is queried again.""" self._completed_queries.clear() - def _get_next_chunk(self, item_type, offset, limit, **kwargs): - completed_queries = self._completed_queries.setdefault(item_type, set()) - qry_key = tuple(sorted(kwargs.items())) - if qry_key in completed_queries: - items = [x for x in 
self.mapped_table(item_type).values() if all(x.get(k) == v for k, v in kwargs.items())] - if limit is None: - return items[offset:] - return items[offset : offset + limit] - qry = self._make_query(item_type, **kwargs) - if not qry: - return [] - if not limit: - completed_queries.add(qry_key) - return [dict(x) for x in qry] - chunk = [dict(x) for x in qry.limit(limit).offset(offset)] - if len(chunk) < limit: - completed_queries.add(qry_key) - return chunk - - def _advance_query(self, item_type, offset, limit, **kwargs): - """Advances the DB query that fetches items of given type - and adds the results to the corresponding mapped table. - - Args: - item_type (str) - - Returns: - list: items fetched from the DB - """ - chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) - if not chunk: - return [] - mapped_table = self.mapped_table(item_type) - return [mapped_table.add_item(item) for item in chunk] - def _check_item_type(self, item_type): if item_type not in self.all_item_types(): candidate = max(self.all_item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio()) @@ -245,10 +210,9 @@ def _check_item_type(self, item_type): def mapped_table(self, item_type): self._check_item_type(item_type) - return self._mapped_tables.setdefault(item_type, _MappedTable(self, item_type)) - - def get(self, item_type, default=None): - return self._mapped_tables.get(item_type, default) + if item_type not in self._mapped_tables: + self._mapped_tables[item_type] = _MappedTable(self, item_type) + return self._mapped_tables[item_type] def reset(self, *item_types): """Resets the mapping for given item types as if nothing was fetched from the DB or modified in the mapping. @@ -277,11 +241,54 @@ def get_mapped_item(self, item_type, id_): return {} return item + def _get_next_chunk(self, item_type, offset, limit, **kwargs): + completed_queries = self._completed_queries.setdefault(item_type, set()) + qry_key = tuple(sorted(kwargs.items())) + if qry_key in completed_queries: + items = [x for x in self.mapped_table(item_type).values() if all(x.get(k) == v for k, v in kwargs.items())] + if limit is None: + return items[offset:] + return items[offset : offset + limit] + qry = self._make_query(item_type, **kwargs) + if not qry: + return [] + if not limit: + completed_queries.add(qry_key) + return [dict(x) for x in qry] + chunk = [dict(x) for x in qry.limit(limit).offset(offset)] + if len(chunk) < limit: + completed_queries.add(qry_key) + return chunk + def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): - return self._advance_query(item_type, offset, limit, **kwargs) + """Fetches items from the DB and adds them to the mapping. - def do_fetch_all(self, item_type, **kwargs): - self.do_fetch_more(item_type, **kwargs) + Args: + item_type (str) + + Returns: + list(MappedItem): items fetched from the DB. 
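+
+        Note: items are first added to the mapped table and only then registered
+        under their unique key values, so that items in the same chunk that refer
+        to one another resolve correctly (see the comments in the body below).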
+ """ + chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) + if not chunk: + return [] + mapped_table = self.mapped_table(item_type) + items = [] + new_items = [] + # Add items first + for x in chunk: + item, new = mapped_table.add_item_from_db(x) + if new: + new_items.append(item) + items.append(item) + # Once all items are added, add the unique key values + # This is because entity (class) items can refer other entity (class) items + for item in new_items: + mapped_table.add_unique(item) + return items + + def do_fetch_all(self, item_type): + self.do_fetch_more(item_type, offset=0, limit=None) def fetch_ref(self, item_type, id_): self.do_fetch_all(item_type) @@ -381,8 +388,7 @@ def _find_item_by_id(self, id_, fetch=True): current_item = self.get(id_) if current_item is None and fetch: current_item = self._db_map.fetch_ref(self._item_type, id_) - if current_item: - return current_item + return current_item def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True): for key in self._db_map._item_factory(self._item_type)._unique_keys: @@ -456,24 +462,37 @@ def remove_unique(self, item): if id_by_value.get(value) == id_: del id_by_value[value] - def add_item(self, item, new=False): - if not new: - # Item comes from the DB; don̈́'t add it twice - current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key(item, fetch=False) - if current: - return current + def _make_and_add_item(self, item): if not isinstance(item, MappedItemBase): item = self._make_item(item) item.polish() if "id" not in item or not item.is_id_valid: item["id"] = self._new_id() self[item["id"]] = item - self.add_unique(item) - if new: - item.status = Status.to_add - elif self.purged: + return item + + def add_item_from_db(self, item): + """Adds an item fetched from the DB. + + Args: + item (dict): item from the DB. + + Returns: + tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. + """ + current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key(item, fetch=False) + if current: + return current, False + item = self._make_and_add_item(item) + if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. item.cascade_remove(source=self.wildcard_item) + return item, True + + def add_item(self, item): + item = self._make_and_add_item(item) + self.add_unique(item) + item.status = Status.to_add return item def update_item(self, item): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 9d980ba9..84c88c55 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -264,6 +264,29 @@ def test_get_items_gives_commits(self): self.assertEqual(len(items), 1) self.assertEqual(items[0].item_type, "commit") + def test_fetch_entities_that_refer_to_unfetched_entities(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + db_map.add_entity_class_item(name="dog") + db_map.add_entity_class_item(name="cat") + db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat")) + db_map.add_entity_item(name="Pulgoso", class_name="dog") + db_map.add_entity_item(name="Sylvester", class_name="cat") + db_map.add_entity_item(name="Tom", class_name="cat") + db_map.commit_session("Arf!") + with DatabaseMapping(url) as db_map: + # Remove the entity in the middle and add a multi-D one referring to the third entity. 
+ # The multi-D one will go in the middle. + db_map.get_entity_item(name="Sylvester", class_name="cat").remove() + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), class_name="dog__cat") + db_map.commit_session("Meow!") + with DatabaseMapping(url) as db_map: + # The ("Pulgoso", "Tom") entity will be fetched before "Tom". + # What happens? + entities = db_map.get_items("entity") + self.assertEqual(len(entities), 3) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure.""" diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index c677a659..0b130d85 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -33,7 +33,7 @@ class TestDBMappingBase(unittest.TestCase): def test_rolling_back_new_item_invalidates_its_id(self): db_map = TestDBMapping() mapped_table = db_map.mapped_table("cutlery") - item = mapped_table.add_item({}, new=True) + item = mapped_table.add_item({}) self.assertTrue(item.is_id_valid) self.assertIn("id", item) id_ = item["id"] @@ -46,11 +46,11 @@ class TestMappedTable(unittest.TestCase): def test_readding_item_with_invalid_id_creates_new_id(self): db_map = TestDBMapping() mapped_table = db_map.mapped_table("cutlery") - item = mapped_table.add_item({}, new=True) + item = mapped_table.add_item({}) id_ = item["id"] db_map._rollback() self.assertFalse(item.is_id_valid) - mapped_table.add_item(item, new=True) + mapped_table.add_item(item) self.assertTrue(item.is_id_valid) self.assertNotEqual(item["id"], id_) From 9650d018d41c73d100e629d1b9ff6a8c750c4627 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 25 Oct 2023 15:12:47 +0200 Subject: [PATCH 162/317] Optimize fetching a little bit --- spinedb_api/db_mapping.py | 6 +-- spinedb_api/db_mapping_base.py | 78 ++++++++++++++++++---------------- 2 files changed, 44 insertions(+), 40 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 44a2da06..951ed7d9 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -413,7 +413,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): """ item_type = self.real_item_type(item_type) if fetch: - self.fetch_all(item_type) + self.do_fetch_all(item_type, **kwargs) mapped_table = self.mapped_table(item_type) get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] @@ -608,7 +608,7 @@ def purge_items(self, item_type): """ return bool(self.remove_items(item_type, Asterisk)) - def fetch_more(self, item_type, offset=0, limit=None): + def fetch_more(self, item_type, offset=0, limit=None, ticket=None): """Fetches items from the DB into the in-memory mapping, incrementally. Args: @@ -620,7 +620,7 @@ def fetch_more(self, item_type, offset=0, limit=None): list(:class:`PublicItem`): The items fetched. """ item_type = self.real_item_type(item_type) - return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit)] + return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, ticket=ticket)] def fetch_all(self, *item_types): """Fetches items from the DB into the in-memory mapping. 
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 169b2fd6..c4b99988 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -39,7 +39,7 @@ class DatabaseMappingBase: def __init__(self): self._mapped_tables = {} - self._completed_queries = {} + self._completed_tickets = {} item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -201,7 +201,7 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" - self._completed_queries.clear() + self._completed_tickets.clear() def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -209,8 +209,8 @@ def _check_item_type(self, item_type): raise SpineDBAPIError(f"Invalid item type '{item_type}' - maybe you meant '{candidate}'?") def mapped_table(self, item_type): - self._check_item_type(item_type) if item_type not in self._mapped_tables: + self._check_item_type(item_type) self._mapped_tables[item_type] = _MappedTable(self, item_type) return self._mapped_tables[item_type] @@ -222,7 +222,7 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._completed_queries.pop(item_type, None) + self._completed_tickets.pop(item_type, None) def _add_descendants(self, item_types): while True: @@ -242,25 +242,20 @@ def get_mapped_item(self, item_type, id_): return item def _get_next_chunk(self, item_type, offset, limit, **kwargs): - completed_queries = self._completed_queries.setdefault(item_type, set()) - qry_key = tuple(sorted(kwargs.items())) - if qry_key in completed_queries: - items = [x for x in self.mapped_table(item_type).values() if all(x.get(k) == v for k, v in kwargs.items())] - if limit is None: - return items[offset:] - return items[offset : offset + limit] + """Gets chunk of items from the DB. + + Returns: + tuple(list(dict),bool): list of dictionary items and whether this is the last chunk. + """ qry = self._make_query(item_type, **kwargs) if not qry: - return [] + return [], True if not limit: - completed_queries.add(qry_key) - return [dict(x) for x in qry] + return [dict(x) for x in qry], True chunk = [dict(x) for x in qry.limit(limit).offset(offset)] - if len(chunk) < limit: - completed_queries.add(qry_key) - return chunk + return chunk, len(chunk) < limit - def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): + def do_fetch_more(self, item_type, offset=0, limit=None, ticket=None, **kwargs): """Fetches items from the DB and adds them to the mapping. Args: @@ -269,7 +264,12 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): Returns: list(MappedItem): items fetched from the DB. 
""" - chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) + completed_tickets = self._completed_tickets.setdefault(item_type, set()) + if ticket in completed_tickets: + return [] + chunk, completed = self._get_next_chunk(item_type, offset, limit, **kwargs) + if ticket is not None and completed: + completed_tickets.add(ticket) if not chunk: return [] mapped_table = self.mapped_table(item_type) @@ -287,8 +287,8 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): mapped_table.add_unique(item) return items - def do_fetch_all(self, item_type): - self.do_fetch_more(item_type, offset=0, limit=None) + def do_fetch_all(self, item_type, **kwargs): + self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) def fetch_ref(self, item_type, id_): self.do_fetch_all(item_type) @@ -390,7 +390,7 @@ def _find_item_by_id(self, id_, fetch=True): current_item = self._db_map.fetch_ref(self._item_type, id_) return current_item - def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True): + def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): for key in self._db_map._item_factory(self._item_type)._unique_keys: if key in skip_keys: continue @@ -400,18 +400,19 @@ def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True): current_item = self._unique_key_value_to_item(key, value, fetch=fetch) if current_item: return current_item - # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too... - mapped_item = self._make_item(item) - error = mapped_item.resolve_inverse_references(item.keys()) - if error: - return None - error = mapped_item.polish() - if error: - return None - for key, value in mapped_item.unique_values(skip_keys=skip_keys): - current_item = self._unique_key_value_to_item(key, value, fetch=fetch) - if current_item: - return current_item + if complete: + # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too... + mapped_item = self._make_item(item) + error = mapped_item.resolve_inverse_references(item.keys()) + if error: + return None + error = mapped_item.polish() + if error: + return None + for key, value in mapped_item.unique_values(skip_keys=skip_keys): + current_item = self._unique_key_value_to_item(key, value, fetch=fetch) + if current_item: + return current_item def check_item(self, item, for_update=False, skip_keys=()): # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, @@ -480,7 +481,9 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ - current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key(item, fetch=False) + current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key( + item, fetch=False, complete=False + ) if current: return current, False item = self._make_and_add_item(item) @@ -854,9 +857,10 @@ def add_referrer(self, referrer): Args: referrer (MappedItemBase) """ - if referrer.key is None: + key = referrer.key + if key is None: return - self._referrers[referrer.key] = self._weak_referrers.pop(referrer.key, referrer) + self._referrers[key] = self._weak_referrers.pop(key, referrer) def remove_referrer(self, referrer): """Removes a strong referrer. 
From 1900a69d0a791414567a86ccf0cdfe1fbe84bc94 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 25 Oct 2023 15:14:49 +0200 Subject: [PATCH 163/317] Progress tutorial --- docs/source/tutorial.rst | 75 +++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 619cf58e..b1c6da53 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -22,6 +22,7 @@ The main mean of communication with a Spine DB is the :class:`.DatabaseMapping`, specially designed to retrieve and modify data from the DB. To create a :class:`.DatabaseMapping`, we just pass the URL of the DB to the class constructor:: + import spinedb_api as api from spinedb_api import DatabaseMapping url = "mysql://spine_db" # The URL of an existing Spine DB @@ -43,6 +44,7 @@ Creating a DB If you're following this tutorial, chances are you don't have a Spine DB to play with just yet. We can remediate this by creating a SQLite DB (which is just a file in your system), as follows:: + import spinedb_api as api from spinedb_api import DatabaseMapping url = "sqlite:///first.sqlite" @@ -58,8 +60,7 @@ that we want the DB to be created at the given URL. .. note:: In the remainder we will skip the above step and work directly with ``db_map``. In other words, - all the examples below assume we are inside the ``with`` block above - except when we need to modify the ``import`` line. + all the examples below assume we are inside the ``with`` block above. Adding data ----------- @@ -85,7 +86,8 @@ Let's add entities to our zero-dimensional classes:: db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (for now).") db_map.add_item( - "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat." + "entity", + class_name="cat", name="Felix", description="The wonderful wonderful cat." ) Let's add a multi-dimensional entity to our multi-dimensional class. For this we need to specify the entity names @@ -98,70 +100,64 @@ Let's add a parameter definition for one of our entity classes:: db_map.add_item("parameter_definition", entity_class_name="fish", name="color") Finally, let's specify a parameter value for one of our entities. -For this we need :func:`.to_database` function which converts the value into its database representation. -Let's modify the import statement at the beginning of our script:: +We use :func:`.to_database` to convert our value +into a tuple of value and type to specify for our parameter value item:: - from spinedb_api import DatabaseMapping, to_database - -Now we're ready to go:: - - color, value_type = to_database("mainly orange") + value, type_ = api.to_database("mainly orange") db_map.add_item( "parameter_value", entity_class_name="fish", entity_byname=("Nemo",), parameter_definition_name="color", alternative_name="Base", - value=color, - type=value_type + value=value, + type=type_ ) -Note that in the above, we must refer the entity by its *byname* which is a tuple of its dimensions. -We also set the value to belong to an *alternative* called ``"Base"`` +Note that in the above, we refer to the entity by its *byname* which is a tuple of its elements. +We also set the value to belong to an *alternative* called ``Base`` which is readily available in new databases. .. note:: The data we've added so far is not yet in the DB, but only in an in-memory mapping within our ``db_map`` object. - You need to call :meth:`~.DatabaseMapping.commit_session` to actually store the data. 
+    Don't worry, we will save it to the DB soon (see `Committing data`_ if you're impatient).
 
 
 Retrieving data
 ---------------
 
-To retrieve data from the DB (and the in-memory mapping), we use :meth:`~.DatabaseMapping.get_item`.
+To retrieve data, we use :meth:`~.DatabaseMapping.get_item`. This implicitly fetches data from the DB
+into the in-memory mapping, if not already there.
 For example, let's find one of the entities we inserted above::
 
-    felix = db_map.get_item("entity", class_name="cat", name="Felix")
-    print(felix["description"])  # Prints 'The wonderful wonderful cat.'
+    felix_item = db_map.get_entity_item(class_name="cat", name="Felix")
+    assert felix_item["description"] == "The wonderful wonderful cat."
 
-Above, ``felix`` is a :class:`~.PublicItem` object, representing an item (or row) in a Spine DB.
+Above, ``felix_item`` is a :class:`~.PublicItem` object, representing an item (or row) in a Spine DB.
 
 Let's find our multi-dimensional entity::
 
-    nemo_felix = db_map.get_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))
-    print(nemo_felix["dimension_name_list"])  # Prints "('fish', 'cat')"
-
-Parameter values need to be converted to Python values using :func:`.from_database` before we can use them.
-First we need to import the function::
+    nemo_felix_item = db_map.get_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix"))
+    assert nemo_felix_item["dimension_name_list"] == ('fish', 'cat')
 
-    from spinedb_api import DatabaseMapping, to_database, from_database
+Now let's retrieve our parameter value.
+We use :func:`.from_database` to convert the value and type from the parameter value item into our original value::
 
-Then we can retrieve the ``"color"`` of ``"Nemo"`` (in the ``"Base"`` alternative)::
-
-    color_value = db_map.get_item(
+    nemo_color_item = db_map.get_item(
         "parameter_value",
         entity_class_name="fish",
         entity_byname=("Nemo",),
         parameter_definition_name="color",
         alternative_name="Base"
     )
-    color = from_database(color_value["value"], color_value["type"])
-    print(color)  # Prints 'mainly orange'
+    nemo_color = api.from_database(nemo_color_item["value"], nemo_color_item["type"])
+    assert nemo_color == "mainly orange"
 
 To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`::
 
-    print(list(entity["byname"] for entity in db_map.get_items("entity")))
-    # Prints [("Nemo",), ("Felix",), ("Nemo", "Felix"),]
+    assert [entity["byname"] for entity in db_map.get_items("entity")] == [
+        ("Nemo",), ("Felix",), ("Nemo", "Felix")
+    ]
 
 Now you should use the above to try and find Nemo.
 
@@ -177,16 +173,16 @@ Let's rename our fish entity to avoid any copyright infringements::
 
 To be safe, let's also change the color::
 
-    new_color, value_type = to_database("not that orange")
+    new_value, new_type = api.to_database("not that orange")
     db_map.get_item(
         "parameter_value",
         entity_class_name="fish",
         entity_byname=("NotNemo",),
         parameter_definition_name="color",
         alternative_name="Base",
-    ).update(value=new_color, type=value_type)
+    ).update(value=new_value, type=new_type)
 
 Note how we need to use the new entity name ``NotNemo`` to retrieve the parameter value. This makes sense.
 
 Removing data
 -------------
@@ -200,3 +196,12 @@ Note that the above call removes items in *cascade*, meaning that
 items that depend on ``"NotNemo"`` will get removed as well.
 We have one such item in the database, namely the ``"color"`` parameter value
We have one such item in the database, namely the ``"color"`` parameter value which also gets dropped when the above method is called. + + +Committing data +--------------- + +Enough messing around. To save the contents of the in-memory mapping into the DB, +we use :meth:`~.DatabaseMapping.commit_session`:: + + db_map.commit_session("Find Nemo, then lose him again") From 39d1cfa5e148f901b74e4a413caeff6b752de3ef Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 26 Oct 2023 08:38:54 +0200 Subject: [PATCH 164/317] Don't raise if cannot filter query, that's ok --- spinedb_api/db_mapping_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c4b99988..9e10c048 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -100,8 +100,6 @@ def _make_query(self, item_type, **kwargs): src_key, (ref_type, ref_key) = self._item_factory(item_type)._references[key] ref_sq = self._make_sq(ref_type) qry = qry.filter(getattr(sq.c, src_key) == ref_sq.c.id, getattr(ref_sq.c, ref_key) == value) - else: - raise SpineDBAPIError(f"invalid filter {key}={value} for {item_type}") return qry def _make_sq(self, item_type): From b1def6d3457f23a263dca9032761b3488554a151 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 27 Oct 2023 11:34:45 +0200 Subject: [PATCH 165/317] First steps towards superclass_subclass --- spinedb_api/db_mapping.py | 1 + spinedb_api/db_mapping_base.py | 47 ++++++-------- spinedb_api/db_mapping_query_mixin.py | 16 +++++ spinedb_api/export_functions.py | 4 +- spinedb_api/helpers.py | 16 +++++ spinedb_api/import_functions.py | 23 ++++--- spinedb_api/mapped_items.py | 68 +++++++++++---------- tests/export_mapping/test_export_mapping.py | 24 ++++---- tests/export_mapping/test_settings.py | 6 +- tests/test_DatabaseMapping.py | 8 +-- tests/test_export_functions.py | 3 +- tests/test_import_functions.py | 20 +++--- 12 files changed, 135 insertions(+), 101 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 951ed7d9..a2aa527e 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -119,6 +119,7 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat "scenario": "scenario_sq", "scenario_alternative": "scenario_alternative_sq", "entity_class": "wide_entity_class_sq", + "superclass_subclass": "superclass_subclass_sq", "entity": "wide_entity_sq", "entity_group": "entity_group_sq", "entity_alternative": "entity_alternative_sq", diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 9e10c048..7805c3c3 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -232,12 +232,9 @@ def _add_descendants(self, item_types): if not changed: break - def get_mapped_item(self, item_type, id_): + def get_mapped_item(self, item_type, id_, fetch=True): mapped_table = self.mapped_table(item_type) - item = mapped_table.get(id_) - if item is None: - return {} - return item + return mapped_table.find_item_by_id(id_, fetch=fetch) or {} def _get_next_chunk(self, item_type, offset, limit, **kwargs): """Gets chunk of items from the DB. 
@@ -280,7 +277,7 @@ def do_fetch_more(self, item_type, offset=0, limit=None, ticket=None, **kwargs): new_items.append(item) items.append(item) # Once all items are added, add the unique key values - # This is because entity (class) items can refer other entity (class) items + # Otherwise items that refer to other items that come later in the query will be seen as corrupted for item in new_items: mapped_table.add_unique(item) return items @@ -288,12 +285,6 @@ def do_fetch_more(self, item_type, offset=0, limit=None, ticket=None, **kwargs): def do_fetch_all(self, item_type, **kwargs): self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) - def fetch_ref(self, item_type, id_): - self.do_fetch_all(item_type) - ref = self.get_mapped_item(item_type, id_) - if ref: - return ref - class _MappedTable(dict): def __init__(self, db_map, item_type, *args, **kwargs): @@ -379,16 +370,17 @@ def find_item(self, item, skip_keys=(), fetch=True): """ id_ = item.get("id") if id_ is not None: - return self._find_item_by_id(id_, fetch=fetch) - return self._find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) + return self.find_item_by_id(id_, fetch=fetch) + return self.find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) - def _find_item_by_id(self, id_, fetch=True): + def find_item_by_id(self, id_, fetch=True): current_item = self.get(id_) if current_item is None and fetch: - current_item = self._db_map.fetch_ref(self._item_type, id_) + self._db_map.do_fetch_all(self._item_type) + current_item = self.get(id_) return current_item - def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): + def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): for key in self._db_map._item_factory(self._item_type)._unique_keys: if key in skip_keys: continue @@ -479,7 +471,7 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ - current = self._find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key( + current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( item, fetch=False, complete=False ) if current: @@ -708,7 +700,7 @@ def _convert(x): return not all(_convert(self.get(key)) == _convert(value) for key, value in other.items()) def first_invalid_key(self): - """Goes through the ``_references`` class attribute and returns the key of the first one + """Goes through the ``_references`` class attribute and returns the key of the first reference that cannot be resolved. Returns: @@ -799,11 +791,11 @@ def _get_ref(self, ref_type, ref_id, strong=True): Returns: MappedItemBase or dict """ - ref = self._db_map.get_mapped_item(ref_type, ref_id) + ref = self._db_map.get_mapped_item(ref_type, ref_id, fetch=False) if not ref: if not strong: return {} - ref = self._db_map.fetch_ref(ref_type, ref_id) + ref = self._db_map.get_mapped_item(ref_type, ref_id, fetch=True) if not ref: self._corrupted = True return {} @@ -866,9 +858,9 @@ def remove_referrer(self, referrer): Args: referrer (MappedItemBase) """ - if referrer.key is None: - return - self._referrers.pop(referrer.key, None) + key = referrer.key + if key is not None: + self._referrers.pop(key, None) def add_weak_referrer(self, referrer): """Adds a weak referrer to this item. 
@@ -877,10 +869,11 @@ def add_weak_referrer(self, referrer): Args: referrer (MappedItemBase) """ - if referrer.key is None: + key = referrer.key + if key is None: return - if referrer.key not in self._referrers: - self._weak_referrers[referrer.key] = referrer + if key not in self._referrers: + self._weak_referrers[key] = referrer def _update_weak_referrers(self): for weak_referrer in self._weak_referrers.values(): diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index 93ab4e68..f349fe03 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -43,6 +43,7 @@ def __init__(self, *args, **kwargs): self._metadata_sq = None self._parameter_value_metadata_sq = None self._entity_metadata_sq = None + self._superclass_subclass_sq = None # Special convenience subqueries that join two or more tables self._wide_entity_class_sq = None self._wide_entity_sq = None @@ -123,6 +124,21 @@ def _subquery(self, tablename): table = self._metadata.tables[tablename] return self.query(table).subquery(tablename + "_sq") + @property + def superclass_subclass_sq(self): + """A subquery of the form: + + .. code-block:: sql + + SELECT * FROM superclass_subclass + + Returns: + :class:`~sqlalchemy.sql.expression.Alias` + """ + if self._superclass_subclass_sq is None: + self._superclass_subclass_sq = self._subquery("superclass_subclass") + return self._superclass_subclass_sq + @property def entity_class_sq(self): """A subquery of the form: diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index a4e07fff..14e2094b 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -130,8 +130,8 @@ def export_entity_classes(db_map, ids=Asterisk): def export_entities(db_map, ids=Asterisk): return sorted( - ((x.class_name, x.element_name_list or x.name, x.description) for x in _get_items(db_map, "entity", ids)), - key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0], (x[1],) if isinstance(x[1], str) else x[1]), + ((x.class_name, x.name, x.element_name_list, x.description) for x in _get_items(db_map, "entity", ids)), + key=lambda x: (len(x[2]), x[0], x[2], x[1]), ) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 2dfbb7fe..ee93c2bf 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -352,6 +352,22 @@ def create_spine_metadata(): Column("display_icon", BigInteger, server_default=null()), Column("hidden", Integer, server_default="0"), ) + Table( + "superclass_subclass", + meta, + Column("id", Integer, primary_key=True), + Column( + "superclass_id", + Integer, + ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), + ), + Column( + "subclass_id", + Integer, + ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), + unique=True, + ), + ) Table( "entity_class_dimension", meta, diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 0560f924..4e7501ce 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -207,7 +207,7 @@ def get_data_for_import( yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) # Legacy if object_classes: - yield from get_data_for_import(db_map, entity_classes=_object_classes_from_entity_classes(object_classes)) + yield from get_data_for_import(db_map, entity_classes=_object_classes_to_entity_classes(object_classes)) if relationship_classes: yield from get_data_for_import(db_map, 
entity_classes=relationship_classes) if object_parameters: @@ -519,12 +519,19 @@ def _ref_count(name): def _get_entities_for_import(db_map, data): items_by_el_count = {} - key = ("class_name", "byname", "description") - for class_name, name_or_element_name_list, *optionals in data: - is_zero_dim = isinstance(name_or_element_name_list, str) - byname = (name_or_element_name_list,) if is_zero_dim else tuple(name_or_element_name_list) - item = dict(zip(key, (class_name, byname, *optionals))) - el_count = 0 if is_zero_dim else len(name_or_element_name_list) + key = ("class_name", "name", "element_name_list", "description") + for class_name, name_or_el_name_list, *optionals in data: + if isinstance(name_or_el_name_list, (list, tuple)): + name = None + el_name_list = name_or_el_name_list + else: + name = name_or_el_name_list + if optionals and isinstance(optionals[0], (list, tuple)): + el_name_list = tuple(optionals.pop(0)) + else: + el_name_list = () + item = dict(zip(key, (class_name, name, el_name_list, *optionals))) + el_count = len(el_name_list) items_by_el_count.setdefault(el_count, []).append(item) return ( _get_items_for_import(db_map, "entity", items_by_el_count[el_count]) for el_count in sorted(items_by_el_count) @@ -699,7 +706,7 @@ def _data_iterator(): # Legacy -def _object_classes_from_entity_classes(data): +def _object_classes_to_entity_classes(data): for x in data: if isinstance(x, str): yield x, () diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 3fddf4d8..426321a0 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -9,7 +9,6 @@ # this program. If not, see . ###################################################################################################################### -import uuid from operator import itemgetter from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase @@ -19,6 +18,7 @@ def item_factory(item_type): return { "commit": CommitItem, "entity_class": EntityClassItem, + "superclass_subclass": SuperclassSubclassItem, "entity": EntityItem, "entity_alternative": EntityAlternativeItem, "entity_group": EntityGroupItem, @@ -70,6 +70,17 @@ def __init__(self, *args, **kwargs): kwargs["dimension_id_list"] = tuple(dimension_id_list) super().__init__(*args, **kwargs) + def __getitem__(self, key): + if key == "superclass_name": + # FIXME: create a weak reference too + return ( + self._db_map.mapped_table("superclass_subclass").find_item_by_unique_key( + {"subclass_name": self["name"]} + ) + or {} + ).get("superclass_name") + return super().__getitem__(key) + def merge(self, other): dimension_id_list = other.pop("dimension_id_list", None) error = ( @@ -118,57 +129,35 @@ def __init__(self, *args, **kwargs): kwargs["element_id_list"] = tuple(element_id_list) super().__init__(*args, **kwargs) - def _byname_iter(self, entity): + def _element_name_list_iter(self, entity): element_id_list = entity["element_id_list"] if not element_id_list: yield entity["name"] else: for el_id in element_id_list: element = self._get_ref("entity", el_id) - yield from self._byname_iter(element) + yield from self._element_name_list_iter(element) def __getitem__(self, key): + if key == "root_element_name_list": + return tuple(self._element_name_list_iter(self)) if key == "byname": - return tuple(self._byname_iter(self)) + return self["element_name_list"] or (self["name"],) return super().__getitem__(key) - def resolve_inverse_references(self, skip_keys=()): - error = 
super().resolve_inverse_references(skip_keys=skip_keys) - if error: - return error - byname = dict.pop(self, "byname", None) - if byname is None: - return - if not self["dimension_id_list"]: - self["name"] = byname[0] - return - byname_remainder = list(byname) - self["element_name_list"] = self._element_name_list_recursive(self["class_name"], byname_remainder) - return self._do_resolve_inverse_reference("element_id_list") - - def _element_name_list_recursive(self, class_name, byname_remainder): - dimension_name_list = self._db_map.get_item("entity_class", name=class_name).get("dimension_name_list") - if not dimension_name_list: - name = byname_remainder.pop(0) - return (name,) - return tuple( - self._db_map.get_item( - "entity", class_name=dim_name, byname=self._element_name_list_recursive(dim_name, byname_remainder) - ).get("name") - for dim_name in dimension_name_list - ) - def polish(self): error = super().polish() if error: return error - if "name" in self: + if self.get("name") is not None: return - base_name = self["class_name"] + "_" + "__".join(self["element_name_list"]) + base_name = "__".join(self["element_name_list"]) name = base_name mapped_table = self._db_map.mapped_table(self._item_type) + index = 1 while mapped_table.find_item({"class_name": self["class_name"], "name": name}): - name = base_name + "_" + uuid.uuid4().hex + name = base_name + f"_{index}" + index += 1 self["name"] = name @@ -607,3 +596,16 @@ class ParameterValueMetadataItem(MappedItemBase): ), "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), } + + +class SuperclassSubclassItem(MappedItemBase): + fields = {"superclass_name": ("str", "The superclass name."), "subclass_name": ("str", "The subclass name.")} + _unique_keys = (("subclass_name",),) + _references = { + "superclass_name": ("superclass_id", ("entity_class", "name")), + "subclass_name": ("subclass_id", ("entity_class", "name")), + } + _inverse_references = { + "superclass_id": (("superclass_name",), ("entity_class", ("name",))), + "subclass_id": (("subclass_name",), ("entity_class", ("name",))), + } diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index b8ea8063..1ec8f669 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -648,9 +648,7 @@ def test_object_relationship_name_as_table_name(self): tables = dict() for title, title_key in titles(mappings, db_map): tables[title] = list(rows(mappings, db_map, title_key)) - self.assertEqual( - tables, {"rc_o1__O,o1": [["rc", "oc1", "oc2", "O"]], "rc_o2__O,o2": [["rc", "oc1", "oc2", "O"]]} - ) + self.assertEqual(tables, {"o1__O,o1": [["rc", "oc1", "oc2", "O"]], "o2__O,o2": [["rc", "oc1", "oc2", "O"]]}) db_map.close() def test_parameter_definitions_with_value_lists(self): @@ -823,9 +821,9 @@ def test_export_relationships(self): element1_mapping = relationship_mapping.child = ElementMapping(4) element1_mapping.child = ElementMapping(5) expected = [ - ['rc1', 'oc1', '', 'rc1_o11', 'o11', ''], - ['rc2', 'oc2', 'oc1', 'rc2_o21__o11', 'o21', 'o11'], - ['rc2', 'oc2', 'oc1', 'rc2_o21__o12', 'o21', 'o12'], + ['rc1', 'oc1', '', 'o11', 'o11', ''], + ['rc2', 'oc2', 'oc1', 'o21__o11', 'o21', 'o11'], + ['rc2', 'oc2', 'oc1', 'o21__o12', 'o21', 'o12'], ] self.assertEqual(list(rows(relationship_class_mapping, db_map)), expected) db_map.close() @@ -1072,7 +1070,7 @@ def test_header_position_with_relationships(self): ElementMapping(3), ] ) - expected = [["", "", "oc1", "oc2"], ["rc", 
"rc_o11__o21", "o11", "o21"]] + expected = [["", "", "oc1", "oc2"], ["rc", "o11__o21", "o11", "o21"]] self.assertEqual(list(rows(root, db_map)), expected) db_map.close() @@ -1481,8 +1479,8 @@ def test_highlight_relationship_objects(self): ] ) expected = [ - ["rc", "oc1", "oc2", "rc_o11__o21", "o11", "o21"], - ["rc", "oc1", "oc2", "rc_o12__o22", "o12", "o22"], + ["rc", "oc1", "oc2", "o11__o21", "o11", "o21"], + ["rc", "oc1", "oc2", "o12__o22", "o12", "o22"], ] self.assertEqual(list(rows(root_mapping, db_map)), expected) db_map.close() @@ -1507,7 +1505,7 @@ def test_export_object_parameters_while_exporting_relationships(self): ParameterValueMapping(6), ] ) - expected = [["rc", "oc", "rc_o", "o", "p", "Base", 23.0]] + expected = [["rc", "oc", "o", "o", "p", "Base", 23.0]] self.assertEqual(list(rows(root_mapping, db_map)), expected) db_map.close() @@ -1563,9 +1561,9 @@ def test_export_object_parameters_while_exporting_relationships_with_multiple_pa ] ) expected = [ - ["rc12", "oc1", "oc2", "rc12_o11__o21", "o11", "o21", "p21", "Base", 5.5], - ["rc12", "oc1", "oc2", "rc12_o12__o21", "o12", "o21", "p21", "Base", 5.5], - ["rc23", "oc2", "oc3", "rc23_o21__o31", "o21", "o31", "p31", "Base", 7.7], + ["rc12", "oc1", "oc2", "o11__o21", "o11", "o21", "p21", "Base", 5.5], + ["rc12", "oc1", "oc2", "o12__o21", "o12", "o21", "p21", "Base", 5.5], + ["rc23", "oc2", "oc3", "o21__o31", "o21", "o31", "p31", "Base", 7.7], ] self.assertEqual(list(rows(root_mapping, db_map)), expected) db_map.close() diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index 4c3d0df5..fb84d20a 100644 --- a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -164,9 +164,9 @@ def test_export_with_two_dimensions(self): ) set_entity_dimensions(root_mapping, 2) expected = [ - ["rc", "p11", "rc_o11__o21", "oc1", "oc2", "o11", "o21", "Base", "single_value", 2.3], - ["rc", "p11", "rc_o12__o21", "oc1", "oc2", "o12", "o21", "Base", "single_value", -2.3], - ["rc", "p12", "rc_o12__o21", "oc1", "oc2", "o12", "o21", "Base", "single_value", -5.0], + ["rc", "p11", "o11__o21", "oc1", "oc2", "o11", "o21", "Base", "single_value", 2.3], + ["rc", "p11", "o12__o21", "oc1", "oc2", "o12", "o21", "Base", "single_value", -2.3], + ["rc", "p12", "o12__o21", "oc1", "oc2", "o12", "o21", "Base", "single_value", -5.0], ] self.assertEqual(list(rows(root_mapping, self._db_map)), expected) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 84c88c55..7ddb75e9 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -646,7 +646,7 @@ def test_entity_sq(self): entity_rows = self._db_map.query(self._db_map.entity_sq).all() self.assertEqual(len(entity_rows), len(objects) + len(relationships)) object_names = [o[1] for o in objects] - relationship_names = [r[0] + "_" + "__".join(r[1]) for r in relationships] + relationship_names = ["__".join(r[1]) for r in relationships] for row, expected_name in zip(entity_rows, object_names + relationship_names): self.assertEqual(row.name, expected_name) @@ -690,7 +690,7 @@ def test_wide_relationship_sq(self): relationship_rows = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationship_rows), 2) for row, relationship in zip(relationship_rows, relationships): - self.assertEqual(row.name, relationship[0] + "_" + "__".join(relationship[1])) + self.assertEqual(row.name, "__".join(relationship[1])) self.assertEqual(row.class_name, relationship[0]) 
self.assertEqual(row.object_class_name_list, ",".join(object_classes[relationship[0]])) self.assertEqual(row.object_name_list, ",".join(relationship[1])) @@ -886,7 +886,7 @@ def test_add_object_with_invalid_name(self): """Test that adding object classes with empty name raises error""" self._db_map.add_object_classes({"name": "fish"}) with self.assertRaises(SpineIntegrityError): - self._db_map.add_objects({"name": "", "class_id": 1}, strict=True) + self._db_map.add_objects({"name": "", "class_name": "fish"}, strict=True) def test_add_objects_with_same_name(self): """Test that adding two objects with the same name only adds one of them.""" @@ -1423,7 +1423,7 @@ def test_add_entity_metadata_for_relationship(self): dict(entity_metadata[0]), { "entity_id": 2, - "entity_name": "my_relationship_class_my_object", + "entity_name": "my_object", "metadata_name": "title", "metadata_value": "My metadata.", "metadata_id": 1, diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 489d549a..ac40310d 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -104,7 +104,8 @@ def test_export_data(self): ) self.assertIn("entities", exported) self.assertEqual( - exported["entities"], [("object_class", "object", None), ("relationship_class", ("object",), None)] + exported["entities"], + [("object_class", "object", (), None), ("relationship_class", "object", ("object",), None)], ) self.assertIn("parameter_values", exported) self.assertEqual( diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 1b071478..45313333 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -354,7 +354,7 @@ def test_import_relationships(self): _, errors = import_relationships(db_map, (("relationship_class", ("object",)),)) self.assertFalse(errors) db_map.commit_session("test") - self.assertIn("relationship_class_object", [r.name for r in db_map.query(db_map.relationship_sq)]) + self.assertIn("object", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() def test_import_valid_relationship(self): @@ -364,7 +364,7 @@ def test_import_valid_relationship(self): _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) self.assertFalse(errors) db_map.commit_session("test") - self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) + self.assertIn("object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() def test_import_valid_relationship_with_object_name_in_multiple_classes(self): @@ -375,7 +375,7 @@ def test_import_valid_relationship_with_object_name_in_multiple_classes(self): _, errors = import_relationships(db_map, [["relationship_class", ["duplicate", "object2"]]]) self.assertFalse(errors) db_map.commit_session("test") - self.assertIn("relationship_class_duplicate__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) + self.assertIn("duplicate__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() def test_import_relationship_with_invalid_class_name(self): @@ -403,10 +403,10 @@ def test_import_existing_relationship(self): import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) db_map.commit_session("test") - self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) + 
self.assertIn("object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) self.assertFalse(errors) - self.assertIn("relationship_class_object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) + self.assertIn("object1__object2", [r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() def test_import_relationship_with_one_None_object(self): @@ -430,12 +430,12 @@ def test_import_relationship_of_relationships(self): ["meta_relationship_class", ["relationship_class1", "relationship_class2"]], ], entities=[ - ["relationship_class1", ["object1", "object2"]], - ["relationship_class2", ["object2", "object1"]], + ["relationship_class1", "object1__object2", ["object1", "object2"]], + ["relationship_class2", "object2__object1", ["object2", "object1"]], ], ) _, errors = import_data( - db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1"]]] + db_map, entities=[["meta_relationship_class", ["object1__object2", "object2__object1"]]] ) self.assertFalse(errors) db_map.commit_session("test") @@ -1346,7 +1346,7 @@ def test_import_relationship_parameter_value_metadata(self): dict(metadata[0]), { "alternative_name": "Base", - "entity_name": "relationship_class_object", + "entity_name": "object", "id": 1, "metadata_id": 1, "metadata_name": "co-author", @@ -1360,7 +1360,7 @@ def test_import_relationship_parameter_value_metadata(self): dict(metadata[1]), { "alternative_name": "Base", - "entity_name": "relationship_class_object", + "entity_name": "object", "id": 2, "metadata_id": 2, "metadata_name": "age", From 590fbc26bb1d4f9b43cefb286f69ad44ade3b595 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 27 Oct 2023 12:16:47 +0200 Subject: [PATCH 166/317] Add reference to superclass_name from entity --- spinedb_api/mapped_items.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 426321a0..19dc431c 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -113,6 +113,7 @@ class EntityItem(MappedItemBase): "class_name": ("class_id", ("entity_class", "name")), "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")), "dimension_name_list": ("class_id", ("entity_class", "dimension_name_list")), + "superclass_name": ("class_id", ("entity_class", "superclass_name")), "element_name_list": ("element_id_list", ("entity", "name")), } _inverse_references = { From dc8ca6c8b00b6ac3759dd23dd9ad71d5e3830d7f Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 30 Oct 2023 11:24:31 +0100 Subject: [PATCH 167/317] Register instances under superclass too --- spinedb_api/db_mapping.py | 4 +- spinedb_api/db_mapping_base.py | 206 ++++++++++++----------- spinedb_api/import_functions.py | 4 +- spinedb_api/mapped_items.py | 283 +++++++++++++++++++------------- tests/test_DatabaseMapping.py | 12 ++ tests/test_db_mapping_base.py | 2 +- 6 files changed, 294 insertions(+), 217 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index a2aa527e..7f638d76 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -208,7 +208,7 @@ def all_item_types(): return list(DatabaseMapping._sq_name_by_item_type) @staticmethod - def _item_factory(item_type): + def item_factory(item_type): return item_factory(item_type) def _make_sq(self, item_type): @@ -771,7 +771,7 @@ def _add_convenience_methods(node): if node.name != "DatabaseMapping": return node for 
item_type in DatabaseMapping.item_types(): - factory = DatabaseMapping._item_factory(item_type) + factory = DatabaseMapping.item_factory(item_type) uq_fields = {f_name: factory.fields[f_name] for f_names in factory._unique_keys for f_name in f_names} a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" padding = 20 * " " diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 7805c3c3..beb25e81 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -34,7 +34,7 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :meth:`item_types`, :meth:`_item_factory`, and :meth:`_make_query`. + When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_query`. """ def __init__(self): @@ -44,7 +44,7 @@ def __init__(self): self._sorted_item_types = [] while item_types: item_type = item_types.pop(0) - if self._item_factory(item_type).ref_types() & set(item_types): + if self.item_factory(item_type).ref_types() & set(item_types): item_types.append(item_type) else: self._sorted_item_types.append(item_type) @@ -68,7 +68,7 @@ def all_item_types(): raise NotImplementedError() @staticmethod - def _item_factory(item_type): + def item_factory(item_type): """Returns a subclass of :class:`.MappedItemBase` to make items of given type. Args: @@ -96,10 +96,11 @@ def _make_query(self, item_type, **kwargs): for key, value in kwargs.items(): if hasattr(sq.c, key): qry = qry.filter(getattr(sq.c, key) == value) - elif key in self._item_factory(item_type)._references: - src_key, (ref_type, ref_key) = self._item_factory(item_type)._references[key] + elif key in self._external_fields: + src_key, key = self.item_factory(item_type)._external_fields[key] + ref_type, ref_key = self.item_factory(item_type)._references[src_key] ref_sq = self._make_sq(ref_type) - qry = qry.filter(getattr(sq.c, src_key) == ref_sq.c.id, getattr(ref_sq.c, ref_key) == value) + qry = qry.filter(getattr(sq.c, src_key) == getattr(ref_sq.c, ref_key), getattr(ref_sq.c, key) == value) return qry def _make_sq(self, item_type): @@ -115,7 +116,7 @@ def _make_sq(self, item_type): raise NotImplementedError() def make_item(self, item_type, **item): - factory = self._item_factory(item_type) + factory = self.item_factory(item_type) return factory(self, item_type, **item) def dirty_ids(self, item_type): @@ -158,7 +159,7 @@ def _dirty_items(self): # FIXME: We should also fetch the current item type because of multi-dimensional entities and # classes which also depend on zero-dimensional ones for other_item_type in self.item_types(): - if item_type in self._item_factory(other_item_type).ref_types(): + if item_type in self.item_factory(other_item_type).ref_types(): self.fetch_all(other_item_type) if to_add or to_update or to_remove: dirty_items.append((item_type, (to_add, to_update, to_remove))) @@ -226,7 +227,7 @@ def _add_descendants(self, item_types): while True: changed = False for item_type in set(self.item_types()) - item_types: - if self._item_factory(item_type).ref_types() & item_types: + if self.item_factory(item_type).ref_types() & item_types: item_types.add(item_type) changed = True if not changed: @@ -321,13 +322,12 @@ def _callback(db_id): temp_id.add_resolve_callback(_callback) return temp_id - def unique_key_value_to_id(self, key, value, strict=False, fetch=True): + def _unique_key_value_to_id(self, key, value, 
fetch=True): """Returns the id that has the given value for the given unique key, or None if not found. Args: key (tuple) value (tuple) - strict (bool): if True, raise a KeyError if id is not found fetch (bool): whether to fetch the DB until found. Returns: @@ -338,12 +338,10 @@ def unique_key_value_to_id(self, key, value, strict=False, fetch=True): self._db_map.do_fetch_all(self._item_type) id_by_unique_value = self._id_by_unique_key_value.get(key, {}) value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - if strict: - return id_by_unique_value[value] return id_by_unique_value.get(value) def _unique_key_value_to_item(self, key, value, fetch=True): - return self.get(self.unique_key_value_to_id(key, value, fetch=fetch)) + return self.get(self._unique_key_value_to_id(key, value, fetch=fetch)) def valid_values(self): return (x for x in self.values() if x.is_valid()) @@ -374,42 +372,38 @@ def find_item(self, item, skip_keys=(), fetch=True): return self.find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) def find_item_by_id(self, id_, fetch=True): - current_item = self.get(id_) - if current_item is None and fetch: + current_item = self.get(id_, {}) + if not current_item and fetch: self._db_map.do_fetch_all(self._item_type) - current_item = self.get(id_) + current_item = self.get(id_, {}) return current_item def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): - for key in self._db_map._item_factory(self._item_type)._unique_keys: - if key in skip_keys: - continue - value = tuple(item.get(k) for k in key) - if None in value: - continue + for key, value in self._db_map.item_factory(self._item_type).unique_values_for_item(item, skip_keys=skip_keys): current_item = self._unique_key_value_to_item(key, value, fetch=fetch) if current_item: return current_item if complete: # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too... 
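            # (Illustrative note: e.g. an entity given only "class_name" and "element_name_list"
            # gets "class_id" and "element_id_list" filled in by resolve_internal_fields() and a
            # missing "name" generated by polish(), so the unique-key lookup below can then succeed.)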
mapped_item = self._make_item(item) - error = mapped_item.resolve_inverse_references(item.keys()) + error = mapped_item.resolve_internal_fields(item.keys()) if error: - return None + return {} error = mapped_item.polish() if error: - return None + return {} for key, value in mapped_item.unique_values(skip_keys=skip_keys): current_item = self._unique_key_value_to_item(key, value, fetch=fetch) if current_item: return current_item + return {} def check_item(self, item, for_update=False, skip_keys=()): # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) if for_update: current_item = self.find_item(item, skip_keys=skip_keys) - if current_item is None: + if not current_item: return None, f"no {self._item_type} matching {item} to update" full_item, merge_error = current_item.merge(item) if full_item is None: @@ -418,7 +412,7 @@ def check_item(self, item, for_update=False, skip_keys=()): current_item = None full_item, merge_error = item, None candidate_item = self._make_item(full_item) - error = candidate_item.resolve_inverse_references(skip_keys=item.keys()) + error = candidate_item.resolve_internal_fields(skip_keys=item.keys()) if error: return None, error error = candidate_item.polish() @@ -504,7 +498,7 @@ def remove_item(self, id_): current_item.cascade_remove(source=self.wildcard_item) return self.wildcard_item current_item = self.find_item({"id": id_}) - if current_item is not None: + if current_item: self.remove_unique(current_item) current_item.cascade_remove() return current_item @@ -517,7 +511,7 @@ def restore_item(self, id_): current_item.cascade_restore(source=self.wildcard_item) return self.wildcard_item current_item = self.find_item({"id": id_}) - if current_item is not None: + if current_item: self.add_unique(current_item) current_item.cascade_restore() return current_item @@ -533,25 +527,21 @@ class MappedItemBase(dict): _unique_keys = () """A tuple where each element is itself a tuple of keys corresponding to a unique constraint""" _references = {} - """A dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the key they reference in another item. - - The recipe is a tuple of the form (src_key, (ref_item_type, ref_key)), - to be interpreted as follows: - 1. take the value from the src_key of this item, which should be an id, - 2. locate the item of type ref_item_type that has that id, - 3. return the value from the ref_key of that item. + """A dictionary mapping source keys, to a tuple of reference item type and reference key. + Used to access external fields. """ - _inverse_references = {} - """Another dictionary mapping keys that are not in the original dictionary, - to a recipe for finding the key they reference in another item. - Used only for creating new items, when the user provides names and we want to find the ids. - - The recipe is a tuple of the form (src_unique_key, (ref_item_type, ref_unique_key)), - to be interpreted as follows: - 1. take the values from the src_unique_key of this item, to form a tuple, - 2. locate the item of type ref_item_type where the ref_unique_key is exactly that tuple of values, - 3. return the id of that item. + _external_fields = {} + """A dictionary mapping keys that are not in the original dictionary, to a tuple of source key and reference key. + Keys in _external_fields are accessed via the reference key of the reference pointed at by the source key. 
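+
+    Example (illustrative): with _references = {"class_id": ("entity_class", "id")} and
+    _external_fields = {"class_name": ("class_id", "name")}, item["class_name"] resolves to
+    the "name" of the entity_class whose "id" equals item["class_id"].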
+ """ + _alt_references = {} + """A dictionary mapping source keys, to a tuple of reference item type and reference key. + Used only to resolve internal fields at item creation. + """ + _internal_fields = {} + """A dictionary mapping keys that are not in the original dictionary, to a tuple of source key and reference key. + Keys in _internal_fields are resolved to the reference key of the alternative reference pointed at by the + source key. """ def __init__(self, db_map, item_type, **kwargs): @@ -585,7 +575,7 @@ def ref_types(cls): Returns: set(str) """ - return set(ref_type for _src_key, (ref_type, _ref_key) in cls._references.values()) + return set(ref_type for ref_type, _ref_key in cls._references.values()) @property def status(self): @@ -660,7 +650,7 @@ def _extended(self): dict """ d = self._asdict() - d.update({key: self[key] for key in self._references}) + d.update({key: self[key] for key in self._external_fields}) return d def _asdict(self): @@ -706,17 +696,35 @@ def first_invalid_key(self): Returns: str or None: unresolved reference's key if any. """ - for src_key, (ref_type, _ref_key) in self._references.values(): + return next(self._invalid_keys(), None) + + def _invalid_keys(self): + """Goes through the ``_references`` class attribute and returns the keys of the ones + that cannot be resolved. + + Yields: + str: unresolved keys if any. + """ + for src_key, (ref_type, ref_key) in self._references.items(): try: - ref_id = self[src_key] + src_val = self[src_key] except KeyError: - return src_key - if isinstance(ref_id, tuple): - for x in ref_id: - if not self._get_ref(ref_type, x): - return src_key - elif not self._get_ref(ref_type, ref_id): - return src_key + yield src_key + else: + if isinstance(src_val, tuple): + for x in src_val: + if not self._get_ref(ref_type, {ref_key: x}): + yield src_key + elif not self._get_ref(ref_type, {ref_key: src_val}): + yield src_key + + @classmethod + def unique_values_for_item(cls, item, skip_keys=()): + for key in cls._unique_keys: + if key not in skip_keys: + value = tuple(item.get(k) for k in key) + if None not in value: + yield key, value def unique_values(self, skip_keys=()): """Yields tuples of unique keys and their values. @@ -727,12 +735,10 @@ def unique_values(self, skip_keys=()): Yields: tuple(tuple,tuple): the first element is the unique key, the second is the values. """ - for key in self._unique_keys: - if key not in skip_keys: - yield key, tuple(self.get(k) for k in key) + yield from self.unique_values_for_item(self, skip_keys=skip_keys) - def resolve_inverse_references(self, skip_keys=()): - """Goes through the ``_inverse_references`` class attribute and updates this item + def resolve_internal_fields(self, skip_keys=()): + """Goes through the ``_internal_fields`` class attribute and updates this item by resolving those references. Returns any error. @@ -742,27 +748,26 @@ def resolve_inverse_references(self, skip_keys=()): Returns: str or None: error description if any. 
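+
+        Example (illustrative): for EntityItem, _internal_fields maps "class_id" to
+        (("class_name",), "id") and _alt_references maps ("class_name",) to
+        ("entity_class", ("name",)), so given {"class_name": "fish"} this method sets
+        item["class_id"] to the "id" of the entity_class named "fish".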
""" - for src_key in self._inverse_references: - if src_key in skip_keys: + for key, (src_key, target_key) in self._internal_fields.items(): + if key in skip_keys: continue - error = self._do_resolve_inverse_reference(src_key) + error = self._do_resolve_internal_field(key, src_key, target_key) if error: return error - def _do_resolve_inverse_reference(self, src_key): - id_key, (ref_type, ref_key) = self._inverse_references[src_key] - id_value = tuple(dict.pop(self, k, None) or self.get(k) for k in id_key) - if None in id_value: + def _do_resolve_internal_field(self, key, src_key, target_key): + ref_type, ref_key = self._alt_references[src_key] + src_val = tuple(dict.pop(self, k, None) or self.get(k) for k in src_key) + if None in src_val: return mapped_table = self._db_map.mapped_table(ref_type) try: - self[src_key] = ( - tuple(mapped_table.unique_key_value_to_id(ref_key, v, strict=True) for v in zip(*id_value)) - if all(isinstance(v, (tuple, list)) for v in id_value) - else mapped_table.unique_key_value_to_id(ref_key, id_value, strict=True) + self[key] = ( + tuple(mapped_table.find_item(dict(zip(ref_key, v)))[target_key] for v in zip(*src_val)) + if all(isinstance(v, (tuple, list)) for v in src_val) + else mapped_table.find_item(dict(zip(ref_key, src_val)))[target_key] ) except KeyError as err: - # Happens at unique_key_value_to_id(..., strict=True) return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" def polish(self): @@ -777,7 +782,7 @@ def polish(self): self.setdefault(key, default_value) return "" - def _get_ref(self, ref_type, ref_id, strong=True): + def _get_ref(self, ref_type, key_val, strong=True): """Collects a reference from the in-memory mapping. Adds this item to the reference's list of referrers if strong is True; or weak referrers if strong is False. @@ -785,17 +790,18 @@ def _get_ref(self, ref_type, ref_id, strong=True): Args: ref_type (str): The reference's type - ref_id (int): The reference's id + key_val (dict): The reference's key and value to match strong (bool): True if the reference corresponds to a foreign key, False otherwise Returns: MappedItemBase or dict """ - ref = self._db_map.get_mapped_item(ref_type, ref_id, fetch=False) + mapped_table = self._db_map.mapped_table(ref_type) + ref = mapped_table.find_item(key_val, fetch=False) if not ref: if not strong: return {} - ref = self._db_map.get_mapped_item(ref_type, ref_id, fetch=True) + ref = mapped_table.find_item(key_val, fetch=True) if not ref: self._corrupted = True return {} @@ -810,14 +816,15 @@ def _get_ref(self, ref_type, ref_id, strong=True): return {} return ref - def _invalidate_ref(self, ref_type, ref_id): + def _invalidate_ref(self, ref_type, key_val): """Invalidates a reference previously collected from the in-memory mapping. 
Args: ref_type (str): The reference's type - ref_id (int): The reference's id + key_val (dict): The reference's key and value to match """ - ref = self._db_map.get_mapped_item(ref_type, ref_id) + mapped_table = self._db_map.mapped_table(ref_type) + ref = mapped_table.find_item(key_val) ref.remove_referrer(self) def is_valid(self): @@ -833,8 +840,8 @@ def is_valid(self): return False self._to_remove = False self._corrupted = False - for key in self._references: - _ = self[key] + for _ in self._invalid_keys(): # This sets self._to_remove and self._corrupted + pass if self._to_remove: self.cascade_remove() self._valid = not self._removed and not self._corrupted @@ -986,13 +993,14 @@ def __getattr__(self, name): def __getitem__(self, key): """Overridden to return references.""" - ref = self._references.get(key) - if ref: - src_key, (ref_type, ref_key) = ref - ref_id = self[src_key] - if isinstance(ref_id, tuple): - return tuple(self._get_ref(ref_type, x).get(ref_key) for x in ref_id) - return self._get_ref(ref_type, ref_id).get(ref_key) + ext_val = self._external_fields.get(key) + if ext_val: + src_key, key = ext_val + ref_type, ref_key = self._references[src_key] + src_val = self[src_key] + if isinstance(src_val, tuple): + return tuple(self._get_ref(ref_type, {ref_key: x}).get(key) for x in src_val) + return self._get_ref(ref_type, {ref_key: src_val}).get(key) return super().__getitem__(key) def __setitem__(self, key, value): @@ -1015,15 +1023,15 @@ def update(self, other): self._backup = self._asdict() elif self._status in (Status.to_remove, Status.added_and_removed): raise RuntimeError("invalid status of item being updated") - for src_key, (ref_type, _ref_key) in self._references.values(): - ref_id = self[src_key] - if src_key in other and other[src_key] != ref_id: + for src_key, (ref_type, ref_key) in self._references.items(): + src_val = self[src_key] + if src_key in other and other[src_key] != src_val: # Invalidate references - if isinstance(ref_id, tuple): - for x in ref_id: - self._invalidate_ref(ref_type, x) + if isinstance(src_val, tuple): + for x in src_val: + self._invalidate_ref(ref_type, {ref_key: x}) else: - self._invalidate_ref(ref_type, ref_id) + self._invalidate_ref(ref_type, {ref_key: src_val}) super().update(other) if self._asdict() == self._backup: self._status = Status.committed diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 4e7501ce..9144a8e3 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -584,7 +584,7 @@ def _data_iterator(): "type": None, } pv = db_map.mapped_table("parameter_value").find_item(item) - if pv is not None: + if pv: value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) item.update({"value": value, "type": type_}) yield item @@ -610,7 +610,7 @@ def _get_scenario_alternatives_for_import(db_map, data): alt_name_list_by_scen_name, errors = {}, [] for scen_name, alt_name, *optionals in data: scen = db_map.mapped_table("scenario").find_item({"name": scen_name}) - if scen is None: + if not scen: errors.append(f"no scenario with name {scen_name} to set alternatives for") continue alternative_name_list = alt_name_list_by_scen_name.setdefault(scen_name, scen["alternative_name_list"]) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 19dc431c..3229455b 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -58,8 +58,10 @@ class EntityClassItem(MappedItemBase): } _defaults = {"description": None, "display_icon": 
None, "display_order": 99, "hidden": False}
     _unique_keys = (("name",),)
-    _references = {"dimension_name_list": ("dimension_id_list", ("entity_class", "name"))}
-    _inverse_references = {"dimension_id_list": (("dimension_name_list",), ("entity_class", ("name",)))}
+    _references = {"dimension_id_list": ("entity_class", "id")}
+    _external_fields = {"dimension_name_list": ("dimension_id_list", "name")}
+    _alt_references = {("dimension_name_list",): ("entity_class", ("name",))}
+    _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")}

     def __init__(self, *args, **kwargs):
         dimension_id_list = kwargs.get("dimension_id_list")
@@ -71,14 +73,12 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def __getitem__(self, key):
-        if key == "superclass_name":
+        if key in ("superclass_id", "superclass_name"):
+            superclass_subclass = self._db_map.mapped_table("superclass_subclass").find_item_by_unique_key(
+                {"subclass_name": self["name"]}
+            )
             # FIXME: create a weak reference too
-            return (
-                self._db_map.mapped_table("superclass_subclass").find_item_by_unique_key(
-                    {"subclass_name": self["name"]}
-                )
-                or {}
-            ).get("superclass_name")
+            return superclass_subclass.get(key)
         return super().__getitem__(key)

     def merge(self, other):
@@ -108,17 +108,23 @@ class EntityItem(MappedItemBase):
         "description": ("str, optional", "The entity description."),
     }
     _defaults = {"description": None}
-    _unique_keys = (("class_name", "name"), ("class_name", "byname"))
-    _references = {
-        "class_name": ("class_id", ("entity_class", "name")),
-        "dimension_id_list": ("class_id", ("entity_class", "dimension_id_list")),
-        "dimension_name_list": ("class_id", ("entity_class", "dimension_name_list")),
-        "superclass_name": ("class_id", ("entity_class", "superclass_name")),
-        "element_name_list": ("element_id_list", ("entity", "name")),
+    _unique_keys = (("class_name", "name"), ("class_name", "byname"), ("superclass_name", "name"))
+    _references = {"class_id": ("entity_class", "id"), "element_id_list": ("entity", "id")}
+    _external_fields = {
+        "class_name": ("class_id", "name"),
+        "dimension_id_list": ("class_id", "dimension_id_list"),
+        "dimension_name_list": ("class_id", "dimension_name_list"),
+        "superclass_id": ("class_id", "superclass_id"),
+        "superclass_name": ("class_id", "superclass_name"),
+        "element_name_list": ("element_id_list", "name"),
+    }
+    _alt_references = {
+        ("class_name",): ("entity_class", ("name",)),
+        ("dimension_name_list", "element_name_list"): ("entity", ("class_name", "name")),
     }
-    _inverse_references = {
-        "class_id": (("class_name",), ("entity_class", ("name",))),
-        "element_id_list": (("dimension_name_list", "element_name_list"), ("entity", ("class_name", "name"))),
+    _internal_fields = {
+        "class_id": (("class_name",), "id"),
+        "element_id_list": (("dimension_name_list", "element_name_list"), "id"),
     }

     def __init__(self, *args, **kwargs):
@@ -136,7 +142,7 @@ def _element_name_list_iter(self, entity):
             yield entity["name"]
         else:
             for el_id in element_id_list:
-                element = self._get_ref("entity", el_id)
+                element = self._get_ref("entity", {"id": el_id})
                 yield from self._element_name_list_iter(element)

     def __getitem__(self, key):
@@ -150,13 +156,22 @@ def polish(self):
         error = super().polish()
         if error:
             return error
+        dim_name_lst, el_name_lst = dict.get(self, "dimension_name_list"), dict.get(self, "element_name_list")
+        if dim_name_lst and el_name_lst:
+            for dim_name, el_name in zip(dim_name_lst, el_name_lst):
+                if not (
+                    self._db_map.get_item("entity", class_name=dim_name, 
name=el_name, fetch=False) + or self._db_map.get_item("entity", superclass_name=dim_name, name=el_name, fetch=False) + ): + return f"element '{el_name}' is not an instance of class '{dim_name}'" if self.get("name") is not None: return base_name = "__".join(self["element_name_list"]) name = base_name - mapped_table = self._db_map.mapped_table(self._item_type) index = 1 - while mapped_table.find_item({"class_name": self["class_name"], "name": name}): + while self._db_map.get_item("entity", class_name=self["class_name"], name=name) or self._db_map.get_item( + "entity", superclass_name=self["superclass_name"], name=name + ): name = base_name + f"_{index}" index += 1 self["name"] = name @@ -170,15 +185,25 @@ class EntityGroupItem(MappedItemBase): } _unique_keys = (("class_name", "group_name", "member_name"),) _references = { - "class_name": ("entity_class_id", ("entity_class", "name")), - "group_name": ("entity_id", ("entity", "name")), - "member_name": ("member_id", ("entity", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), + "entity_class_id": ("entity_class", "id"), + "entity_id": ("entity", "id"), + "member_id": ("entity", "id"), + } + _external_fields = { + "class_name": ("entity_class_id", "name"), + "dimension_id_list": ("entity_class_id", "dimension_id_list"), + "group_name": ("entity_id", "name"), + "member_name": ("member_id", "name"), + } + _alt_references = { + ("class_name",): ("entity_class", ("name",)), + ("class_name", "group_name"): ("entity", ("class_name", "name")), + ("class_name", "member_name"): ("entity", ("class_name", "name")), } - _inverse_references = { - "entity_class_id": (("class_name",), ("entity_class", ("name",))), - "entity_id": (("class_name", "group_name"), ("entity", ("class_name", "name"))), - "member_id": (("class_name", "member_name"), ("entity", ("class_name", "name"))), + _internal_fields = { + "entity_class_id": (("class_name",), "id"), + "entity_id": (("class_name", "group_name"), "id"), + "member_id": (("class_name", "member_name"), "id"), } def __getitem__(self, key): @@ -203,19 +228,28 @@ class EntityAlternativeItem(MappedItemBase): _defaults = {"active": True} _unique_keys = (("entity_class_name", "entity_byname", "alternative_name"),) _references = { - "entity_class_id": ("entity_id", ("entity", "class_id")), - "entity_class_name": ("entity_class_id", ("entity_class", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), - "entity_name": ("entity_id", ("entity", "name")), - "entity_byname": ("entity_id", ("entity", "byname")), - "element_id_list": ("entity_id", ("entity", "element_id_list")), - "element_name_list": ("entity_id", ("entity", "element_name_list")), - "alternative_name": ("alternative_id", ("alternative", "name")), + "entity_id": ("entity", "id"), + "entity_class_id": ("entity_class", "id"), + "alternative_id": ("alternative", "id"), } - _inverse_references = { - "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), - "alternative_id": (("alternative_name",), ("alternative", ("name",))), + _external_fields = { + "entity_class_id": ("entity_id", "class_id"), + "entity_class_name": ("entity_class_id", "name"), + "dimension_id_list": ("entity_class_id", "dimension_id_list"), + "dimension_name_list": ("entity_class_id", "dimension_name_list"), + "entity_name": ("entity_id", "name"), + "entity_byname": ("entity_id", "byname"), + 
"element_id_list": ("entity_id", "element_id_list"), + "element_name_list": ("entity_id", "element_name_list"), + "alternative_name": ("alternative_id", "name"), + } + _alt_references = { + ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ("alternative_name",): ("alternative", ("name",)), + } + _internal_fields = { + "entity_id": (("entity_class_name", "entity_byname"), "id"), + "alternative_id": (("alternative_name",), "id"), } @@ -305,7 +339,6 @@ def polish(self): type_ = super().__getitem__(self._type_key) if type_ == "list_value_ref": return - # value = self[self._value_key] value = super().__getitem__(self._value_key) parsed_value = from_database(value, type_) if parsed_value is None: @@ -330,14 +363,19 @@ class ParameterDefinitionItem(ParameterItemBase): } _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} _unique_keys = (("entity_class_name", "name"),) - _references = { - "entity_class_name": ("entity_class_id", ("entity_class", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), + _references = {"entity_class_id": ("entity_class", "id")} + _external_fields = { + "entity_class_name": ("entity_class_id", "name"), + "dimension_id_list": ("entity_class_id", "dimension_id_list"), + "dimension_name_list": ("entity_class_id", "dimension_name_list"), } - _inverse_references = { - "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), - "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), + _alt_references = { + ("entity_class_name",): ("entity_class", ("name",)), + ("parameter_value_list_name",): ("parameter_value_list", ("name",)), + } + _internal_fields = { + "entity_class_id": (("entity_class_name",), "id"), + "parameter_value_list_id": (("parameter_value_list_name",), "id"), } @property @@ -354,12 +392,14 @@ def __getitem__(self, key): if key == "parameter_value_list_id": return dict.get(self, key) if key == "parameter_value_list_name": - return self._get_ref("parameter_value_list", self["parameter_value_list_id"], strong=False).get("name") + return self._get_ref("parameter_value_list", {"id": self["parameter_value_list_id"]}, strong=False).get( + "name" + ) if key in ("default_value", "default_type"): list_value_id = self["list_value_id"] if list_value_id is not None: list_value_key = {"default_value": "value", "default_type": "type"}[key] - return self._get_ref("list_value", list_value_id, strong=False).get(list_value_key) + return self._get_ref("list_value", {"id": list_value_id}, strong=False).get(list_value_key) return dict.get(self, key) return super().__getitem__(key) @@ -399,26 +439,35 @@ class ParameterValueItem(ParameterItemBase): } _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { - "entity_class_name": ("entity_class_id", ("entity_class", "name")), - "dimension_id_list": ("entity_class_id", ("entity_class", "dimension_id_list")), - "dimension_name_list": ("entity_class_id", ("entity_class", "dimension_name_list")), - "parameter_definition_name": ("parameter_definition_id", ("parameter_definition", "name")), - "parameter_value_list_id": ("parameter_definition_id", ("parameter_definition", "parameter_value_list_id")), - "parameter_value_list_name": ("parameter_definition_id", ("parameter_definition", 
"parameter_value_list_name")), - "entity_name": ("entity_id", ("entity", "name")), - "entity_byname": ("entity_id", ("entity", "byname")), - "element_id_list": ("entity_id", ("entity", "element_id_list")), - "element_name_list": ("entity_id", ("entity", "element_name_list")), - "alternative_name": ("alternative_id", ("alternative", "name")), - } - _inverse_references = { - "entity_class_id": (("entity_class_name",), ("entity_class", ("name",))), - "parameter_definition_id": ( - ("entity_class_name", "parameter_definition_name"), - ("parameter_definition", ("entity_class_name", "name")), - ), - "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), - "alternative_id": (("alternative_name",), ("alternative", ("name",))), + "entity_class_id": ("entity_class", "id"), + "parameter_definition_id": ("parameter_definition", "id"), + "entity_id": ("entity", "id"), + "alternative_id": ("alternative", "id"), + } + _external_fields = { + "entity_class_name": ("entity_class_id", "name"), + "dimension_id_list": ("entity_class_id", "dimension_id_list"), + "dimension_name_list": ("entity_class_id", "dimension_name_list"), + "parameter_definition_name": ("parameter_definition_id", "name"), + "parameter_value_list_id": ("parameter_definition_id", "parameter_value_list_id"), + "parameter_value_list_name": ("parameter_definition_id", "parameter_value_list_name"), + "entity_name": ("entity_id", "name"), + "entity_byname": ("entity_id", "byname"), + "element_id_list": ("entity_id", "element_id_list"), + "element_name_list": ("entity_id", "element_name_list"), + "alternative_name": ("alternative_id", "name"), + } + _alt_references = { + ("entity_class_name",): ("entity_class", ("name",)), + ("entity_class_name", "parameter_definition_name"): ("parameter_definition", ("entity_class_name", "name")), + ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ("alternative_name",): ("alternative", ("name",)), + } + _internal_fields = { + "entity_class_id": (("entity_class_name",), "id"), + "parameter_definition_id": (("entity_class_name", "parameter_definition_name"), "id"), + "entity_id": (("entity_class_name", "entity_byname"), "id"), + "alternative_id": (("alternative_name",), "id"), } @property @@ -437,7 +486,7 @@ def __getitem__(self, key): if key in ("value", "type"): list_value_id = self["list_value_id"] if list_value_id: - return self._get_ref("list_value", list_value_id, strong=False).get(key) + return self._get_ref("list_value", {"id": list_value_id}, strong=False).get(key) return super().__getitem__(key) def _value_not_in_list_error(self, parsed_value, list_name): @@ -459,11 +508,11 @@ class ListValueItem(ParsedValueBase): "type": ("str", "The value type."), "index": ("int, optional", "The value index."), } - _unique_keys = (("parameter_value_list_name", "value", "type"), ("parameter_value_list_name", "index")) - _references = {"parameter_value_list_name": ("parameter_value_list_id", ("parameter_value_list", "name"))} - _inverse_references = { - "parameter_value_list_id": (("parameter_value_list_name",), ("parameter_value_list", ("name",))), - } + _unique_keys = (("parameter_value_list_name", "parsed_value"), ("parameter_value_list_name", "index")) + _references = {"parameter_value_list_id": ("parameter_value_list", "id")} + _external_fields = {"parameter_value_list_name": ("parameter_value_list_id", "name")} + _alt_references = {("parameter_value_list_name",): ("parameter_value_list", ("name",))} + _internal_fields = {"parameter_value_list_id": 
(("parameter_value_list_name",), "id")} @property def _value_key(self): @@ -517,14 +566,10 @@ class ScenarioAlternativeItem(MappedItemBase): "rank": ("int", "The rank - the higher has precedence."), } _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) - _references = { - "scenario_name": ("scenario_id", ("scenario", "name")), - "alternative_name": ("alternative_id", ("alternative", "name")), - } - _inverse_references = { - "scenario_id": (("scenario_name",), ("scenario", ("name",))), - "alternative_id": (("alternative_name",), ("alternative", ("name",))), - } + _references = {"scenario_id": ("scenario", "id"), "alternative_id": ("alternative", "id")} + _external_fields = {"scenario_name": ("scenario_id", "name"), "alternative_name": ("alternative_id", "name")} + _alt_references = {("scenario_name",): ("scenario", ("name",)), ("alternative_name",): ("alternative", ("name",))} + _internal_fields = {"scenario_id": (("scenario_name",), "id"), "alternative_id": (("alternative_name",), "id")} def __getitem__(self, key): # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. @@ -532,9 +577,9 @@ def __getitem__(self, key): # the second will have the third, etc, and the last will have None. # Note that alternatives with higher ranks overwrite the values of those with lower ranks. if key == "before_alternative_name": - return self._get_ref("alternative", self["before_alternative_id"], strong=False).get("name") + return self._get_ref("alternative", {"id": self["before_alternative_id"]}, strong=False).get("name") if key == "before_alternative_id": - scenario = self._get_ref("scenario", self["scenario_id"], strong=False) + scenario = self._get_ref("scenario", {"id": self["scenario_id"]}, strong=False) try: return scenario["alternative_id_list"][self["rank"]] except IndexError: @@ -554,14 +599,19 @@ class EntityMetadataItem(MappedItemBase): "metadata_value": ("str", "The metadata entry value."), } _unique_keys = (("entity_name", "metadata_name", "metadata_value"),) - _references = { - "entity_name": ("entity_id", ("entity", "name")), - "metadata_name": ("metadata_id", ("metadata", "name")), - "metadata_value": ("metadata_id", ("metadata", "value")), + _references = {"entity_id": ("entity", "id"), "metadata_id": ("metadata", "id")} + _external_fields = { + "entity_name": ("entity_id", "name"), + "metadata_name": ("metadata_id", "name"), + "metadata_value": ("metadata_id", "value"), } - _inverse_references = { - "entity_id": (("entity_class_name", "entity_byname"), ("entity", ("class_name", "byname"))), - "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + _alt_references = { + ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ("metadata_name", "metadata_value"): ("metadata", ("name", "value")), + } + _internal_fields = { + "entity_id": (("entity_class_name", "entity_byname"), "id"), + "metadata_id": (("metadata_name", "metadata_value"), "id"), } @@ -580,33 +630,40 @@ class ParameterValueMetadataItem(MappedItemBase): _unique_keys = ( ("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name", "metadata_value"), ) - _references = { - "parameter_definition_name": ("parameter_value_id", ("parameter_value", "parameter_definition_name")), - "entity_byname": ("parameter_value_id", ("parameter_value", "entity_byname")), - "alternative_name": ("parameter_value_id", ("parameter_value", "alternative_name")), - "metadata_name": ("metadata_id", 
("metadata", "name")), - "metadata_value": ("metadata_id", ("metadata", "value")), + _references = {"parameter_value_id": ("parameter_value", "id"), "metadata_id": ("metadata", "id")} + _external_fields = { + "parameter_definition_name": ("parameter_value_id", "parameter_definition_name"), + "entity_byname": ("parameter_value_id", "entity_byname"), + "alternative_name": ("parameter_value_id", "alternative_name"), + "metadata_name": ("metadata_id", "name"), + "metadata_value": ("metadata_id", "value"), + } + _alt_references = { + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"): ( + "parameter_value", + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), + ), + ("metadata_name", "metadata_value"): ("metadata", ("name", "value")), } - _inverse_references = { + _internal_fields = { "parameter_value_id": ( ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), - ( - "parameter_value", - ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), - ), + "id", ), - "metadata_id": (("metadata_name", "metadata_value"), ("metadata", ("name", "value"))), + "metadata_id": (("metadata_name", "metadata_value"), "id"), } class SuperclassSubclassItem(MappedItemBase): fields = {"superclass_name": ("str", "The superclass name."), "subclass_name": ("str", "The subclass name.")} _unique_keys = (("subclass_name",),) - _references = { - "superclass_name": ("superclass_id", ("entity_class", "name")), - "subclass_name": ("subclass_id", ("entity_class", "name")), + _references = {"superclass_id": ("entity_class", "id")} + _external_fields = { + "superclass_name": ("superclass_id", "name"), + "subclass_name": ("subclass_id", "name"), } - _inverse_references = { - "superclass_id": (("superclass_name",), ("entity_class", ("name",))), - "subclass_id": (("subclass_name",), ("entity_class", ("name",))), + _alt_references = { + ("superclass_name",): ("entity_class", ("name",)), + ("subclass_name",): ("entity_class", ("name",)), } + _internal_fields = {"superclass_id": (("superclass_name",), "id"), "subclass_id": (("subclass_name",), "id")} diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 7ddb75e9..82d4f4ec 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -1597,6 +1597,18 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): }, ) + def test_add_entity_to_a_class_with_abstract_dimensions(self): + import_functions.import_entity_classes( + self._db_map, (("fish", ()), ("dog", ()), ("animal", ()), ("two_animals", ("animal", "animal"))) + ) + import_functions.import_superclass_subclasses(self._db_map, (("animal", "fish"), ("animal", "dog"))) + import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) + self._db_map.commit_session("Add test data.") + items, errors = self._db_map.add_item( + "entity", {"class_name": "two_animals", "entity_name_list": ("Nemo", "Pulgoso")}, strict=False + ) + self.fail() + class TestDatabaseMappingUpdate(unittest.TestCase): def setUp(self): diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index 0b130d85..179c1ed8 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -23,7 +23,7 @@ def all_item_types(): return ["cutlery"] @staticmethod - def _item_factory(item_type): + def item_factory(item_type): if item_type == "cutlery": return MappedItemBase raise RuntimeError(f"unknown item_type 
'{item_type}'") From 500c69fbfca6843625161249d78f0a475e6ead25 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 30 Oct 2023 16:45:44 +0100 Subject: [PATCH 168/317] Add multi-D instances with abstract dimensions --- spinedb_api/db_mapping_base.py | 23 ++++++++++++++--------- spinedb_api/import_functions.py | 22 ++++++++++++++++++++++ spinedb_api/mapped_items.py | 24 +++++++++++++++--------- tests/test_DatabaseMapping.py | 7 +++---- 4 files changed, 54 insertions(+), 22 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index beb25e81..2b0e86c1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -756,19 +756,24 @@ def resolve_internal_fields(self, skip_keys=()): return error def _do_resolve_internal_field(self, key, src_key, target_key): - ref_type, ref_key = self._alt_references[src_key] src_val = tuple(dict.pop(self, k, None) or self.get(k) for k in src_key) if None in src_val: return + ref_type, ref_key = self._alt_references[src_key] mapped_table = self._db_map.mapped_table(ref_type) - try: - self[key] = ( - tuple(mapped_table.find_item(dict(zip(ref_key, v)))[target_key] for v in zip(*src_val)) - if all(isinstance(v, (tuple, list)) for v in src_val) - else mapped_table.find_item(dict(zip(ref_key, src_val)))[target_key] - ) - except KeyError as err: - return f"can't find {ref_type} with {dict(zip(ref_key, err.args[0]))}" + if all(isinstance(v, (tuple, list)) for v in src_val): + refs = [] + for v in zip(*src_val): + ref = mapped_table.find_item(dict(zip(ref_key, v))) + if not ref: + return f"can't find {ref_type} with {dict(zip(ref_key, v))}" + refs.append(ref) + self[key] = tuple(ref[target_key] for ref in refs) + else: + ref = mapped_table.find_item(dict(zip(ref_key, src_val))) + if not ref: + return f"can't find {ref_type} with {dict(zip(ref_key, src_val))}" + self[key] = ref[target_key] def polish(self): """Polishes this item once all it's references have been resolved. Returns any error. diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 9144a8e3..26dd94b6 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -117,6 +117,7 @@ def get_data_for_import( metadata=(), entity_metadata=(), parameter_value_metadata=(), + superclass_subclasses=(), # legacy object_classes=(), relationship_classes=(), @@ -175,6 +176,8 @@ def get_data_for_import( alternatives = list({item[1]: None for item in scenario_alternatives}) yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) + if superclass_subclasses: + yield ("superclass_subclass", _get_parameter_superclass_subclasses_for_import(db_map, superclass_subclasses)) if entity_classes: for bucket in _get_entity_classes_for_import(db_map, entity_classes): yield ("entity_class", bucket) @@ -240,6 +243,20 @@ def get_data_for_import( yield from get_data_for_import(db_map, parameter_value_metadata=relationship_parameter_value_metadata) +def import_superclass_subclasses(db_map, data): + """Imports superclass_subclasses into a Spine database using a standard format. 
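+
+    Example (illustrative)::
+
+        import_superclass_subclasses(db_map, (("animal", "fish"), ("animal", "dog")))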
+
+    Args:
+        db_map (spinedb_api.DiffDatabaseMapping): database mapping
+        data (list(tuple(str,str))): tuples of (superclass name, subclass name)
+
+    Returns:
+        int: number of items imported
+        list: errors
+    """
+    return import_data(db_map, superclass_subclasses=data)
+
+
 def import_entity_classes(db_map, data):
     """Imports entity classes into a Spine database using a standard format.

@@ -491,6 +508,11 @@ def _add_to_seen(checked_item, seen):
     seen.setdefault(key, set()).add(value)


+def _get_parameter_superclass_subclasses_for_import(db_map, data):
+    key = ("superclass_name", "subclass_name")
+    return _get_items_for_import(db_map, "superclass_subclass", (dict(zip(key, x)) for x in data))
+
+
 def _get_entity_classes_for_import(db_map, data):
     dim_name_list_by_name = {}
     items = []
diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py
index 3229455b..e00728c1 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -108,7 +108,7 @@ class EntityItem(MappedItemBase):
         "description": ("str, optional", "The entity description."),
     }
     _defaults = {"description": None}
-    _unique_keys = (("class_name", "name"), ("class_name", "byname"), ("superclass_name", "name"))
+    _unique_keys = (("class_name", "name"), ("class_name", "byname"))
     _references = {"class_id": ("entity_class", "id"), "element_id_list": ("entity", "id")}
     _external_fields = {
         "class_name": ("class_id", "name"),
@@ -136,6 +136,15 @@ def __init__(self, *args, **kwargs):
             kwargs["element_id_list"] = tuple(element_id_list)
         super().__init__(*args, **kwargs)

+    @classmethod
+    def unique_values_for_item(cls, item, skip_keys=()):
+        yield from super().unique_values_for_item(item, skip_keys=skip_keys)
+        key = ("class_name", "name")
+        if key not in skip_keys:
+            value = tuple(item.get(k) for k in ("superclass_name", "name"))
+            if None not in value:
+                yield key, value
+
     def _element_name_list_iter(self, entity):
         element_id_list = entity["element_id_list"]
         if not element_id_list:
@@ -159,20 +168,17 @@ def polish(self):
         dim_name_lst, el_name_lst = dict.get(self, "dimension_name_list"), dict.get(self, "element_name_list")
         if dim_name_lst and el_name_lst:
             for dim_name, el_name in zip(dim_name_lst, el_name_lst):
-                if not (
-                    self._db_map.get_item("entity", class_name=dim_name, name=el_name, fetch=False)
-                    or self._db_map.get_item("entity", superclass_name=dim_name, name=el_name, fetch=False)
-                ):
+                if not self._db_map.get_item("entity", class_name=dim_name, name=el_name, fetch=False):
                     return f"element '{el_name}' is not an instance of class '{dim_name}'"
         if self.get("name") is not None:
             return
         base_name = "__".join(self["element_name_list"])
         name = base_name
         index = 1
-        while self._db_map.get_item("entity", class_name=self["class_name"], name=name) or self._db_map.get_item(
-            "entity", superclass_name=self["superclass_name"], name=name
+        while any(
+            self._db_map.get_item("entity", class_name=self[k], name=name) for k in ("class_name", "superclass_name")
         ):
-            name = base_name + f"_{index}"
+            name = f"{base_name}_{index}"
             index += 1
         self["name"] = name
@@ -657,7 +663,7 @@ class ParameterValueMetadataItem(MappedItemBase):
 class SuperclassSubclassItem(MappedItemBase):
     fields = {"superclass_name": ("str", "The superclass name."), "subclass_name": ("str", "The subclass name.")}
     _unique_keys = (("subclass_name",),)
-    _references = {"superclass_id": ("entity_class", "id")}
+    _references = {"superclass_id": ("entity_class", "id"), "subclass_id": ("entity_class", "id")}
     _external_fields = {
         "superclass_name": 
("superclass_id", "name"), "subclass_name": ("subclass_id", "name"), diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 82d4f4ec..02ea8bf7 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -1604,10 +1604,9 @@ def test_add_entity_to_a_class_with_abstract_dimensions(self): import_functions.import_superclass_subclasses(self._db_map, (("animal", "fish"), ("animal", "dog"))) import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) self._db_map.commit_session("Add test data.") - items, errors = self._db_map.add_item( - "entity", {"class_name": "two_animals", "entity_name_list": ("Nemo", "Pulgoso")}, strict=False - ) - self.fail() + item, error = self._db_map.add_item("entity", class_name="two_animals", element_name_list=("Nemo", "Pulgoso")) + print(item) + # self.fail() class TestDatabaseMappingUpdate(unittest.TestCase): From 4ce5eef6fd1637a86d36d4f1f2192cc4a4a038f0 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 30 Oct 2023 16:57:58 +0100 Subject: [PATCH 169/317] Fix typo --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 2b0e86c1..5a42262d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -96,7 +96,7 @@ def _make_query(self, item_type, **kwargs): for key, value in kwargs.items(): if hasattr(sq.c, key): qry = qry.filter(getattr(sq.c, key) == value) - elif key in self._external_fields: + elif key in self.item_factory(item_type)._external_fields: src_key, key = self.item_factory(item_type)._external_fields[key] ref_type, ref_key = self.item_factory(item_type)._references[src_key] ref_sq = self._make_sq(ref_type) From a9f71fcc6e86bbff1b5b329953fe853d933e196b Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 30 Oct 2023 17:13:16 +0100 Subject: [PATCH 170/317] Export superclass subclasses --- spinedb_api/export_functions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 14e2094b..191ce571 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -23,6 +23,7 @@ def export_data( db_map, entity_class_ids=Asterisk, + superclass_subclass_ids=Asterisk, entity_ids=Asterisk, entity_group_ids=Asterisk, parameter_value_list_ids=Asterisk, @@ -57,6 +58,7 @@ def export_data( """ data = { "entity_classes": export_entity_classes(db_map, entity_class_ids), + "superclass_subclasses": export_superclass_subclasses(db_map, superclass_subclass_ids), "entities": export_entities(db_map, entity_ids), "entity_alternatives": export_entity_alternatives(db_map, entity_alternative_ids), "entity_groups": export_entity_groups(db_map, entity_group_ids), @@ -128,6 +130,10 @@ def export_entity_classes(db_map, ids=Asterisk): ) +def export_superclass_subclasses(db_map, ids=Asterisk): + return sorted(((x.superclass_name, x.subclass_name) for x in _get_items(db_map, "superclass_subclasses", ids))) + + def export_entities(db_map, ids=Asterisk): return sorted( ((x.class_name, x.name, x.element_name_list, x.description) for x in _get_items(db_map, "entity", ids)), From 7aa4961095c120bdb0a442e526cf5b1b5214889f Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 30 Oct 2023 17:13:36 +0100 Subject: [PATCH 171/317] Add migration script to create the superclass_subclass table --- ...63bef2_create_superclass_subclass_table.py | 45 +++++++++++++++++++ spinedb_api/helpers.py | 4 +- 2 files changed, 47 
insertions(+), 2 deletions(-) create mode 100644 spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py diff --git a/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py new file mode 100644 index 00000000..a8e1cc54 --- /dev/null +++ b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py @@ -0,0 +1,45 @@ +"""create superclass_subclass table + +Revision ID: 5385f063bef2 +Revises: ce9faa82ed59 +Create Date: 2023-10-30 17:11:23.316879 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5385f063bef2' +down_revision = 'ce9faa82ed59' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + 'superclass_subclass', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('superclass_id', sa.Integer(), nullable=True), + sa.Column('subclass_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ['subclass_id'], + ['entity_class.id'], + name=op.f('fk_superclass_subclass_subclass_id_entity_class'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['superclass_id'], + ['entity_class.id'], + name=op.f('fk_superclass_subclass_superclass_id_entity_class'), + onupdate='CASCADE', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('id', name=op.f('pk_superclass_subclass')), + sa.UniqueConstraint('subclass_id', name=op.f('uq_superclass_subclass_subclass_id')), + ) + + +def downgrade(): + pass diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index ee93c2bf..ad86e52c 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -47,7 +47,7 @@ ) from sqlalchemy.ext.automap import generate_relationship from sqlalchemy.ext.compiler import compiles -from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError +from sqlalchemy.exc import DatabaseError, IntegrityError from sqlalchemy.dialects.mysql import TINYINT, DOUBLE from sqlalchemy.sql.expression import FunctionElement, bindparam, cast from alembic.config import Config @@ -601,7 +601,7 @@ def create_new_spine_database(db_url): meta.create_all(engine) engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute("INSERT INTO alembic_version VALUES ('ce9faa82ed59')") + engine.execute("INSERT INTO alembic_version VALUES ('5385f063bef2')") except DatabaseError as e: raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None return engine From 09c7cc18d9e116f477d44fc50a9ffb3c34d7e0ff Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 31 Oct 2023 08:34:24 +0100 Subject: [PATCH 172/317] Complete test and fix typo --- spinedb_api/export_functions.py | 2 +- tests/test_DatabaseMapping.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 191ce571..bd54280b 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -131,7 +131,7 @@ def export_entity_classes(db_map, ids=Asterisk): def export_superclass_subclasses(db_map, ids=Asterisk): - return sorted(((x.superclass_name, x.subclass_name) for x in _get_items(db_map, "superclass_subclasses", ids))) + return sorted(((x.superclass_name, x.subclass_name) for x in _get_items(db_map, "superclass_subclass", ids))) def export_entities(db_map, 
ids=Asterisk): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 02ea8bf7..785efa04 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -1605,8 +1605,12 @@ def test_add_entity_to_a_class_with_abstract_dimensions(self): import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) self._db_map.commit_session("Add test data.") item, error = self._db_map.add_item("entity", class_name="two_animals", element_name_list=("Nemo", "Pulgoso")) - print(item) - # self.fail() + self.assertTrue(item) + self.assertFalse(error) + self._db_map.commit_session("Add test data.") + entities = self._db_map.query(self._db_map.wide_entity_sq).all() + self.assertEqual(len(entities), 3) + self.assertIn("Nemo,Pulgoso", {x.element_name_list for x in entities}) class TestDatabaseMappingUpdate(unittest.TestCase): From 0ea0ee4cb96ab60b9c51a41ccd62d962c947ecd7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 31 Oct 2023 08:40:39 +0100 Subject: [PATCH 173/317] Fix creation of superclass_subclass table --- .../versions/5385f063bef2_create_superclass_subclass_table.py | 4 ++-- spinedb_api/helpers.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py index a8e1cc54..2fdb006b 100644 --- a/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py +++ b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py @@ -20,8 +20,8 @@ def upgrade(): op.create_table( 'superclass_subclass', sa.Column('id', sa.Integer(), nullable=False), - sa.Column('superclass_id', sa.Integer(), nullable=True), - sa.Column('subclass_id', sa.Integer(), nullable=True), + sa.Column('superclass_id', sa.Integer(), nullable=False), + sa.Column('subclass_id', sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ['subclass_id'], ['entity_class.id'], diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index ad86e52c..6a6bd2b5 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -360,11 +360,13 @@ def create_spine_metadata(): "superclass_id", Integer, ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, ), Column( "subclass_id", Integer, ForeignKey("entity_class.id", onupdate="CASCADE", ondelete="CASCADE"), + nullable=False, unique=True, ), ) From 0b16b3c86ca77933d6c5126ddc82faf03a1c3cb0 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 31 Oct 2023 08:48:48 +0100 Subject: [PATCH 174/317] Fix unique keys of list value item --- spinedb_api/mapped_items.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index e00728c1..3a6aa71e 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -514,7 +514,7 @@ class ListValueItem(ParsedValueBase): "type": ("str", "The value type."), "index": ("int, optional", "The value index."), } - _unique_keys = (("parameter_value_list_name", "parsed_value"), ("parameter_value_list_name", "index")) + _unique_keys = (("parameter_value_list_name", "value_and_type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_id": ("parameter_value_list", "id")} _external_fields = {"parameter_value_list_name": ("parameter_value_list_id", "name")} _alt_references = {("parameter_value_list_name",): ("parameter_value_list", ("name",))} @@ -528,6 +528,11 @@ def 
_value_key(self):
     def _type_key(self):
         return "type"
 
+    def __getitem__(self, key):
+        if key == "value_and_type":
+            return (self["value"], self["type"])
+        return super().__getitem__(key)
+
 
 class AlternativeItem(MappedItemBase):
     fields = {

From f7d5f59f81460dacb9a22e06ad0dda64248bbf9a Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 2 Nov 2023 10:17:52 +0100
Subject: [PATCH 175/317] Introduce MappedItem.check_mutability

To prevent removal or update of items. At the moment it is used to prevent
modifying the superclass of a class that already has entities, which would
lead to chaos.
---
 spinedb_api/db_mapping.py       | 41 +++++++++++------
 spinedb_api/db_mapping_base.py  | 79 +++++++++++++++++++++++----------
 spinedb_api/import_functions.py |  8 ++--
 spinedb_api/mapped_items.py     | 15 +++++++
 4 files changed, 103 insertions(+), 40 deletions(-)

diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 7f638d76..87c50d87 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -441,11 +441,8 @@ def add_item(self, item_type, check=True, **kwargs):
         self._convert_legacy(item_type, kwargs)
         if not check:
             return mapped_table.add_item(kwargs), None
-        checked_item, error = mapped_table.check_item(kwargs)
-        return (
-            mapped_table.add_item(checked_item).public_item if checked_item and not error else None,
-            error,
-        )
+        checked_item, error = mapped_table.checked_item_and_error(kwargs)
+        return (mapped_table.add_item(checked_item).public_item if checked_item else None, error)
 
     def add_items(self, item_type, *items, check=True, strict=False):
         """Add many items to the in-memory mapping.
@@ -497,7 +494,7 @@ def update_item(self, item_type, check=True, **kwargs):
         self._convert_legacy(item_type, kwargs)
         if not check:
             return mapped_table.update_item(kwargs), None
-        checked_item, error = mapped_table.check_item(kwargs, for_update=True)
+        checked_item, error = mapped_table.checked_item_and_error(kwargs, for_update=True)
         return (mapped_table.update_item(checked_item._asdict()).public_item if checked_item else None, error)
 
     def update_items(self, item_type, *items, check=True, strict=False):
@@ -525,7 +522,7 @@ def update_items(self, item_type, *items, check=True, strict=False):
                 updated.append(item)
         return updated, errors
 
-    def remove_item(self, item_type, id_):
+    def remove_item(self, item_type, id_, check=True):
         """Removes an item from the in-memory mapping.
 
         Example::
@@ -537,32 +534,48 @@ def remove_item(self, item_type, id_):
         Args:
             item_type (str): One of .
             id_ (int): The id of the item to remove.
+            check (bool, optional): Whether to carry out integrity checks.
 
         Returns:
-            tuple(:class:`PublicItem` or None, str): The removed item if any.
+            tuple(:class:`PublicItem` or None, str): The removed item and any errors.
         """
         item_type = self.real_item_type(item_type)
         mapped_table = self.mapped_table(item_type)
-        return mapped_table.remove_item(id_).public_item
+        item, error = mapped_table.item_to_remove_and_error(id_)
+        if check and error:
+            return None, error
+        return mapped_table.remove_item(item).public_item, None
 
-    def remove_items(self, item_type, *ids):
+    def remove_items(self, item_type, *ids, check=True, strict=False):
         """Removes many items from the in-memory mapping.
 
         Args:
             item_type (str): One of .
             *ids (Iterable(int)): Ids of items to be removed.
+            check (bool): Whether or not to run integrity checks.
+            strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
+                if the removal of one of the items violates an integrity constraint.
 
         Returns:
-            list(:class:`PublicItem`): the removed items.
+            tuple(list(:class:`PublicItem`), list(str)): the items successfully removed and any violations found.
         """
-        if not ids:
-            return []
         item_type = self.real_item_type(item_type)
         ids = set(ids)
         if item_type == "alternative":
             # Do not remove the Base alternative
             ids.discard(1)
+        if not ids:
+            return [], []
+        removed, errors = [], []
+        for id_ in ids:
+            item, error = self.remove_item(item_type, id_, check=check)
+            if error:
+                if strict:
+                    raise SpineIntegrityError(error)
+                errors.append(error)
+            if item:
+                removed.append(item)
+        return removed, errors
 
     def restore_item(self, item_type, id_):
         """Restores a previously removed item into the in-memory mapping.
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 5a42262d..2feee21c 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -194,7 +194,7 @@ def _rollback(self):
         for item_type, to_add in to_add_by_type:
             mapped_table = self.mapped_table(item_type)
             for item in to_add:
-                if mapped_table.remove_item(item["id"]) is not None:
+                if mapped_table.remove_item(item) is not None:
                     item.invalidate_id()
         return True
 
@@ -392,13 +392,13 @@ def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True)
             error = mapped_item.polish()
             if error:
                 return {}
-            for key, value in mapped_item.unique_values(skip_keys=skip_keys):
+            for key, value in mapped_item.unique_key_values(skip_keys=skip_keys):
                 current_item = self._unique_key_value_to_item(key, value, fetch=fetch)
                 if current_item:
                     return current_item
         return {}
 
-    def check_item(self, item, for_update=False, skip_keys=()):
+    def checked_item_and_error(self, item, for_update=False, skip_keys=()):
         # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives,
         # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank)
         if for_update:
@@ -408,41 +408,66 @@ def check_item(self, item, for_update=False, skip_keys=()):
             full_item, merge_error = current_item.merge(item)
             if full_item is None:
                 return None, merge_error
+            mutability_error = current_item.check_mutability()
+            if mutability_error:
+                return None, mutability_error
         else:
             current_item = None
             full_item, merge_error = item, None
         candidate_item = self._make_item(full_item)
-        error = candidate_item.resolve_internal_fields(skip_keys=item.keys())
+        error = self._prepare_item(candidate_item, current_item, item, skip_keys)
         if error:
             return None, error
+        return candidate_item, merge_error
+
+    def _prepare_item(self, candidate_item, current_item, original_item, skip_keys):
+        """Prepares item for insertion or update, returns any errors.
+
+        Args:
+            candidate_item (MappedItem)
+            current_item (MappedItem)
+            original_item (dict)
+            skip_keys (tuple, optional)
+
+        Returns:
+            str or None: errors if any.
+ """ + error = candidate_item.resolve_internal_fields(skip_keys=original_item.keys()) + if error: + return error error = candidate_item.polish() if error: - return None, error + return error first_invalid_key = candidate_item.first_invalid_key() if first_invalid_key: - return None, f"invalid {first_invalid_key} for {self._item_type}" + return f"invalid {first_invalid_key} for {self._item_type}" try: - for key, value in candidate_item.unique_values(skip_keys=skip_keys): + for key, value in candidate_item.unique_key_values(skip_keys=skip_keys): empty = {k for k, v in zip(key, value) if v == ""} if empty: - return None, f"invalid empty keys {empty} for {self._item_type}" + return f"invalid empty keys {empty} for {self._item_type}" unique_item = self._unique_key_value_to_item(key, value) if unique_item not in (None, current_item) and unique_item.is_valid(): - return None, f"there's already a {self._item_type} with {dict(zip(key, value))}" + return f"there's already a {self._item_type} with {dict(zip(key, value))}" except KeyError as e: - return None, f"missing {e} for {self._item_type}" - if "id" not in candidate_item: - candidate_item["id"] = self._new_id() - return candidate_item, merge_error + return f"missing {e} for {self._item_type}" + + def item_to_remove_and_error(self, id_): + if id_ is Asterisk: + return self.wildcard_item, None + current_item = self.find_item({"id": id_}) + if not current_item: + return None, None + return current_item, current_item.check_mutability() def add_unique(self, item): id_ = item["id"] - for key, value in item.unique_values(): + for key, value in item.unique_key_values(): self._id_by_unique_key_value.setdefault(key, {})[value] = id_ def remove_unique(self, item): id_ = item["id"] - for key, value in item.unique_values(): + for key, value in item.unique_key_values(): id_by_value = self._id_by_unique_key_value.get(key, {}) if id_by_value.get(value) == id_: del id_by_value[value] @@ -490,18 +515,18 @@ def update_item(self, item): current_item.cascade_update() return current_item - def remove_item(self, id_): - if id_ is Asterisk: + def remove_item(self, item): + if not item: + return None + if item is self.wildcard_item: self.purged = True for current_item in self.valid_values(): self.remove_unique(current_item) current_item.cascade_remove(source=self.wildcard_item) return self.wildcard_item - current_item = self.find_item({"id": id_}) - if current_item: - self.remove_unique(current_item) - current_item.cascade_remove() - return current_item + self.remove_unique(item) + item.cascade_remove() + return item def restore_item(self, id_): if id_ is Asterisk: @@ -726,7 +751,7 @@ def unique_values_for_item(cls, item, skip_keys=()): if None not in value: yield key, value - def unique_values(self, skip_keys=()): + def unique_key_values(self, skip_keys=()): """Yields tuples of unique keys and their values. Args: @@ -787,6 +812,14 @@ def polish(self): self.setdefault(key, default_value) return "" + def check_mutability(self): + """Checks if this item can be mutated (updated or removed). Returns any errors. + + Returns: + str or None: error description if any. + """ + return "" + def _get_ref(self, ref_type, key_val, strong=True): """Collects a reference from the in-memory mapping. 
Adds this item to the reference's list of referrers if strong is True; diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 26dd94b6..cd2d1ad6 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -475,13 +475,15 @@ def _get_items_for_import(db_map, item_type, data, check_skip_keys=()): to_update = [] seen = {} for item in data: - checked_item, add_error = mapped_table.check_item(item, skip_keys=check_skip_keys) + checked_item, add_error = mapped_table.checked_item_and_error(item, skip_keys=check_skip_keys) if not add_error: if not _check_unique(item_type, checked_item, seen, errors): continue to_add.append(checked_item) continue - checked_item, update_error = mapped_table.check_item(item, for_update=True, skip_keys=check_skip_keys) + checked_item, update_error = mapped_table.checked_item_and_error( + item, for_update=True, skip_keys=check_skip_keys + ) if not update_error: if checked_item: if not _check_unique(item_type, checked_item, seen, errors): @@ -502,7 +504,7 @@ def _check_unique(item_type, checked_item, seen, errors): def _add_to_seen(checked_item, seen): - for key, value in checked_item.unique_values(): + for key, value in checked_item.unique_key_values(): if value in seen.get(key, set()): return dict(zip(key, value)) seen.setdefault(key, set()).add(value) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 3a6aa71e..6d029b07 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -678,3 +678,18 @@ class SuperclassSubclassItem(MappedItemBase): ("subclass_name",): ("entity_class", ("name",)), } _internal_fields = {"superclass_id": (("superclass_name",), "id"), "subclass_id": (("subclass_name",), "id")} + + def _subclass_entities(self): + return self._db_map.get_items("entity", class_id=self["subclass_id"]) + + def polish(self): + error = super().polish() + if error: + return error + for ent in self._subclass_entities(): + pass + + def check_mutability(self): + if self._subclass_entities(): + return "can't modify the superclass of a class that already has entities" + return super().check_mutability() From 302d6e11d51d0ac2ceb7185796d9280c9b53d6c6 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 2 Nov 2023 10:35:42 +0100 Subject: [PATCH 176/317] Call check_mutability also for adding items --- spinedb_api/db_mapping_base.py | 8 ++++---- spinedb_api/mapped_items.py | 9 +-------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 2feee21c..7e438a98 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -408,9 +408,6 @@ def checked_item_and_error(self, item, for_update=False, skip_keys=()): full_item, merge_error = current_item.merge(item) if full_item is None: return None, merge_error - mutability_error = current_item.check_mutability() - if mutability_error: - return None, mutability_error else: current_item = None full_item, merge_error = item, None @@ -433,6 +430,9 @@ def _prepare_item(self, candidate_item, current_item, original_item, skip_keys): str or None: errors if any. """ error = candidate_item.resolve_internal_fields(skip_keys=original_item.keys()) + if error: + return error + error = candidate_item.check_mutability() if error: return error error = candidate_item.polish() @@ -813,7 +813,7 @@ def polish(self): return "" def check_mutability(self): - """Checks if this item can be mutated (updated or removed). Returns any errors. 
+ """Called before adding, updating, or removing this item. Returns any errors that prevent that. Returns: str or None: error description if any. diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 6d029b07..bbf4bf59 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -682,14 +682,7 @@ class SuperclassSubclassItem(MappedItemBase): def _subclass_entities(self): return self._db_map.get_items("entity", class_id=self["subclass_id"]) - def polish(self): - error = super().polish() - if error: - return error - for ent in self._subclass_entities(): - pass - def check_mutability(self): if self._subclass_entities(): - return "can't modify the superclass of a class that already has entities" + return "can't set or modify the superclass for a class that already has entities" return super().check_mutability() From 591d679c09b8451075cdf3dc1517f0375d8e6678 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 2 Nov 2023 10:45:29 +0100 Subject: [PATCH 177/317] Add weak ref to superclass_subclass from entity_class --- spinedb_api/db_mapping_base.py | 5 ++--- spinedb_api/mapped_items.py | 6 +----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 7e438a98..6edd917a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -837,11 +837,10 @@ def _get_ref(self, ref_type, key_val, strong=True): mapped_table = self._db_map.mapped_table(ref_type) ref = mapped_table.find_item(key_val, fetch=False) if not ref: - if not strong: - return {} ref = mapped_table.find_item(key_val, fetch=True) if not ref: - self._corrupted = True + if strong: + self._corrupted = True return {} # Here we have a ref if strong: diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index bbf4bf59..959545d4 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -74,11 +74,7 @@ def __init__(self, *args, **kwargs): def __getitem__(self, key): if key in ("superclass_id", "superclass_name"): - superclass_subclass = self._db_map.mapped_table("superclass_subclass").find_item_by_unique_key( - {"subclass_name": self["name"]} - ) - # FIXME: create a weak reference too - return superclass_subclass.get(key) + return self._get_ref("superclass_subclass", {"subclass_id": self["id"]}, strong=False).get(key) return super().__getitem__(key) def merge(self, other): From ade3e60a4333ae6fd5dd47fb01ad7cc6dbab4c55 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 8 Nov 2023 15:09:20 +0200 Subject: [PATCH 178/317] Fix importing scenario alternatives _get_scenario_alternatives_for_import() did not properly handle cases where scenario alternatives were not ordered by before alternatives. Re spine-tools/Spine-Toolbox#2374 --- spinedb_api/import_functions.py | 38 +++++-- tests/test_import_functions.py | 195 ++++++++++++++++++++------------ 2 files changed, 150 insertions(+), 83 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index cd2d1ad6..97c7ef0a 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -12,6 +12,7 @@ """ Functions for importing data into a Spine database in a standard format. 
""" +from collections import defaultdict from .parameter_value import to_database, fix_conflict from .helpers import _parse_metadata @@ -631,24 +632,37 @@ def _get_scenarios_for_import(db_map, data): def _get_scenario_alternatives_for_import(db_map, data): - alt_name_list_by_scen_name, errors = {}, [] + alt_name_list_by_scen_name = {} + errors = [] + successors_by_scen_name = defaultdict(dict) for scen_name, alt_name, *optionals in data: + successors_by_scen_name[scen_name][alt_name] = optionals[0] if optionals else None + for scen_name, successors in successors_by_scen_name.items(): scen = db_map.mapped_table("scenario").find_item({"name": scen_name}) if not scen: errors.append(f"no scenario with name {scen_name} to set alternatives for") continue - alternative_name_list = alt_name_list_by_scen_name.setdefault(scen_name, scen["alternative_name_list"]) - if alt_name in alternative_name_list: - alternative_name_list.remove(alt_name) - before_alt_name = optionals[0] if optionals else None - if before_alt_name is None: - alternative_name_list.append(alt_name) + alt_names = set(successors) + alternative_name_list = alt_name_list_by_scen_name[scen_name] = [ + a for a in scen["alternative_name_list"] if a not in alt_names + ] + for predecessor, successor in list(successors.items()): + if successor is None: + alternative_name_list.append(predecessor) + del successors[predecessor] + predecessors = {successor: predecessor for predecessor, successor in successors.items()} + predecessor_errors = [] + for predecessor in predecessors: + if predecessor not in successors and predecessor not in alternative_name_list: + predecessor_errors.append(f"{predecessor} is not in {scen_name}") + if predecessor_errors: + errors += predecessor_errors continue - if before_alt_name in alternative_name_list: - pos = alternative_name_list.index(before_alt_name) - alternative_name_list.insert(pos, alt_name) - else: - errors.append(f"{before_alt_name} is not in {scen_name}") + while predecessors: + for i, alt_name in enumerate(alternative_name_list): + if (predecessor := predecessors.pop(alt_name, None)) is not None: + alternative_name_list.insert(i, predecessor) + break def _data_iterator(): for scen_name, alternative_name_list in alt_name_list_by_scen_name.items(): diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 45313333..242c2287 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -70,7 +70,7 @@ def _assert_same_elements(test, obs_vals, exp_vals): test.assertEqual(obs_vals, exp_vals) -def create_diff_db_map(): +def create_db_map(): db_url = "sqlite://" return DatabaseMapping(db_url, username="UnitTest", create=True) @@ -113,7 +113,7 @@ def test_import_data_integration(self): class TestImportObjectClass(unittest.TestCase): def test_import_object_class(self): - db_map = create_diff_db_map() + db_map = create_db_map() _, errors = import_object_classes(db_map, ["new_class"]) self.assertFalse(errors) db_map.commit_session("test") @@ -123,7 +123,7 @@ def test_import_object_class(self): class TestImportObject(unittest.TestCase): def test_import_valid_objects(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) _, errors = import_objects(db_map, [["object_class", "new_object"]]) self.assertFalse(errors) @@ -132,13 +132,13 @@ def test_import_valid_objects(self): db_map.close() def test_import_object_with_invalid_object_class_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() _, 
errors = import_objects(db_map, [["nonexistent_class", "new_object"]]) self.assertTrue(errors) db_map.close() def test_import_two_objects_with_same_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_objects(db_map, [["object_class1", "object"], ["object_class2", "object"]]) self.assertFalse(errors) @@ -154,7 +154,7 @@ def test_import_two_objects_with_same_name(self): db_map.close() def test_import_existing_object(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) import_objects(db_map, [["object_class", "object"]]) db_map.commit_session("test") @@ -167,7 +167,7 @@ def test_import_existing_object(self): class TestImportRelationshipClass(unittest.TestCase): def test_import_valid_relationship_class(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) self.assertFalse(errors) @@ -180,7 +180,7 @@ def test_import_valid_relationship_class(self): db_map.close() def test_import_relationship_class_with_invalid_object_class_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) _, errors = import_relationship_classes(db_map, [["relationship_class", ["object_class", "nonexistent"]]]) self.assertTrue(errors) @@ -189,7 +189,7 @@ def test_import_relationship_class_with_invalid_object_class_name(self): db_map.close() def test_import_relationship_class_name_twice(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_relationship_classes( db_map, [["new_rc", ["object_class1", "object_class2"]], ["new_rc", ["object_class1", "object_class2"]]] @@ -204,7 +204,7 @@ def test_import_relationship_class_name_twice(self): db_map.close() def test_import_existing_relationship_class(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) import_relationship_classes(db_map, [["rc", ["object_class1", "object_class2"]]]) _, errors = import_relationship_classes(db_map, [["rc", ["object_class1", "object_class2"]]]) @@ -212,7 +212,7 @@ def test_import_existing_relationship_class(self): db_map.close() def test_import_relationship_class_with_one_object_class_as_None(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1"]) _, errors = import_relationship_classes(db_map, [["new_rc", ["object_class", None]]]) self.assertTrue(errors) @@ -223,7 +223,7 @@ def test_import_relationship_class_with_one_object_class_as_None(self): class TestImportObjectClassParameter(unittest.TestCase): def test_import_valid_object_class_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) _, errors = import_object_parameters(db_map, [["object_class", "new_parameter"]]) self.assertFalse(errors) @@ -232,13 +232,13 @@ def test_import_valid_object_class_parameter(self): db_map.close() def test_import_parameter_with_invalid_object_class_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() _, errors = import_object_parameters(db_map, [["nonexistent_object_class", "new_parameter"]]) self.assertTrue(errors) db_map.close() def 
test_import_object_class_parameter_name_twice(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) _, errors = import_object_parameters( db_map, [["object_class1", "new_parameter"], ["object_class2", "new_parameter"]] @@ -254,7 +254,7 @@ def test_import_object_class_parameter_name_twice(self): db_map.close() def test_import_existing_object_class_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) import_object_parameters(db_map, [["object_class", "parameter"]]) db_map.commit_session("test") @@ -281,7 +281,7 @@ def test_import_object_class_parameter_with_null_default_value_and_db_server_unp class TestImportRelationshipClassParameter(unittest.TestCase): def test_import_valid_relationship_class_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationship_parameters(db_map, [["relationship_class", "new_parameter"]]) @@ -299,13 +299,13 @@ def test_import_valid_relationship_class_parameter(self): db_map.close() def test_import_parameter_with_invalid_relationship_class_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() _, errors = import_relationship_parameters(db_map, [["nonexistent_relationship_class", "new_parameter"]]) self.assertTrue(errors) db_map.close() def test_import_relationship_class_parameter_name_twice(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) import_relationship_classes( db_map, @@ -331,7 +331,7 @@ def test_import_relationship_class_parameter_name_twice(self): db_map.close() def test_import_existing_relationship_class_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class1", "object_class2"]) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) import_relationship_parameters(db_map, [["relationship_class", "new_parameter"]]) @@ -358,7 +358,7 @@ def test_import_relationships(self): db_map.close() def test_import_valid_relationship(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) @@ -368,7 +368,7 @@ def test_import_valid_relationship(self): db_map.close() def test_import_valid_relationship_with_object_name_in_multiple_classes(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_objects(db_map, [["object_class1", "duplicate"], ["object_class2", "duplicate"]]) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) @@ -379,7 +379,7 @@ def test_import_valid_relationship_with_object_name_in_multiple_classes(self): db_map.close() def test_import_relationship_with_invalid_class_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) _, errors = import_relationships(db_map, [["nonexistent_relationship_class", ["object1", "object2"]]]) self.assertTrue(errors) @@ -388,7 +388,7 @@ def test_import_relationship_with_invalid_class_name(self): db_map.close() def 
test_import_relationship_with_invalid_object_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", ["nonexistent_object", "object2"]]]) @@ -398,7 +398,7 @@ def test_import_relationship_with_invalid_object_name(self): db_map.close() def test_import_existing_relationship(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) @@ -410,7 +410,7 @@ def test_import_existing_relationship(self): db_map.close() def test_import_relationship_with_one_None_object(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_relationship_classes(db_map, [["relationship_class", ["object_class1", "object_class2"]]]) _, errors = import_relationships(db_map, [["relationship_class", [None, "object2"]]]) @@ -420,7 +420,7 @@ def test_import_relationship_with_one_None_object(self): db_map.close() def test_import_relationship_of_relationships(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_data( db_map, @@ -561,7 +561,7 @@ def populate_with_relationship(db_map): import_relationships(db_map, [["relationship_class", ["object1", "object2"]]]) def test_import_valid_object_parameter_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", 1]]) self.assertFalse(errors) @@ -572,7 +572,7 @@ def test_import_valid_object_parameter_value(self): db_map.close() def test_import_valid_object_parameter_value_string(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "value_string"]]) self.assertFalse(errors) @@ -583,7 +583,7 @@ def test_import_valid_object_parameter_value_string(self): db_map.close() def test_import_valid_object_parameter_value_with_duplicate_object_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_objects(db_map, [["object_class1", "duplicate_object"], ["object_class2", "duplicate_object"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "duplicate_object", "parameter", 1]]) @@ -595,7 +595,7 @@ def test_import_valid_object_parameter_value_with_duplicate_object_name(self): db_map.close() def test_import_valid_object_parameter_value_with_duplicate_parameter_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_object_parameters(db_map, [["object_class2", "parameter"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", 1]]) @@ -607,7 +607,7 @@ def test_import_valid_object_parameter_value_with_duplicate_parameter_name(self) db_map.close() def test_import_object_parameter_value_with_invalid_object(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) import_object_parameters(db_map, [["object_class", "parameter"]]) _, errors = import_object_parameter_values(db_map, [["object_class", "nonexistent_object", "parameter", 1]]) @@ -617,7 +617,7 @@ def 
test_import_object_parameter_value_with_invalid_object(self): db_map.close() def test_import_object_parameter_value_with_invalid_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_object_classes(db_map, ["object_class"]) import_objects(db_map, ["object_class", "object"]) _, errors = import_object_parameter_values(db_map, [["object_class", "object", "nonexistent_parameter", 1]]) @@ -627,7 +627,7 @@ def test_import_object_parameter_value_with_invalid_parameter(self): db_map.close() def test_import_existing_object_parameter_value_update_the_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "initial_value"]]) _, errors = import_object_parameter_values(db_map, [["object_class1", "object1", "parameter", "new_value"]]) @@ -639,7 +639,7 @@ def test_import_existing_object_parameter_value_update_the_value(self): db_map.close() def test_import_existing_object_parameter_value_on_conflict_keep(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) initial_value = {"type": "time_series", "data": [("2000-01-01T01:00", "1"), ("2000-01-01T02:00", "2")]} new_value = {"type": "time_series", "data": [("2000-01-01T02:00", "3"), ("2000-01-01T03:00", "4")]} @@ -656,7 +656,7 @@ def test_import_existing_object_parameter_value_on_conflict_keep(self): db_map.close() def test_import_existing_object_parameter_value_on_conflict_replace(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) initial_value = {"type": "time_series", "data": [("2000-01-01T01:00", "1"), ("2000-01-01T02:00", "2")]} new_value = {"type": "time_series", "data": [("2000-01-01T02:00", "3"), ("2000-01-01T03:00", "4")]} @@ -673,7 +673,7 @@ def test_import_existing_object_parameter_value_on_conflict_replace(self): db_map.close() def test_import_existing_object_parameter_value_on_conflict_merge(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) initial_value = {"type": "time_series", "data": [("2000-01-01T01:00", "1"), ("2000-01-01T02:00", "2")]} new_value = {"type": "time_series", "data": [("2000-01-01T02:00", "3"), ("2000-01-01T03:00", "4")]} @@ -692,7 +692,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge(self): db_map.close() def test_import_existing_object_parameter_value_on_conflict_merge_map(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) initial_value = { "type": "map", @@ -721,7 +721,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge_map(self): db_map.close() def test_import_duplicate_object_parameter_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) _, errors = import_object_parameter_values( db_map, @@ -735,7 +735,7 @@ def test_import_duplicate_object_parameter_value(self): db_map.close() def test_import_object_parameter_value_with_alternative(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) import_alternatives(db_map, ["alternative"]) count, errors = import_object_parameter_values( @@ -752,7 +752,7 @@ def test_import_object_parameter_value_with_alternative(self): db_map.close() def test_import_object_parameter_value_fails_with_nonexistent_alternative(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) count, errors = import_object_parameter_values( db_map, [["object_class1", "object1", 
"parameter", 1, "nonexistent_alternative"]] @@ -762,7 +762,7 @@ def test_import_object_parameter_value_fails_with_nonexistent_alternative(self): db_map.close() def test_import_parameter_values_from_committed_value_list(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_data(db_map, parameter_value_lists=(("values_1", 5.0),)) db_map.commit_session("test") count, errors = import_data( @@ -781,7 +781,7 @@ def test_import_parameter_values_from_committed_value_list(self): db_map.close() def test_valid_object_parameter_value_from_value_list(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_parameter_value_lists(db_map, (("values_1", 5.0),)) import_object_classes(db_map, ("object_class",)) import_object_parameters(db_map, (("object_class", "parameter", None, "values_1"),)) @@ -797,7 +797,7 @@ def test_valid_object_parameter_value_from_value_list(self): db_map.close() def test_non_existent_object_parameter_value_from_value_list_fails_gracefully(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_parameter_value_lists(db_map, (("values_1", 5.0),)) import_object_classes(db_map, ("object_class",)) import_object_parameters(db_map, (("object_class", "parameter", None, "values_1"),)) @@ -808,7 +808,7 @@ def test_non_existent_object_parameter_value_from_value_list_fails_gracefully(se db_map.close() def test_import_valid_relationship_parameter_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) _, errors = import_relationship_parameter_values( db_map, [["relationship_class", ["object1", "object2"], "parameter", 1]] @@ -821,7 +821,7 @@ def test_import_valid_relationship_parameter_value(self): db_map.close() def test_import_valid_relationship_parameter_value_with_duplicate_parameter_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) import_relationship_classes(db_map, [["relationship_class2", ["object_class2", "object_class1"]]]) import_relationship_parameters(db_map, [["relationship_class2", "parameter"]]) @@ -836,7 +836,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_parameter_name db_map.close() def test_import_valid_relationship_parameter_value_with_duplicate_object_name(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) import_objects(db_map, [["object_class1", "duplicate_object"], ["object_class2", "duplicate_object"]]) import_relationships(db_map, [["relationship_class", ["duplicate_object", "duplicate_object"]]]) @@ -851,7 +851,7 @@ def test_import_valid_relationship_parameter_value_with_duplicate_object_name(se db_map.close() def test_import_relationship_parameter_value_with_invalid_object(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) _, errors = import_relationship_parameter_values( db_map, [["relationship_class", ["nonexistent_object", "object2"], "parameter", 1]] @@ -862,7 +862,7 @@ def test_import_relationship_parameter_value_with_invalid_object(self): db_map.close() def test_import_relationship_parameter_value_with_invalid_relationship_class(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) _, errors = import_relationship_parameter_values( db_map, [["nonexistent_class", ["object1", "object2"], "parameter", 1]] @@ -873,7 +873,7 @@ def test_import_relationship_parameter_value_with_invalid_relationship_class(sel 
db_map.close() def test_import_relationship_parameter_value_with_invalid_parameter(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) _, errors = import_relationship_parameter_values( db_map, [["relationship_class", ["object1", "object2"], "nonexistent_parameter", 1]] @@ -884,7 +884,7 @@ def test_import_relationship_parameter_value_with_invalid_parameter(self): db_map.close() def test_import_existing_relationship_parameter_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) import_relationship_parameter_values( db_map, [["relationship_class", ["object1", "object2"], "parameter", "initial_value"]] @@ -900,7 +900,7 @@ def test_import_existing_relationship_parameter_value(self): db_map.close() def test_import_duplicate_relationship_parameter_value(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) _, errors = import_relationship_parameter_values( db_map, @@ -917,7 +917,7 @@ def test_import_duplicate_relationship_parameter_value(self): db_map.close() def test_import_relationship_parameter_value_with_alternative(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate_with_relationship(db_map) import_alternatives(db_map, ["alternative"]) count, errors = import_relationship_parameter_values( @@ -935,7 +935,7 @@ def test_import_relationship_parameter_value_with_alternative(self): db_map.close() def test_import_relationship_parameter_value_fails_with_nonexistent_alternative(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) count, errors = import_relationship_parameter_values( db_map, [["relationship_class", ["object1", "object2"], "parameter", 1, "alternative"]] @@ -945,7 +945,7 @@ def test_import_relationship_parameter_value_fails_with_nonexistent_alternative( db_map.close() def test_valid_relationship_parameter_value_from_value_list(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_parameter_value_lists(db_map, (("values_1", 5.0),)) import_object_classes(db_map, ("object_class",)) import_objects(db_map, (("object_class", "my_object"),)) @@ -965,7 +965,7 @@ def test_valid_relationship_parameter_value_from_value_list(self): db_map.close() def test_non_existent_relationship_parameter_value_from_value_list_fails_gracefully(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_parameter_value_lists(db_map, (("values_1", 5.0),)) import_object_classes(db_map, ("object_class",)) import_objects(db_map, (("object_class", "my_object"),)) @@ -1023,7 +1023,7 @@ def test_import_twelfth_value(self): class TestImportAlternative(unittest.TestCase): def test_single_alternative(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_alternatives(db_map, ["alternative"]) self.assertEqual(count, 1) self.assertFalse(errors) @@ -1035,7 +1035,7 @@ def test_single_alternative(self): db_map.close() def test_alternative_description(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_alternatives(db_map, [["alternative", "description"]]) self.assertEqual(count, 1) self.assertFalse(errors) @@ -1046,7 +1046,7 @@ def test_alternative_description(self): db_map.close() def test_update_alternative_description(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_alternatives(db_map, [["Base", "new description"]]) self.assertEqual(count, 1) 
self.assertFalse(errors) @@ -1059,7 +1059,7 @@ def test_update_alternative_description(self): class TestImportScenario(unittest.TestCase): def test_single_scenario(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_scenarios(db_map, ["scenario"]) self.assertEqual(count, 1) self.assertFalse(errors) @@ -1069,7 +1069,7 @@ def test_single_scenario(self): db_map.close() def test_scenario_with_description(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_scenarios(db_map, [["scenario", False, "description"]]) self.assertEqual(count, 1) self.assertFalse(errors) @@ -1079,7 +1079,7 @@ def test_scenario_with_description(self): db_map.close() def test_update_scenario_description(self): - db_map = create_diff_db_map() + db_map = create_db_map() import_scenarios(db_map, [["scenario", False, "initial description"]]) count, errors = import_scenarios(db_map, [["scenario", False, "new description"]]) self.assertEqual(count, 1) @@ -1092,7 +1092,7 @@ def test_update_scenario_description(self): class TestImportScenarioAlternative(unittest.TestCase): def setUp(self): - self._db_map = create_diff_db_map() + self._db_map = create_db_map() def tearDown(self): self._db_map.close() @@ -1134,8 +1134,61 @@ def test_fails_with_nonexistent_before_alternative(self): count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative", "nonexistent_alternative"]] ) - self.assertTrue(errors) + self.assertEqual(errors, ["nonexistent_alternative is not in scenario"]) self.assertEqual(count, 2) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {}) + + def test_importing_existing_scenario_alternative_does_not_alter_scenario_alternatives(self): + count, errors = import_scenario_alternatives( + self._db_map, + [["scenario", "alternative2", "alternative1"], ["scenario", "alternative1"]], + ) + self.assertFalse(errors) + self.assertEqual(count, 5) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative2": 1}}) + count, errors = import_scenario_alternatives( + self._db_map, + [["scenario", "alternative1"]], + ) + self.assertFalse(errors) + self.assertEqual(count, 0) + + def test_import_scenario_alternatives_in_arbitrary_order(self): + count, errors = import_scenarios(self._db_map, [('A (1)', False, '')]) + self.assertEqual(errors, []) + self.assertEqual(count, 1) + count, errors = import_alternatives( + self._db_map, [('Base', 'Base alternative'), ('b', ''), ('c', ''), ('d', '')] + ) + self.assertEqual(errors, []) + self.assertEqual(count, 3) + count, errors = import_scenario_alternatives( + self._db_map, [('A (1)', 'c', 'd'), ('A (1)', 'd', None), ('A (1)', 'Base', 'b'), ('A (1)', 'b', 'c')] + ) + self.assertEqual(errors, []) + self.assertEqual(count, 4) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"A (1)": {"Base": 1, "b": 2, "c": 3, "d": 4}}) + + def test_insert_scenario_alternative_in_the_middle_of_other_alternatives(self): + count, errors = import_scenario_alternatives( + self._db_map, + [["scenario", "alternative2", "alternative1"], ["scenario", "alternative1"]], + ) + self.assertFalse(errors) + self.assertEqual(count, 5) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative2": 1}}) + count, errors = import_scenario_alternatives( + self._db_map, + 
[["scenario", "alternative3", "alternative1"]], + ) + self.assertFalse(errors) + self.assertEqual(count, 3) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 3, "alternative2": 1, "alternative3": 2}}) def scenario_alternatives(self): self._db_map.commit_session("test") @@ -1157,7 +1210,7 @@ def scenario_alternatives(self): class TestImportMetadata(unittest.TestCase): def test_import_metadata(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_metadata(db_map, ['{"name": "John", "age": 17}', '{"name": "Charly", "age": 90}']) self.assertEqual(count, 4) self.assertFalse(errors) @@ -1171,7 +1224,7 @@ def test_import_metadata(self): db_map.close() def test_import_metadata_with_duplicate_entry(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_metadata(db_map, ['{"name": "John", "age": 17}', '{"name": "Charly", "age": 17}']) self.assertEqual(count, 3) self.assertFalse(errors) @@ -1184,7 +1237,7 @@ def test_import_metadata_with_duplicate_entry(self): db_map.close() def test_import_metadata_with_nested_dict(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_metadata(db_map, ['{"name": "John", "info": {"age": 17, "city": "LA"}}']) db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] @@ -1196,7 +1249,7 @@ def test_import_metadata_with_nested_dict(self): db_map.close() def test_import_metadata_with_nested_list(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_metadata(db_map, ['{"contributors": [{"name": "John"}, {"name": "Charly"}]}']) db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] @@ -1208,7 +1261,7 @@ def test_import_metadata_with_nested_list(self): db_map.close() def test_import_unformatted_metadata(self): - db_map = create_diff_db_map() + db_map = create_db_map() count, errors = import_metadata(db_map, ['not a JSON object']) db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] @@ -1233,7 +1286,7 @@ def populate(db_map): import_metadata(db_map, ['{"co-author": "John", "age": 17}', '{"co-author": "Charly", "age": 90}']) def test_import_object_metadata(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) count, errors = import_object_metadata( db_map, @@ -1256,7 +1309,7 @@ def test_import_object_metadata(self): db_map.close() def test_import_relationship_metadata(self): - db_map = create_diff_db_map() + db_map = create_db_map() self.populate(db_map) count, errors = import_relationship_metadata( db_map, @@ -1279,7 +1332,7 @@ def test_import_relationship_metadata(self): class TestImportParameterValueMetadata(unittest.TestCase): def setUp(self): - self._db_map = create_diff_db_map() + self._db_map = create_db_map() import_metadata(self._db_map, ['{"co-author": "John", "age": 17}']) def tearDown(self): From 5c38a210b525bc9194b6b122bd3adc832d6d252d Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 08:58:55 +0100 Subject: [PATCH 179/317] Fix docs build and update tutorial --- docs/source/conf.py | 2 +- docs/source/tutorial.rst | 57 +++++++++++++++++++++------------------ spinedb_api/db_mapping.py | 6 ++++- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9e11999d..83a3f00a 100644 --- a/docs/source/conf.py +++ 
b/docs/source/conf.py @@ -113,7 +113,7 @@ def _process_docstring(app, what, name, obj, options, lines): else: new_lines = [] for item_type in DatabaseMapping.item_types(): - factory = DatabaseMapping._item_factory(item_type) + factory = DatabaseMapping.item_factory(item_type) new_lines.extend([item_type, len(item_type) * "-", ""]) new_lines.extend( [ diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index b1c6da53..92e788a8 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -65,18 +65,18 @@ that we want the DB to be created at the given URL. Adding data ----------- -To insert data, we use :meth:`~.DatabaseMapping.add_item`. +To insert data, we use e.g. :meth:`~.DatabaseMapping.add_entity_class_item`, :meth:`~.DatabaseMapping.add_entity_item`, +and so on. Let's begin the party by adding a couple of entity classes:: - db_map.add_item("entity_class", name="fish", description="It swims.") - db_map.add_item("entity_class", name="cat", description="Eats fish.") + db_map.add_entity_class_item(name="fish", description="It swims.") + db_map.add_entity_class_item(name="cat", description="Eats fish.") Now let's add a multi-dimensional entity class between the two above. For this we need to specify the class names as `dimension_name_list`:: - db_map.add_item( - "entity_class", + db_map.add_entity_class_item( name="fish__cat", dimension_name_list=("fish", "cat"), description="A fish getting eaten by a cat?", @@ -84,28 +84,28 @@ as `dimension_name_list`:: Let's add entities to our zero-dimensional classes:: - db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (for now).") - db_map.add_item( - "entity", + db_map.add_entity_item(class_name="fish", name="Nemo", description="Lost (for now).") + db_map.add_entity_item( class_name="cat", name="Felix", description="The wonderful wonderful cat." ) Let's add a multi-dimensional entity to our multi-dimensional class. For this we need to specify the entity names as `element_name_list`:: - db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) + db_map.add_entity_item(class_name="fish__cat", element_name_list=("Nemo", "Felix")) Let's add a parameter definition for one of our entity classes:: - db_map.add_item("parameter_definition", entity_class_name="fish", name="color") + db_map.add_parameter_definition_item(entity_class_name="fish", name="color") Finally, let's specify a parameter value for one of our entities. -We use :func:`.to_database` to convert our value -into a tuple of value and type to specify for our parameter value item:: +First, we use :func:`.to_database` to convert the value we want to give into a tuple of ``value`` and ``type``:: value, type_ = api.to_database("mainly orange") - db_map.add_item( - "parameter_value", + +Now we create our parameter value:: + + db_map.add_parameter_value_item( entity_class_name="fish", entity_byname=("Nemo",), parameter_definition_name="color", @@ -114,7 +114,7 @@ into a tuple of value and type to specify for our parameter value item:: type=type_ ) -Note that in the above, we refer to the entity by its *byname* which is a tuple of its elements. +Note that in the above, we refer to the entity by its *byname*. We also set the value to belong to an *alternative* called ``Base`` which is readily available in new databases. @@ -126,30 +126,32 @@ which is readily available in new databases. Retrieving data --------------- -To retrieve data, we use :meth:`~.DatabaseMapping.get_item`. 
This implicitly fetches data from the DB +To retrieve data, we use e.g. :meth:`~.DatabaseMapping.get_entity_class_item`, +:meth:`~.DatabaseMapping.get_entity_item`, etc. +This implicitly fetches data from the DB into the in-memory mapping, if not already there. For example, let's find one of the entities we inserted above:: felix_item = db_map.get_entity_item(class_name="cat", name="Felix") assert felix_item["description"] == "The wonderful wonderful cat." -Above, ``felix_item`` is a :class:`~.PublicItem` object, representing an item (or row) in a Spine DB. +Above, ``felix_item`` is a :class:`~.PublicItem` object, representing an item. Let's find our multi-dimensional entity:: - nemo_felix_item = db_map.get_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) + nemo_felix_item = db_map.get_entity_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) assert nemo_felix_item["dimension_name_list"] == ('fish', 'cat') -Now let's retrieve our parameter value. -We use :func:`.from_database` to convert the value and type from the parameter value item into our original value:: +Now let's retrieve our parameter value:: - nemo_color_item = db_map.get_item( - "parameter_value", + nemo_color_item = db_map.get_parameter_value_item( entity_class_name="fish", entity_byname=("Nemo",), parameter_definition_name="color", alternative_name="Base" ) + +We use :func:`.from_database` to convert the value and type from the parameter value into our original value:: nemo_color = api.from_database(nemo_color_item["value"], nemo_color_item["type"]) assert nemo_color == "mainly orange" @@ -169,13 +171,12 @@ To update data, we use the :meth:`~.PublicItem.update` method of :class:`~.Publi Let's rename our fish entity to avoid any copyright infringements:: - db_map.get_item("entity", class_name="fish", name="Nemo").update(name="NotNemo") + db_map.get_entity_item(class_name="fish", name="Nemo").update(name="NotNemo") To be safe, let's also change the color:: new_value, new_type = api.to_database("not that orange") - db_map.get_item( - "parameter_value", + db_map.get_parameter_value_item( entity_class_name="fish", entity_byname=("NotNemo",), parameter_definition_name="color", @@ -190,13 +191,17 @@ Removing data You know what, let's just remove the entity entirely. To do this we use the :meth:`~.PublicItem.remove` method of :class:`~.PublicItem`:: - db_map.get_item("entity", class_name="fish", name="NotNemo").remove() + db_map.get_entity_item(class_name="fish", name="NotNemo").remove() Note that the above call removes items in *cascade*, meaning that items that depend on ``"NotNemo"`` will get removed as well. We have one such item in the database, namely the ``"color"`` parameter value which also gets dropped when the above method is called. 
+Restoring data +-------------- + +TODO Committing data --------------- diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 87c50d87..a6628d72 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -785,7 +785,11 @@ def _add_convenience_methods(node): return node for item_type in DatabaseMapping.item_types(): factory = DatabaseMapping.item_factory(item_type) - uq_fields = {f_name: factory.fields[f_name] for f_names in factory._unique_keys for f_name in f_names} + uq_fields = { + f_name: factory.fields[f_name] + for f_names in factory._unique_keys + for f_name in set(f_names) & set(factory.fields.keys()) + } a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" padding = 20 * " " get_kwargs = f"\n{padding}".join( From 1b416d90786e4387103a1bd2d619cbbea92c4b70 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 11:34:00 +0100 Subject: [PATCH 180/317] Introduce get_{item_type}_items convenience method for DatabaseMapping --- spinedb_api/db_mapping.py | 62 +++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index a6628d72..cd514fa3 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -78,7 +78,8 @@ class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, Dat otherwise it is fetched from the DB, stored in memory, and then returned. In other words, the data is fetched from the DB exactly once. - For convenience, we also provide specialized 'get' methods for each item type, e.g., :meth:`get_entity_item`. + For convenience, we also provide specialized 'get' methods for each item type, e.g., :meth:`get_entity_item` + and :meth:`get_entity_items`. Data is added via :meth:`add_item`; updated via :meth:`update_item`; @@ -774,6 +775,7 @@ def get_filter_configs(self): # Define convenience methods for it in DatabaseMapping.item_types(): setattr(DatabaseMapping, "get_" + it + "_item", partialmethod(DatabaseMapping.get_item, it)) + setattr(DatabaseMapping, "get_" + it + "_items", partialmethod(DatabaseMapping.get_items, it)) setattr(DatabaseMapping, "add_" + it + "_item", partialmethod(DatabaseMapping.add_item, it)) setattr(DatabaseMapping, "update_" + it + "_item", partialmethod(DatabaseMapping.update_item, it)) setattr(DatabaseMapping, "remove_" + it + "_item", partialmethod(DatabaseMapping.remove_item, it)) @@ -783,22 +785,25 @@ def get_filter_configs(self): def _add_convenience_methods(node): if node.name != "DatabaseMapping": return node - for item_type in DatabaseMapping.item_types(): - factory = DatabaseMapping.item_factory(item_type) - uq_fields = { + + def _a(item_type): + return "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" + + def _uq_fields(factory): + return { f_name: factory.fields[f_name] for f_names in factory._unique_keys for f_name in set(f_names) & set(factory.fields.keys()) } - a = "an" if any(item_type.lower().startswith(x) for x in "aeiou") else "a" - padding = 20 * " " - get_kwargs = f"\n{padding}".join( - [f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in uq_fields.items()] - ) - add_kwargs = f"\n{padding}".join( - [f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in factory.fields.items()] - ) - update_kwargs = f"id (int): The id of the item to update.\n{padding}" + add_kwargs + + def _kwargs(fields): + return f"\n{padding}".join([f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in fields.items()]) + 
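+    # The three helpers above only serve to render the docstrings of the
+    # generated convenience methods: _a() picks the right article for the item
+    # type, _uq_fields() collects the fields that take part in some unique key,
+    # and _kwargs() lays them out as indented "name (type): description" lines.
+    # Note that _kwargs() closes over `padding`, which is bound just below
+    # before any of the helpers is called, so the forward reference is safe.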
+ padding = 20 * " " + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping.item_factory(item_type) + a = _a(item_type) + get_kwargs = _kwargs(_uq_fields(factory)) child = astroid.extract_node( f''' def get_{item_type}_item(self, fetch=True, skip_removed=True, **kwargs): @@ -816,6 +821,31 @@ def get_{item_type}_item(self, fetch=True, skip_removed=True, **kwargs): ) child.parent = node node.body.append(child) + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping.item_factory(item_type) + a = _a(item_type) + get_kwargs = _kwargs(_uq_fields(factory)) + child = astroid.extract_node( + f''' + def get_{item_type}_items(self, fetch=True, skip_removed=True, **kwargs): + """Finds and returns all {item_type} items. + + Args: + fetch (bool, optional): Whether to fetch the DB before returning the items. + skip_removed (bool, optional): Whether to ignore removed items. + {get_kwargs} + + Returns: + list(:class:`PublicItem`): The items. + """ + ''' + ) + child.parent = node + node.body.append(child) + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping.item_factory(item_type) + a = _a(item_type) + add_kwargs = _kwargs(factory.fields) child = astroid.extract_node( f''' def add_{item_type}_item(self, check=True, **kwargs): @@ -832,6 +862,10 @@ def add_{item_type}_item(self, check=True, **kwargs): ) child.parent = node node.body.append(child) + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping.item_factory(item_type) + a = _a(item_type) + update_kwargs = f"id (int): The id of the item to update.\n{padding}" + _kwargs(factory.fields) child = astroid.extract_node( f''' def update_{item_type}_item(self, check=True, **kwargs): @@ -848,6 +882,7 @@ def update_{item_type}_item(self, check=True, **kwargs): ) child.parent = node node.body.append(child) + for item_type in DatabaseMapping.item_types(): child = astroid.extract_node( f''' def remove_{item_type}_item(self, id): @@ -863,6 +898,7 @@ def remove_{item_type}_item(self, id): ) child.parent = node node.body.append(child) + for item_type in DatabaseMapping.item_types(): child = astroid.extract_node( f''' def restore_{item_type}_item(self, id): From 3004d7a1c4751bfaae34fb33b47eb2450e68eeb1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 11:40:47 +0100 Subject: [PATCH 181/317] Provide cascade_remove_items for legacy --- spinedb_api/db_mapping.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index cd514fa3..9aceaf7b 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -578,6 +578,11 @@ def remove_items(self, item_type, *ids, check=True, strict=False): removed.append(item) return removed, errors + def cascade_remove_items(self, cache=None, **kwargs): + # Legacy + for item_type, ids in kwargs.items(): + self.remove_items(item_type, *ids) + def restore_item(self, item_type, id_): """Restores a previously removed item into the in-memory mapping. 
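
For callers migrating from the legacy API, the ``cascade_remove_items`` shim above
simply dispatches to ``remove_items`` once per item type; for instance (a sketch,
with hypothetical ids)::

    # Legacy style: one call covering several item types.
    db_map.cascade_remove_items(entity_class={1, 2}, alternative={3})
    # Equivalent to:
    db_map.remove_items("entity_class", 1, 2)
    db_map.remove_items("alternative", 3)
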
From 1ff7a9cfc6456e2ad316c90518a26ddac6c0c5c3 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 12:30:35 +0100 Subject: [PATCH 182/317] Get rid of the ticket in fetch_more, let clients manage fetch completion --- spinedb_api/db_mapping.py | 4 ++-- spinedb_api/db_mapping_base.py | 21 ++++++--------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 9aceaf7b..30596880 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -628,7 +628,7 @@ def purge_items(self, item_type): """ return bool(self.remove_items(item_type, Asterisk)) - def fetch_more(self, item_type, offset=0, limit=None, ticket=None): + def fetch_more(self, item_type, offset=0, limit=None): """Fetches items from the DB into the in-memory mapping, incrementally. Args: @@ -640,7 +640,7 @@ def fetch_more(self, item_type, offset=0, limit=None, ticket=None): list(:class:`PublicItem`): The items fetched. """ item_type = self.real_item_type(item_type) - return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, ticket=ticket)] + return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit)] def fetch_all(self, *item_types): """Fetches items from the DB into the in-memory mapping. diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 6edd917a..028eea4f 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -39,7 +39,6 @@ class DatabaseMappingBase: def __init__(self): self._mapped_tables = {} - self._completed_tickets = {} item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -200,7 +199,6 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" - self._completed_tickets.clear() def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -221,7 +219,6 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._completed_tickets.pop(item_type, None) def _add_descendants(self, item_types): while True: @@ -241,17 +238,16 @@ def _get_next_chunk(self, item_type, offset, limit, **kwargs): """Gets chunk of items from the DB. Returns: - tuple(list(dict),bool): list of dictionary items and whether this is the last chunk. + list(dict): list of dictionary items. """ qry = self._make_query(item_type, **kwargs) if not qry: - return [], True + return [] if not limit: - return [dict(x) for x in qry], True - chunk = [dict(x) for x in qry.limit(limit).offset(offset)] - return chunk, len(chunk) < limit + return [dict(x) for x in qry] + return [dict(x) for x in qry.limit(limit).offset(offset)] - def do_fetch_more(self, item_type, offset=0, limit=None, ticket=None, **kwargs): + def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): """Fetches items from the DB and adds them to the mapping. Args: @@ -260,12 +256,7 @@ def do_fetch_more(self, item_type, offset=0, limit=None, ticket=None, **kwargs): Returns: list(MappedItem): items fetched from the DB. 
""" - completed_tickets = self._completed_tickets.setdefault(item_type, set()) - if ticket in completed_tickets: - return [] - chunk, completed = self._get_next_chunk(item_type, offset, limit, **kwargs) - if ticket is not None and completed: - completed_tickets.add(ticket) + chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) if not chunk: return [] mapped_table = self.mapped_table(item_type) From 3b2b887f8181ef95017a347a09daf62669065e3b Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 12:33:23 +0100 Subject: [PATCH 183/317] Accept filtering kwargs in fetch_more --- spinedb_api/db_mapping.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 30596880..a1c2b8cf 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -409,6 +409,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): item_type (str): One of . fetch (bool, optional): Whether to fetch the DB before returning the items. skip_removed (bool, optional): Whether to ignore removed items. + **kwargs: Fields and values for one the unique keys as specified for the item type in `DB mapping schema`_. Returns: list(:class:`PublicItem`): The items. @@ -628,19 +629,20 @@ def purge_items(self, item_type): """ return bool(self.remove_items(item_type, Asterisk)) - def fetch_more(self, item_type, offset=0, limit=None): + def fetch_more(self, item_type, offset=0, limit=None, **kwargs): """Fetches items from the DB into the in-memory mapping, incrementally. Args: item_type (str): One of . offset (int): The initial row. limit (int): The maximum number of rows to fetch. + **kwargs: Fields and values for one the unique keys as specified for the item type in `DB mapping schema`_. Returns: list(:class:`PublicItem`): The items fetched. """ item_type = self.real_item_type(item_type) - return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit)] + return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, **kwargs)] def fetch_all(self, *item_types): """Fetches items from the DB into the in-memory mapping. From 786f9614bfb383f9f614f23545d78c7ef3056b64 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 14:17:31 +0100 Subject: [PATCH 184/317] Move db mapping schema to its own file --- docs/source/conf.py | 91 ++++++++++++++++++++++-------------- docs/source/front_matter.rst | 2 +- docs/source/index.rst | 1 + spinedb_api/db_mapping.py | 14 ------ 4 files changed, 57 insertions(+), 51 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 83a3f00a..459f7d59 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -93,6 +93,9 @@ autoapi_dirs = ['../../spinedb_api'] # package to be documented autoapi_ignore = [ '*/spinedb_api/alembic/*', + '*/spinedb_api/export_mapping/*', + '*/spinedb_api/import_mapping/*', + '*/spinedb_api/spine_io/*', ] # ignored modules @@ -104,51 +107,67 @@ def _skip_member(app, what, name, obj, skip, options): return skip +def _spine_item_types(): + return ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()]) + + def _process_docstring(app, what, name, obj, options, lines): - # Expand - try: - i = lines.index("") - except ValueError: - pass - else: - new_lines = [] - for item_type in DatabaseMapping.item_types(): - factory = DatabaseMapping.item_factory(item_type) - new_lines.extend([item_type, len(item_type) * "-", ""]) - new_lines.extend( - [ - ".. 
list-table:: Fields and values",
-                    "   :header-rows: 1",
-                    "",
-                    "   * - field",
-                    "     - type",
-                    "     - value",
-                ]
-            )
-            for f_name, (f_type, f_value) in factory.fields.items():
-                new_lines.extend([f"   * - {f_name}", f"     - {f_type}", f"     - {f_value}"])
-            new_lines.append("")
-            new_lines.extend(
-                [
-                    ".. list-table:: Unique keys",
-                    "   :header-rows: 0",
-                    "",
-                ]
-            )
-            for f_names in factory._unique_keys:
-                f_names = ", ".join(f_names)
-                new_lines.extend([f"   * - {f_names}"])
-            lines[i : i + 1] = new_lines
     # Expand <spine_item_types>
-    spine_item_types = ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()])
     for k, line in enumerate(lines):
         if "<spine_item_types>" in line:
-            lines[k] = line.replace("<spine_item_types>", spine_item_types)
+            lines[k] = line.replace("<spine_item_types>", _spine_item_types())
+
+
+def _db_mapping_schema_lines():
+    lines = [
+        "DB mapping schema",
+        "=================",
+        "",
+        "The DB mapping schema is a close cousin of the Spine DB schema with some extra flexibility, "
+        "like the ability to specify references by name rather than by numerical id.",
+        "",
+        f"The schema defines the following item types: {_spine_item_types()}. "
+        "As you can see, these follow the names of some of the tables in the Spine DB schema.",
+        "",
+        "The following subsections provide all the details you need to know about the different item types, namely, "
+        "their fields, values, and unique keys.",
+        "",
+    ]
+    for item_type in DatabaseMapping.item_types():
+        factory = DatabaseMapping.item_factory(item_type)
+        lines.extend([item_type, len(item_type) * "-", ""])
+        lines.extend(
+            [
+                ".. list-table:: Fields and values",
+                "   :header-rows: 1",
+                "",
+                "   * - field",
+                "     - type",
+                "     - value",
+            ]
+        )
+        for f_name, (f_type, f_value) in factory.fields.items():
+            lines.extend([f"   * - {f_name}", f"     - {f_type}", f"     - {f_value}"])
+        lines.append("")
+        lines.extend([".. list-table:: Unique keys", "   :header-rows: 0", ""])
+        for f_names in factory._unique_keys:
+            f_names = ", ".join(f_names)
+            lines.extend([f"   * - {f_names}"])
+    return lines
 
 
 def setup(sphinx):
     sphinx.connect("autoapi-skip-member", _skip_member)
     sphinx.connect("autodoc-process-docstring", _process_docstring)
+    with open(os.path.join(os.path.dirname(__file__), "db_mapping_schema.rst"), "w") as f:
+        for line in _db_mapping_schema_lines():
+            f.write(line + "\n")
 
 
 # -- Options for HTML output -------------------------------------------------
diff --git a/docs/source/front_matter.rst b/docs/source/front_matter.rst
index eb5fdb76..1c034732 100644
--- a/docs/source/front_matter.rst
+++ b/docs/source/front_matter.rst
@@ -1,4 +1,4 @@
-.. spinedb_api tutorial
+.. spinedb_api front matter
    Created: 18.6.2018
 
 .. _SQLAlchemy: http://www.sqlalchemy.org/
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9d5f5337..9815a94c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -19,6 +19,7 @@ Welcome to Spine Database API's documentation!
    parameter_value_format
    metadata_description
    results_metadata_description
+   db_mapping_schema
    autoapi/index
 
 Indices and tables
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index a1c2b8cf..014916e3 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -13,20 +13,6 @@
 This module defines the :class:`.DatabaseMapping` class, the main means to communicate with a Spine DB.
 If you're planning to use this class, it is probably a good idea to first
 familiarize yourself a little bit with the DB mapping schema below.
-
-
-DB mapping schema
-=================
-
-The DB mapping schema is a close cousin of the Spine DB schema with some extra flexibility,
-like the ability to specify references by name rather than by numerical id.
-The schema defines the following item types: <spine_item_types>. As you can see, these follow the names
-of some of the tables in the Spine DB schema.
-
-The following subsections provide all the details you need to know about the different item types, namely,
-their fields, values, and unique keys.
-
-
 """
 import hashlib

From d0df284a93dadb4fe894bc16926f3c7ce3b100bc Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Thu, 9 Nov 2023 14:25:52 +0100
Subject: [PATCH 185/317] Fix refs to the DB mapping schema

---
 .gitignore                |  1 +
 docs/source/conf.py       |  2 ++
 spinedb_api/db_mapping.py | 21 ++++++++++++---------
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/.gitignore b/.gitignore
index 05d53d7b..89f5b605 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 /.idea/
 /docs/build/
 /docs/source/autoapi/
+/docs/source/db_mapping_schema.rst
 
 # Setuptools distribution folder.
 /build/
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 459f7d59..8df70417 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -120,6 +120,8 @@ def _process_docstring(app, what, name, obj, options, lines):
 
 def _db_mapping_schema_lines():
     lines = [
+        ".. _db_mapping_schema:",
+        "",
         "DB mapping schema",
         "=================",
         "",
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index 014916e3..dace41c3 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -12,7 +12,7 @@
 """
 This module defines the :class:`.DatabaseMapping` class, the main means to communicate with a Spine DB.
 If you're planning to use this class, it is probably a good idea to first familiarize yourself a little bit with the
-DB mapping schema below.
+:ref:`db_mapping_schema`.
 """
 import hashlib
@@ -57,7 +57,7 @@
 class DatabaseMapping(DatabaseMappingQueryMixin, DatabaseMappingCommitMixin, DatabaseMappingBase):
     """Enables communication with a Spine DB.
 
-    The DB is incrementally mapped into memory as data is requested/modified, following the `DB mapping schema`_.
+    The DB is incrementally mapped into memory as data is requested/modified, following the :ref:`db_mapping_schema`.
 
     Data is typically retrieved using :meth:`get_item` or :meth:`get_items`.
     If the requested data is already in memory, it is returned from there;
@@ -375,7 +375,8 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
             item_type (str): One of <spine_item_types>.
             fetch (bool, optional): Whether to fetch the DB in case the item is not found in memory.
             skip_removed (bool, optional): Whether to ignore removed items.
-            **kwargs: Fields and values for one of the unique keys as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields and values for one of the unique keys as specified for the item type
+                in :ref:`db_mapping_schema`.
 
         Returns:
             :class:`PublicItem` or None
@@ -395,7 +396,8 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
             item_type (str): One of <spine_item_types>.
             fetch (bool, optional): Whether to fetch the DB before returning the items.
             skip_removed (bool, optional): Whether to ignore removed items.
-            **kwargs: Fields and values for one of the unique keys as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields and values for one of the unique keys as specified for the item type
+                in :ref:`db_mapping_schema`.
 
         Returns:
             list(:class:`PublicItem`): The items.
@@ -419,7 +421,7 @@ def add_item(self, item_type, check=True, **kwargs):
         Args:
             item_type (str): One of <spine_item_types>.
             check (bool, optional): Whether to carry out integrity checks.
-            **kwargs: Fields and values as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields and values as specified for the item type in :ref:`db_mapping_schema`.
 
         Returns:
             tuple(:class:`PublicItem` or None, str): The added item and any errors.
@@ -438,7 +440,7 @@ def add_items(self, item_type, *items, check=True, strict=False):
         Args:
             item_type (str): One of <spine_item_types>.
             *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type,
-                as specified in `DB mapping schema`_.
+                as specified in :ref:`db_mapping_schema`.
             check (bool): Whether or not to run integrity checks.
             strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
                 if the insertion of one of the items violates an integrity constraint.
@@ -472,7 +474,7 @@ def update_item(self, item_type, check=True, **kwargs):
             item_type (str): One of <spine_item_types>.
             check (bool, optional): Whether to carry out integrity checks.
             id (int): The id of the item to update.
-            **kwargs: Fields to update and their new values as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields to update and their new values as specified for the item type in :ref:`db_mapping_schema`.
 
         Returns:
             tuple(:class:`PublicItem` or None, str): The updated item and any errors.
@@ -491,7 +493,7 @@ def update_items(self, item_type, *items, check=True, strict=False):
         Args:
             item_type (str): One of <spine_item_types>.
             *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type,
-                as specified in :ref:`db_mapping_schema` and including the `id`.
+                as specified in :ref:`db_mapping_schema` and including the `id`.
             check (bool): Whether or not to run integrity checks.
             strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError`
                 if the update of one of the items violates an integrity constraint.
@@ -622,7 +624,8 @@ def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
             item_type (str): One of <spine_item_types>.
             offset (int): The initial row.
             limit (int): The maximum number of rows to fetch.
-            **kwargs: Fields and values for one of the unique keys as specified for the item type in `DB mapping schema`_.
+            **kwargs: Fields and values for one of the unique keys as specified for the item type
+                in :ref:`db_mapping_schema`.
 
         Returns:
             list(:class:`PublicItem`): The items fetched.
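
To see the incremental fetching and the unique-key filtering documented above in
action, here is a sketch (the URL and the class name are assumptions)::

    with DatabaseMapping("sqlite:///example.db") as db_map:
        # Page through entity classes, fifty rows at a time.
        first_page = db_map.fetch_more("entity_class", offset=0, limit=50)
        next_page = db_map.fetch_more("entity_class", offset=50, limit=50)
        # Or filter with fields from one of the unique keys.
        fish_entities = db_map.get_items("entity", class_name="fish")
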
From fc0d2309ad20572c97873537397d79dba92a3936 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 9 Nov 2023 16:05:13 +0100 Subject: [PATCH 186/317] Introduce :meta private: so that autoapi skip some members --- docs/source/conf.py | 29 ++++++++-------- spinedb_api/compatibility.py | 2 +- spinedb_api/export_functions.py | 1 - spinedb_api/graph_layout_generator.py | 1 - spinedb_api/helpers.py | 14 ++++---- spinedb_api/mapping.py | 2 +- spinedb_api/parameter_value.py | 48 +++++++++++++++++++++++++-- spinedb_api/perfect_split.py | 1 - spinedb_api/purge.py | 2 -- 9 files changed, 67 insertions(+), 33 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 8df70417..c2bdca63 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,22 +96,26 @@ '*/spinedb_api/export_mapping/*', '*/spinedb_api/import_mapping/*', '*/spinedb_api/spine_io/*', + '*/spinedb_api/compatibility*', + '*/spinedb_api/exception*', + '*/spinedb_api/export_functions*', + '*/spinedb_api/helpers*', + '*/spinedb_api/mapping*', + '*/spinedb_api/perfect_split*', + '*/spinedb_api/purge*', + '*/spinedb_api/query*', + '*/spinedb_api/spine_db_client*', + '*/spinedb_api/spine_db_server*', ] # ignored modules -def _skip_member(app, what, name, obj, skip, options): - if what == "class" and any( - x in name for x in ("SpineDBServer", "group_concat", "DBRequestHandler", "ReceiveAllMixing") - ): - skip = True - return skip - - def _spine_item_types(): return ", ".join([f"``{x}``" for x in DatabaseMapping.item_types()]) def _process_docstring(app, what, name, obj, options, lines): + if any(":meta private:" in line for line in lines): + lines.clear() # Expand for k, line in enumerate(lines): if "" in line: @@ -151,13 +155,7 @@ def _db_mapping_schema_lines(): for f_name, (f_type, f_value) in factory.fields.items(): lines.extend([f" * - {f_name}", f" - {f_type}", f" - {f_value}"]) lines.append("") - lines.extend( - [ - ".. list-table:: Unique keys", - " :header-rows: 0", - "", - ] - ) + lines.extend([".. list-table:: Unique keys", " :header-rows: 0", ""]) for f_names in factory._unique_keys: f_names = ", ".join(f_names) lines.extend([f" * - {f_names}"]) @@ -165,7 +163,6 @@ def _db_mapping_schema_lines(): def setup(sphinx): - sphinx.connect("autoapi-skip-member", _skip_member) sphinx.connect("autodoc-process-docstring", _process_docstring) with open(os.path.join(os.path.dirname(__file__), "db_mapping_schema.rst"), "w") as f: for line in _db_mapping_schema_lines(): diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 38bc144c..92a6783b 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -9,7 +9,7 @@ # this program. If not, see . ###################################################################################################################### -# Dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it +"""Dirty hacks needed to maintain compatibility in cases where migration alone doesn't do it.""" import sqlalchemy as sa diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index bd54280b..a15c0fcd 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -11,7 +11,6 @@ """ Functions for exporting data from a Spine database in a standard format. 
- """ from operator import itemgetter diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index 5ffd3595..be0a149a 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -12,7 +12,6 @@ """ This module defines the :class:`.GraphLayoutGenerator` class. """ - import math import numpy as np from numpy import atleast_1d as arr diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 6a6bd2b5..e68a6267 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -11,9 +11,7 @@ """ General helper functions. - """ - import os import json import warnings @@ -98,13 +96,13 @@ def set_sqlite_pragma(dbapi_connection, connection_record): @compiles(TINYINT, "sqlite") def compile_TINYINT_mysql_sqlite(element, compiler, **kw): - # Handles mysql TINYINT datatype as INTEGER in sqlite. + """Handles mysql TINYINT datatype as INTEGER in sqlite.""" return compiler.visit_INTEGER(element, **kw) @compiles(DOUBLE, "sqlite") def compile_DOUBLE_mysql_sqlite(element, compiler, **kw): - # Handles mysql DOUBLE datatype as REAL in sqlite. + """Handles mysql DOUBLE datatype as REAL in sqlite.""" return compiler.visit_REAL(element, **kw) @@ -256,7 +254,7 @@ def copy_database_bind(dest_bind, source_bind, overwrite=True, upgrade=False, on def custom_generate_relationship(base, direction, return_fn, attrname, local_cls, referred_cls, **kw): - # Make all relationships view only to avoid warnings. + """Make all relationships view only to avoid warnings.""" kw["viewonly"] = True kw["cascade"] = "" kw["passive_deletes"] = False @@ -763,8 +761,8 @@ def _create_first_spine_database(db_url): def forward_sweep(root, fn, *args): - # Recursively visit, using `get_children()`, the given sqlalchemy object. - # Apply `fn` on every visited node.""" + """Recursively visit, using `get_children()`, the given sqlalchemy object. + Apply `fn` on every visited node.""" current = root parent = {} children = {current: iter(current.get_children(column_collections=False))} @@ -804,7 +802,7 @@ def __repr__(self): def fix_name_ambiguity(input_list, offset=0, prefix=""): - # Modify repeated entries in name list by appending an increasing integer. + """Modify repeated entries in name list by appending an increasing integer.""" result = [] ocurrences = {} for item in input_list: diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py index 0d5d0461..5b197b17 100644 --- a/spinedb_api/mapping.py +++ b/spinedb_api/mapping.py @@ -9,7 +9,7 @@ # this program. If not, see . ###################################################################################################################### -# Base class for import and export mappings. +"""Base class for import and export mappings.""" from enum import Enum, unique from itertools import takewhile diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 9d699de1..774e2c55 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -52,6 +52,8 @@ def duration_to_relativedelta(duration): """ Converts a duration to a relativedelta object. + :meta private: + Args: duration (str): a duration string. @@ -83,6 +85,8 @@ def relativedelta_to_duration(delta): """ Converts a relativedelta to duration. + :meta private: + Args: delta (:class:`~dateutil.relativedelta.relativedelta`): the relativedelta to convert. 
@@ -122,6 +126,8 @@ def load_db_value(db_value, value_type=None): Parses a database representation of a parameter value (value and type) into a Python object, using JSON. If the result is a dict, adds the "type" property to it. + :meta private: + Args: db_value (bytes, optional): the database value. value_type (str, optional): the value type. @@ -145,6 +151,8 @@ def dump_db_value(parsed_value): Unparses a Python object into a database representation of a parameter value (value and type), using JSON. If the given object is a dict, extracts the "type" property from it. + :meta private: + Args: parsed_value (any): a Python object, typically obtained by calling :func:`load_db_value`. @@ -183,6 +191,8 @@ def from_database_to_single_value(database_value, value_type): """ Same as :func:`from_database`, but in the case of indexed types it returns just the type as a string. + :meta private: + Args: database_value (bytes): the database value value_type (str, optional): the value type @@ -199,6 +209,8 @@ def from_database_to_dimension_count(database_value, value_type): """ Counts the dimensions in a database representation of a parameter value (value and type). + :meta private: + Args: database_value (bytes): the database value value_type (str, optional): the value type @@ -235,6 +247,8 @@ def from_dict(value): """ Converts a dictionary representation of a parameter value into an encoded parameter value. + :meta private: + Args: value (dict): the value dictionary including the "type" key. @@ -263,6 +277,8 @@ def from_dict(value): def fix_conflict(new, old, on_conflict="merge"): """Resolves conflicts between parameter values: + :meta private: + Args: new (:class:`ParameterValue`, float, str, bool or None): new parameter value to be written. old (:class:`ParameterValue`, float, str, bool or None): an existing parameter value in the db. @@ -286,6 +302,8 @@ def fix_conflict(new, old, on_conflict="merge"): def merge(value, other): """Merges the DB representation of two parameter values. + :meta private: + Args: value (tuple(bytes,str)): recipient value and type. other (tuple(bytes,str)): other value and type. @@ -616,7 +634,10 @@ def _array_from_database(value_dict): class ParameterValue: - """Base class for all encoded parameter values.""" + """Base class for all encoded parameter values. + + :meta private: + """ def to_dict(self): """Returns the dictionary representation of this object. @@ -776,6 +797,8 @@ class _Indexes(np.ndarray): position = indexes.index(element) which might be too slow compared to dictionary lookup. + + :meta private: """ def __new__(cls, other, dtype=None): @@ -805,6 +828,8 @@ class IndexedValue(ParameterValue): """ Base class for all values that have indexes. + :meta private: + Attributes: index_name (str): index name """ @@ -988,6 +1013,8 @@ class IndexedNumberArray(IndexedValue): Abstract base class for all values mapping indexes to floats. The indexes and numbers are stored in :class:`~numpy.ndarray`s. + + :meta private: """ def __init__(self, index_name, values): @@ -1014,7 +1041,10 @@ def to_dict(self): class TimeSeries(IndexedNumberArray): - """Abstract base class for time-series.""" + """Abstract base class for time-series. + + :meta private: + """ VALUE_TYPE = "time series" DEFAULT_INDEX_NAME = "t" @@ -1453,6 +1483,8 @@ def to_dict(self): def map_dimensions(map_): """Counts the dimensions in a map. + :meta private: + Args: map_ (:class:`Map`): the map to process. 
@@ -1477,6 +1509,8 @@ def convert_leaf_maps_to_specialized_containers(map_): - If the ``index_type`` is a :class:`DateTime` and all ``values`` are float, then the leaf is converted to a :class:`TimeSeries`. + :meta private: + Args: map_ (:class:`Map`): a map to process. @@ -1502,6 +1536,8 @@ def convert_containers_to_maps(value): If ``value`` is a :class:`Map` then converts leaf values into maps recursively. + :meta private: + Args: value (:class:`IndexedValue`): an indexed value to convert. @@ -1531,6 +1567,8 @@ def convert_map_to_table(map_, make_square=True, row_this_far=None, empty=None): """ Converts :class:`Map` into list of rows recursively. + :meta private: + Args: map_ (:class:`Map`): map to convert. make_square (bool): if True, then pad rows with None so they all have the same length. @@ -1564,6 +1602,8 @@ def convert_map_to_dict(map_): """ Converts a :class:`Map` to a nested dictionary. + :meta private: + Args: map_ (:class:`Map`): map to convert @@ -1610,6 +1650,8 @@ def join_value_and_type(db_value, db_type): In case of complex types (duration, date_time, time_series, time_pattern, array, map), the type is just added as top-level key. + :meta private: + Args: db_value (bytes): database value db_type (str, optional): value type @@ -1627,6 +1669,8 @@ def join_value_and_type(db_value, db_type): def split_value_and_type(value_and_type): """Splits the given string into value and type. + :meta private: + Args: value_and_type (str): a string joining value and type, as obtained by calling :func:`join_value_and_type`. diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index f9413317..6028e47a 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -11,7 +11,6 @@ """ This module provides the :func:`perfect_split` function. - """ from .db_mapping import DatabaseMapping from .export_functions import export_data diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 87f35d53..472fa5c3 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -11,9 +11,7 @@ """ Functions to purge DBs. - """ - from .db_mapping import DatabaseMapping from .exception import SpineDBAPIError, SpineDBVersionError from .filters.tools import clear_filter_configs From 26012e98537c18e73fdd989da9784ff1ac7b0c35 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 10 Nov 2023 13:15:30 +0100 Subject: [PATCH 187/317] Improve docs for the parameter_value module --- docs/source/parameter_value_format.rst | 3 + spinedb_api/__init__.py | 1 - spinedb_api/db_mapping.py | 62 +-- spinedb_api/db_mapping_base.py | 6 + spinedb_api/parameter_value.py | 532 +++++++++++++---------- spinedb_api/spine_db_server.py | 4 +- tests/spine_io/test_excel_integration.py | 6 +- tests/test_parameter_value.py | 77 ++-- 8 files changed, 379 insertions(+), 312 deletions(-) diff --git a/docs/source/parameter_value_format.rst b/docs/source/parameter_value_format.rst index b2bc5e21..ea6f3513 100644 --- a/docs/source/parameter_value_format.rst +++ b/docs/source/parameter_value_format.rst @@ -1,3 +1,6 @@ +.. 
_parameter_value_format: + + ********************** Parameter value format ********************** diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 679c68e9..b6b41a6b 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -83,7 +83,6 @@ Array, DateTime, Duration, - IndexedNumberArray, IndexedValue, Map, TimePattern, diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index dace41c3..cf6d28a0 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -202,28 +202,6 @@ def _make_sq(self, item_type): sq_name = self._sq_name_by_item_type[item_type] return getattr(self, sq_name) - def close(self): - """Closes this DB mapping. This is only needed if you're keeping a long-lived session. - For instance:: - - class MyDBMappingWrapper: - def __init__(self, url): - self._db_map = DatabaseMapping(url) - - # More methods that do stuff with self._db_map - - def __del__(self): - self._db_map.close() - - Otherwise, the usage as context manager is recommended:: - - with DatabaseMapping(url) as db_map: - # Do stuff with db_map - ... - # db_map.close() is automatically called when leaving this block - """ - self.closed = True - def _make_codename(self, codename): if codename: return str(codename) @@ -338,14 +316,6 @@ def _convert_legacy(tablename, item): if entity_id: item["entity_id"] = entity_id - def has_external_commits(self): - """Test whether the database has had commits from other sources than this mapping. - - Returns: - bool: True if database has external commits, False otherwise - """ - return self._commit_count != self.query(self.commit_sq).count() - def get_import_alternative_name(self): if self._import_alternative_name is None: self._create_import_alternative() @@ -724,6 +694,36 @@ def refresh_session(self): """Resets the fetch status so new items from the DB can be retrieved.""" self._refresh() + def has_external_commits(self): + """Tests whether the database has had commits from other sources than this mapping. + + Returns: + bool: True if database has external commits, False otherwise + """ + return self._commit_count != self.query(self.commit_sq).count() + + def close(self): + """Closes this DB mapping. This is only needed if you're keeping a long-lived session. + For instance:: + + class MyDBMappingWrapper: + def __init__(self, url): + self._db_map = DatabaseMapping(url) + + # More methods that do stuff with self._db_map + + def __del__(self): + self._db_map.close() + + Otherwise, the usage as context manager is recommended:: + + with DatabaseMapping(url) as db_map: + # Do stuff with db_map + ... + # db_map.close() is automatically called when leaving this block + """ + self.closed = True + def add_ext_entity_metadata(self, *items, **kwargs): metadata_items = self.get_metadata_to_add_with_item_metadata_items(*items) self.add_items("metadata", *metadata_items, **kwargs) @@ -760,7 +760,7 @@ def remove_unused_metadata(self): self.remove_items("metadata", *unused_metadata_ids) def get_filter_configs(self): - """Returns the filters used to build this DB mapping. + """Returns the filters from this mapping's URL. Returns: list(dict): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 028eea4f..bcc28bb2 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -52,6 +52,8 @@ def __init__(self): def item_types(): """Returns a list of public item types from the DB mapping schema (equivalent to the table names). 
+        :meta private:
+
         Returns:
             list(str)
         """
@@ -61,6 +63,8 @@
     def all_item_types():
         """Returns a list of all item types from the DB mapping schema (equivalent to the table names).
 
+        :meta private:
+
         Returns:
             list(str)
         """
@@ -70,6 +74,8 @@
     def item_factory(item_type):
         """Returns a subclass of :class:`.MappedItemBase` to make items of given type.
 
+        :meta private:
+
         Args:
             item_type (str)
diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py
index 774e2c55..e2e3828e 100644
--- a/spinedb_api/parameter_value.py
+++ b/spinedb_api/parameter_value.py
@@ -10,19 +10,70 @@
 
 """
-Support utilities and classes to deal with Spine parameter values.
+Parameter values in a Spine DB can be of different types (see :ref:`parameter_value_format`).
+For each of these types, this module provides a Python class to represent values of that type.
+
+.. list-table:: Parameter value type and Python class
+   :header-rows: 1
+
+   * - type
+     - Python class
+   * - ``date_time``
+     - :class:`DateTime`
+   * - ``duration``
+     - :class:`Duration`
+   * - ``array``
+     - :class:`Array`
+   * - ``time_pattern``
+     - :class:`TimePattern`
+   * - ``time_series``
+     - :class:`TimeSeriesFixedResolution` and :class:`TimeSeriesVariableResolution`
+   * - ``map``
+     - :class:`Map`
+
+The module also provides the functions :func:`to_database` and :func:`from_database`
+to translate between instances of the above classes and their DB representation (namely, the `value` and `type` fields
+that would go in the ``parameter_value`` table).
+
+For example, to write a Python object into a parameter value in the DB::
+
+    # Create the Python object
+    parsed_value = TimeSeriesFixedResolution(
+        "2023-01-01T00:00",  # start (an ISO 8601 string is accepted here)
+        "1D",  # resolution
+        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],  # values
+        ignore_year=False,
+        repeat=False,
+    )
+    # Translate it to value and type
+    value, type_ = to_database(parsed_value)
+    # Add a parameter_value to the DB with that value and type
+    with DatabaseMapping(url) as db_map:
+        db_map.add_parameter_value_item(
+            entity_class_name="cat",
+            entity_byname=("Tom",),
+            parameter_definition_name="number_of_lives",
+            alternative_name="Base",
+            value=value,
+            type=type_,
+        )
+        db_map.commit_session("Tom is living one day at a time")
 
-The :func:`from_database` function receives a DB representation of a parameter value (value and type) and returns
-a float, str, bool, :class:`DateTime`, :class:`Duration`, :class:`Array`, :class:`TimePattern`,
-:class:`TimeSeriesFixedResolution`, :class:`TimeSeriesVariableResolution` or :class:`Map` object.
+Similarly, to read a parameter value from the DB into a Python object::
 
-The above objects can be converted back to the database format by the :func:`to_database` free function
-or by their :meth:`~ParameterValue.to_database` member function.
+    # Get the parameter_value from the DB
+    with DatabaseMapping(url) as db_map:
+        pval_item = db_map.get_parameter_value_item(
+            entity_class_name="cat",
+            entity_byname=("Tom",),
+            parameter_definition_name="number_of_lives",
+            alternative_name="Base",
+        )
+        # Obtain value and type
+        value, type_ = pval_item["value"], pval_item["type"]
+        # Translate value and type to a Python object
+        parsed_value = from_database(value, type_)
 
-Individual datetimes are represented as datetime objects from the standard Python library.
-Individual time steps are represented as relativedelta objects from the dateutil package. -Datetime indexes (as returned by TimeSeries.indexes()) are represented as -numpy.ndarray arrays holding numpy.datetime64 objects. """ from collections.abc import Sequence @@ -48,6 +99,43 @@ _TIME_SERIES_PLAIN_INDEX_UNIT = "m" +def from_database(value, type_=None): + """ + Converts a parameter value from the DB into a Python object. + + Args: + value (bytes or None): the `value` field from the ``parameter_value`` table. + type_ (str, optional): the `type` field from the ``parameter_value`` table. + + Returns: + :class:`ParameterValue`, float, str, bool or None: a Python object representing the parameter value. + """ + parsed = load_db_value(value, type_) + if isinstance(parsed, dict): + return from_dict(parsed) + if isinstance(parsed, bool): + return parsed + if isinstance(parsed, Number): + return float(parsed) + return parsed + + +def to_database(parsed_value): + """ + Converts a Python object representing a parameter value into their DB representation. + + Args: + parsed_value (any): the Python object. + + Returns: + tuple(bytes,str): the `value` and `type` fields that would go in the ``parameter_value`` table. + """ + if hasattr(parsed_value, "to_database"): + return parsed_value.to_database() + db_value = json.dumps(parsed_value).encode("UTF8") + return db_value, None + + def duration_to_relativedelta(duration): """ Converts a duration to a relativedelta object. @@ -121,7 +209,7 @@ def relativedelta_to_duration(delta): return "0h" -def load_db_value(db_value, value_type=None): +def load_db_value(db_value, type_=None): """ Parses a database representation of a parameter value (value and type) into a Python object, using JSON. If the result is a dict, adds the "type" property to it. @@ -130,7 +218,7 @@ def load_db_value(db_value, value_type=None): Args: db_value (bytes, optional): the database value. - value_type (str, optional): the value type. + type_ (str, optional): the value type. Returns: any: the parsed parameter value @@ -142,7 +230,7 @@ def load_db_value(db_value, value_type=None): except JSONDecodeError as err: raise ParameterValueFormatError(f"Could not decode the value: {err}") from err if isinstance(parsed, dict): - return {"type": value_type, **parsed} + return {"type": type_, **parsed} return parsed @@ -166,27 +254,6 @@ def dump_db_value(parsed_value): return db_value, value_type -def from_database(database_value, value_type=None): - """ - Converts a database representation of a parameter value (value and type) into an encoded parameter value. - - Args: - database_value (bytes, optional): the database value - value_type (str, optional): the value type - - Returns: - :class:`ParameterValue`, float, str, bool or None: the encoded parameter value. - """ - parsed = load_db_value(database_value, value_type) - if isinstance(parsed, dict): - return from_dict(parsed) - if isinstance(parsed, bool): - return parsed - if isinstance(parsed, Number): - return float(parsed) - return parsed - - def from_database_to_single_value(database_value, value_type): """ Same as :func:`from_database`, but in the case of indexed types it returns just the type as a string. @@ -227,22 +294,6 @@ def from_database_to_dimension_count(database_value, value_type): return 0 -def to_database(parsed_value): - """ - Converts an encoded parameter value into its database representation (value and type). - - Args: - value(any): a Python object, typically obtained by calling :func:`load_db_value` or :func:`from_database`. 
- - Returns: - tuple(bytes,str): database representation (value and type). - """ - if hasattr(parsed_value, "to_database"): - return parsed_value.to_database() - db_value = json.dumps(parsed_value).encode("UTF8") - return db_value, None - - def from_dict(value): """ Converts a dictionary representation of a parameter value into an encoded parameter value. @@ -634,13 +685,14 @@ def _array_from_database(value_dict): class ParameterValue: - """Base class for all encoded parameter values. + """Base for all classes representing parameter values.""" - :meta private: - """ + VALUE_TYPE = NotImplemented def to_dict(self): - """Returns the dictionary representation of this object. + """Returns a dictionary representation of this parameter value. + + :meta private: Returns: dict: a dictionary including the "type" key. @@ -649,18 +701,18 @@ def to_dict(self): @staticmethod def type_(): - """Returns the value type for this object. + """Returns the type of the parameter value represented by this object. Returns: - str: the value type. + str """ raise NotImplementedError() def to_database(self): - """Returns the database representation of this object (value and type). + """Returns the DB representation of this object. Equivalent to calling :func:`to_database` with it. Returns: - tuple(bytes,str): the DB value and type. + tuple(bytes,str): the `value` and `type` fields that would go in the ``parameter_value`` table. """ return json.dumps(self.to_dict()).encode("UTF8"), self.type_() @@ -678,14 +730,14 @@ def to_database(self): class DateTime(ParameterValue): - """A moment in time.""" + """A parameter value of type 'date_time'. A point in time.""" VALUE_TYPE = "single value" def __init__(self, value=None): """ Args: - value (:class:`DateTime` or str or datetime.datetime): a timestamp + value (:class:`DateTime` or str or :class:`~datetime.datetime`): the `date_time` value. """ if value is None: value = datetime(year=2000, month=1, day=1) @@ -717,7 +769,10 @@ def __str__(self): return self._value.isoformat() def value_to_database_data(self): - """Returns the database representation of the datetime.""" + """Returns the database representation of the datetime. + + :meta private: + """ return self._value.isoformat() def to_dict(self): @@ -725,6 +780,10 @@ def to_dict(self): @staticmethod def type_(): + """See base class + + :meta private: + """ return "date_time" @property @@ -732,16 +791,14 @@ def value(self): """The value. Returns: - :class:`~datetime.datetime`: + :class:`~datetime.datetime` """ return self._value class Duration(ParameterValue): """ - A duration in time. - - Durations are always handled as :class:`~dateutil.dateutil.relativedelta`s. + A parameter value of type 'duration'. An extension of time. """ VALUE_TYPE = "single value" @@ -749,7 +806,7 @@ class Duration(ParameterValue): def __init__(self, value=None): """ Args: - value (str or :class:`~dateutil.dateutil.relativedelta`): the duration + value (str or :class:`Duration` or :class:`~dateutil.dateutil.relativedelta`): the `duration` value. """ if value is None: value = relativedelta(hours=1) @@ -773,7 +830,10 @@ def __str__(self): return str(relativedelta_to_duration(self._value)) def value_to_database_data(self): - """Returns the 'data' property of this object's database representation.""" + """Returns the 'data' property of this object's database representation. 
+ + :meta private: + """ return relativedelta_to_duration(self._value) def to_dict(self): @@ -781,11 +841,19 @@ def to_dict(self): @staticmethod def type_(): + """See base class + + :meta private: + """ return "duration" @property def value(self): - """Returns the duration as a :class:`relativedelta`.""" + """The value. + + Returns + :class:`~dateutil.dateutil.relativedelta` + """ return self._value @@ -797,8 +865,6 @@ class _Indexes(np.ndarray): position = indexes.index(element) which might be too slow compared to dictionary lookup. - - :meta private: """ def __new__(cls, other, dtype=None): @@ -826,24 +892,23 @@ def __bool__(self): class IndexedValue(ParameterValue): """ - Base class for all values that have indexes. - - :meta private: - - Attributes: - index_name (str): index name + Base for all classes representing indexed parameter values. """ - VALUE_TYPE = NotImplemented + DEFAULT_INDEX_NAME = NotImplemented - def __init__(self, index_name): + def __init__(self, values, value_type=None, index_name=""): """ + :meta private: + Args: - index_name (str): index name. + index_name (str): a label for the index. """ + self._value_type = value_type self._indexes = None self._values = None - self.index_name = index_name + self.values = values + self.index_name = index_name if index_name else self.DEFAULT_INDEX_NAME def __bool__(self): # NOTE: Use self.indexes rather than self._indexes, otherwise TimeSeriesFixedResolution gives wrong result @@ -854,6 +919,10 @@ def __len__(self): @staticmethod def type_(): + """See base class + + :meta private: + """ raise NotImplementedError() @property @@ -892,12 +961,29 @@ def values(self, values): """ self._values = values + @property + def value_type(self): + """The type of the values. + + Returns: + type: + """ + return self._value_type + def get_nearest(self, index): + """Returns the value at the nearest index to the given one. + + Args: + index (any): The index. + + Returns: + any: The value. + """ pos = np.searchsorted(self.indexes, index) return self.values[pos] def get_value(self, index): - """Returns the value at a given index. + """Returns the value at the given index. Args: index (any): The index. @@ -911,7 +997,7 @@ def get_value(self, index): return self.values[pos] def set_value(self, index, value): - """Sets the value at a given index. + """Sets the value at the given index. Args: index (any): The index. @@ -937,7 +1023,7 @@ def merge(self, other): class Array(IndexedValue): - """A one dimensional array with zero based indexing.""" + """A parameter value of type 'array'. A one dimensional array with zero based indexing.""" VALUE_TYPE = "array" DEFAULT_INDEX_NAME = "i" @@ -945,28 +1031,26 @@ class Array(IndexedValue): def __init__(self, values, value_type=None, index_name=""): """ Args: - values (Sequence): the values in the array. - value_type (Type, optional): array element type; will be deduced from ``values`` if not given - and defaults to float if ``values`` is empty. - index_name (str): index name. + values (Sequence): the array values. + value_type (type, optional): the type of the values; if not given, it will be deduced from `values`. + Defaults to float if `values` is empty. + index_name (str): the name you would give to the array index in your application. 
""" - super().__init__(index_name if index_name else Array.DEFAULT_INDEX_NAME) if value_type is None: value_type = type(values[0]) if values else float - if value_type == int: - try: - values = [float(x) for x in values] - except ValueError: - raise ParameterValueFormatError("Cannot convert array's values to float.") - value_type = float - if any(not isinstance(x, value_type) for x in values): + if value_type == int: + value_type = float + try: + values = [value_type(x) for x in values] + except ValueError: + raise ParameterValueFormatError("Cannot convert array's values to float.") + if not all(isinstance(x, value_type) for x in values): try: values = [value_type(x) for x in values] except ValueError: raise ParameterValueFormatError("Not all array's values are of the same type.") + super().__init__(values, value_type=value_type, index_name=index_name) self.indexes = range(len(values)) - self.values = list(values) - self._value_type = value_type def __eq__(self, other): if not isinstance(other, Array): @@ -998,59 +1082,125 @@ def to_dict(self): value_dict["index_name"] = self.index_name return value_dict - @property - def value_type(self): - """Returns the type of the values. - Returns: - str: +class _TimePatternIndexes(_Indexes): + """An array of *checked* time pattern indexes.""" + + @staticmethod + def _check_index(union_str): """ - return self._value_type + Checks if a time pattern index has the right format. + Args: + union_str (str): The time pattern index to check. Generally assumed to be a union of interval intersections. -class IndexedNumberArray(IndexedValue): - """ - Abstract base class for all values mapping indexes to floats. + Raises: + ParameterValueFormatError: If the given string doesn't comply with time pattern spec. + """ + if not union_str: + # We accept empty strings so we can add empty rows in the parameter value editor UI + return + union_dlm = "," + intersection_dlm = ";" + range_dlm = "-" + regexp = r"(Y|M|D|WD|h|m|s)" + for intersection_str in union_str.split(union_dlm): + for interval_str in intersection_str.split(intersection_dlm): + m = re.match(regexp, interval_str) + if m is None: + raise ParameterValueFormatError( + f"Invalid interval {interval_str}, it should start with either Y, M, D, WD, h, m, or s." + ) + key = m.group(0) + lower_upper_str = interval_str[len(key) :] + lower_upper = lower_upper_str.split(range_dlm) + if len(lower_upper) != 2: + raise ParameterValueFormatError( + f"Invalid interval bounds {lower_upper_str}, it should be two integers separated by dash (-)." + ) + lower_str, upper_str = lower_upper + try: + lower = int(lower_str) + except: + raise ParameterValueFormatError(f"Invalid lower bound {lower_str}, must be an integer.") + try: + upper = int(upper_str) + except: + raise ParameterValueFormatError(f"Invalid upper bound {upper_str}, must be an integer.") + if lower > upper: + raise ParameterValueFormatError(f"Lower bound {lower} can't be higher than upper bound {upper}.") - The indexes and numbers are stored in :class:`~numpy.ndarray`s. + def __array_finalize__(self, obj): + """Checks indexes when building the array.""" + for x in obj: + self._check_index(x) + super().__array_finalize__(obj) - :meta private: + def __eq__(self, other): + return list(self) == list(other) + + def __setitem__(self, position, index): + """Checks indexes when setting and item.""" + self._check_index(index) + super().__setitem__(position, index) + + +class TimePattern(IndexedValue): + """A parameter value of type 'time_pattern'. 
+ A mapping from time patterns strings to numerical values. """ - def __init__(self, index_name, values): + VALUE_TYPE = "time pattern" + DEFAULT_INDEX_NAME = "p" + + def __init__(self, indexes, values, index_name=""): """ Args: - index_name (str): index name. - values (Sequence): the values in the array; index handling should be implemented by subclasses. + indexes (list): the time pattern strings. + values (Sequence): the values associated to different patterns. + index_name (str): index name """ - super().__init__(index_name) - self.values = values + if len(indexes) != len(values): + raise ParameterValueFormatError("Length of values does not match length of indexes") + if not indexes: + raise ParameterValueFormatError("Empty time pattern not allowed") + super().__init__(values, value_type=float, index_name=index_name) + self.indexes = indexes - @IndexedValue.values.setter - def values(self, values): - if not isinstance(values, np.ndarray) or not values.dtype == np.dtype(float): - values = np.array(values, dtype=float) - self._values = values + def __eq__(self, other): + if not isinstance(other, TimePattern): + return NotImplemented + return ( + self._indexes == other._indexes + and np.all(self._values == other._values) + and self.index_name == other.index_name + ) + + @IndexedValue.indexes.setter + def indexes(self, indexes): + self._indexes = _TimePatternIndexes(indexes, dtype=np.object_) @staticmethod def type_(): - raise NotImplementedError() + return "time_pattern" def to_dict(self): - raise NotImplementedError() - + value_dict = {"data": dict(zip(self._indexes, self._values))} + if self.index_name != "p": + value_dict["index_name"] = self.index_name + return value_dict -class TimeSeries(IndexedNumberArray): - """Abstract base class for time-series. - :meta private: - """ +class TimeSeries(IndexedValue): + """Base for all classes representing 'time_series' parameter values.""" VALUE_TYPE = "time series" DEFAULT_INDEX_NAME = "t" def __init__(self, values, ignore_year, repeat, index_name=""): """ + :meta private: + Args: values (Sequence): the values in the time-series. ignore_year (bool): True if the year should be ignored. @@ -1059,7 +1209,7 @@ def __init__(self, values, ignore_year, repeat, index_name=""): """ if len(values) < 1: raise ParameterValueFormatError("Time series too short. Must have one or more values") - super().__init__(index_name if index_name else TimeSeries.DEFAULT_INDEX_NAME, values) + super().__init__(values, value_type=float, index_name=index_name) self._ignore_year = ignore_year self._repeat = repeat @@ -1102,123 +1252,29 @@ def repeat(self, repeat): """ self._repeat = bool(repeat) - @staticmethod - def type_(): - return "time_series" - - def to_dict(self): - raise NotImplementedError() - - -def _check_time_pattern_index(union_str): - """ - Checks if a time pattern index has the right format. - - Args: - union_str (str): The time pattern index to check. Generally assumed to be a union of interval intersections. - - Raises: - ParameterValueFormatError: If the given string doesn't comply with time pattern spec. 
- """ - if not union_str: - # We accept empty strings so we can add empty rows in the parameter value editor UI - return - union_dlm = "," - intersection_dlm = ";" - range_dlm = "-" - regexp = r"(Y|M|D|WD|h|m|s)" - for intersection_str in union_str.split(union_dlm): - for interval_str in intersection_str.split(intersection_dlm): - m = re.match(regexp, interval_str) - if m is None: - raise ParameterValueFormatError( - f"Invalid interval {interval_str}, it should start with either Y, M, D, WD, h, m, or s." - ) - key = m.group(0) - lower_upper_str = interval_str[len(key) :] - lower_upper = lower_upper_str.split(range_dlm) - if len(lower_upper) != 2: - raise ParameterValueFormatError( - f"Invalid interval bounds {lower_upper_str}, it should be two integers separated by dash (-)." - ) - lower_str, upper_str = lower_upper - try: - lower = int(lower_str) - except: - raise ParameterValueFormatError(f"Invalid lower bound {lower_str}, must be an integer.") - try: - upper = int(upper_str) - except: - raise ParameterValueFormatError(f"Invalid upper bound {upper_str}, must be an integer.") - if lower > upper: - raise ParameterValueFormatError(f"Lower bound {lower} can't be higher than upper bound {upper}.") - - -class _TimePatternIndexes(_Indexes): - """An array of *checked* time pattern indexes.""" - - def __array_finalize__(self, obj): - """Checks indexes when building the array.""" - for x in obj: - _check_time_pattern_index(x) - super().__array_finalize__(obj) - - def __eq__(self, other): - return list(self) == list(other) - - def __setitem__(self, position, index): - """Checks indexes when setting and item.""" - _check_time_pattern_index(index) - super().__setitem__(position, index) - - -class TimePattern(IndexedNumberArray): - """A time-pattern parameter value.""" - - VALUE_TYPE = "time pattern" - DEFAULT_INDEX_NAME = "p" + @IndexedValue.values.setter + def values(self, values): + """Sets the values. - def __init__(self, indexes, values, index_name=""): - """ Args: - indexes (list): a list of time pattern strings - values (Sequence): the value for each time pattern. - index_name (str): index name + values (:class:`~numpy.ndarray`) """ - if len(indexes) != len(values): - raise ParameterValueFormatError("Length of values does not match length of indexes") - if not indexes: - raise ParameterValueFormatError("Empty time pattern not allowed") - super().__init__(index_name if index_name else TimePattern.DEFAULT_INDEX_NAME, values) - self.indexes = indexes - - def __eq__(self, other): - if not isinstance(other, TimePattern): - return NotImplemented - return ( - self._indexes == other._indexes - and np.all(self._values == other._values) - and self.index_name == other.index_name - ) - - @IndexedNumberArray.indexes.setter - def indexes(self, indexes): - self._indexes = _TimePatternIndexes(indexes, dtype=np.object_) + if not isinstance(values, np.ndarray) or not values.dtype == np.dtype(float): + values = np.array(values, dtype=float) + self._values = values @staticmethod def type_(): - return "time_pattern" + return "time_series" def to_dict(self): - value_dict = {"data": dict(zip(self._indexes, self._values))} - if self.index_name != "p": - value_dict["index_name"] = self.index_name - return value_dict + raise NotImplementedError() class TimeSeriesFixedResolution(TimeSeries): """ - A time-series with fixed durations between the time stamps. + A parameter value of type 'time_series'. + A mapping from time stamps to numerical values, with fixed durations between the time stamps. 
When getting the indexes the durations are applied cyclically. @@ -1231,8 +1287,8 @@ class TimeSeriesFixedResolution(TimeSeries): def __init__(self, start, resolution, values, ignore_year, repeat, index_name=""): """ Args: - start (str or :class:`~datetime.datetime` or :class:`numpy.datetime64`): the first time stamp - resolution (str, :class:`dateutil.relativedelta.relativedelta`, list): duration(s) between the time stamps. + start (str or :class:`~datetime.datetime` or :class:`~numpy.datetime64`): the first time stamp + resolution (str, :class:`~dateutil.relativedelta.relativedelta`, list): duration(s) between the time stamps. values (Sequence): the values in the time-series. ignore_year (bool): True if the year should be ignored. repeat (bool): True if the series is repeating. @@ -1366,7 +1422,9 @@ def to_dict(self): class TimeSeriesVariableResolution(TimeSeries): - """A time-series with variable time steps.""" + """A parameter value of type 'time_series'. + A mapping from time stamps to numerical values with arbitrary time steps. + """ def __init__(self, indexes, values, ignore_year, repeat, index_name=""): """ @@ -1420,7 +1478,9 @@ def to_dict(self): class Map(IndexedValue): - """A nested general purpose indexed value.""" + """A parameter value of type 'map'. A mapping from key to value, where the values can be other instances + of :class:`ParameterValue`. + """ VALUE_TYPE = "map" DEFAULT_INDEX_NAME = "x" @@ -1439,7 +1499,7 @@ def __init__(self, indexes, values, index_type=None, index_name=""): raise ParameterValueFormatError('Type of index does not match "index_type" argument.') if len(indexes) != len(values): raise ParameterValueFormatError("Length of values does not match length of indexes") - super().__init__(index_name if index_name else Map.DEFAULT_INDEX_NAME) + super().__init__(values, index_name=index_name) self.indexes = indexes self._index_type = index_type if index_type is not None else type(indexes[0]) self._values = values diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 140a31a9..4ac5b7d8 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -127,8 +127,8 @@ def get_current_server_version(): return _current_server_version -def _parse_value(v, value_type=None): - return (v, value_type) +def _parse_value(v, type_=None): + return (v, type_) def _unparse_value(value_and_type): diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index 8b500b07..7ca1a318 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -29,12 +29,12 @@ class TestExcelIntegration(unittest.TestCase): def test_array(self): array = b'{"type": "array", "data": [1, 2, 3]}' - array = from_database(array, value_type="array") + array = from_database(array, type_="array") self._check_parameter_value(array) def test_time_series(self): ts = b'{"type": "time_series", "index": {"start": "1999-12-31 23:00:00", "resolution": "1h"}, "data": [0.1, 0.2]}' - ts = from_database(ts, value_type="time_series") + ts = from_database(ts, type_="time_series") self._check_parameter_value(ts) def test_map(self): @@ -97,7 +97,7 @@ def test_map(self): ], } ).encode("UTF8") - map_ = from_database(map_, value_type="map") + map_ = from_database(map_, type_="map") self._check_parameter_value(map_) def _check_parameter_value(self, val): diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index c4a0339f..8ae63ef5 100644 --- a/tests/test_parameter_value.py +++ 
b/tests/test_parameter_value.py @@ -32,7 +32,6 @@ Array, DateTime, Duration, - IndexedNumberArray, Map, TimePattern, TimeSeriesFixedResolution, @@ -143,13 +142,13 @@ def test_relativedelta_to_duration_years(self): def test_from_database_plain_number(self): database_value = b"23.0" - value = from_database(database_value, value_type=None) + value = from_database(database_value, type_=None) self.assertTrue(isinstance(value, float)) self.assertEqual(value, 23.0) def test_from_database_boolean(self): database_value = b"true" - value = from_database(database_value, value_type=None) + value = from_database(database_value, type_=None) self.assertTrue(isinstance(value, bool)) self.assertEqual(value, True) @@ -169,7 +168,7 @@ def test_to_database_DateTime(self): def test_from_database_DateTime(self): database_value = b'{"data": "2019-06-01T22:15:00+01:00"}' - value = from_database(database_value, value_type="date_time") + value = from_database(database_value, type_="date_time") self.assertEqual(value.value, dateutil.parser.parse("2019-06-01T22:15:00+01:00")) def test_DateTime_to_database(self): @@ -181,17 +180,17 @@ def test_DateTime_to_database(self): def test_from_database_Duration(self): database_value = b'{"data": "4 seconds"}' - value = from_database(database_value, value_type="duration") + value = from_database(database_value, type_="duration") self.assertEqual(value.value, relativedelta(seconds=4)) def test_from_database_Duration_default_units(self): database_value = b'{"data": 23}' - value = from_database(database_value, value_type="duration") + value = from_database(database_value, type_="duration") self.assertEqual(value.value, relativedelta(minutes=23)) def test_from_database_Duration_legacy_list_format_converted_to_Array(self): database_value = b'{"data": ["1 hour", "1h", 60, "2 hours"]}' - value = from_database(database_value, value_type="duration") + value = from_database(database_value, type_="duration") expected = Array([Duration("1h"), Duration("1h"), Duration("1h"), Duration("2h")]) self.assertEqual(value, expected) @@ -211,7 +210,7 @@ def test_from_database_TimePattern(self): } } """ - value = from_database(database_value, value_type="time_pattern") + value = from_database(database_value, type_="time_pattern") self.assertEqual(len(value), 2) self.assertEqual(value.indexes, ["m1-4,m9-12", "m5-8"]) numpy.testing.assert_equal(value.values, numpy.array([300.0, 221.5])) @@ -226,7 +225,7 @@ def test_from_database_TimePattern_with_index_name(self): } } """ - value = from_database(database_value, value_type="time_pattern") + value = from_database(database_value, type_="time_pattern") self.assertEqual(value.indexes, ["M1-12"]) numpy.testing.assert_equal(value.values, numpy.array([300.0])) self.assertEqual(value.index_name, "index") @@ -267,7 +266,7 @@ def test_from_database_TimeSeriesVariableResolution_as_dictionary(self): "1983-05-25": 6 } }""" - time_series = from_database(releases, value_type="time_series") + time_series = from_database(releases, type_="time_series") self.assertEqual( time_series.indexes, numpy.array( @@ -288,7 +287,7 @@ def test_from_database_TimeSeriesVariableResolution_as_dictionary_with_index_nam }, "index_name": "index" }""" - time_series = from_database(releases, value_type="time_series") + time_series = from_database(releases, type_="time_series") self.assertEqual(time_series.index_name, "index") def test_from_database_TimeSeriesVariableResolution_as_two_column_array(self): @@ -299,7 +298,7 @@ def 
test_from_database_TimeSeriesVariableResolution_as_two_column_array(self): ["1983-05-25", 6] ] }""" - time_series = from_database(releases, value_type="time_series") + time_series = from_database(releases, type_="time_series") self.assertEqual( time_series.indexes, numpy.array( @@ -320,7 +319,7 @@ def test_from_database_TimeSeriesVariableResolution_as_two_column_array_with_ind ], "index_name": "index" }""" - time_series = from_database(releases, value_type="time_series") + time_series = from_database(releases, type_="time_series") self.assertEqual(time_series.index_name, "index") def test_from_database_TimeSeriesFixedResolution_default_repeat(self): @@ -331,7 +330,7 @@ def test_from_database_TimeSeriesFixedResolution_default_repeat(self): "data": [["2019-07-02T10:00:00", 7.0], ["2019-07-02T10:00:01", 4.0]] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertTrue(time_series.ignore_year) self.assertFalse(time_series.repeat) @@ -378,7 +377,7 @@ def test_from_database_TimeSeriesFixedResolution(self): }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(days_of_our_lives, value_type="time_series") + time_series = from_database(days_of_our_lives, type_="time_series") self.assertEqual(len(time_series), 3) self.assertEqual( time_series.indexes, @@ -401,7 +400,7 @@ def test_from_database_TimeSeriesFixedResolution_no_index(self): "data": [1, 2, 3, 4, 5, 8] } """ - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertEqual(len(time_series), 6) self.assertEqual( time_series.indexes, @@ -430,7 +429,7 @@ def test_from_database_TimeSeriesFixedResolution_index_name(self): "index_name": "index" } """ - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertEqual(time_series.index_name, "index") def test_from_database_TimeSeriesFixedResolution_resolution_list(self): @@ -443,7 +442,7 @@ def test_from_database_TimeSeriesFixedResolution_resolution_list(self): }, "data": [7.0, 5.0, 8.1, -4.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertEqual(len(time_series), 4) self.assertEqual( time_series.indexes, @@ -473,7 +472,7 @@ def test_from_database_TimeSeriesFixedResolution_default_resolution_is_1hour(sel }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertEqual(len(time_series), 3) self.assertEqual(len(time_series.resolution), 1) self.assertEqual(time_series.resolution[0], relativedelta(hours=1)) @@ -486,7 +485,7 @@ def test_from_database_TimeSeriesFixedResolution_default_resolution_unit_is_minu }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertEqual(len(time_series), 3) self.assertEqual(len(time_series.resolution), 1) self.assertEqual(time_series.resolution[0], relativedelta(minutes=30)) @@ -497,7 +496,7 @@ def test_from_database_TimeSeriesFixedResolution_default_resolution_unit_is_minu }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") 
self.assertEqual(len(time_series), 3) self.assertEqual(len(time_series.resolution), 2) self.assertEqual(time_series.resolution[0], relativedelta(minutes=30)) @@ -513,7 +512,7 @@ def test_from_database_TimeSeriesFixedResolution_default_ignore_year(self): }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertFalse(time_series.ignore_year) # Should be true if start is omitted database_value = b"""{ @@ -523,7 +522,7 @@ def test_from_database_TimeSeriesFixedResolution_default_ignore_year(self): }, "data": [7.0, 5.0, 8.1] }""" - time_series = from_database(database_value, value_type="time_series") + time_series = from_database(database_value, type_="time_series") self.assertTrue(time_series.ignore_year) def test_TimeSeriesFixedResolution_to_database(self): @@ -602,7 +601,7 @@ def test_TimeSeriesVariableResolution_init_conversion(self): def test_from_database_Map_with_index_name(self): database_value = b'{"index_type":"str", "index_name": "index", "data":[["a", 1.1]]}' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertIsInstance(value, Map) self.assertEqual(value.indexes, ["a"]) self.assertEqual(value.values, [1.1]) @@ -610,7 +609,7 @@ def test_from_database_Map_with_index_name(self): def test_from_database_Map_dictionary_format(self): database_value = b'{"index_type":"str", "data":{"a": 1.1, "b": 2.2}}' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertIsInstance(value, Map) self.assertEqual(value.indexes, ["a", "b"]) self.assertEqual(value.values, [1.1, 2.2]) @@ -618,7 +617,7 @@ def test_from_database_Map_dictionary_format(self): def test_from_database_Map_two_column_array_format(self): database_value = b'{"index_type":"float", "data":[[1.1, "a"], [2.2, "b"]]}' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertIsInstance(value, Map) self.assertEqual(value.indexes, [1.1, 2.2]) self.assertEqual(value.values, ["a", "b"]) @@ -632,7 +631,7 @@ def test_from_database_Map_nested_maps(self): "index_type": "date_time", "data": {"2020-01-01T12:00": {"type":"duration", "data":"3 hours"}}}]] }''' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [Duration("1 hour")]) nested_map = value.values[0] self.assertIsInstance(nested_map, Map) @@ -648,7 +647,7 @@ def test_from_database_Map_with_TimeSeries_values(self): } ]] }''' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [Duration("1 hour")]) self.assertEqual( value.values, @@ -661,7 +660,7 @@ def test_from_database_Map_with_Array_values(self): "index_type": "duration", "data":[["1 hour", {"type": "array", "data": [-3.0, -9.3]}]] }''' - value = from_database(database_value, value_type="map") + value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [Duration("1 hour")]) self.assertEqual(value.values, [Array([-3.0, -9.3])]) @@ -671,7 +670,7 @@ def test_from_database_Map_with_TimePattern_values(self): "index_type": "float", "data":[["2.3", {"type": "time_pattern", "data": {"M1-2": -9.3, "M3-12": -3.9}}]] }''' - value = from_database(database_value, value_type="map") + value = from_database(database_value, 
type_="map") self.assertEqual(value.indexes, [2.3]) self.assertEqual(value.values, [TimePattern(["M1-2", "M3-12"], [-9.3, -3.9])]) @@ -771,7 +770,7 @@ def test_Array_of_floats_from_database(self): "value_type": "float", "data": [1.2, 2.3] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, [1.2, 2.3]) self.assertEqual(array.indexes, [0, 1]) self.assertEqual(array.index_name, "i") @@ -780,7 +779,7 @@ def test_Array_of_default_value_type_from_database(self): database_value = b"""{ "data": [1.2, 2.3] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, [1.2, 2.3]) self.assertEqual(array.indexes, [0, 1]) self.assertEqual(array.index_name, "i") @@ -790,7 +789,7 @@ def test_Array_of_strings_from_database(self): "value_type": "str", "data": ["A", "B"] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, ["A", "B"]) self.assertEqual(array.indexes, [0, 1]) self.assertEqual(array.index_name, "i") @@ -800,7 +799,7 @@ def test_Array_of_DateTimes_from_database(self): "value_type": "date_time", "data": ["2020-03-25T10:34:00"] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, [DateTime("2020-03-25T10:34:00")]) self.assertEqual(array.indexes, [0]) self.assertEqual(array.index_name, "i") @@ -810,7 +809,7 @@ def test_Array_of_Durations_from_database(self): "value_type": "duration", "data": ["2 years", "7 seconds"] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, [Duration("2 years"), Duration("7s")]) self.assertEqual(array.indexes, [0, 1]) self.assertEqual(array.index_name, "i") @@ -821,7 +820,7 @@ def test_Array_from_database_with_index_name(self): "index_name": "index", "data": [1.2] }""" - array = from_database(database_value, value_type="array") + array = from_database(database_value, type_="array") self.assertEqual(array.values, [1.2]) self.assertEqual(array.indexes, [0]) self.assertEqual(array.index_name, "index") @@ -892,11 +891,11 @@ def test_TimeSeriesVariableResolution_equality(self): inequal_series = TimeSeriesVariableResolution(["2000-01-01T00:00", "2002-01-01T00:00"], [4.2, 2.4], False, True) self.assertNotEqual(series, inequal_series) - def test_IndexedValue_constructor_converts_values_to_floats(self): - value = IndexedNumberArray("", [4, -9, 11]) + def test_TimeSeries_constructor_converts_values_to_floats(self): + value = TimeSeries([4, -9, 11], False, False) self.assertEqual(value.values.dtype, np.dtype(float)) numpy.testing.assert_equal(value.values, numpy.array([4.0, -9.0, 11.0])) - value = IndexedNumberArray("", numpy.array([16, -251, 99])) + value = TimeSeries(numpy.array([16, -251, 99]), False, False) self.assertEqual(value.values.dtype, np.dtype(float)) numpy.testing.assert_equal(value.values, numpy.array([16.0, -251.0, 99.0])) From fd5497cba7f2fd99e3bfb852f8b5a53275371530 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 14 Nov 2023 11:11:56 +0200 Subject: [PATCH 188/317] Make remove_credentials_from_url() work with special characters Some special characters in passwords broke remove_credentials_from_url(). 
--- spinedb_api/helpers.py | 7 ++++--- tests/test_helpers.py | 35 ++++++++++++++++++++--------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index e68a6267..a5d7702e 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -839,10 +839,11 @@ def remove_credentials_from_url(url): Returns: str: sanitized URL """ - parsed = urlparse(url) - if parsed.username is None: + if "@" not in url: return url - return urlunparse(parsed._replace(netloc=parsed.netloc.partition("@")[-1])) + head, tail = url.rsplit("@", maxsplit=1) + scheme, credentials = head.split("://", maxsplit=1) + return scheme + "://" + tail def group_consecutive(list_of_numbers): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 1148862e..c28b7546 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -8,37 +8,42 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Unit tests for helpers.py. - -""" +"""Unit tests for helpers.py.""" import unittest -from spinedb_api.helpers import compare_schemas, create_new_spine_database - - -class TestHelpers(unittest.TestCase): - def setUp(self): - pass +from spinedb_api.helpers import compare_schemas, create_new_spine_database, remove_credentials_from_url - def tearDown(self): - pass +class TestCreateNewSpineEngine(unittest.TestCase): def test_same_schema(self): - """Test that importing object class works""" engine1 = create_new_spine_database('sqlite://') engine2 = create_new_spine_database('sqlite://') self.assertTrue(compare_schemas(engine1, engine2)) def test_different_schema(self): - """Test that importing object class works""" engine1 = create_new_spine_database('sqlite://') engine2 = create_new_spine_database('sqlite://') engine2.execute("drop table entity") self.assertFalse(compare_schemas(engine1, engine2)) +class TestRemoveCredentialsFromUrl(unittest.TestCase): + def test_url_without_credentials_is_returned_as_is(self): + url = "mysql://example.com/db" + sanitized = remove_credentials_from_url(url) + self.assertEqual(url, sanitized) + + def test_username_and_password_are_removed(self): + url = "mysql://user:secret@example.com/db" + sanitized = remove_credentials_from_url(url) + self.assertEqual(sanitized, "mysql://example.com/db") + + def test_password_with_special_characters(self): + url = "mysql://user:p@ass://word@example.com/db" + sanitized = remove_credentials_from_url(url) + self.assertEqual(sanitized, "mysql://example.com/db") + + if __name__ == "__main__": unittest.main() From 55adad4ea0b4d74c9e1d6f8e5e79708817c45396 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 16 Nov 2023 08:13:41 +0200 Subject: [PATCH 189/317] Add unit test for adding scenario alternatives Added a unit test to investigate a bug in Toolbox. Might as well commit it to the repository. Re spine-tools/Spine-Toolbox#2417 --- tests/test_DatabaseMapping.py | 40 ++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 785efa04..80ecb8ac 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. 
If not, see .
######################################################################################################################
-
-"""
-Unit tests for DatabaseMapping class.
-
-"""
+""" Unit tests for DatabaseMapping class. """
 import os.path
 from tempfile import TemporaryDirectory
 import unittest
@@ -287,6 +283,40 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self):
             entities = db_map.get_items("entity")
             self.assertEqual(len(entities), 3)
 
+    def test_committing_scenario_alternatives(self):
+        with TemporaryDirectory() as temp_dir:
+            url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite")
+            with DatabaseMapping(url, create=True) as db_map:
+                item, error = db_map.add_alternative_item(name="alt1")
+                self.assertIsNone(error)
+                self.assertIsNotNone(item)
+                item, error = db_map.add_alternative_item(name="alt2")
+                self.assertIsNone(error)
+                self.assertIsNotNone(item)
+                item, error = db_map.add_scenario_item(name="my_scenario")
+                self.assertIsNone(error)
+                self.assertIsNotNone(item)
+                item, error = db_map.add_scenario_alternative_item(
+                    scenario_name="my_scenario", alternative_name="alt1", rank=0
+                )
+                self.assertIsNone(error)
+                self.assertIsNotNone(item)
+                item, error = db_map.add_scenario_alternative_item(
+                    scenario_name="my_scenario", alternative_name="alt2", rank=1
+                )
+                self.assertIsNone(error)
+                self.assertIsNotNone(item)
+                db_map.commit_session("Add test data.")
+            with DatabaseMapping(url) as db_map:
+                scenario_alternatives = db_map.get_items("scenario_alternative")
+                self.assertEqual(len(scenario_alternatives), 2)
+                self.assertEqual(scenario_alternatives[0]["scenario_name"], "my_scenario")
+                self.assertEqual(scenario_alternatives[0]["alternative_name"], "alt1")
+                self.assertEqual(scenario_alternatives[0]["rank"], 0)
+                self.assertEqual(scenario_alternatives[1]["scenario_name"], "my_scenario")
+                self.assertEqual(scenario_alternatives[1]["alternative_name"], "alt2")
+                self.assertEqual(scenario_alternatives[1]["rank"], 1)
+
 
 class TestDatabaseMappingLegacy(unittest.TestCase):
     """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure."""

From 0f273997cea0d806767bdde190855e01dc4cf981 Mon Sep 17 00:00:00 2001
From: Antti Soininen 
Date: Fri, 17 Nov 2023 10:02:06 +0200
Subject: [PATCH 190/317] Add methods to generate entity and entity class names

Having functions to generate entity class names from dimensions and entity
names from elements makes it easier to standardize the naming conventions
and change the default later on.

Re spine-tools/Spine-Toolbox#2423
---
 spinedb_api/helpers.py                      | 27 ++++++++++++-
 spinedb_api/mapped_items.py                 |  4 +-
 tests/export_mapping/test_export_mapping.py | 45 ++++++++++-----------
 tests/test_DatabaseMapping.py               |  7 ++--
 tests/test_export_functions.py              |  2 +-
 tests/test_helpers.py                       | 24 ++++++++++-
 tests/test_import_functions.py              |  6 +--
 7 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index a5d7702e..98722b35 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -17,7 +17,6 @@
 import warnings
 from operator import itemgetter
 from itertools import groupby
-from urllib.parse import urlparse, urlunparse
 from sqlalchemy import (
     Boolean,
     BigInteger,
@@ -81,6 +80,32 @@
 LONGTEXT_LENGTH = 2 ** 32 - 1
 
 
+def name_from_elements(elements):
+    """Creates an entity name by combining element names. 
+ + Args: + elements (Sequence of str): element names + + Returns: + str: entity name + """ + if len(elements) == 1: + return elements[0] + "__" + return "__".join(elements) + + +def name_from_dimensions(dimensions): + """Creates an entity class name by combining dimension names. + + Args: + dimensions (Sequence of str): dimension names + + Returns: + str: entity class name + """ + return name_from_elements(dimensions) + + # NOTE: Deactivated since foreign keys are too difficult to get right in the diff tables. # For example, the diff_object table would need a `class_id` field and a `diff_class_id` field, # plus a CHECK constraint that at least one of the two is NOT NULL. diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 959545d4..0853467a 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -10,6 +10,8 @@ ###################################################################################################################### from operator import itemgetter + +from .helpers import name_from_elements from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase @@ -168,7 +170,7 @@ def polish(self): return f"element '{el_name}' is not an instance of class '{dim_name}'" if self.get("name") is not None: return - base_name = "__".join(self["element_name_list"]) + base_name = name_from_elements(self["element_name_list"]) name = base_name index = 1 while any( diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 1ec8f669..826588ac 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -821,7 +821,7 @@ def test_export_relationships(self): element1_mapping = relationship_mapping.child = ElementMapping(4) element1_mapping.child = ElementMapping(5) expected = [ - ['rc1', 'oc1', '', 'o11', 'o11', ''], + ['rc1', 'oc1', '', 'o11__', 'o11', ''], ['rc2', 'oc2', 'oc1', 'o21__o11', 'o21', 'o11'], ['rc2', 'oc2', 'oc1', 'o21__o12', 'o21', 'o12'], ] @@ -1486,28 +1486,27 @@ def test_highlight_relationship_objects(self): db_map.close() def test_export_object_parameters_while_exporting_relationships(self): - db_map = DatabaseMapping("sqlite://", create=True) - import_object_classes(db_map, ("oc",)) - import_object_parameters(db_map, (("oc", "p"),)) - import_objects(db_map, (("oc", "o"),)) - import_object_parameter_values(db_map, (("oc", "o", "p", 23.0),)) - import_relationship_classes(db_map, (("rc", ("oc",)),)) - import_relationships(db_map, (("rc", ("o",)),)) - db_map.commit_session("Add test data") - root_mapping = unflatten( - [ - EntityClassMapping(0, highlight_position=0), - DimensionMapping(1), - EntityMapping(2), - ElementMapping(3), - ParameterDefinitionMapping(4), - AlternativeMapping(5), - ParameterValueMapping(6), - ] - ) - expected = [["rc", "oc", "o", "o", "p", "Base", 23.0]] - self.assertEqual(list(rows(root_mapping, db_map)), expected) - db_map.close() + with DatabaseMapping("sqlite://", create=True) as db_map: + import_object_classes(db_map, ("oc",)) + import_object_parameters(db_map, (("oc", "p"),)) + import_objects(db_map, (("oc", "o"),)) + import_object_parameter_values(db_map, (("oc", "o", "p", 23.0),)) + import_relationship_classes(db_map, (("rc", ("oc",)),)) + import_relationships(db_map, (("rc", ("o",)),)) + db_map.commit_session("Add test data") + root_mapping = unflatten( + [ + EntityClassMapping(0, highlight_position=0), + DimensionMapping(1), + EntityMapping(2), + 
ElementMapping(3), + ParameterDefinitionMapping(4), + AlternativeMapping(5), + ParameterValueMapping(6), + ] + ) + expected = [["rc", "oc", "o__", "o", "p", "Base", 23.0]] + self.assertEqual(list(rows(root_mapping, db_map)), expected) def test_export_default_values_of_object_parameters_while_exporting_relationships(self): db_map = DatabaseMapping("sqlite://", create=True) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 80ecb8ac..b899a7e1 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -24,6 +24,7 @@ SpineDBAPIError, SpineIntegrityError, ) +from spinedb_api.helpers import name_from_elements from .custom_db_mapping import CustomDatabaseMapping @@ -676,7 +677,7 @@ def test_entity_sq(self): entity_rows = self._db_map.query(self._db_map.entity_sq).all() self.assertEqual(len(entity_rows), len(objects) + len(relationships)) object_names = [o[1] for o in objects] - relationship_names = ["__".join(r[1]) for r in relationships] + relationship_names = [name_from_elements(r[1]) for r in relationships] for row, expected_name in zip(entity_rows, object_names + relationship_names): self.assertEqual(row.name, expected_name) @@ -720,7 +721,7 @@ def test_wide_relationship_sq(self): relationship_rows = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationship_rows), 2) for row, relationship in zip(relationship_rows, relationships): - self.assertEqual(row.name, "__".join(relationship[1])) + self.assertEqual(row.name, name_from_elements(relationship[1])) self.assertEqual(row.class_name, relationship[0]) self.assertEqual(row.object_class_name_list, ",".join(object_classes[relationship[0]])) self.assertEqual(row.object_name_list, ",".join(relationship[1])) @@ -1453,7 +1454,7 @@ def test_add_entity_metadata_for_relationship(self): dict(entity_metadata[0]), { "entity_id": 2, - "entity_name": "my_object", + "entity_name": "my_object__", "metadata_name": "title", "metadata_value": "My metadata.", "metadata_id": 1, diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index ac40310d..8a55f164 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -105,7 +105,7 @@ def test_export_data(self): self.assertIn("entities", exported) self.assertEqual( exported["entities"], - [("object_class", "object", (), None), ("relationship_class", "object", ("object",), None)], + [("object_class", "object", (), None), ("relationship_class", "object__", ("object",), None)], ) self.assertIn("parameter_values", exported) self.assertEqual( diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c28b7546..5ba88aac 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -12,7 +12,29 @@ import unittest -from spinedb_api.helpers import compare_schemas, create_new_spine_database, remove_credentials_from_url +from spinedb_api.helpers import ( + compare_schemas, + create_new_spine_database, + name_from_dimensions, + name_from_elements, + remove_credentials_from_url, +) + + +class TestNameFromElements(unittest.TestCase): + def test_single_element(self): + self.assertEqual(name_from_elements(("a",)), "a__") + + def test_multiple_elements(self): + self.assertEqual(name_from_elements(("a", "b")), "a__b") + + +class TestNameFromDimensions(unittest.TestCase): + def test_single_dimension(self): + self.assertEqual(name_from_dimensions(("a",)), "a__") + + def test_multiple_dimension(self): + self.assertEqual(name_from_dimensions(("a", "b")), "a__b") class 
TestCreateNewSpineEngine(unittest.TestCase):
diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py
index 242c2287..98da973d 100644
--- a/tests/test_import_functions.py
+++ b/tests/test_import_functions.py
@@ -354,7 +354,7 @@ def test_import_relationships(self):
         _, errors = import_relationships(db_map, (("relationship_class", ("object",)),))
         self.assertFalse(errors)
         db_map.commit_session("test")
-        self.assertIn("object", [r.name for r in db_map.query(db_map.relationship_sq)])
+        self.assertIn("object__", [r.name for r in db_map.query(db_map.relationship_sq)])
         db_map.close()
 
     def test_import_valid_relationship(self):
@@ -1399,7 +1399,7 @@ def test_import_relationship_parameter_value_metadata(self):
             dict(metadata[0]),
             {
                 "alternative_name": "Base",
-                "entity_name": "object",
+                "entity_name": "object__",
                 "id": 1,
                 "metadata_id": 1,
                 "metadata_name": "co-author",
@@ -1413,7 +1413,7 @@ def test_import_relationship_parameter_value_metadata(self):
             dict(metadata[1]),
             {
                 "alternative_name": "Base",
-                "entity_name": "object",
+                "entity_name": "object__",
                 "id": 2,
                 "metadata_id": 2,
                 "metadata_name": "age",

From 9767aaad81ca1ffb0caeb3b91fca39b785200f55 Mon Sep 17 00:00:00 2001
From: Antti Soininen 
Date: Mon, 20 Nov 2023 16:24:58 +0200
Subject: [PATCH 191/317] Make SqlAlchemyConnector work with database schemas

It is now possible to pass a database schema to
SqlAlchemyConnector.connect_to_source()

Re spine-tools/Spine-Toolbox#2329
---
 spinedb_api/spine_io/importers/csv_reader.py  |  5 +++--
 .../spine_io/importers/datapackage_reader.py  |  3 ++-
 .../spine_io/importers/excel_reader.py        |  9 +++++----
 .../spine_io/importers/gdx_connector.py       |  3 ++-
 spinedb_api/spine_io/importers/json_reader.py |  3 ++-
 spinedb_api/spine_io/importers/reader.py      |  7 ++++---
 .../importers/sqlalchemy_connector.py         | 19 +++++++++++--------
 7 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/spinedb_api/spine_io/importers/csv_reader.py b/spinedb_api/spine_io/importers/csv_reader.py
index ecf76682..b3381bba 100644
--- a/spinedb_api/spine_io/importers/csv_reader.py
+++ b/spinedb_api/spine_io/importers/csv_reader.py
@@ -46,11 +46,12 @@ def __init__(self, settings):
         super().__init__(settings)
         self._filename = None
 
-    def connect_to_source(self, source):
+    def connect_to_source(self, source, **extras):
         """saves filepath
 
-        Arguments:
+        Args:
             source (str): filepath
+            **extras: ignored
         """
         self._filename = source
 
diff --git a/spinedb_api/spine_io/importers/datapackage_reader.py b/spinedb_api/spine_io/importers/datapackage_reader.py
index dadb4349..8fa1bc5b 100644
--- a/spinedb_api/spine_io/importers/datapackage_reader.py
+++ b/spinedb_api/spine_io/importers/datapackage_reader.py
@@ -58,11 +58,12 @@ def __setstate__(self, state):
         self.__dict__.update(state)
         self._resource_name_lock = threading.Lock()
 
-    def connect_to_source(self, source):
+    def connect_to_source(self, source, **extras):
         """Creates datapackage.
 
         Args:
             source (str): filepath of a datapackage.json file
+            **extras: ignored
         """
         if source:
             self._datapackage = Package(source)
diff --git a/spinedb_api/spine_io/importers/excel_reader.py b/spinedb_api/spine_io/importers/excel_reader.py
index 08515314..22cabe0f 100644
--- a/spinedb_api/spine_io/importers/excel_reader.py
+++ b/spinedb_api/spine_io/importers/excel_reader.py
@@ -42,11 +42,12 @@ def __init__(self, settings):
         self._filename = None
         self._wb = None
 
-    def connect_to_source(self, source):
-        """saves filepath
+    def connect_to_source(self, source, **extras):
+        """Connects to Excel file. 
- Arguments: - source {str} -- filepath + Args: + source (str): path to file + **extras: ignored """ if source: self._filename = source diff --git a/spinedb_api/spine_io/importers/gdx_connector.py b/spinedb_api/spine_io/importers/gdx_connector.py index 1cb0f26f..61f28935 100644 --- a/spinedb_api/spine_io/importers/gdx_connector.py +++ b/spinedb_api/spine_io/importers/gdx_connector.py @@ -54,12 +54,13 @@ def __exit__(self, exc_type, exc_value, traceback): def __del__(self): self.disconnect() - def connect_to_source(self, source): + def connect_to_source(self, source, **extras): """ Connects to given .gdx file. Args: source (str): path to .gdx file. + **extras: ignored """ if self._gams_dir is None: raise IOError(f"Could not find GAMS directory. Make sure you have GAMS installed.") diff --git a/spinedb_api/spine_io/importers/json_reader.py b/spinedb_api/spine_io/importers/json_reader.py index 56ffc210..024b98d7 100644 --- a/spinedb_api/spine_io/importers/json_reader.py +++ b/spinedb_api/spine_io/importers/json_reader.py @@ -37,11 +37,12 @@ def __init__(self, settings): self._filename = None self._root_prefix = None - def connect_to_source(self, source): + def connect_to_source(self, source, **extras): """saves filepath Args: source (str): filepath + **extras: ignored """ self._filename = source self._root_prefix = os.path.splitext(os.path.basename(source))[0] diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index 051618b9..3a645e96 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -54,11 +54,12 @@ def __init__(self, settings): settings (dict, optional): connector specific settings or None """ - def connect_to_source(self, source): + def connect_to_source(self, source, **extras): """Connects to source, ex: connecting to a database where source is a connection string. - Arguments: - source {} -- object with information on source to be connected to, ex: filepath string for a csv connection + Args: + source (str): file path or URL to connect to + **extras: additional source specific connection data """ raise NotImplementedError() diff --git a/spinedb_api/spine_io/importers/sqlalchemy_connector.py b/spinedb_api/spine_io/importers/sqlalchemy_connector.py index cb5a6441..e187356d 100644 --- a/spinedb_api/spine_io/importers/sqlalchemy_connector.py +++ b/spinedb_api/spine_io/importers/sqlalchemy_connector.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Contains SqlAlchemyConnector class. - -""" +""" Contains SqlAlchemyConnector class. """ from sqlalchemy import create_engine, MetaData @@ -37,23 +33,28 @@ def __init__(self, settings): self._engine = None self._connection = None self._session = None - self._metadata = MetaData() + self._schema = None + self._metadata = None - def connect_to_source(self, source): + def connect_to_source(self, source, **extras): """Saves source. 
Args:
             source (str): url
+            **extras: optional database schema under the "schema" key
         """
         self._connection_string = source
         self._engine = create_engine(source)
         self._connection = self._engine.connect()
         self._session = Session(self._engine)
+        self._schema = extras.get("schema")
+        self._metadata = MetaData(schema=self._schema)
         self._metadata.reflect(bind=self._engine)
 
     def disconnect(self):
         """Disconnect from connected source."""
         self._metadata = None
+        self._schema = None
         self._session.close()
         self._session = None
         self._connection.close()
@@ -67,7 +68,7 @@ def get_tables(self):
         Returns:
             list of str: Table names in list
         """
-        tables = list(self._engine.table_names())
+        tables = list(self._engine.table_names(schema=self._schema))
         return tables
 
     def get_data_iterator(self, table, options, max_rows=-1):
@@ -81,6 +82,8 @@ def get_data_iterator(self, table, options, max_rows=-1):
         Returns:
             tuple: iterator, header, column count
         """
+        if self._schema is not None:
+            table = self._schema + "." + table
         db_table = self._metadata.tables[table]
 
         header = [str(name) for name in db_table.columns.keys()]

From 633ce6df431e56fd6f270f7165b107cc7a9a5bd3 Mon Sep 17 00:00:00 2001
From: Manuel 
Date: Tue, 21 Nov 2023 11:13:58 +0100
Subject: [PATCH 192/317] Fix accessing table attributes using mapping names

Fixes #309

Sometimes a DB mapping item_type has attributes that are not in the
corresponding DB table; this is OK.
---
 spinedb_api/db_mapping_base.py |  7 ++++++-
 tests/test_DatabaseMapping.py  | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index bcc28bb2..9b589b1f 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -105,7 +105,12 @@ def _make_query(self, item_type, **kwargs):
                 src_key, key = self.item_factory(item_type)._external_fields[key]
                 ref_type, ref_key = self.item_factory(item_type)._references[src_key]
                 ref_sq = self._make_sq(ref_type)
-                qry = qry.filter(getattr(sq.c, src_key) == getattr(ref_sq.c, ref_key), getattr(ref_sq.c, key) == value)
+                try:
+                    qry = qry.filter(
+                        getattr(sq.c, src_key) == getattr(ref_sq.c, ref_key), getattr(ref_sq.c, key) == value
+                    )
+                except AttributeError:
+                    pass
         return qry
 
     def _make_sq(self, item_type):
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index b899a7e1..9b997872 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -857,6 +857,31 @@ def test_filter_query_accepts_multiple_criteria(self):
             self.assertEqual(entity.class_id, real_class_id)
 
 
+class TestDatabaseMappingGet(unittest.TestCase):
+    def setUp(self):
+        self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True)
+
+    def tearDown(self):
+        self._db_map.close()
+
+    def test_get_entity_alternative_items(self):
+        import_functions.import_data(
+            self._db_map,
+            entity_classes=(("fish",),),
+            entities=(("fish", "Nemo"),),
+            entity_alternatives=(("fish", "Nemo", "Base", True),),
+        )
+        ea_item = self._db_map.get_entity_alternative_item(
+            alternative_name="Base", entity_class_name="fish", entity_byname=("Nemo",)
+        )
+        self.assertIsNotNone(ea_item)
+        ea_items = self._db_map.get_entity_alternative_items(
+            alternative_name="Base", entity_class_name="fish", entity_byname=("Nemo",)
+        )
+        self.assertEqual(len(ea_items), 1)
+        self.assertEqual(ea_items[0], ea_item)
+
+
 class TestDatabaseMappingAdd(unittest.TestCase):
     def setUp(self):
         self._db_map = CustomDatabaseMapping(IN_MEMORY_DB_URL, create=True)

From 313ad5475799b01de40b17fc1ad08ee656ab58a4 Mon Sep 17 
00:00:00 2001 From: Manuel Date: Tue, 21 Nov 2023 11:16:26 +0100 Subject: [PATCH 193/317] Undocument the check keyword argument to DatabaseMapping methods Re #307 It's for internal use only. We might even rename it to _check in the future? --- spinedb_api/db_mapping.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index cf6d28a0..f1bbf416 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -390,7 +390,6 @@ def add_item(self, item_type, check=True, **kwargs): Args: item_type (str): One of . - check (bool, optional): Whether to carry out integrity checks. **kwargs: Fields and values as specified for the item type in :ref:`db_mapping_schema`. Returns: @@ -411,7 +410,6 @@ def add_items(self, item_type, *items, check=True, strict=False): item_type (str): One of . *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, as specified in :ref:`db_mapping_schema`. - check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the insertion of one of the items violates an integrity constraint. @@ -442,7 +440,6 @@ def update_item(self, item_type, check=True, **kwargs): Args: item_type (str): One of . - check (bool, optional): Whether to carry out integrity checks. id (int): The id of the item to update. **kwargs: Fields to update and their new values as specified for the item type in :ref:`db_mapping_schema`. @@ -464,7 +461,6 @@ def update_items(self, item_type, *items, check=True, strict=False): item_type (str): One of . *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, as specified in :ref:`db_mapping_schema` and including the `id`. - check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the update of one of the items violates an integrity constraint. @@ -494,7 +490,6 @@ def remove_item(self, item_type, id_, check=True): Args: item_type (str): One of . id_ (int): The id of the item to remove. - check (bool, optional): Whether to carry out integrity checks. Returns: tuple(:class:`PublicItem` or None, str): The removed item and any errors. @@ -512,7 +507,6 @@ def remove_items(self, item_type, *ids, check=True, strict=False): Args: item_type (str): One of . *ids (Iterable(int)): Ids of items to be removed. - check (bool): Whether or not to run integrity checks. strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` if the update of one of the items violates an integrity constraint. @@ -848,7 +842,6 @@ def add_{item_type}_item(self, check=True, **kwargs): """Adds {a} `{item_type}` item to the in-memory mapping. Args: - check (bool, optional): Whether to carry out integrity checks. {add_kwargs} Returns: @@ -868,7 +861,6 @@ def update_{item_type}_item(self, check=True, **kwargs): """Updates {a} `{item_type}` item in the in-memory mapping. Args: - check (bool, optional): Whether to carry out integrity checks. 
{update_kwargs} Returns: From 90d48543f32ad1db3b07b46c2fadf85a4795ca72 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 21 Nov 2023 16:35:26 +0100 Subject: [PATCH 194/317] Introduce check_fields to do type checking in DatabaseMapping Re #307 --- docs/source/conf.py | 7 +- spinedb_api/db_mapping.py | 14 ++- spinedb_api/db_mapping_base.py | 27 +++++ spinedb_api/mapped_items.py | 174 +++++++++++++++++++-------------- tests/test_DatabaseMapping.py | 21 ++-- 5 files changed, 155 insertions(+), 88 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c2bdca63..b82aec6e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -123,6 +123,9 @@ def _process_docstring(app, what, name, obj, options, lines): def _db_mapping_schema_lines(): + def type_(f_dict): + return f_dict['type'].__name__ + (', optional' if f_dict.get('optional', False) else '') + lines = [ ".. _db_mapping_schema:", "", @@ -152,8 +155,8 @@ def _db_mapping_schema_lines(): " - value", ] ) - for f_name, (f_type, f_value) in factory.fields.items(): - lines.extend([f" * - {f_name}", f" - {f_type}", f" - {f_value}"]) + for f_name, f_dict in factory.fields.items(): + lines.extend([f" * - {f_name}", f" - {type_(f_dict)}", f" - {f_dict['value']}"]) lines.append("") lines.extend([".. list-table:: Unique keys", " :header-rows: 0", ""]) for f_names in factory._unique_keys: diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index f1bbf416..eb2efab6 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -352,7 +352,9 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): :class:`PublicItem` or None """ item_type = self.real_item_type(item_type) - item = self.mapped_table(item_type).find_item(kwargs, fetch=fetch) + mapped_table = self.mapped_table(item_type) + mapped_table.check_fields(kwargs) + item = mapped_table.find_item(kwargs, fetch=fetch) if not item: return {} if skip_removed and not item.is_valid(): @@ -373,9 +375,10 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): list(:class:`PublicItem`): The items. 
""" item_type = self.real_item_type(item_type) + mapped_table = self.mapped_table(item_type) + mapped_table.check_fields(kwargs) if fetch: self.do_fetch_all(item_type, **kwargs) - mapped_table = self.mapped_table(item_type) get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] @@ -787,7 +790,12 @@ def _uq_fields(factory): } def _kwargs(fields): - return f"\n{padding}".join([f"{f_name} ({f_type}): {f_value}" for f_name, (f_type, f_value) in fields.items()]) + def type_(f_dict): + return f_dict['type'].__name__ + (', optional' if f_dict.get('optional', False) else '') + + return f"\n{padding}".join( + [f"{f_name} ({type_(f_dict)}): {f_dict['value']}" for f_name, f_dict in fields.items()] + ) padding = 20 * " " for item_type in DatabaseMapping.item_types(): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 9b589b1f..f522bbd0 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -417,6 +417,7 @@ def checked_item_and_error(self, item, for_update=False, skip_keys=()): error = self._prepare_item(candidate_item, current_item, item, skip_keys) if error: return None, error + self.check_fields(candidate_item._asdict()) return candidate_item, merge_error def _prepare_item(self, candidate_item, current_item, original_item, skip_keys): @@ -503,6 +504,31 @@ def add_item_from_db(self, item): item.cascade_remove(source=self.wildcard_item) return item, True + def check_fields(self, item): + factory = self._db_map.item_factory(self._item_type) + + def _error(key, value): + if key in set(factory._internal_fields) | set(factory._external_fields) | factory._private_fields | { + "id", + "commit_id", + }: + # The user seems to know what they're doing + return + f_dict = factory.fields.get(key) + if f_dict is None: + valid_args = ", ".join(factory.fields) + return f"invalid keyword argument '{key}' for '{self._item_type}' - valid arguments are {valid_args}." + valid_types = (f_dict["type"],) if not f_dict.get("optional", False) else (f_dict["type"], type(None)) + if not isinstance(value, valid_types): + return ( + f"invalid type for '{key}' of '{self._item_type}' - " + f"got {type(value).__name__}, expected {f_dict['type'].__name__}." + ) + + errors = list(filter(lambda x: x is not None, (_error(key, value) for key, value in item.items()))) + if errors: + raise SpineDBAPIError("\n".join(errors)) + def add_item(self, item): item = self._make_and_add_item(item) self.add_unique(item) @@ -570,6 +596,7 @@ class MappedItemBase(dict): Keys in _internal_fields are resolved to the reference key of the alternative reference pointed at by the source key. 
""" + _private_fields = set() def __init__(self, db_map, item_type, **kwargs): """ diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 0853467a..72a5eee6 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -39,10 +39,11 @@ def item_factory(item_type): class CommitItem(MappedItemBase): fields = { - "comment": ("str", "A comment describing the commit."), - "date": {"datetime", "Date and time of the commit."}, - "user": {"str", "Username of the committer."}, + 'comment': {'type': str, 'value': 'A comment describing the commit.'}, + 'date': {'type': str, 'value': 'Date and time of the commit in ISO 8601 format.'}, + 'user': {'type': str, 'value': 'Username of the committer.'}, } + _unique_keys = (("date",),) def commit(self, commit_id): @@ -51,12 +52,20 @@ def commit(self, commit_id): class EntityClassItem(MappedItemBase): fields = { - "name": ("str", "The class name."), - "dimension_name_list": ("tuple, optional", "The dimension names for a multi-dimensional class."), - "description": ("str, optional", "The class description."), - "display_icon": ("int, optional", "An integer representing an icon within your application."), - "display_order": ("int, optional", "Not in use at the moment."), - "hidden": ("bool, optional", "Not in use at the moment."), + 'name': {'type': str, 'value': 'The class name.'}, + 'dimension_name_list': { + 'type': tuple, + 'value': 'The dimension names for a multi-dimensional class.', + 'optional': True, + }, + 'description': {'type': str, 'value': 'The class description.', 'optional': True}, + 'display_icon': { + 'type': int, + 'value': 'An integer representing an icon within your application.', + 'optional': True, + }, + 'display_order': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, + 'hidden': {'type': bool, 'value': 'Not in use at the moment.', 'optional': True}, } _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = (("name",),) @@ -95,16 +104,17 @@ def commit(self, _commit_id): class EntityItem(MappedItemBase): fields = { - "class_name": ("str", "The entity class name."), - "name": ("str", "The entity name."), - "element_name_list": ("tuple", "The element names if the entity is multi-dimensional."), - "byname": ( - "tuple", - "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element names if it is multi-dimensional.", - ), - "description": ("str, optional", "The entity description."), + 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'name': {'type': str, 'value': 'The entity name.'}, + 'element_name_list': {'type': tuple, 'value': 'The element names if the entity is multi-dimensional.'}, + 'byname': { + 'type': tuple, + 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional,' + 'or the element names if it is multi-dimensional.', + }, + 'description': {'type': str, 'value': 'The entity description.', 'optional': True}, } + _defaults = {"description": None} _unique_keys = (("class_name", "name"), ("class_name", "byname")) _references = {"class_id": ("entity_class", "id"), "element_id_list": ("entity", "id")} @@ -174,7 +184,9 @@ def polish(self): name = base_name index = 1 while any( - self._db_map.get_item("entity", class_name=self[k], name=name) for k in ("class_name", "superclass_name") + self._db_map.get_item("entity", class_name=self[k], name=name) + for k in ("class_name", "superclass_name") + if self[k] is not None ): name = 
f"{base_name}_{index}" index += 1 @@ -183,9 +195,9 @@ def polish(self): class EntityGroupItem(MappedItemBase): fields = { - "class_name": ("str", "The entity class name."), - "group_name": ("str", "The group entity name."), - "member_name": ("str", "The member entity name."), + 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'group_name': {'type': str, 'value': 'The group entity name.'}, + 'member_name': {'type': str, 'value': 'The member entity name.'}, } _unique_keys = (("class_name", "group_name", "member_name"),) _references = { @@ -220,14 +232,18 @@ def __getitem__(self, key): class EntityAlternativeItem(MappedItemBase): fields = { - "entity_class_name": ("str", "The entity class name."), - "entity_byname": ( - "tuple", - "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element names if it is multi-dimensional.", - ), - "alternative_name": ("str", "The alternative name."), - "active": ("bool, optional", "Whether the entity is active in the alternative - defaults to True."), + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_byname': { + 'type': tuple, + 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' + 'or the element names if it is multi-dimensional.', + }, + 'alternative_name': {'type': str, 'value': 'The alternative name.'}, + 'active': { + 'type': bool, + 'value': 'Whether the entity is active in the alternative - defaults to True.', + 'optional': True, + }, } _defaults = {"active": True} _unique_keys = (("entity_class_name", "entity_byname", "alternative_name"),) @@ -258,6 +274,8 @@ class EntityAlternativeItem(MappedItemBase): class ParsedValueBase(MappedItemBase): + _private_fields = {"list_value_id"} + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parsed_value = None @@ -358,12 +376,16 @@ def polish(self): class ParameterDefinitionItem(ParameterItemBase): fields = { - "entity_class_name": ("str", "The entity class name."), - "name": ("str", "The parameter name."), - "default_value": ("any, optional", "The default value."), - "default_type": ("str, optional", "The default value type."), - "parameter_value_list_name": ("str, optional", "The parameter value list name if any."), - "description": ("str, optional", "The parameter description."), + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, + 'name': {'type': str, 'value': 'The parameter name.'}, + 'default_value': {'type': bytes, 'value': 'The default value.', 'optional': True}, + 'default_type': {'type': str, 'value': 'The default value type.', 'optional': True}, + 'parameter_value_list_name': { + 'type': str, + 'value': 'The parameter value list name if any.', + 'optional': True, + }, + 'description': {'type': str, 'value': 'The parameter description.', 'optional': True}, } _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} _unique_keys = (("entity_class_name", "name"),) @@ -430,16 +452,16 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueItem(ParameterItemBase): fields = { - "entity_class_name": ("str", "The entity class name."), - "parameter_definition_name": ("str", "The parameter name."), - "entity_byname": ( - "tuple", - "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element names if the entity is multi-dimensional.", - ), - "value": ("any", "The value."), - "type": ("str", "The value 
type."), - "alternative_name": ("str, optional", "The alternative name - defaults to 'Base'."), + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, + 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, + 'entity_byname': { + 'type': tuple, + 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' + 'or the element names if the entity is multi-dimensional.', + }, + 'value': {'type': bytes, 'value': 'The value.'}, + 'type': {'type': str, 'value': 'The value type.', 'optional': True}, + 'alternative_name': {'type': str, 'value': "The alternative name - defaults to 'Base'.", 'optional': True}, } _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { @@ -501,16 +523,16 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueListItem(MappedItemBase): - fields = {"name": ("str", "The parameter value list name.")} + fields = {'name': {'type': str, 'value': 'The parameter value list name.'}} _unique_keys = (("name",),) class ListValueItem(ParsedValueBase): fields = { - "parameter_value_list_name": ("str", "The parameter value list name."), - "value": ("any", "The value."), - "type": ("str", "The value type."), - "index": ("int, optional", "The value index."), + 'parameter_value_list_name': {'type': str, 'value': 'The parameter value list name.'}, + 'value': {'type': bytes, 'value': 'The value.'}, + 'type': {'type': str, 'value': 'The value type.', 'optional': True}, + 'index': {'type': int, 'value': 'The value index.', 'optional': True}, } _unique_keys = (("parameter_value_list_name", "value_and_type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_id": ("parameter_value_list", "id")} @@ -534,8 +556,8 @@ def __getitem__(self, key): class AlternativeItem(MappedItemBase): fields = { - "name": ("str", "The alternative name."), - "description": ("str, optional", "The alternative description."), + 'name': {'type': str, 'value': 'The alternative name.'}, + 'description': {'type': str, 'value': 'The alternative description.', 'optional': True}, } _defaults = {"description": None} _unique_keys = (("name",),) @@ -543,9 +565,9 @@ class AlternativeItem(MappedItemBase): class ScenarioItem(MappedItemBase): fields = { - "name": ("str", "The scenario name."), - "description": ("str, optional", "The scenario description."), - "active": ("bool, optional", "Not in use at the moment."), + 'name': {'type': str, 'value': 'The scenario name.'}, + 'description': {'type': str, 'value': 'The scenario description.', 'optional': True}, + 'active': {'type': bool, 'value': 'Not in use at the moment.', 'optional': True}, } _defaults = {"active": False, "description": None} _unique_keys = (("name",),) @@ -570,9 +592,9 @@ def __getitem__(self, key): class ScenarioAlternativeItem(MappedItemBase): fields = { - "scenario_name": ("str", "The scenario name."), - "alternative_name": ("str", "The alternative name."), - "rank": ("int", "The rank - the higher has precedence."), + 'scenario_name': {'type': str, 'value': 'The scenario name.'}, + 'alternative_name': {'type': str, 'value': 'The alternative name.'}, + 'rank': {'type': int, 'value': 'The rank - higher has precedence.'}, } _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) _references = {"scenario_id": ("scenario", "id"), "alternative_id": ("alternative", "id")} @@ -597,15 +619,18 @@ def __getitem__(self, key): class 
MetadataItem(MappedItemBase): - fields = {"name": ("str", "The metadata entry name."), "value": ("str", "The metadata entry value.")} + fields = { + 'name': {'type': str, 'value': 'The metadata entry name.'}, + 'value': {'type': str, 'value': 'The metadata entry value.'}, + } _unique_keys = (("name", "value"),) class EntityMetadataItem(MappedItemBase): fields = { - "entity_name": ("str", "The entity name."), - "metadata_name": ("str", "The metadata entry name."), - "metadata_value": ("str", "The metadata entry value."), + 'entity_name': {'type': str, 'value': 'The entity name.'}, + 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, + 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, } _unique_keys = (("entity_name", "metadata_name", "metadata_value"),) _references = {"entity_id": ("entity", "id"), "metadata_id": ("metadata", "id")} @@ -626,15 +651,15 @@ class EntityMetadataItem(MappedItemBase): class ParameterValueMetadataItem(MappedItemBase): fields = { - "parameter_definition_name": ("str", "The parameter name."), - "entity_byname": ( - "tuple", - "A tuple with the entity name as single element if the entity is zero-dimensional, " - "or the element names if it is multi-dimensional.", - ), - "alternative_name": ("str", "The alternative name."), - "metadata_name": ("str", "The metadata entry name."), - "metadata_value": ("str", "The metadata entry value."), + 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, + 'entity_byname': { + 'type': tuple, + 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' + 'or the element names if it is multi-dimensional.', + }, + 'alternative_name': {'type': str, 'value': 'The alternative name.'}, + 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, + 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, } _unique_keys = ( ("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name", "metadata_value"), @@ -664,7 +689,10 @@ class ParameterValueMetadataItem(MappedItemBase): class SuperclassSubclassItem(MappedItemBase): - fields = {"superclass_name": ("str", "The superclass name."), "subclass_name": ("str", "The subclass name.")} + fields = { + 'superclass_name': {'type': str, 'value': 'The superclass name.'}, + 'subclass_name': {'type': str, 'value': 'The subclass name.'}, + } _unique_keys = (("subclass_name",),) _references = {"superclass_id": ("entity_class", "id"), "subclass_id": ("entity_class", "id")} _external_fields = { diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 9b997872..174c4137 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -864,6 +864,14 @@ def setUp(self): def tearDown(self): self._db_map.close() + def test_get_entity_class_items_check_fields(self): + import_functions.import_data(self._db_map, entity_classes=(("fish",),)) + with self.assertRaises(SpineDBAPIError): + self._db_map.get_entity_class_item(class_name="fish") + with self.assertRaises(SpineDBAPIError): + self._db_map.get_entity_class_item(name=("fish",)) + self._db_map.get_entity_class_item(name="fish") + def test_get_entity_alternative_items(self): import_functions.import_data( self._db_map, @@ -1555,7 +1563,7 @@ def test_add_parameter_value_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_parameter_value_metadata( - 
{"parameter_value_id": 1, "metadata_id": 1, "alternative_id": 1}, strict=False + {"parameter_value_id": 1, "metadata_id": 1}, strict=False ) self.assertEqual(errors, []) self.assertEqual(len(items), 1) @@ -1591,13 +1599,7 @@ def test_add_ext_parameter_value_metadata(self): import_functions.import_object_parameter_values(self._db_map, (("fish", "leviathan", "paranormality", 3.9),)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_ext_parameter_value_metadata( - { - "parameter_value_id": 1, - "metadata_name": "key", - "metadata_value": "parameter metadata", - "alternative_id": 1, - }, - strict=False, + {"parameter_value_id": 1, "metadata_name": "key", "metadata_value": "parameter metadata"}, strict=False ) self.assertEqual(errors, []) self.assertEqual(len(items), 1) @@ -1627,8 +1629,7 @@ def test_add_ext_parameter_value_metadata_reuses_existing_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.add_ext_parameter_value_metadata( - {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata.", "alternative_id": 1}, - strict=False, + {"parameter_value_id": 1, "metadata_name": "title", "metadata_value": "My metadata."}, strict=False ) self.assertEqual(errors, []) self.assertEqual(len(items), 1) From 32a23a82dd9afeedcd9172caa6044943727afc76 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 21 Nov 2023 17:21:01 +0100 Subject: [PATCH 195/317] Don't update equivalent entities in import_data Fixes #314 --- spinedb_api/db_mapping_base.py | 11 +++++++++-- tests/test_import_functions.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f522bbd0..e8d8a21c 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -574,7 +574,8 @@ class MappedItemBase(dict): """A dictionary that represents a db item.""" fields = {} - """A dictionary mapping keys to a tuple of (type, value description)""" + """A dictionary mapping keys to a another dict mapping "type" to a Python type, + "value" to a description of the value for the key, and "optional" to a bool.""" _defaults = {} """A dictionary mapping keys to their default values""" _unique_keys = () @@ -597,6 +598,7 @@ class MappedItemBase(dict): source key. 
""" _private_fields = set() + """A set with fields that should be ignored in validations.""" def __init__(self, db_map, item_type, **kwargs): """ @@ -741,7 +743,12 @@ def _something_to_update(self, other): def _convert(x): return tuple(x) if isinstance(x, list) else x - return not all(_convert(self.get(key)) == _convert(value) for key, value in other.items()) + return not all( + _convert(self.get(key)) == _convert(value) + for key, value in other.items() + if value is not None + or self.fields.get(key, {}).get("optional", False) # Ignore mandatory fields that are None + ) def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first reference diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 98da973d..b6ebfd2c 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -340,6 +340,27 @@ def test_import_existing_relationship_class_parameter(self): db_map.close() +class TestImportEntity(unittest.TestCase): + def test_import_multi_d_entity_twice(self): + db_map = DatabaseMapping("sqlite://", create=True) + import_data( + db_map, + entity_classes=( + ("object_class1",), + ("object_class2",), + ("relationship_class", ("object_class1", "object_class2")), + ), + entities=( + ("object_class1", "object1"), + ("object_class2", "object2"), + ("relationship_class", ("object1", "object2")), + ), + ) + count, errors = import_data(db_map, entities=(("relationship_class", ("object1", "object2")),)) + self.assertEqual(count, 0) + self.assertEqual(errors, []) + + class TestImportRelationship(unittest.TestCase): @staticmethod def populate(db_map): From 511a32ececc6133803e9a223cfb14862a8c01f00 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 22 Nov 2023 12:09:26 +0100 Subject: [PATCH 196/317] Don't complain about None values when checking keyword args Client code (toolbox) seems to like sending None with keyword args sometimes. To simplify our lives we don't check the type of those args in that case. 
---
 spinedb_api/db_mapping.py | 4 ++--
 spinedb_api/db_mapping_base.py | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index eb2efab6..b575763d 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -353,7 +353,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """
         item_type = self.real_item_type(item_type)
         mapped_table = self.mapped_table(item_type)
-        mapped_table.check_fields(kwargs)
+        mapped_table.check_fields(kwargs, valid_types=(type(None),))
         item = mapped_table.find_item(kwargs, fetch=fetch)
         if not item:
             return {}
@@ -376,7 +376,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """
         item_type = self.real_item_type(item_type)
         mapped_table = self.mapped_table(item_type)
-        mapped_table.check_fields(kwargs)
+        mapped_table.check_fields(kwargs, valid_types=(type(None),))
         if fetch:
             self.do_fetch_all(item_type, **kwargs)
         get_items = mapped_table.valid_values if skip_removed else mapped_table.values
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index e8d8a21c..7a81f65e 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -504,7 +504,7 @@ def add_item_from_db(self, item):
             item.cascade_remove(source=self.wildcard_item)
         return item, True
 
-    def check_fields(self, item):
+    def check_fields(self, item, valid_types=()):
         factory = self._db_map.item_factory(self._item_type)
 
         def _error(key, value):
@@ -518,7 +518,9 @@ def _error(key, value):
             if f_dict is None:
                 valid_args = ", ".join(factory.fields)
                 return f"invalid keyword argument '{key}' for '{self._item_type}' - valid arguments are {valid_args}."
-            valid_types = (f_dict["type"],) if not f_dict.get("optional", False) else (f_dict["type"], type(None))
+            valid_types = valid_types + (f_dict["type"],)
+            if f_dict.get("optional", False):
+                valid_types = valid_types + (type(None),)
             if not isinstance(value, valid_types):
                 return (
                     f"invalid type for '{key}' of '{self._item_type}' - "
From 4c85065336592dea7823f1ac2bfb5bf55efec667 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 22 Nov 2023 12:17:38 +0100
Subject: [PATCH 197/317] Very minor code refactoring

---
 spinedb_api/db_mapping_base.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 7a81f65e..e908156e 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -811,14 +811,15 @@ def resolve_internal_fields(self, skip_keys=()):
         Returns:
             str or None: error description if any.
         """
-        for key, (src_key, target_key) in self._internal_fields.items():
+        for key in self._internal_fields:
             if key in skip_keys:
                 continue
-            error = self._do_resolve_internal_field(key, src_key, target_key)
+            error = self._do_resolve_internal_field(key)
             if error:
                 return error
 
-    def _do_resolve_internal_field(self, key, src_key, target_key):
+    def _do_resolve_internal_field(self, key):
+        src_key, target_key = self._internal_fields[key]
         src_val = tuple(dict.pop(self, k, None) or self.get(k) for k in src_key)
         if None in src_val:
             return
From d4839dc6ffee02bbfc1b71197486d6fc7e055919 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 22 Nov 2023 12:38:56 +0100
Subject: [PATCH 198/317] Fix check_fields

Assigning to valid_types inside the local _error function makes that name
local to _error, so the enclosing function's keyword argument is not
visible there and reading it raises UnboundLocalError. Pass the value in
explicitly instead.
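A minimal standalone illustration of the scoping rule (not part of the
patch):

    def outer(valid_types=()):
        def inner():
            # Assigning to valid_types anywhere in inner() makes the name local
            # to inner() for its whole body, so the read on the right-hand side
            # happens before any local binding exists.
            valid_types = valid_types + (str,)
            return valid_types
        return inner()

    outer()  # raises UnboundLocalError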
--- spinedb_api/db_mapping_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e908156e..5462c303 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -507,7 +507,7 @@ def add_item_from_db(self, item): def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) - def _error(key, value): + def _error(key, value, valid_types): if key in set(factory._internal_fields) | set(factory._external_fields) | factory._private_fields | { "id", "commit_id", @@ -527,7 +527,7 @@ def _error(key, value): f"got {type(value).__name__}, expected {f_dict['type'].__name__}." ) - errors = list(filter(lambda x: x is not None, (_error(key, value) for key, value in item.items()))) + errors = list(filter(lambda x: x is not None, (_error(key, value, valid_types) for key, value in item.items()))) if errors: raise SpineDBAPIError("\n".join(errors)) From 7749109b9ded363055cc07758425d00cf1f965aa Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 23 Nov 2023 08:58:09 +0100 Subject: [PATCH 199/317] Add entities with multi-d dimensions by specifying all zero-d elements ...and adapt superclass to that too. --- spinedb_api/import_functions.py | 20 ++++++-------- spinedb_api/mapped_items.py | 49 +++++++++++++++++++++++++++++---- tests/test_import_functions.py | 46 +++++++++++++++++++++++++++---- 3 files changed, 92 insertions(+), 23 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 97c7ef0a..60664127 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -177,11 +177,11 @@ def get_data_for_import( alternatives = list({item[1]: None for item in scenario_alternatives}) yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) - if superclass_subclasses: - yield ("superclass_subclass", _get_parameter_superclass_subclasses_for_import(db_map, superclass_subclasses)) if entity_classes: for bucket in _get_entity_classes_for_import(db_map, entity_classes): yield ("entity_class", bucket) + if superclass_subclasses: + yield ("superclass_subclass", _get_parameter_superclass_subclasses_for_import(db_map, superclass_subclasses)) if entities: for bucket in _get_entities_for_import(db_map, entities): yield ("entity", bucket) @@ -544,19 +544,15 @@ def _ref_count(name): def _get_entities_for_import(db_map, data): items_by_el_count = {} - key = ("class_name", "name", "element_name_list", "description") + key = ("class_name", "byname", "description") for class_name, name_or_el_name_list, *optionals in data: if isinstance(name_or_el_name_list, (list, tuple)): - name = None - el_name_list = name_or_el_name_list + el_count = len(name_or_el_name_list) + byname = name_or_el_name_list else: - name = name_or_el_name_list - if optionals and isinstance(optionals[0], (list, tuple)): - el_name_list = tuple(optionals.pop(0)) - else: - el_name_list = () - item = dict(zip(key, (class_name, name, el_name_list, *optionals))) - el_count = len(el_name_list) + el_count = 0 + byname = (name_or_el_name_list,) + item = dict(zip(key, (class_name, byname, *optionals))) items_by_el_count.setdefault(el_count, []).append(item) return ( _get_items_for_import(db_map, "entity", items_by_el_count[el_count]) for el_count in sorted(items_by_el_count) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 
72a5eee6..67e8b70e 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -73,6 +73,7 @@ class EntityClassItem(MappedItemBase): _external_fields = {"dimension_name_list": ("dimension_id_list", "name")} _alt_references = {("dimension_name_list",): ("entity_class", ("name",))} _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")} + _private_fields = {"superclass_id", "superclass_name"} def __init__(self, *args, **kwargs): dimension_id_list = kwargs.get("dimension_id_list") @@ -153,22 +154,58 @@ def unique_values_for_item(cls, item, skip_keys=()): if None not in value: yield key, value - def _element_name_list_iter(self, entity): + def _byname_iter(self, entity): element_id_list = entity["element_id_list"] if not element_id_list: yield entity["name"] else: for el_id in element_id_list: - element = self._get_ref("entity", {"id", el_id}) - yield from self._element_name_list_iter(element) + element = self._get_ref("entity", {"id": el_id}) + yield from self._byname_iter(element) def __getitem__(self, key): - if key == "root_element_name_list": - return tuple(self._element_name_list_iter(self)) if key == "byname": - return self["element_name_list"] or (self["name"],) + return tuple(self._byname_iter(self)) return super().__getitem__(key) + def resolve_internal_fields(self, skip_keys=()): + error = super().resolve_internal_fields(skip_keys=skip_keys) + if error: + return error + byname = dict.pop(self, "byname", None) + if byname is None: + return + if not self["dimension_id_list"]: + self["name"] = byname[0] + return + byname_remainder = list(byname) + _, self["element_name_list"] = self._element_name_list_recursive(self["class_name"], byname_remainder) + return self._do_resolve_internal_field("element_id_list") + + def _element_name_list_recursive(self, class_name, byname_remainder): + class_names = [class_name] + [ + x["subclass_name"] for x in self._db_map.get_items("superclass_subclass", superclass_name=class_name) + ] + for class_name_ in class_names: + dimension_name_list = self._db_map.get_item("entity_class", name=class_name_).get("dimension_name_list") + if not dimension_name_list: + continue + byname_remainder_backup = list(byname_remainder) + element_name_list = tuple( + self._db_map.get_item( + "entity", + **dict( + zip(("class_name", "byname"), self._element_name_list_recursive(dim_name, byname_remainder)) + ), + ).get("name") + for dim_name in dimension_name_list + ) + if None not in element_name_list: + return class_name_, element_name_list + byname_remainder = byname_remainder_backup + name = byname_remainder.pop(0) + return class_name, (name,) + def polish(self): error = super().polish() if error: diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index b6ebfd2c..3c50ffa4 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -440,7 +440,7 @@ def test_import_relationship_with_one_None_object(self): self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() - def test_import_relationship_of_relationships(self): + def test_import_multi_d_entity_with_multi_d_elements(self): db_map = create_db_map() self.populate(db_map) import_data( @@ -450,13 +450,44 @@ def test_import_relationship_of_relationships(self): ["relationship_class2", ["object_class2", "object_class1"]], ["meta_relationship_class", ["relationship_class1", "relationship_class2"]], ], - entities=[ - ["relationship_class1", "object1__object2", ["object1", "object2"]], - ["relationship_class2", 
"object2__object1", ["object2", "object1"]], + entities=[["relationship_class1", ["object1", "object2"]], ["relationship_class2", ["object2", "object1"]]], + ) + _, errors = import_data( + db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1"]]] + ) + self.assertFalse(errors) + db_map.commit_session("test") + entities = { + tuple(r.element_name_list.split(",")) if r.element_name_list else r.name: r.name + for r in db_map.query(db_map.wide_entity_sq) + } + self.assertTrue("object1" in entities) + self.assertTrue("object2" in entities) + self.assertTrue(("object1", "object2") in entities) + self.assertTrue(("object2", "object1") in entities) + self.assertTrue((entities["object1", "object2"], entities["object2", "object1"]) in entities) + self.assertEqual(len(entities), 5) + + def test_import_multi_d_entity_with_multi_d_elements_from_superclass(self): + db_map = create_db_map() + self.populate(db_map) + import_data( + db_map, + entity_classes=[ + ["relationship_class1", ["object_class1", "object_class2"]], + ["relationship_class2", ["object_class2", "object_class1"]], + ["superclass", []], ], + superclass_subclasses=[["superclass", "relationship_class1"], ["superclass", "relationship_class2"]], + ) + import_data( + db_map, + entity_classes=[["meta_relationship_class", ["superclass", "superclass"]]], + entities=[["relationship_class1", ["object1", "object2"]], ["relationship_class2", ["object2", "object1"]]], ) + print("NOE") _, errors = import_data( - db_map, entities=[["meta_relationship_class", ["object1__object2", "object2__object1"]]] + db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1"]]] ) self.assertFalse(errors) db_map.commit_session("test") @@ -470,6 +501,11 @@ def test_import_relationship_of_relationships(self): self.assertTrue(("object2", "object1") in entities) self.assertTrue((entities["object1", "object2"], entities["object2", "object1"]) in entities) self.assertEqual(len(entities), 5) + # _, errors = import_data(db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2"]]]) + # _, errors = import_data( + # db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1", "object2"]]] + # ) + # print(errors) class TestImportParameterDefinition(unittest.TestCase): From 10820b8d70e0237449b222c2da942fe84cdde6f1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 23 Nov 2023 11:10:59 +0100 Subject: [PATCH 200/317] Add unique key (superclass_name, byname) for entity --- spinedb_api/mapped_items.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 67e8b70e..50642332 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -147,12 +147,12 @@ def __init__(self, *args, **kwargs): @classmethod def unique_values_for_item(cls, item, skip_keys=()): - yield from super().unique_values_for_item(item, skip_keys=skip_keys) - key = ("class_name", "name") - if key not in skip_keys: - value = tuple(item.get(k) for k in ("superclass_name", "name")) - if None not in value: - yield key, value + """Overriden to also yield unique values for the superclass.""" + for key, value in super().unique_values_for_item(item, skip_keys=skip_keys): + yield key, value + sc_value = tuple(item.get("superclass_name" if k == "class_name" else k) for k in key) + if None not in sc_value: + yield (key, sc_value) def _byname_iter(self, entity): element_id_list = entity["element_id_list"] From 
158d7a8641cb7aebdce6d77b58c50dc2957f2396 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 23 Nov 2023 11:12:08 +0100 Subject: [PATCH 201/317] Error whenever element count is not right --- spinedb_api/mapped_items.py | 23 ++++++---- tests/test_import_functions.py | 77 +++++++++++++++++++++++++++++++--- 2 files changed, 85 insertions(+), 15 deletions(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 50642332..6e80a250 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -73,7 +73,6 @@ class EntityClassItem(MappedItemBase): _external_fields = {"dimension_name_list": ("dimension_id_list", "name")} _alt_references = {("dimension_name_list",): ("entity_class", ("name",))} _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")} - _private_fields = {"superclass_id", "superclass_name"} def __init__(self, *args, **kwargs): dimension_id_list = kwargs.get("dimension_id_list") @@ -175,17 +174,23 @@ def resolve_internal_fields(self, skip_keys=()): byname = dict.pop(self, "byname", None) if byname is None: return - if not self["dimension_id_list"]: + dim_count = len(self["dimension_id_list"]) + if not dim_count: self["name"] = byname[0] return byname_remainder = list(byname) - _, self["element_name_list"] = self._element_name_list_recursive(self["class_name"], byname_remainder) + element_name_list, _ = self._element_name_list_recursive(self["class_name"], byname_remainder) + if len(element_name_list) < dim_count: + return f"too few elements given for entity ({byname})" + if byname_remainder: + return f"too many elements given for entity ({byname})" + self["element_name_list"] = element_name_list return self._do_resolve_internal_field("element_id_list") def _element_name_list_recursive(self, class_name, byname_remainder): - class_names = [class_name] + [ + class_names = [ x["subclass_name"] for x in self._db_map.get_items("superclass_subclass", superclass_name=class_name) - ] + ] or [class_name] for class_name_ in class_names: dimension_name_list = self._db_map.get_item("entity_class", name=class_name_).get("dimension_name_list") if not dimension_name_list: @@ -195,16 +200,16 @@ def _element_name_list_recursive(self, class_name, byname_remainder): self._db_map.get_item( "entity", **dict( - zip(("class_name", "byname"), self._element_name_list_recursive(dim_name, byname_remainder)) + zip(("byname", "class_name"), self._element_name_list_recursive(dim_name, byname_remainder)) ), ).get("name") for dim_name in dimension_name_list ) if None not in element_name_list: - return class_name_, element_name_list + return element_name_list, class_name_ byname_remainder = byname_remainder_backup - name = byname_remainder.pop(0) - return class_name, (name,) + name = byname_remainder.pop(0) if byname_remainder else None + return (name,), class_name def polish(self): error = super().polish() diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 3c50ffa4..c3f02d48 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -440,6 +440,51 @@ def test_import_relationship_with_one_None_object(self): self.assertFalse([r.name for r in db_map.query(db_map.relationship_sq)]) db_map.close() + def test_import_multi_d_entity_with_elements_from_superclass(self): + db_map = create_db_map() + import_data( + db_map, + entity_classes=[ + ["object_class1", []], + ["object_class2", []], + ["superclass", []], + ["relationship_class1", ["superclass", "superclass"]], + ], + superclass_subclasses=[["superclass", 
"object_class1"], ["superclass", "object_class2"]], + entities=[["object_class1", "object1"], ["object_class2", "object2"]], + ) + _, errors = import_data(db_map, entities=[["relationship_class1", ["object1", "object2"]]]) + self.assertFalse(errors) + db_map.commit_session("test") + entities = { + tuple(r.element_name_list.split(",")) if r.element_name_list else r.name: r.name + for r in db_map.query(db_map.wide_entity_sq) + } + self.assertTrue("object1" in entities) + self.assertTrue("object2" in entities) + self.assertTrue(("object1", "object2") in entities) + self.assertEqual(len(entities), 3) + + def test_import_multi_d_entity_with_elements_from_superclass_fails_with_wrong_dimension_count(self): + db_map = create_db_map() + import_data( + db_map, + entity_classes=[ + ["object_class1", []], + ["object_class2", []], + ["superclass", []], + ["relationship_class1", ["superclass", "superclass"]], + ], + superclass_subclasses=[["superclass", "object_class1"], ["superclass", "object_class2"]], + entities=[["object_class1", "object1"], ["object_class2", "object2"]], + ) + _, errors = import_data(db_map, entities=[["relationship_class1", ["object1"]]]) + self.assertEqual(len(errors), 1) + self.assertIn("too few elements", errors[0]) + _, errors = import_data(db_map, entities=[["relationship_class1", ["object1", "object2", "object1"]]]) + self.assertEqual(len(errors), 1) + self.assertIn("too many elements", errors[0]) + def test_import_multi_d_entity_with_multi_d_elements(self): db_map = create_db_map() self.populate(db_map) @@ -485,7 +530,6 @@ def test_import_multi_d_entity_with_multi_d_elements_from_superclass(self): entity_classes=[["meta_relationship_class", ["superclass", "superclass"]]], entities=[["relationship_class1", ["object1", "object2"]], ["relationship_class2", ["object2", "object1"]]], ) - print("NOE") _, errors = import_data( db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1"]]] ) @@ -501,11 +545,32 @@ def test_import_multi_d_entity_with_multi_d_elements_from_superclass(self): self.assertTrue(("object2", "object1") in entities) self.assertTrue((entities["object1", "object2"], entities["object2", "object1"]) in entities) self.assertEqual(len(entities), 5) - # _, errors = import_data(db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2"]]]) - # _, errors = import_data( - # db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1", "object2"]]] - # ) - # print(errors) + + def test_import_multi_d_entity_with_multi_d_elements_from_superclass_fails_with_wrong_dimension_count(self): + db_map = create_db_map() + self.populate(db_map) + import_data( + db_map, + entity_classes=[ + ["relationship_class1", ["object_class1", "object_class2"]], + ["relationship_class2", ["object_class2", "object_class1"]], + ["superclass", []], + ], + superclass_subclasses=[["superclass", "relationship_class1"], ["superclass", "relationship_class2"]], + ) + import_data( + db_map, + entity_classes=[["meta_relationship_class", ["superclass", "superclass"]]], + entities=[["relationship_class1", ["object1", "object2"]], ["relationship_class2", ["object2", "object1"]]], + ) + _, errors = import_data(db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2"]]]) + self.assertEqual(len(errors), 1) + self.assertIn("too few elements", errors[0]) + _, errors = import_data( + db_map, entities=[["meta_relationship_class", ["object1", "object2", "object2", "object1", "object1"]]] + ) + 
self.assertEqual(len(errors), 1)
+        self.assertIn("too many elements", errors[0])
 
 
 class TestImportParameterDefinition(unittest.TestCase):
From ab01cc811ef58dfff94bbd06626dca96a159d84a Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 23 Nov 2023 12:16:22 +0100
Subject: [PATCH 202/317] Add element_byname_list key to entity

---
 spinedb_api/mapped_items.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py
index 6e80a250..0a22b9ad 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -125,6 +125,7 @@ class EntityItem(MappedItemBase):
         "superclass_id": ("class_id", "superclass_id"),
         "superclass_name": ("class_id", "superclass_name"),
         "element_name_list": ("element_id_list", "name"),
+        "element_byname_list": ("element_id_list", "byname"),
     }
     _alt_references = {
         ("class_name",): ("entity_class", ("name",)),
@@ -168,6 +169,7 @@ def __getitem__(self, key):
         return super().__getitem__(key)
 
     def resolve_internal_fields(self, skip_keys=()):
+        """Overridden to translate byname into element name list."""
         error = super().resolve_internal_fields(skip_keys=skip_keys)
         if error:
             return error
@@ -187,7 +189,11 @@ def resolve_internal_fields(self, skip_keys=()):
         self["element_name_list"] = element_name_list
         return self._do_resolve_internal_field("element_id_list")
 
-    def _element_name_list_recursive(self, class_name, byname_remainder):
+    def _element_name_list_recursive(self, class_name, byname):
+        """Returns the element name list corresponding to the given class and byname.
+        If the class is multi-dimensional then recurses for each dimension.
+        If the class is a superclass then it tries each subclass until one yields a complete element name list.
+        """
         class_names = [
             x["subclass_name"] for x in self._db_map.get_items("superclass_subclass", superclass_name=class_name)
         ] or [class_name]
@@ -195,20 +201,18 @@ def _element_name_list_recursive(self, class_name, byname_remainder):
             dimension_name_list = self._db_map.get_item("entity_class", name=class_name_).get("dimension_name_list")
             if not dimension_name_list:
                 continue
-            byname_remainder_backup = list(byname_remainder)
+            byname_backup = list(byname)
             element_name_list = tuple(
                 self._db_map.get_item(
                     "entity",
-                    **dict(
-                        zip(("byname", "class_name"), self._element_name_list_recursive(dim_name, byname_remainder))
-                    ),
+                    **dict(zip(("byname", "class_name"), self._element_name_list_recursive(dim_name, byname))),
                 ).get("name")
                 for dim_name in dimension_name_list
             )
             if None not in element_name_list:
                 return element_name_list, class_name_
-            byname_remainder = byname_remainder_backup
-        name = byname_remainder.pop(0) if byname_remainder else None
+            byname = byname_backup
+        name = byname.pop(0) if byname else None
         return (name,), class_name
From dd0d4dc7504edb0759351acedc5c66ed38294653 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Thu, 23 Nov 2023 14:35:12 +0100
Subject: [PATCH 203/317] Fix type of hidden and add dimension_count to private fields of ent_cls

---
 spinedb_api/mapped_items.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py
index 0a22b9ad..ba1a3430 100644
--- a/spinedb_api/mapped_items.py
+++ b/spinedb_api/mapped_items.py
@@ -65,7 +65,7 @@ class EntityClassItem(MappedItemBase):
         'display_order': {'type': int, 'value': 'Not in use at the moment.', 'optional': True},
-        'hidden': {'type': bool, 'value': 'Not in use at the moment.', 'optional': True},
+ 
'hidden': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, } _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = (("name",),) @@ -73,6 +73,7 @@ class EntityClassItem(MappedItemBase): _external_fields = {"dimension_name_list": ("dimension_id_list", "name")} _alt_references = {("dimension_name_list",): ("entity_class", ("name",))} _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")} + _private_fields = {"dimension_count"} def __init__(self, *args, **kwargs): dimension_id_list = kwargs.get("dimension_id_list") From 83f94e3bf02999193849c89d69317cc5f2cb99dd Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 24 Nov 2023 14:22:43 +0100 Subject: [PATCH 204/317] Introduce add_update and use it in import_data --- spinedb_api/db_mapping.py | 64 +++--- spinedb_api/db_mapping_base.py | 13 +- spinedb_api/import_functions.py | 319 ++++++++++---------------- tests/filters/test_scenario_filter.py | 25 +- tests/test_import_functions.py | 39 ++-- 5 files changed, 195 insertions(+), 265 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index b575763d..8b0280cf 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -382,6 +382,19 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] + @staticmethod + def _modify_items(function, *items, strict=False): + modified, errors = [], [] + for item in items: + item, error = function(item) + if error: + if strict: + raise SpineIntegrityError(error) + errors.append(error) + if item: + modified.append(item) + return modified, errors + def add_item(self, item_type, check=True, **kwargs): """Adds an item to the in-memory mapping. @@ -419,16 +432,7 @@ def add_items(self, item_type, *items, check=True, strict=False): Returns: tuple(list(:class:`PublicItem`),list(str)): items successfully added and found violations. """ - added, errors = [], [] - for item in items: - item, error = self.add_item(item_type, check, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - continue - added.append(item) - return added, errors + return self._modify_items(lambda x: self.add_item(item_type, check=check, **x), *items, strict=strict) def update_item(self, item_type, check=True, **kwargs): """Updates an item in the in-memory mapping. @@ -470,16 +474,25 @@ def update_items(self, item_type, *items, check=True, strict=False): Returns: tuple(list(:class:`PublicItem`),list(str)): items successfully updated and found violations. 
""" - updated, errors = [], [] - for item in items: - item, error = self.update_item(item_type, check=check, **item) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - if item: - updated.append(item) - return updated, errors + return self._modify_items(lambda x: self.update_item(item_type, check=check, **x), *items, strict=strict) + + def add_update_item(self, item_type, check=True, **kwargs): + added, add_error = self.add_item(item_type, check=check, **kwargs) + if not add_error: + return (added, None), add_error + updated, update_error = self.update_item(item_type, check=check, **kwargs) + if not update_error: + return (None, updated), update_error + return (None, None), add_error or update_error + + def add_update_items(self, item_type, *items, check=True, strict=False): + added_updated, errors = self._modify_items( + lambda x: self.add_update_item(item_type, check=check, **x), *items, strict=strict + ) + added, updated = zip(*added_updated) if added_updated else ([], []) + added = [x for x in added if x] + updated = [x for x in updated if x] + return added, updated, errors def remove_item(self, item_type, id_, check=True): """Removes an item from the in-memory mapping. @@ -523,16 +536,7 @@ def remove_items(self, item_type, *ids, check=True, strict=False): ids.discard(1) if not ids: return [], [] - removed, errors = [], [] - for id_ in ids: - item, error = self.remove_item(item_type, id_, check=check) - if error: - if strict: - raise SpineIntegrityError(error) - errors.append(error) - if item: - removed.append(item) - return removed, errors + return self._modify_items(lambda x: self.remove_item(item_type, x, check=check), *ids, strict=strict) def cascade_remove_items(self, cache=None, **kwargs): # Legacy diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 5462c303..705cbbb1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -400,11 +400,9 @@ def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True) return current_item return {} - def checked_item_and_error(self, item, for_update=False, skip_keys=()): - # FIXME: The only use-case for skip_keys at the moment is that of importing scenario alternatives, - # where we only want to match by (scen_name, alt_name) and not by (scen_name, rank) + def checked_item_and_error(self, item, for_update=False): if for_update: - current_item = self.find_item(item, skip_keys=skip_keys) + current_item = self.find_item(item) if not current_item: return None, f"no {self._item_type} matching {item} to update" full_item, merge_error = current_item.merge(item) @@ -414,20 +412,19 @@ def checked_item_and_error(self, item, for_update=False, skip_keys=()): current_item = None full_item, merge_error = item, None candidate_item = self._make_item(full_item) - error = self._prepare_item(candidate_item, current_item, item, skip_keys) + error = self._prepare_item(candidate_item, current_item, item) if error: return None, error self.check_fields(candidate_item._asdict()) return candidate_item, merge_error - def _prepare_item(self, candidate_item, current_item, original_item, skip_keys): + def _prepare_item(self, candidate_item, current_item, original_item): """Prepares item for insertion or update, returns any errors. Args: candidate_item (MappedItem) current_item (MappedItem) original_item (dict) - skip_keys (optional, tuple) Returns: str or None: errors if any. 
@@ -445,7 +442,7 @@ def _prepare_item(self, candidate_item, current_item, original_item, skip_keys): if first_invalid_key: return f"invalid {first_invalid_key} for {self._item_type}" try: - for key, value in candidate_item.unique_key_values(skip_keys=skip_keys): + for key, value in candidate_item.unique_key_values(): empty = {k for k, v in zip(key, value) if v == ""} if empty: return f"invalid empty keys {empty} for {self._item_type}" diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 60664127..2fbcbee0 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -91,12 +91,12 @@ def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs """ all_errors = [] num_imports = 0 - for tablename, (to_add, to_update, errors) in get_data_for_import( - db_map, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs - ): - updated, _ = db_map.update_items(tablename, *to_update, check=False) - added, _ = db_map.add_items(tablename, *to_add, check=False) - num_imports += len(added) + len(updated) + for item_type, items in get_data_for_import(db_map, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs): + if isinstance(items, tuple): + items, input_errors = items + all_errors.extend(input_errors) + added, updated, errors = db_map.add_update_items(item_type, *items, strict=False) + num_imports += len(added + updated) all_errors.extend(errors) return num_imports, all_errors @@ -170,18 +170,12 @@ def get_data_for_import( if scenarios: yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) if scenario_alternatives: - if not scenarios: - scenarios = list({item[0]: None for item in scenario_alternatives}) - yield ("scenario", _get_scenarios_for_import(db_map, scenarios)) - if not alternatives: - alternatives = list({item[1]: None for item in scenario_alternatives}) - yield ("alternative", _get_alternatives_for_import(db_map, alternatives)) yield ("scenario_alternative", _get_scenario_alternatives_for_import(db_map, scenario_alternatives)) if entity_classes: for bucket in _get_entity_classes_for_import(db_map, entity_classes): yield ("entity_class", bucket) if superclass_subclasses: - yield ("superclass_subclass", _get_parameter_superclass_subclasses_for_import(db_map, superclass_subclasses)) + yield ("superclass_subclass", _get_superclass_subclasses_for_import(db_map, superclass_subclasses)) if entities: for bucket in _get_entities_for_import(db_map, entities): yield ("entity", bucket) @@ -205,9 +199,13 @@ def get_data_for_import( if metadata: yield ("metadata", _get_metadata_for_import(db_map, metadata)) if entity_metadata: - yield ("metadata", _get_metadata_for_import(db_map, (metadata for _, _, metadata in entity_metadata))) + yield ("metadata", _get_metadata_for_import(db_map, (ent_metadata[2] for ent_metadata in entity_metadata))) yield ("entity_metadata", _get_entity_metadata_for_import(db_map, entity_metadata)) if parameter_value_metadata: + yield ( + "metadata", + _get_metadata_for_import(db_map, (pval_metadata[3] for pval_metadata in parameter_value_metadata)), + ) yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) # Legacy if object_classes: @@ -469,53 +467,6 @@ def import_relationship_parameter_value_metadata(db_map, data): return import_data(db_map, relationship_parameter_value_metadata=data) -def _get_items_for_import(db_map, item_type, data, check_skip_keys=()): - mapped_table = db_map.mapped_table(item_type) - errors = [] - 
to_add = [] - to_update = [] - seen = {} - for item in data: - checked_item, add_error = mapped_table.checked_item_and_error(item, skip_keys=check_skip_keys) - if not add_error: - if not _check_unique(item_type, checked_item, seen, errors): - continue - to_add.append(checked_item) - continue - checked_item, update_error = mapped_table.checked_item_and_error( - item, for_update=True, skip_keys=check_skip_keys - ) - if not update_error: - if checked_item: - if not _check_unique(item_type, checked_item, seen, errors): - continue - to_update.append(checked_item) - continue - errors.append(add_error) - return to_add, to_update, errors - - -def _check_unique(item_type, checked_item, seen, errors): - dupe_key = _add_to_seen(checked_item, seen) - if not dupe_key: - return True - if item_type in ("parameter_value",): - errors.append(f"attempting to import more than one {item_type} with {dupe_key} - only first will be considered") - return False - - -def _add_to_seen(checked_item, seen): - for key, value in checked_item.unique_key_values(): - if value in seen.get(key, set()): - return dict(zip(key, value)) - seen.setdefault(key, set()).add(value) - - -def _get_parameter_superclass_subclasses_for_import(db_map, data): - key = ("superclass_name", "subclass_name") - return _get_items_for_import(db_map, "superclass_subclass", (dict(zip(key, x)) for x in data)) - - def _get_entity_classes_for_import(db_map, data): dim_name_list_by_name = {} items = [] @@ -536,10 +487,12 @@ def _ref_count(name): items_by_ref_count = {} for item in items: items_by_ref_count.setdefault(_ref_count(item["name"]), []).append(item) - return ( - _get_items_for_import(db_map, "entity_class", items_by_ref_count[ref_count]) - for ref_count in sorted(items_by_ref_count) - ) + return (items_by_ref_count[ref_count] for ref_count in sorted(items_by_ref_count)) + + +def _get_superclass_subclasses_for_import(db_map, data): + key = ("superclass_name", "subclass_name") + return (dict(zip(key, x)) for x in data) def _get_entities_for_import(db_map, data): @@ -554,180 +507,149 @@ def _get_entities_for_import(db_map, data): byname = (name_or_el_name_list,) item = dict(zip(key, (class_name, byname, *optionals))) items_by_el_count.setdefault(el_count, []).append(item) - return ( - _get_items_for_import(db_map, "entity", items_by_el_count[el_count]) for el_count in sorted(items_by_el_count) - ) + return (items_by_el_count[el_count] for el_count in sorted(items_by_el_count)) def _get_entity_alternatives_for_import(db_map, data): - def _data_iterator(): - for class_name, entity_name_or_element_name_list, alternative, active in data: - is_zero_dim = isinstance(entity_name_or_element_name_list, str) - entity_byname = (entity_name_or_element_name_list,) if is_zero_dim else entity_name_or_element_name_list - key = ("entity_class_name", "entity_byname", "alternative_name", "active") - yield dict(zip(key, (class_name, entity_byname, alternative, active))) - - return _get_items_for_import(db_map, "entity_alternative", _data_iterator()) + for class_name, entity_name_or_element_name_list, alternative, active in data: + is_zero_dim = isinstance(entity_name_or_element_name_list, str) + entity_byname = (entity_name_or_element_name_list,) if is_zero_dim else entity_name_or_element_name_list + key = ("entity_class_name", "entity_byname", "alternative_name", "active") + yield dict(zip(key, (class_name, entity_byname, alternative, active))) def _get_entity_groups_for_import(db_map, data): key = ("class_name", "group_name", "member_name") - return 
_get_items_for_import(db_map, "entity_group", (dict(zip(key, x)) for x in data)) + return (dict(zip(key, x)) for x in data) def _get_parameter_definitions_for_import(db_map, data, unparse_value): - def _data_iterator(): - for class_name, parameter_name, *optionals in data: - if not optionals: - yield class_name, parameter_name - continue - value = optionals.pop(0) - value, type_ = unparse_value(value) - yield class_name, parameter_name, value, type_, *optionals - key = ("entity_class_name", "name", "default_value", "default_type", "parameter_value_list_name", "description") - return _get_items_for_import(db_map, "parameter_definition", (dict(zip(key, x)) for x in _data_iterator())) + for class_name, parameter_name, *optionals in data: + if not optionals: + yield dict(zip(key, (class_name, parameter_name))) + continue + value = optionals.pop(0) + value, type_ = unparse_value(value) + yield dict(zip(key, (class_name, parameter_name, value, type_, *optionals))) def _get_parameter_values_for_import(db_map, data, unparse_value, on_conflict): - def _data_iterator(): - for class_name, entity_byname, parameter_name, value, *optionals in data: - if isinstance(entity_byname, str): - entity_byname = (entity_byname,) - alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() - value, type_ = unparse_value(value) - item = { - "entity_class_name": class_name, - "entity_byname": entity_byname, - "parameter_definition_name": parameter_name, - "alternative_name": alternative_name, - "value": None, - "type": None, - } - pv = db_map.mapped_table("parameter_value").find_item(item) - if pv: - value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) - item.update({"value": value, "type": type_}) - yield item - - return _get_items_for_import(db_map, "parameter_value", _data_iterator()) + seen = set() + errors = [] + items = [] + key = ("entity_class_name", "entity_byname", "parameter_definition_name", "alternative_name", "value", "type") + for class_name, entity_byname, parameter_name, value, *optionals in data: + if isinstance(entity_byname, str): + entity_byname = (entity_byname,) + else: + entity_byname = tuple(entity_byname) + alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() + unique_values = (class_name, entity_byname, parameter_name, alternative_name) + if unique_values in seen: + dupe = dict(zip(key, unique_values)) + errors.append( + f"attempting to import more than one parameter_value with {dupe} - only first will be considered" + ) + continue + seen.add(unique_values) + value, type_ = unparse_value(value) + item = dict(zip(key, unique_values + (None, None))) + pv = db_map.mapped_table("parameter_value").find_item(item) + if pv: + value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) + item.update({"value": value, "type": type_}) + items.append(item) + return items, errors def _get_alternatives_for_import(db_map, data): key = ("name", "description") - return _get_items_for_import( - db_map, "alternative", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) - ) + return ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) def _get_scenarios_for_import(db_map, data): key = ("name", "active", "description") - return _get_items_for_import( - db_map, "scenario", ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) - ) + return ({"name": x} if isinstance(x, str) else dict(zip(key, x)) for x in data) def 
_get_scenario_alternatives_for_import(db_map, data):
+    # FIXME: maybe when updating, we only want to match by (scen_name, alt_name) and not by (scen_name, rank)
     alt_name_list_by_scen_name = {}
-    errors = []
-    successors_by_scen_name = defaultdict(dict)
-    for scen_name, alt_name, *optionals in data:
-        successors_by_scen_name[scen_name][alt_name] = optionals[0] if optionals else None
-    for scen_name, successors in successors_by_scen_name.items():
+    succ_by_pred_by_scen_name = defaultdict(dict)
+    for scen_name, predecessor, *optionals in data:
+        successor = optionals[0] if optionals else None
+        succ_by_pred_by_scen_name[scen_name][predecessor] = successor
+    for scen_name, succ_by_pred in succ_by_pred_by_scen_name.items():
         scen = db_map.mapped_table("scenario").find_item({"name": scen_name})
-        if not scen:
-            errors.append(f"no scenario with name {scen_name} to set alternatives for")
-            continue
-        alt_names = set(successors)
-        alternative_name_list = alt_name_list_by_scen_name[scen_name] = [
-            a for a in scen["alternative_name_list"] if a not in alt_names
-        ]
-        for predecessor, successor in list(successors.items()):
-            if successor is None:
-                alternative_name_list.append(predecessor)
-                del successors[predecessor]
-        predecessors = {successor: predecessor for predecessor, successor in successors.items()}
-        predecessor_errors = []
-        for predecessor in predecessors:
-            if predecessor not in successors and predecessor not in alternative_name_list:
-                predecessor_errors.append(f"{predecessor} is not in {scen_name}")
-        if predecessor_errors:
-            errors += predecessor_errors
-            continue
-        while predecessors:
-            for i, alt_name in enumerate(alternative_name_list):
-                if (predecessor := predecessors.pop(alt_name, None)) is not None:
-                    alternative_name_list.insert(i, predecessor)
-                    break
-
-    def _data_iterator():
-        for scen_name, alternative_name_list in alt_name_list_by_scen_name.items():
-            for k, alt_name in enumerate(alternative_name_list):
-                yield {"scenario_name": scen_name, "alternative_name": alt_name, "rank": k + 1}
-
-    to_add, to_update, more_errors = _get_items_for_import(
-        db_map, "scenario_alternative", _data_iterator(), check_skip_keys=(("scenario_name", "rank"),)
+        alternative_name_list = alt_name_list_by_scen_name[scen_name] = scen.get("alternative_name_list", [])
+        alternative_name_list.append(None)  # So alternatives where successor is None find their place at the tail
+        while succ_by_pred:
+            some_added = False
+            for pred, succ in list(succ_by_pred.items()):
+                if succ in alternative_name_list:
+                    i = alternative_name_list.index(succ)
+                    if pred in alternative_name_list:
+                        alternative_name_list.remove(pred)
+                    alternative_name_list.insert(i, pred)
+                    del succ_by_pred[pred]
+                    some_added = True
+            if not some_added:
+                break
+        alternative_name_list.pop(-1)  # Remove the None
+    items = (
+        {"scenario_name": scen_name, "alternative_name": alt_name, "rank": k + 1}
+        for scen_name, alternative_name_list in alt_name_list_by_scen_name.items()
+        for k, alt_name in enumerate(alternative_name_list)
+    )
+    errors = (
+        f"can't insert alternative '{pred}' before '{succ}' because the latter is not in scenario '{scen_name}'"
+        for scen_name, succ_by_pred in succ_by_pred_by_scen_name.items()
+        for pred, succ in succ_by_pred.items()
     )
-    return to_add, to_update, errors + more_errors
+    return items, errors
 
 
 def _get_parameter_value_lists_for_import(db_map, data):
-    return _get_items_for_import(db_map, "parameter_value_list", ({"name": x} for x in {x[0]: None for x in data}))
+    return ({"name": x} for x in {x[0]: None for x in 
data}) def _get_list_values_for_import(db_map, data, unparse_value): - def _data_iterator(): - index_by_list_name = {} - for list_name, value in data: - value, type_ = unparse_value(value) - index = index_by_list_name.get(list_name) - if index is None: - current_list = db_map.mapped_table("parameter_value_list").find_item({"name": list_name}) - index = max( - ( - x["index"] - for x in db_map.mapped_table("list_value").valid_values() - if x["parameter_value_list_id"] == current_list["id"] - ), - default=-1, - ) - index += 1 - index_by_list_name[list_name] = index - yield {"parameter_value_list_name": list_name, "value": value, "type": type_, "index": index} - - return _get_items_for_import(db_map, "list_value", _data_iterator()) + index_by_list_name = {} + for list_name, value in data: + value, type_ = unparse_value(value) + index = index_by_list_name.get(list_name) + if index is None: + current_list = db_map.mapped_table("parameter_value_list").find_item({"name": list_name}) + index = max( + ( + x["index"] + for x in db_map.mapped_table("list_value").valid_values() + if x["parameter_value_list_id"] == current_list["id"] + ), + default=-1, + ) + index += 1 + index_by_list_name[list_name] = index + yield {"parameter_value_list_name": list_name, "value": value, "type": type_, "index": index} def _get_metadata_for_import(db_map, data): - def _data_iterator(): - for metadata in data: - for name, value in _parse_metadata(metadata): - yield {"name": name, "value": value} - - return _get_items_for_import(db_map, "metadata", _data_iterator()) + for metadata in data: + for name, value in _parse_metadata(metadata): + yield {"name": name, "value": value} def _get_entity_metadata_for_import(db_map, data): - def _data_iterator(): - for class_name, entity_byname, metadata in data: - if isinstance(entity_byname, str): - entity_byname = (entity_byname,) - for name, value in _parse_metadata(metadata): - yield (class_name, entity_byname, name, value) - key = ("entity_class_name", "entity_byname", "metadata_name", "metadata_value") - return _get_items_for_import(db_map, "entity_metadata", (dict(zip(key, x)) for x in _data_iterator())) + for class_name, entity_byname, metadata in data: + if isinstance(entity_byname, str): + entity_byname = (entity_byname,) + for name, value in _parse_metadata(metadata): + yield dict(zip(key, (class_name, entity_byname, name, value))) def _get_parameter_value_metadata_for_import(db_map, data): - def _data_iterator(): - for class_name, entity_byname, parameter_name, metadata, *optionals in data: - if isinstance(entity_byname, str): - entity_byname = (entity_byname,) - alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() - for name, value in _parse_metadata(metadata): - yield (class_name, entity_byname, parameter_name, name, value, alternative_name) - key = ( "entity_class_name", "entity_byname", @@ -736,7 +658,12 @@ def _data_iterator(): "metadata_value", "alternative_name", ) - return _get_items_for_import(db_map, "parameter_value_metadata", (dict(zip(key, x)) for x in _data_iterator())) + for class_name, entity_byname, parameter_name, metadata, *optionals in data: + if isinstance(entity_byname, str): + entity_byname = (entity_byname,) + alternative_name = optionals[0] if optionals else db_map.get_import_alternative_name() + for name, value in _parse_metadata(metadata): + yield dict(zip(key, (class_name, entity_byname, parameter_name, name, value, alternative_name))) # Legacy diff --git a/tests/filters/test_scenario_filter.py 
b/tests/filters/test_scenario_filter.py index c2f2f51f..cff394f7 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -268,25 +268,22 @@ def test_scenario_filter_selects_highest_ranked_alternative(self): ) def test_scenario_filter_selects_highest_ranked_alternative_of_active_scenario(self): - import_alternatives(self._out_db_map, ["alternative3"]) - import_alternatives(self._out_db_map, ["alternative1"]) - import_alternatives(self._out_db_map, ["alternative2"]) - import_alternatives(self._out_db_map, ["non_active_alternative"]) + import_alternatives( + self._out_db_map, ["alternative3", "alternative1", "alternative2", "non_active_alternative"] + ) import_object_classes(self._out_db_map, ["object_class"]) import_objects(self._out_db_map, [("object_class", "object")]) import_object_parameters(self._out_db_map, [("object_class", "parameter")]) - import_object_parameter_values(self._out_db_map, [("object_class", "object", "parameter", -1.0)]) - import_object_parameter_values( - self._out_db_map, [("object_class", "object", "parameter", 10.0, "alternative1")] - ) import_object_parameter_values( - self._out_db_map, [("object_class", "object", "parameter", 2000.0, "alternative2")] - ) - import_object_parameter_values( - self._out_db_map, [("object_class", "object", "parameter", 300.0, "alternative3")] + self._out_db_map, + [ + ("object_class", "object", "parameter", -1.0), + ("object_class", "object", "parameter", 10.0, "alternative1"), + ("object_class", "object", "parameter", 2000.0, "alternative2"), + ("object_class", "object", "parameter", 300.0, "alternative3"), + ], ) - import_scenarios(self._out_db_map, [("scenario", True)]) - import_scenarios(self._out_db_map, [("non_active_scenario", False)]) + import_scenarios(self._out_db_map, [("scenario", True), "non_active_scenario"]) import_scenario_alternatives( self._out_db_map, [ diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index c3f02d48..a7a813b8 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1220,54 +1220,58 @@ def tearDown(self): self._db_map.close() def test_single_scenario_alternative_import(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative"]) count, errors = import_scenario_alternatives(self._db_map, [["scenario", "alternative"]]) self.assertFalse(errors) - self.assertEqual(count, 3) - scenario_alternatives = self.scenario_alternatives() - self.assertEqual(scenario_alternatives, {"scenario": {"alternative": 1}}) - - def test_scenario_alternative_import_imports_missing_scenarios_and_alternatives(self): - count, errors = import_scenario_alternatives(self._db_map, [["scenario", "alternative"]]) - self.assertFalse(errors) - self.assertEqual(count, 3) + self.assertEqual(count, 1) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative": 1}}) def test_scenario_alternative_import_multiple_without_before_alternatives(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative1", "alternative2"]) count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative1"], ["scenario", "alternative2"]] ) self.assertFalse(errors) - self.assertEqual(count, 5) + self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 1, "alternative2": 2}}) def 
test_scenario_alternative_import_multiple_with_before_alternatives(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative1", "alternative2", "alternative3"]) count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative1"], ["scenario", "alternative3"], ["scenario", "alternative2", "alternative3"]], ) self.assertFalse(errors) - self.assertEqual(count, 7) + self.assertEqual(count, 3) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 1, "alternative2": 2, "alternative3": 3}}) def test_fails_with_nonexistent_before_alternative(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative"]) count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative", "nonexistent_alternative"]] ) - self.assertEqual(errors, ["nonexistent_alternative is not in scenario"]) - self.assertEqual(count, 2) + self.assertEqual( + errors, + [ + "can't insert alternative 'alternative' before 'nonexistent_alternative' " + "because the latter is not in scenario 'scenario'" + ], + ) + self.assertEqual(count, 0) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {}) def test_importing_existing_scenario_alternative_does_not_alter_scenario_alternatives(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative1", "alternative2"]) count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative2", "alternative1"], ["scenario", "alternative1"]], ) self.assertFalse(errors) - self.assertEqual(count, 5) + self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative2": 1}}) count, errors = import_scenario_alternatives( @@ -1295,12 +1299,13 @@ def test_import_scenario_alternatives_in_arbitrary_order(self): self.assertEqual(scenario_alternatives, {"A (1)": {"Base": 1, "b": 2, "c": 3, "d": 4}}) def test_insert_scenario_alternative_in_the_middle_of_other_alternatives(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative1", "alternative2", "alternative3"]) count, errors = import_scenario_alternatives( self._db_map, [["scenario", "alternative2", "alternative1"], ["scenario", "alternative1"]], ) self.assertFalse(errors) - self.assertEqual(count, 5) + self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative2": 1}}) count, errors = import_scenario_alternatives( @@ -1308,7 +1313,7 @@ def test_insert_scenario_alternative_in_the_middle_of_other_alternatives(self): [["scenario", "alternative3", "alternative1"]], ) self.assertFalse(errors) - self.assertEqual(count, 3) + self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 3, "alternative2": 1, "alternative3": 2}}) @@ -1323,9 +1328,9 @@ def scenario_alternatives(self): .filter(self._db_map.scenario_alternative_sq.c.scenario_id == self._db_map.scenario_sq.c.id) .filter(self._db_map.scenario_alternative_sq.c.alternative_id == self._db_map.alternative_sq.c.id) ) - scenario_alternatives = dict() + scenario_alternatives = {} for scenario_alternative in scenario_alternative_qry: - alternative_rank = scenario_alternatives.setdefault(scenario_alternative.scenario_name, dict()) + alternative_rank = 
scenario_alternatives.setdefault(scenario_alternative.scenario_name, {}) alternative_rank[scenario_alternative.alternative_name] = scenario_alternative.rank return scenario_alternatives From 04c38daa273f5ffd323957e5cf70ebe80f4a8e0b Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 24 Nov 2023 15:07:31 +0100 Subject: [PATCH 205/317] Fix export_data as per latest changes regarding multi d entities --- spinedb_api/export_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index a15c0fcd..4bb5b822 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -135,8 +135,8 @@ def export_superclass_subclasses(db_map, ids=Asterisk): def export_entities(db_map, ids=Asterisk): return sorted( - ((x.class_name, x.name, x.element_name_list, x.description) for x in _get_items(db_map, "entity", ids)), - key=lambda x: (len(x[2]), x[0], x[2], x[1]), + ((x.class_name, x.element_name_list or x.name, x.description) for x in _get_items(db_map, "entity", ids)), + key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0], (x[1],) if isinstance(x[1], str) else x[1]), ) From 7c7713e2db2059deb2fa1000ef4db008947c77e7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 24 Nov 2023 15:13:55 +0100 Subject: [PATCH 206/317] Fix test too --- tests/test_export_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 8a55f164..489d549a 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -104,8 +104,7 @@ def test_export_data(self): ) self.assertIn("entities", exported) self.assertEqual( - exported["entities"], - [("object_class", "object", (), None), ("relationship_class", "object__", ("object",), None)], + exported["entities"], [("object_class", "object", None), ("relationship_class", ("object",), None)] ) self.assertIn("parameter_values", exported) self.assertEqual( From df9d246442cfc30f8594fb1f4f6294a1065f63d7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 27 Nov 2023 14:40:18 +0100 Subject: [PATCH 207/317] Improve return type of add_update_item --- spinedb_api/db_mapping.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 8b0280cf..c0050f71 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -479,16 +479,18 @@ def update_items(self, item_type, *items, check=True, strict=False): def add_update_item(self, item_type, check=True, **kwargs): added, add_error = self.add_item(item_type, check=check, **kwargs) if not add_error: - return (added, None), add_error + return added, None, add_error updated, update_error = self.update_item(item_type, check=check, **kwargs) if not update_error: - return (None, updated), update_error - return (None, None), add_error or update_error + return None, updated, update_error + return None, None, add_error or update_error def add_update_items(self, item_type, *items, check=True, strict=False): - added_updated, errors = self._modify_items( - lambda x: self.add_update_item(item_type, check=check, **x), *items, strict=strict - ) + def _function(item): + added, updated, error = self.add_update_item(item_type, check=check, **item) + return (added, updated), error + + added_updated, errors = self._modify_items(_function, *items, strict=strict) added, updated = zip(*added_updated) if added_updated else ([], []) added = [x for x in added if x] 
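        # Editorial note: zip(*added_updated) transposes the [(added, updated), ...] pairs
        # into two parallel tuples; each pair carries None in one slot, so the filter above
        # and the one below keep only the items that were actually added or actually updated.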
updated = [x for x in updated if x] From 29ce79ebd194cb63888e617c64987d3d61fa7fb9 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 27 Nov 2023 14:48:48 +0100 Subject: [PATCH 208/317] Reset cache after clearing subqueries just in case Re #322 --- spinedb_api/db_mapping_query_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index f349fe03..89cf0ae8 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -103,10 +103,10 @@ def _clear_subqueries(self, *tablenames): """Set to `None` subquery attributes involving the affected tables. This forces the subqueries to be refreshed when the corresponding property is accessed. """ - self.reset(*tablenames) attr_names = set(attr for tablename in tablenames for attr in self._get_table_to_sq_attr().get(tablename, [])) for attr_name in attr_names: setattr(self, attr_name, None) + self.reset(*tablenames) def _subquery(self, tablename): """A subquery of the form: From e96f513737e32d4376cba6d58aa101b6e3d51577 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 27 Nov 2023 14:49:27 +0100 Subject: [PATCH 209/317] Minor fixes: don't update removed items and always set list_value_id --- spinedb_api/db_mapping_base.py | 2 ++ spinedb_api/mapped_items.py | 1 + 2 files changed, 3 insertions(+) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 5462c303..65a3f2f3 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1018,6 +1018,8 @@ def cascade_update(self): """Updates this item and all its referrers in cascade. Also, calls items' update callbacks. """ + if self._removed: + return self.call_update_callbacks() for referrer in self._referrers.values(): referrer.cascade_update() diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index ba1a3430..c1d86e40 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -398,6 +398,7 @@ def resolve(self): return d def polish(self): + self["list_value_id"] = None error = super().polish() if error: return error From 90caffec0efea95bb1049c5bd497b1a3f90eab97 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 27 Nov 2023 15:41:43 +0100 Subject: [PATCH 210/317] Add docs and convenience add_update methods --- spinedb_api/db_mapping.py | 48 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index c0050f71..1a2db137 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -420,7 +420,7 @@ def add_item(self, item_type, check=True, **kwargs): return (mapped_table.add_item(checked_item).public_item if checked_item else None, error) def add_items(self, item_type, *items, check=True, strict=False): - """Add many items to the in-memory mapping. + """Adds many items to the in-memory mapping. Args: item_type (str): One of . @@ -477,6 +477,16 @@ def update_items(self, item_type, *items, check=True, strict=False): return self._modify_items(lambda x: self.update_item(item_type, check=check, **x), *items, strict=strict) def add_update_item(self, item_type, check=True, **kwargs): + """Adds an item to the in-memory mapping if it doesn't exist; otherwise updates the current one. + + Args: + item_type (str): One of . + **kwargs: Fields and values as specified for the item type in :ref:`db_mapping_schema`. 
+ + Returns: + tuple(:class:`PublicItem` or None, :class:`PublicItem` or None, str): The added item if any, + the updated item if any, and any errors. + """ added, add_error = self.add_item(item_type, check=check, **kwargs) if not add_error: return added, None, add_error @@ -486,6 +496,20 @@ return None, None, add_error or update_error def add_update_items(self, item_type, *items, check=True, strict=False): + """Adds or updates many items into the in-memory mapping. + + Args: + item_type (str): One of . + *items (Iterable(dict)): One or more :class:`dict` objects mapping fields to values of the item type, + as specified in :ref:`db_mapping_schema`. + strict (bool): Whether or not the method should raise :exc:`~.exception.SpineIntegrityError` + if the insertion of one of the items violates an integrity constraint. + + Returns: + tuple(list(:class:`PublicItem`),list(:class:`PublicItem`),list(str)): items successfully added, + items successfully updated, and found violations. + """ + def _function(item): added, updated, error = self.add_update_item(item_type, check=check, **item) return (added, updated), error @@ -777,6 +801,7 @@ def get_filter_configs(self): setattr(DatabaseMapping, "get_" + it + "_items", partialmethod(DatabaseMapping.get_items, it)) setattr(DatabaseMapping, "add_" + it + "_item", partialmethod(DatabaseMapping.add_item, it)) setattr(DatabaseMapping, "update_" + it + "_item", partialmethod(DatabaseMapping.update_item, it)) + setattr(DatabaseMapping, "add_update_" + it + "_item", partialmethod(DatabaseMapping.add_update_item, it)) setattr(DatabaseMapping, "remove_" + it + "_item", partialmethod(DatabaseMapping.remove_item, it)) setattr(DatabaseMapping, "restore_" + it + "_item", partialmethod(DatabaseMapping.restore_item, it)) @@ -884,6 +909,27 @@ def update_{item_type}_item(self, check=True, **kwargs): ) child.parent = node node.body.append(child) + for item_type in DatabaseMapping.item_types(): + factory = DatabaseMapping.item_factory(item_type) + a = _a(item_type) + add_kwargs = _kwargs(factory.fields) + child = astroid.extract_node( + f''' + def add_update_{item_type}_item(self, check=True, **kwargs): + """Adds {a} `{item_type}` item to the in-memory mapping if it doesn't exist; + otherwise updates the current one. + + Args: + {add_kwargs} + + Returns: + tuple(:class:`PublicItem` or None, :class:`PublicItem` or None, str): The added item if any, + the updated item if any, and any errors. + """ + ''' + ) + child.parent = node + node.body.append(child) for item_type in DatabaseMapping.item_types(): child = astroid.extract_node( f''' From 245d3c170bb2054558360d5124b5f3b4210281e7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 27 Nov 2023 16:06:12 +0100 Subject: [PATCH 211/317] Add note about add_update to import_data documentation --- spinedb_api/import_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 2fbcbee0..48248a79 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -11,6 +11,8 @@ """ Functions for importing data into a Spine database in a standard format. +This functionality is equivalent to the one provided by :meth:`.DatabaseMapping.add_update_item`, +but the syntax is a little more compact. 
""" from collections import defaultdict From b1f7c4c3b8215c73ffe6efc7b40b0e76c3c6c957 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Nov 2023 08:30:23 +0100 Subject: [PATCH 212/317] Accept None when updating Fixes #2443 --- spinedb_api/db_mapping_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c0286490..1602d613 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -415,7 +415,8 @@ def checked_item_and_error(self, item, for_update=False): error = self._prepare_item(candidate_item, current_item, item) if error: return None, error - self.check_fields(candidate_item._asdict()) + valid_types = (type(None),) if for_update else () + self.check_fields(candidate_item._asdict(), valid_types=valid_types) return candidate_item, merge_error def _prepare_item(self, candidate_item, current_item, original_item): From fcd49bf2fa655f2ec530e70b7c282bd07f13b833 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Nov 2023 11:45:45 +0100 Subject: [PATCH 213/317] Simplify computation of time-series indices --- spinedb_api/parameter_value.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index e2e3828e..db36b73a 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -1317,20 +1317,11 @@ def _get_memoized_indexes(self): memoized_indexes = self._memoized_indexes.get(key) if memoized_indexes is not None: return memoized_indexes - step_index = 0 - step_cycle_index = 0 - full_cycle_duration = sum(self._resolution, relativedelta()) - stamps = np.empty(len(self), dtype=_NUMPY_DATETIME_DTYPE) - stamps[0] = self._start - for stamp_index in range(1, len(self._values)): - if step_index >= len(self._resolution): - step_index = 0 - step_cycle_index += 1 - current_cycle_duration = sum(self._resolution[: step_index + 1], relativedelta()) - duration_from_start = step_cycle_index * full_cycle_duration + current_cycle_duration - stamps[stamp_index] = self._start + duration_from_start - step_index += 1 - memoized_indexes = self._memoized_indexes[key] = np.array(stamps, dtype=_NUMPY_DATETIME_DTYPE) + cycle_count = -(-len(self) // len(self.resolution)) + resolution = (cycle_count * self.resolution)[: len(self) - 1] + resolution.insert(0, self._start) + resolution_arr = np.array(resolution) + memoized_indexes = self._memoized_indexes[key] = resolution_arr.cumsum().astype(_NUMPY_DATETIME_DTYPE) return memoized_indexes @property From e577113687d16d315df2ddc67733c35fedaf29ce Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Nov 2023 12:30:18 +0100 Subject: [PATCH 214/317] Accept np.dtype as value_type for IndexedValue --- spinedb_api/parameter_value.py | 41 ++++++++++++++-------------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index db36b73a..7dbd704d 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -561,7 +561,7 @@ def _time_pattern_from_database(value_dict): TimePattern: restored time pattern """ patterns, values = _break_dictionary(value_dict["data"]) - return TimePattern(patterns, values, value_dict.get("index_name", "p")) + return TimePattern(patterns, values, value_dict.get("index_name", TimePattern.DEFAULT_INDEX_NAME)) def _map_from_database(value_dict): @@ -574,7 +574,7 @@ def _map_from_database(value_dict): Map: restored Map """ index_type = 
_map_index_type_from_database(value_dict["index_type"]) - index_name = value_dict.get("index_name", "x") + index_name = value_dict.get("index_name", Map.DEFAULT_INDEX_NAME) data = value_dict["data"] if isinstance(data, dict): indexes = _map_indexes_from_database(data.keys(), index_type) @@ -680,7 +680,7 @@ def _array_from_database(value_dict): except (TypeError, ParameterValueFormatError) as error: raise ParameterValueFormatError(f'Failed to read values for Array: {error}') else: - index_name = value_dict.get("index_name", "i") + index_name = value_dict.get("index_name", Array.DEFAULT_INDEX_NAME) return Array(data, value_type, index_name) @@ -959,6 +959,10 @@ def values(self, values): Args: values (:class:`~numpy.ndarray`) """ + if isinstance(self._value_type, np.dtype) and ( + not isinstance(values, np.ndarray) or not values.dtype == self._value_type + ): + values = np.array(values, dtype=self._value_type) self._values = values @property @@ -1078,7 +1082,7 @@ def to_dict(self): else: data = [x.value_to_database_data() for x in self._values] value_dict = {"value_type": value_type_id, "data": data} - if self.index_name != "i": + if self.index_name != self.DEFAULT_INDEX_NAME: value_dict["index_name"] = self.index_name return value_dict @@ -1164,7 +1168,7 @@ def __init__(self, indexes, values, index_name=""): raise ParameterValueFormatError("Length of values does not match length of indexes") if not indexes: raise ParameterValueFormatError("Empty time pattern not allowed") - super().__init__(values, value_type=float, index_name=index_name) + super().__init__(values, value_type=np.dtype(float), index_name=index_name) self.indexes = indexes def __eq__(self, other): @@ -1186,7 +1190,7 @@ def type_(): def to_dict(self): value_dict = {"data": dict(zip(self._indexes, self._values))} - if self.index_name != "p": + if self.index_name != self.DEFAULT_INDEX_NAME: value_dict["index_name"] = self.index_name return value_dict @@ -1209,7 +1213,7 @@ def __init__(self, values, ignore_year, repeat, index_name=""): """ if len(values) < 1: raise ParameterValueFormatError("Time series too short. Must have one or more values") - super().__init__(values, value_type=float, index_name=index_name) + super().__init__(values, value_type=np.dtype(float), index_name=index_name) self._ignore_year = ignore_year self._repeat = repeat @@ -1252,17 +1256,6 @@ def repeat(self, repeat): """ self._repeat = bool(repeat) - @IndexedValue.values.setter - def values(self, values): - """Sets the values. 
- - Args: - values (:class:`~numpy.ndarray`) - """ - if not isinstance(values, np.ndarray) or not values.dtype == np.dtype(float): - values = np.array(values, dtype=float) - self._values = values - @staticmethod def type_(): return "time_series" @@ -1385,9 +1378,9 @@ def resolution(self, resolution): elif not isinstance(resolution, Sequence): resolution = [resolution] else: - for i in range(len(resolution)): - if isinstance(resolution[i], str): - resolution[i] = duration_to_relativedelta(resolution[i]) + for i, r in enumerate(resolution): + if isinstance(r, str): + resolution[i] = duration_to_relativedelta(r) if not resolution: raise ParameterValueFormatError("Resolution cannot be zero.") self._resolution = resolution @@ -1407,7 +1400,7 @@ def to_dict(self): }, "data": self._values.tolist(), } - if self.index_name != "t": + if self.index_name != self.DEFAULT_INDEX_NAME: value_dict["index_name"] = self.index_name return value_dict @@ -1463,7 +1456,7 @@ def to_dict(self): value_dict.setdefault("index", dict())["ignore_year"] = self._ignore_year if self._repeat: value_dict.setdefault("index", dict())["repeat"] = self._repeat - if self.index_name != "t": + if self.index_name != self.DEFAULT_INDEX_NAME: value_dict["index_name"] = self.index_name return value_dict @@ -1526,7 +1519,7 @@ def to_dict(self): "index_type": _map_index_type_to_database(self._index_type), "data": self.value_to_database_data(), } - if self.index_name != "x": + if self.index_name != self.DEFAULT_INDEX_NAME: value_dict["index_name"] = self.index_name return value_dict From 8393b9b16d0beb1139cfbd3d5b91ee71009b1604 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 29 Nov 2023 15:38:35 +0100 Subject: [PATCH 215/317] Fix order of imports to combine legacy and non-legacy --- spinedb_api/import_functions.py | 57 ++++++++++++++++----------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 48248a79..8da06a1d 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -176,15 +176,25 @@ def get_data_for_import( if entity_classes: for bucket in _get_entity_classes_for_import(db_map, entity_classes): yield ("entity_class", bucket) + if object_classes: # Legacy + yield from get_data_for_import(db_map, entity_classes=_object_classes_to_entity_classes(object_classes)) + if relationship_classes: # Legacy + yield from get_data_for_import(db_map, entity_classes=relationship_classes) if superclass_subclasses: yield ("superclass_subclass", _get_superclass_subclasses_for_import(db_map, superclass_subclasses)) if entities: for bucket in _get_entities_for_import(db_map, entities): yield ("entity", bucket) + if objects: # Legacy + yield from get_data_for_import(db_map, entities=objects) + if relationships: # Legacy + yield from get_data_for_import(db_map, entities=relationships) if entity_alternatives: yield ("entity_alternative", _get_entity_alternatives_for_import(db_map, entity_alternatives)) if entity_groups: yield ("entity_group", _get_entity_groups_for_import(db_map, entity_groups)) + if object_groups: # Legacy + yield from get_data_for_import(db_map, entity_groups=object_groups) if parameter_value_lists: yield ("parameter_value_list", _get_parameter_value_lists_for_import(db_map, parameter_value_lists)) yield ("list_value", _get_list_values_for_import(db_map, parameter_value_lists, unparse_value)) @@ -193,11 +203,25 @@ def get_data_for_import( "parameter_definition", _get_parameter_definitions_for_import(db_map, 
parameter_definitions, unparse_value), ) + if object_parameters: # Legacy + yield from get_data_for_import(db_map, unparse_value=unparse_value, parameter_definitions=object_parameters) + if relationship_parameters: # Legacy + yield from get_data_for_import( + db_map, unparse_value=unparse_value, parameter_definitions=relationship_parameters + ) if parameter_values: yield ( "parameter_value", _get_parameter_values_for_import(db_map, parameter_values, unparse_value, on_conflict), ) + if object_parameter_values: # Legacy + yield from get_data_for_import( + db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=object_parameter_values + ) + if relationship_parameter_values: # Legacy + yield from get_data_for_import( + db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=relationship_parameter_values + ) if metadata: yield ("metadata", _get_metadata_for_import(db_map, metadata)) if entity_metadata: @@ -209,38 +233,13 @@ def get_data_for_import( _get_metadata_for_import(db_map, (pval_metadata[3] for pval_metadata in parameter_value_metadata)), ) yield ("parameter_value_metadata", _get_parameter_value_metadata_for_import(db_map, parameter_value_metadata)) - # Legacy - if object_classes: - yield from get_data_for_import(db_map, entity_classes=_object_classes_to_entity_classes(object_classes)) - if relationship_classes: - yield from get_data_for_import(db_map, entity_classes=relationship_classes) - if object_parameters: - yield from get_data_for_import(db_map, unparse_value=unparse_value, parameter_definitions=object_parameters) - if relationship_parameters: - yield from get_data_for_import( - db_map, unparse_value=unparse_value, parameter_definitions=relationship_parameters - ) - if objects: - yield from get_data_for_import(db_map, entities=objects) - if relationships: - yield from get_data_for_import(db_map, entities=relationships) - if object_groups: - yield from get_data_for_import(db_map, entity_groups=object_groups) - if object_parameter_values: - yield from get_data_for_import( - db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=object_parameter_values - ) - if relationship_parameter_values: - yield from get_data_for_import( - db_map, unparse_value=unparse_value, on_conflict=on_conflict, parameter_values=relationship_parameter_values - ) - if object_metadata: + if object_metadata: # Legacy yield from get_data_for_import(db_map, entity_metadata=object_metadata) - if relationship_metadata: + if relationship_metadata: # Legacy yield from get_data_for_import(db_map, entity_metadata=relationship_metadata) - if object_parameter_value_metadata: + if object_parameter_value_metadata: # Legacy yield from get_data_for_import(db_map, parameter_value_metadata=object_parameter_value_metadata) - if relationship_parameter_value_metadata: + if relationship_parameter_value_metadata: # Legacy yield from get_data_for_import(db_map, parameter_value_metadata=relationship_parameter_value_metadata) From 3f785a01732653fb5bb1352e5cce65a96cdb7cab Mon Sep 17 00:00:00 2001 From: Henrik Koski <98282892+PiispaH@users.noreply.github.com> Date: Wed, 13 Dec 2023 11:11:35 +0200 Subject: [PATCH 216/317] Issue 2454 traceback from url bar (#326) * Fix traceback from Spine DB Editor URL bar Re spine-tools/Spine-Toolbox#2454 --- spinedb_api/db_mapping.py | 10 ++++++++-- tests/test_DatabaseMapping.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 1a2db137..d05a4b3c 100644 
--- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -25,7 +25,7 @@ from sqlalchemy import create_engine, MetaData, inspect from sqlalchemy.pool import NullPool from sqlalchemy.event import listen -from sqlalchemy.exc import DatabaseError, DBAPIError +from sqlalchemy.exc import DatabaseError, DBAPIError, ArgumentError from sqlalchemy.engine.url import make_url, URL from alembic.migration import MigrationContext from alembic.environment import EnvironmentContext @@ -155,7 +155,13 @@ def __init__( else: filter_configs = [] self._filter_configs = filter_configs if apply_filters else None - self.sa_url = make_url(db_url) + try: + self.sa_url = make_url(db_url) + except ArgumentError as err: + raise SpineDBAPIError( + f"Could not parse the given URL. " + f"Please check that it is valid." + ) self.username = username if username else "anon" self.codename = self._make_codename(codename) self._memory = memory diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 174c4137..027b559d 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -25,7 +25,7 @@ SpineIntegrityError, ) from spinedb_api.helpers import name_from_elements -from .custom_db_mapping import CustomDatabaseMapping +from tests.custom_db_mapping import CustomDatabaseMapping def create_query_wrapper(db_map): From 4f5674747770bcf543141fe33e65509a3a8a8b46 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 5 Jan 2024 13:04:31 +0200 Subject: [PATCH 217/317] Make entity and value metadata use class_name as unique key class_name is required to uniquely identify entities and parameter values. Re #318,#328 --- spinedb_api/db_mapping.py | 7 +- spinedb_api/import_functions.py | 4 +- spinedb_api/mapped_items.py | 47 ++++++--- tests/test_DatabaseMapping.py | 173 +++++++++++++++++++++++++++++++- 4 files changed, 208 insertions(+), 23 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index d05a4b3c..5a464839 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -157,11 +157,8 @@ def __init__( self._filter_configs = filter_configs if apply_filters else None try: self.sa_url = make_url(db_url) - except ArgumentError as err: - raise SpineDBAPIError( - f"Could not parse the given URL. " - f"Please check that it is valid." - ) + except ArgumentError: + raise SpineDBAPIError("Could not parse the given URL. 
Please check that it is valid.") self.username = username if username else "anon" self.codename = self._make_codename(codename) self._memory = memory diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 8da06a1d..19f64341 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -642,7 +642,7 @@ def _get_metadata_for_import(db_map, data): def _get_entity_metadata_for_import(db_map, data): - key = ("entity_class_name", "entity_byname", "metadata_name", "metadata_value") + key = ("class_name", "entity_byname", "metadata_name", "metadata_value") for class_name, entity_byname, metadata in data: if isinstance(entity_byname, str): entity_byname = (entity_byname,) @@ -652,7 +652,7 @@ def _get_entity_metadata_for_import(db_map, data): def _get_parameter_value_metadata_for_import(db_map, data): key = ( - "entity_class_name", + "class_name", "entity_byname", "parameter_definition_name", "metadata_name", diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index c1d86e40..0f7a6d6b 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -37,6 +37,9 @@ def item_factory(item_type): }.get(item_type, MappedItemBase) +_ENTITY_BYNAME_VALUE = 'A tuple with the entity name as single element if the entity is zero-dimensional, or the element names if the entity is multi-dimensional.' + + class CommitItem(MappedItemBase): fields = { 'comment': {'type': str, 'value': 'A comment describing the commit.'}, @@ -282,8 +285,7 @@ class EntityAlternativeItem(MappedItemBase): 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, 'entity_byname': { 'type': tuple, - 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' - 'or the element names if it is multi-dimensional.', + 'value': _ENTITY_BYNAME_VALUE, }, 'alternative_name': {'type': str, 'value': 'The alternative name.'}, 'active': { @@ -504,8 +506,7 @@ class ParameterValueItem(ParameterItemBase): 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, 'entity_byname': { 'type': tuple, - 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' - 'or the element names if the entity is multi-dimensional.', + 'value': _ENTITY_BYNAME_VALUE, }, 'value': {'type': bytes, 'value': 'The value.'}, 'type': {'type': str, 'value': 'The value type.', 'optional': True}, @@ -676,44 +677,60 @@ class MetadataItem(MappedItemBase): class EntityMetadataItem(MappedItemBase): fields = { - 'entity_name': {'type': str, 'value': 'The entity name.'}, + 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_byname': {'type': tuple, 'value': _ENTITY_BYNAME_VALUE}, 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, } - _unique_keys = (("entity_name", "metadata_name", "metadata_value"),) - _references = {"entity_id": ("entity", "id"), "metadata_id": ("metadata", "id")} + _unique_keys = (("class_name", "entity_byname", "metadata_name", "metadata_value"),) + _references = { + "entity_id": ("entity", "id"), + "metadata_id": ("metadata", "id"), + } _external_fields = { - "entity_name": ("entity_id", "name"), + "class_name": ("entity_id", "class_name"), + "entity_byname": ("entity_id", "byname"), "metadata_name": ("metadata_id", "name"), "metadata_value": ("metadata_id", "value"), } _alt_references = { - ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ( + 
"class_name", + "entity_byname", + ): ("entity", ("class_name", "byname")), ("metadata_name", "metadata_value"): ("metadata", ("name", "value")), } _internal_fields = { - "entity_id": (("entity_class_name", "entity_byname"), "id"), + "entity_id": (("class_name", "entity_byname"), "id"), "metadata_id": (("metadata_name", "metadata_value"), "id"), } class ParameterValueMetadataItem(MappedItemBase): fields = { + 'class_name': {'type': str, 'value': 'The entity class name.'}, 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, 'entity_byname': { 'type': tuple, - 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional, ' - 'or the element names if it is multi-dimensional.', + 'value': _ENTITY_BYNAME_VALUE, }, 'alternative_name': {'type': str, 'value': 'The alternative name.'}, 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, } _unique_keys = ( - ("parameter_definition_name", "entity_byname", "alternative_name", "metadata_name", "metadata_value"), + ( + "class_name", + "parameter_definition_name", + "entity_byname", + "alternative_name", + "metadata_name", + "metadata_value", + ), ) _references = {"parameter_value_id": ("parameter_value", "id"), "metadata_id": ("metadata", "id")} _external_fields = { + "class_name": ("parameter_value_id", "entity_class_name"), "parameter_definition_name": ("parameter_value_id", "parameter_definition_name"), "entity_byname": ("parameter_value_id", "entity_byname"), "alternative_name": ("parameter_value_id", "alternative_name"), @@ -721,7 +738,7 @@ class ParameterValueMetadataItem(MappedItemBase): "metadata_value": ("metadata_id", "value"), } _alt_references = { - ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"): ( + ("class_name", "parameter_definition_name", "entity_byname", "alternative_name"): ( "parameter_value", ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), ), @@ -729,7 +746,7 @@ class ParameterValueMetadataItem(MappedItemBase): } _internal_fields = { "parameter_value_id": ( - ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), + ("class_name", "parameter_definition_name", "entity_byname", "alternative_name"), "id", ), "metadata_id": (("metadata_name", "metadata_value"), "id"), diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 027b559d..cce216fe 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -216,6 +216,169 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): ) self.assertIsNotNone(color) + def test_update_entity_metadata_by_changing_its_entity(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + entity_class, _ = db_map.add_entity_class_item(name="my_class") + db_map.add_entity_item(name="entity_1", class_name="my_class") + entity_2, _ = db_map.add_entity_item(name="entity_2", class_name="my_class") + metadata_value = '{"sources": [], "contributors": []}' + metadata, _ = db_map.add_metadata_item(name="my_metadata", value=metadata_value) + entity_metadata, error = db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("entity_1",), + ) + self.assertIsNone(error) + entity_metadata.update(entity_byname=("entity_2",)) + self.assertEqual( + entity_metadata._extended(), + { + "class_name": "my_class", + 
"entity_byname": ("entity_2",), + "entity_id": entity_2["id"], + "id": entity_metadata["id"], + "metadata_id": metadata["id"], + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + }, + ) + db_map.commit_session("Add initial data.") + entity_sq = ( + db_map.query( + db_map.entity_sq.c.id.label("entity_id"), + db_map.entity_class_sq.c.name.label("class_name"), + db_map.entity_sq.c.name.label("entity_name"), + ) + .join(db_map.entity_class_sq, db_map.entity_class_sq.c.id == db_map.entity_sq.c.class_id) + .subquery() + ) + metadata_records = ( + db_map.query( + db_map.entity_metadata_sq.c.id, + entity_sq.c.class_name, + entity_sq.c.entity_name, + db_map.metadata_sq.c.name.label("metadata_name"), + db_map.metadata_sq.c.value.label("metadata_value"), + ) + .join(entity_sq, entity_sq.c.entity_id == db_map.entity_metadata_sq.c.entity_id) + .join(db_map.metadata_sq, db_map.metadata_sq.c.id == db_map.entity_metadata_sq.c.metadata_id) + .all() + ) + self.assertEqual(len(metadata_records), 1) + self.assertEqual( + dict(**metadata_records[0]), + { + "id": 1, + "class_name": "my_class", + "entity_name": "entity_2", + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + }, + ) + + def test_update_parameter_value_metadata_by_changing_its_parameter(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + entity_class, _ = db_map.add_entity_class_item(name="my_class") + _, error = db_map.add_parameter_definition_item(name="x", entity_class_name="my_class") + self.assertIsNone(error) + db_map.add_parameter_definition_item(name="y", entity_class_name="my_class") + entity, _ = db_map.add_entity_item(name="my_entity", class_name="my_class") + value, value_type = to_database(2.3) + _, error = db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + self.assertIsNone(error) + value, value_type = to_database(-2.3) + y, error = db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) + self.assertIsNone(error) + metadata_value = '{"sources": [], "contributors": []}' + metadata, error = db_map.add_metadata_item(name="my_metadata", value=metadata_value) + self.assertIsNone(error) + value_metadata, error = db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + self.assertIsNone(error) + value_metadata.update(parameter_definition_name="y") + self.assertEqual( + value_metadata._extended(), + { + "class_name": "my_class", + "entity_byname": ("my_entity",), + "alternative_name": "Base", + "parameter_definition_name": "y", + "parameter_value_id": y["id"], + "id": value_metadata["id"], + "metadata_id": metadata["id"], + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + }, + ) + db_map.commit_session("Add initial data.") + parameter_sq = ( + db_map.query( + db_map.parameter_value_sq.c.id.label("value_id"), + db_map.entity_class_sq.c.name.label("class_name"), + db_map.entity_sq.c.name.label("entity_name"), + db_map.parameter_definition_sq.c.name.label("parameter_definition_name"), + db_map.alternative_sq.c.name.label("alternative_name"), + ) + .join( + db_map.entity_class_sq, db_map.entity_class_sq.c.id == 
db_map.parameter_value_sq.c.entity_class_id + ) + .join(db_map.entity_sq, db_map.entity_sq.c.id == db_map.parameter_value_sq.c.entity_id) + .join( + db_map.parameter_definition_sq, + db_map.parameter_definition_sq.c.id == db_map.parameter_value_sq.c.parameter_definition_id, + ) + .join(db_map.alternative_sq, db_map.alternative_sq.c.id == db_map.parameter_value_sq.c.alternative_id) + .subquery("parameter_sq") + ) + metadata_records = ( + db_map.query( + db_map.parameter_value_metadata_sq.c.id, + parameter_sq.c.class_name, + parameter_sq.c.entity_name, + parameter_sq.c.parameter_definition_name, + parameter_sq.c.alternative_name, + db_map.metadata_sq.c.name.label("metadata_name"), + db_map.metadata_sq.c.value.label("metadata_value"), + ) + .join(parameter_sq, parameter_sq.c.value_id == db_map.parameter_value_metadata_sq.c.parameter_value_id) + .join(db_map.metadata_sq, db_map.metadata_sq.c.id == db_map.parameter_value_metadata_sq.c.metadata_id) + .all() + ) + self.assertEqual(len(metadata_records), 1) + self.assertEqual( + dict(**metadata_records[0]), + { + "id": 1, + "class_name": "my_class", + "entity_name": "my_entity", + "parameter_definition_name": "y", + "alternative_name": "Base", + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + }, + ) + def test_fetch_more(self): with DatabaseMapping("sqlite://", create=True) as db_map: alternatives = db_map.fetch_more("alternative") @@ -2092,6 +2255,10 @@ def setUp(self): def tearDown(self): self._db_map.close() + def _assert_import(self, result): + error = result[1] + self.assertEqual(error, []) + def test_remove_object_class(self): """Test adding and removing an object class and committing""" items, _ = self._db_map.add_object_classes({"name": "oc1", "id": 1}, {"name": "oc2", "id": 2}) @@ -2332,7 +2499,11 @@ def test_cascade_remove_entity_metadata_leaves_metadata_used_by_value_intact(sel self._db_map, (("my_class", "my_object", "my_parameter", 99.0),) ) import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) - import_functions.import_object_metadata(self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),)) + self._assert_import( + import_functions.import_object_metadata( + self._db_map, (("my_class", "my_object", '{"title": "My metadata."}'),) + ) + ) import_functions.import_object_parameter_value_metadata( self._db_map, (("my_class", "my_object", "my_parameter", '{"title": "My metadata."}'),) ) From 50ace91b5147871ec983623fe443bd837ee14638 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 8 Jan 2024 13:56:44 +0200 Subject: [PATCH 218/317] Don't add commit ids to items that don't have them The database tables backing EntityGroupItem and SuperclassSubclassItem don't have commit id columns so we shouldn't add the ids to corresponding items when committing. 
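The gist of the fix, as a sketch (it assumes that MappedItemBase.commit records the id it receives and marks the item committed, as the overrides in the diff below suggest):

    def commit(self, _commit_id):
        # entity_group rows have no commit_id column, so mark the item
        # committed without recording any id
        super().commit(None)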
Re #331 --- spinedb_api/mapped_items.py | 6 ++++++ tests/test_DatabaseMapping.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 0f7a6d6b..8a32d102 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -279,6 +279,9 @@ def __getitem__(self, key): return self["entity_id"] return super().__getitem__(key) + def commit(self, _commit_id): + super().commit(None) + class EntityAlternativeItem(MappedItemBase): fields = { @@ -777,3 +780,6 @@ def check_mutability(self): if self._subclass_entities(): return "can't set or modify the superclass for a class that already has entities" return super().check_mutability() + + def commit(self, _commit_id): + super().commit(None) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index cce216fe..70c7055c 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -481,6 +481,35 @@ def test_committing_scenario_alternatives(self): self.assertEqual(scenario_alternatives[1]["alternative_name"], "alt2") self.assertEqual(scenario_alternatives[1]["rank"], 1) + def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class_item(name="my_class") + db_map.commit_session("Add class.") + classes = db_map.get_entity_class_items() + self.assertEqual(len(classes), 1) + self.assertNotIn("commit_id", classes[0]._extended()) + + def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class_item(name="high") + db_map.add_entity_class_item(name="low") + db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low") + db_map.commit_session("Add class hierarchy.") + classes = db_map.get_superclass_subclass_items() + self.assertEqual(len(classes), 1) + self.assertNotIn("commit_id", classes[0]._extended()) + + def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class_item(name="my_class") + db_map.add_entity_item(name="element", class_name="my_class") + db_map.add_entity_item(name="container", class_name="my_class") + db_map.add_entity_group_item(group_name="container", member_name="element", class_name="my_class") + db_map.commit_session("Add entity group.") + groups = db_map.get_entity_group_items() + self.assertEqual(len(groups), 1) + self.assertNotIn("commit_id", groups[0]._extended()) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure.""" From 6959fd512aa25fbf1a71f00b9592accc515c2481 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 13 Dec 2023 16:28:14 +0200 Subject: [PATCH 219/317] Implement rudimentary "merge conflict resolution" We now apply a rudimentary conflict resolution when fetching items that already exist in the db cache. Conflict resolution is handled by a callback that could let user decide what to do in the future but is currently used to force the default behavior, which is to keep cache as-is. 
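For instance, a client that would rather let the database win could pass its own callback (a sketch against the API introduced below; Resolution and Resolved come from the new conflict_resolution module, and the policy shown is equivalent to the built-in select_in_db_item_always helper):

    from spinedb_api.conflict_resolution import Resolution, Resolved

    def prefer_database(conflicts):
        # Resolve every conflict in favor of the freshly fetched DB item.
        return [Resolved(conflict, Resolution.USE_IN_DB) for conflict in conflicts]

    db_map.fetch_more("entity", resolve_conflicts=prefer_database)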
Re spine-tools/Spine-Toolbox#2431 --- spinedb_api/conflict_resolution.py | 104 +++++++++ spinedb_api/db_mapping.py | 21 +- spinedb_api/db_mapping_base.py | 79 ++++--- spinedb_api/item_status.py | 23 ++ spinedb_api/server_client_helpers.py | 2 +- tests/test_DatabaseMapping.py | 315 ++++++++++++++++++++------- 6 files changed, 433 insertions(+), 111 deletions(-) create mode 100644 spinedb_api/conflict_resolution.py create mode 100644 spinedb_api/item_status.py diff --git a/spinedb_api/conflict_resolution.py b/spinedb_api/conflict_resolution.py new file mode 100644 index 00000000..d81d2855 --- /dev/null +++ b/spinedb_api/conflict_resolution.py @@ -0,0 +1,104 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +from __future__ import annotations +from enum import auto, Enum, unique +from dataclasses import dataclass + +from .item_status import Status + + +@unique +class Resolution(Enum): + USE_IN_MEMORY = auto() + USE_IN_DB = auto() + + +@dataclass +class Conflict: + in_memory: MappedItemBase + in_db: MappedItemBase + + +@dataclass +class Resolved(Conflict): + resolution: Resolution + + def __init__(self, conflict, resolution): + self.in_memory = conflict.in_memory + self.in_db = conflict.in_db + self.resolution = resolution + + +def select_in_memory_item_always(conflicts): + return [Resolved(conflict, Resolution.USE_IN_MEMORY) for conflict in conflicts] + + +def select_in_db_item_always(conflicts): + return [Resolved(conflict, Resolution.USE_IN_DB) for conflict in conflicts] + + +@dataclass +class KeepInMemoryAction: + in_memory: MappedItemBase + set_uncommitted: bool + + def __init__(self, conflict): + self.in_memory = conflict.in_memory + self.set_uncommitted = conflict.in_memory.extended() != conflict.in_db.extended() + + +@dataclass +class UpdateInMemoryAction: + in_memory: MappedItemBase + in_db: MappedItemBase + + def __init__(self, conflict): + self.in_memory = conflict.in_memory + self.in_db = conflict.in_db + + +@dataclass +class ResurrectAction: + in_memory: MappedItemBase + in_db: MappedItemBase + + def __init__(self, conflict): + self.in_memory = conflict.in_memory + self.in_db = conflict.in_db + + +def resolved_conflict_actions(conflicts): + for conflict in conflicts: + if conflict.resolution == Resolution.USE_IN_MEMORY: + yield KeepInMemoryAction(conflict) + elif conflict.resolution == Resolution.USE_IN_DB: + yield UpdateInMemoryAction(conflict) + else: + raise RuntimeError(f"unknown conflict resolution") + + +def resurrection_conflicts_from_resolved(conflicts): + resurrection_conflicts = [] + for conflict in conflicts: + if conflict.resolution != Resolution.USE_IN_DB or not 
conflict.in_memory.removed: + continue + resurrection_conflicts.append(conflict) + return resurrection_conflicts + + +def make_changed_in_memory_items_dirty(conflicts): + for conflict in conflicts: + if conflict.resolution != Resolution.USE_IN_MEMORY: + continue + if conflict.in_memory.removed: + conflict.in_memory.status = Status.to_remove + elif conflict.in_memory.asdict_() != conflict.in_db: + conflict.in_memory.status = Status.to_update diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 5a464839..6b758d54 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -33,6 +33,7 @@ from alembic.config import Config from alembic.util.exc import CommandError +from .conflict_resolution import select_in_memory_item_always from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters from .spine_db_client import get_db_url_from_server from .mapped_items import item_factory @@ -364,13 +365,16 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): return {} return item.public_item - def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): + def get_items( + self, item_type, fetch=True, skip_removed=True, resolve_conflicts=select_in_memory_item_always, **kwargs + ): """Finds and returns all the items of one type. Args: item_type (str): One of . fetch (bool, optional): Whether to fetch the DB before returning the items. skip_removed (bool, optional): Whether to ignore removed items. + resolve_conflicts (Callable): function that resolves fetch conflicts **kwargs: Fields and values for one the unique keys as specified for the item type in :ref:`db_mapping_schema`. @@ -381,7 +385,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): mapped_table = self.mapped_table(item_type) mapped_table.check_fields(kwargs, valid_types=(type(None),)) if fetch: - self.do_fetch_all(item_type, **kwargs) + self.do_fetch_all(item_type, resolve_conflicts=resolve_conflicts, **kwargs) get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] @@ -617,7 +621,7 @@ def purge_items(self, item_type): """ return bool(self.remove_items(item_type, Asterisk)) - def fetch_more(self, item_type, offset=0, limit=None, **kwargs): + def fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=select_in_memory_item_always, **kwargs): """Fetches items from the DB into the in-memory mapping, incrementally. Args: @@ -631,7 +635,12 @@ def fetch_more(self, item_type, offset=0, limit=None, **kwargs): list(:class:`PublicItem`): The items fetched. """ item_type = self.real_item_type(item_type) - return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, **kwargs)] + return [ + x.public_item + for x in self.do_fetch_more( + item_type, offset=offset, limit=limit, resolve_conflicts=resolve_conflicts, **kwargs + ) + ] def fetch_all(self, *item_types): """Fetches items from the DB into the in-memory mapping. @@ -720,10 +729,6 @@ def rollback_session(self): if self._memory: self._memory_dirty = False - def refresh_session(self): - """Resets the fetch status so new items from the DB can be retrieved.""" - self._refresh() - def has_external_commits(self): """Tests whether the database has had commits from other sources than this mapping. 
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 1602d613..bb5005c4 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -9,8 +9,16 @@ # this program. If not, see <http://www.gnu.org/licenses/>. ###################################################################################################################### -from enum import Enum, unique, auto from difflib import SequenceMatcher + +from .conflict_resolution import ( + Conflict, + KeepInMemoryAction, + resolved_conflict_actions, + select_in_memory_item_always, + UpdateInMemoryAction, +) +from .item_status import Status from .temp_id import TempId, resolve from .exception import SpineDBAPIError from .helpers import Asterisk @@ -18,17 +26,6 @@ # TODO: Implement MappedItem.pop() to do lookup? -@unique -class Status(Enum): - """Mapped item status.""" - - committed = auto() - to_add = auto() - to_update = auto() - to_remove = auto() - added_and_removed = auto() - - class DatabaseMappingBase: """An in-memory mapping of a DB, mapping item types (table names), to numeric ids, to items. @@ -208,9 +205,6 @@ def _rollback(self): item.invalidate_id() return True - def _refresh(self): - """Clears fetch progress, so the DB is queried again.""" - def _check_item_type(self, item_type): if item_type not in self.all_item_types(): candidate = max(self.all_item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio()) @@ -258,7 +252,7 @@ def _get_next_chunk(self, item_type, offset, limit, **kwargs): return [dict(x) for x in qry] return [dict(x) for x in qry.limit(limit).offset(offset)] - def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): + def do_fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=select_in_memory_item_always, **kwargs): """Fetches items from the DB and adds them to the mapping.
Args: @@ -274,19 +268,49 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): items = [] new_items = [] # Add items first + conflicts = [] for x in chunk: item, new = mapped_table.add_item_from_db(x) - if new: + if not new: + fetched_item = self.make_item(item_type, **x) + fetched_item.polish() + conflicts.append(Conflict(item, fetched_item)) + else: new_items.append(item) - items.append(item) + items.append(item) + if conflicts: + resolved = resolve_conflicts(conflicts) + items += self._apply_conflict_resolutions(resolved) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted for item in new_items: mapped_table.add_unique(item) return items - def do_fetch_all(self, item_type, **kwargs): - self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) + def do_fetch_all(self, item_type, resolve_conflicts=select_in_memory_item_always, **kwargs): + self.do_fetch_more(item_type, offset=0, limit=None, resolve_conflicts=resolve_conflicts, **kwargs) + + @staticmethod + def _apply_conflict_resolutions(resolved_conflicts): + items = [] + for action in resolved_conflict_actions(resolved_conflicts): + if isinstance(action, KeepInMemoryAction): + item = action.in_memory + items.append(item) + if action.set_uncommitted and item.is_committed(): + if item.removed: + item.status = Status.to_remove + else: + item.status = Status.to_update + elif isinstance(action, UpdateInMemoryAction): + item = action.in_memory + if item.removed: + item.resurrect() + item.update(action.in_db) + items.append(item) + else: + raise RuntimeError("unknown conflict resolution action") + return items class _MappedTable(dict): @@ -669,6 +693,11 @@ def removed(self): """ return self._removed + def resurrect(self): + """Sets item as not-removed but does not resurrect referrers.""" + self._removed = False + self._removal_source = None + @property def item_type(self): """Returns this item's type @@ -699,7 +728,7 @@ def invalidate_id(self): """Sets id as invalid.""" self._is_id_valid = False - def _extended(self): + def extended(self): """Returns a dict from this item's original fields plus all the references resolved statically. 
Returns: @@ -734,7 +763,7 @@ def merge(self, other): if not self._something_to_update(other): # Nothing to update, that's fine return None, "" - merged = {**self._extended(), **other} + merged = {**self.extended(), **other} if not isinstance(merged["id"], int): merged["id"] = self["id"] return merged, "" @@ -1060,7 +1089,7 @@ def commit(self, commit_id): def __repr__(self): """Overridden to return a more verbose representation.""" - return f"{self._item_type}{self._extended()}" + return f"{self._item_type}{self.extended()}" def __getattr__(self, name): """Overridden to return the dictionary key named after the attribute, or None if it doesn't exist.""" @@ -1148,8 +1177,8 @@ def is_committed(self): def _asdict(self): return self._mapped_item._asdict() - def _extended(self): - return self._mapped_item._extended() + def extended(self): + return self._mapped_item.extended() def update(self, **kwargs): self._db_map.update_item(self.item_type, id=self["id"], **kwargs) diff --git a/spinedb_api/item_status.py b/spinedb_api/item_status.py new file mode 100644 index 00000000..0fcc6fd8 --- /dev/null +++ b/spinedb_api/item_status.py @@ -0,0 +1,23 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see <http://www.gnu.org/licenses/>.
+###################################################################################################################### + +from enum import auto, Enum, unique + + +@unique +class Status(Enum): + """Mapped item status.""" + + committed = auto() + to_add = auto() + to_update = auto() + to_remove = auto() + added_and_removed = auto() diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index f9b3b0b1..318698a9 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -66,7 +66,7 @@ def default(self, o): if isinstance(o, SpineDBAPIError): return str(o) if isinstance(o, PublicItem): - return o._extended() + return o.extended() return super().default(o) @property diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 70c7055c..862a0098 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -14,6 +14,7 @@ import unittest from unittest import mock from unittest.mock import patch + from sqlalchemy.engine.url import make_url, URL from sqlalchemy.util import KeyedTuple from spinedb_api import ( @@ -24,6 +25,7 @@ SpineDBAPIError, SpineIntegrityError, ) +from spinedb_api.conflict_resolution import select_in_db_item_always from spinedb_api.helpers import name_from_elements from tests.custom_db_mapping import CustomDatabaseMapping @@ -83,29 +85,89 @@ def test_shorthand_filter_query_works(self): class TestDatabaseMapping(unittest.TestCase): + def _assert_success(self, result): + item, error = result + self.assertIsNone(error) + return item + + def test_restore_uncommitted_item(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + item, error = db_map.add_entity_class_item(name="my_class") + self.assertIsNone(error) + self.assertEqual(item["name"], "my_class") + self.assertTrue(item.is_valid()) + self.assertFalse(item.is_committed()) + item.remove() + self.assertFalse(item.is_valid()) + item.restore() + self.assertTrue(item.is_valid()) + + def test_restore_committed_and_removed_item(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + item, error = db_map.add_entity_class_item(name="my_class") + self.assertIsNone(error) + self.assertEqual(item["name"], "my_class") + self.assertTrue(item.is_valid()) + self.assertFalse(item.is_committed()) + db_map.commit_session("Add entity class") + self.assertTrue(item.is_committed()) + entity_classes = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(len(entity_classes), 1) + item.remove() + self.assertFalse(item.is_valid()) + self.assertFalse(item.is_committed()) + db_map.commit_session("Remove entity class") + self.assertFalse(item.is_valid()) + self.assertTrue(item.is_committed()) + entity_classes = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(len(entity_classes), 0) + item.restore() + self.assertTrue(item.is_valid()) + self.assertFalse(item.is_committed()) + db_map.commit_session("Restore entity class") + self.assertTrue(item.is_valid()) + self.assertTrue(item.is_committed()) + entity_classes = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(len(entity_classes), 1) + + def test_add_commit_update_commit(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + item, error = db_map.add_entity_class_item(name="my_class") + self.assertIsNone(error) + self.assertEqual(item["name"], "my_class") + self.assertTrue(item.is_valid()) + self.assertFalse(item.is_committed()) + db_map.commit_session("Add entity class") + self.assertTrue(item.is_committed()) + 
item.update(name="renamed") + self.assertFalse(item.is_committed()) + db_map.commit_session("Rename entity class") + self.assertTrue(item.is_committed()) + entity_classes = db_map.query(db_map.entity_class_sq).all() + self.assertEqual(len(entity_classes), 1) + self.assertEqual(entity_classes[0].name, "renamed") + def test_commit_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item( - "entity", class_name="fish", name="Nemo", description="Peacefully swimming away." + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success( + db_map.add_item("entity", class_name="fish", name="Nemo", description="Peacefully swimming away.") ) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") - self.assertIsNone(error) + self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - _, error = db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) ) - self.assertIsNone(error) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -122,38 +184,42 @@ def test_commit_multidimensional_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item("entity_class", name="cat", description="Eats fish.") - self.assertIsNone(error) - _, error = db_map.add_item( - "entity_class", - name="fish__cat", - dimension_name_list=("fish", "cat"), - description="A fish getting eaten by a cat?", + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success(db_map.add_item("entity_class", name="cat", description="Eats fish.")) + self._assert_success( + db_map.add_item( + "entity_class", + name="fish__cat", + dimension_name_list=("fish", "cat"), + description="A fish getting eaten by a cat?", + ) ) - self.assertIsNone(error) - _, error = db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (soon).") - self.assertIsNone(error) - _, error = db_map.add_item( - "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat." + self._assert_success( + db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (soon).") + ) + self._assert_success( + db_map.add_item( + "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat." 
+ ) + ) + self._assert_success( + db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) + ) + self._assert_success( + db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") ) - self.assertIsNone(error) - _, error = db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") - self.assertIsNone(error) value, type_ = to_database(0.23) - _, error = db_map.add_item( - "parameter_value", - entity_class_name="fish__cat", - entity_byname=("Nemo", "Felix"), - parameter_definition_name="rate", - alternative_name="Base", - value=value, - type=type_, + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish__cat", + entity_byname=("Nemo", "Felix"), + parameter_definition_name="rate", + alternative_name="Base", + value=value, + type=type_, + ) ) - self.assertIsNone(error) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -168,25 +234,23 @@ def test_commit_multidimensional_parameter_value(self): def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item( - "entity", class_name="fish", name="Nemo", description="Peacefully swimming away." + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success( + db_map.add_item("entity", class_name="fish", name="Nemo", description="Peacefully swimming away.") ) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") - self.assertIsNone(error) + self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - _, error = db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) ) - self.assertIsNone(error) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -385,12 +449,11 @@ def test_fetch_more(self): expected = [{"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}] self.assertEqual([a._asdict() for a in alternatives], expected) - def test_fetch_more_after_commit_and_refresh(self): + def test_fetch_more_after_commit(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_item("entity_class", name="Widget") db_map.add_item("entity", class_name="Widget", name="gadget") db_map.commit_session("Add test data.") - db_map.refresh_session() entities = db_map.fetch_more("entity") self.assertEqual([(x["class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) @@ -451,24 +514,19 @@ def test_committing_scenario_alternatives(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - item, error = db_map.add_alternative_item(name="alt1") - self.assertIsNone(error) + 
item = self._assert_success(db_map.add_alternative_item(name="alt1")) self.assertIsNotNone(item) - item, error = db_map.add_alternative_item(name="alt2") - self.assertIsNone(error) + item = self._assert_success(db_map.add_alternative_item(name="alt2")) self.assertIsNotNone(item) - item, error = db_map.add_scenario_item(name="my_scenario") - self.assertIsNone(error) + item = self._assert_success(db_map.add_scenario_item(name="my_scenario")) self.assertIsNotNone(item) - item, error = db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt1", rank=0 + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt1", rank=0) ) - self.assertIsNone(error) self.assertIsNotNone(item) - item, error = db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt2", rank=1 + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt2", rank=1) ) - self.assertIsNone(error) self.assertIsNotNone(item) db_map.commit_session("Add test data.") with DatabaseMapping(url) as db_map: @@ -510,6 +568,112 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): self.assertEqual(len(groups), 1) self.assertNotIn("commit_id", groups[0]._extended()) + def test_additive_commit_from_another_db_map_gets_fetched(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + items = db_map.get_items("entity") + self.assertEqual(len(items), 0) + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) + self._assert_success(shadow_db_map.add_entity_item(name="my_entity", class_name="my_class")) + shadow_db_map.commit_session("Add entity.") + items = db_map.get_items("entity") + self.assertEqual(len(items), 1) + self.assertEqual( + items[0]._asdict(), + { + "id": 1, + "name": "my_entity", + "description": None, + "class_id": 1, + "element_name_list": None, + "element_id_list": (), + "commit_id": 2, + }, + ) + + def test_updating_item_from_another_db_map_is_overwritten_by_default_conflict_resolution(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + original_item = self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + db_map.commit_session("Add initial data.") + self.assertTrue(original_item.is_committed()) + with DatabaseMapping(url) as shadow_db_map: + items = shadow_db_map.get_items("entity") + self.assertEqual(len(items), 1) + items[0].update(name="renamed_entity") + shadow_db_map.commit_session("Renamed the entity.") + items = db_map.get_items("entity") + self.assertEqual(len(items), 1) + self.assertEqual(items[0], original_item) + self.assertFalse(items[0].is_committed()) + + def test_resolve_an_update_conflict_in_favor_of_external_modification(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + items 
= shadow_db_map.fetch_more("entity") + self.assertEqual(len(items), 1) + updated_item = items[0] + updated_item.update(name="renamed_entity") + shadow_db_map.commit_session("Renamed the entity.") + items = db_map.fetch_more("entity", resolve_conflicts=select_in_db_item_always) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].item_type, updated_item.item_type) + self.assertEqual(items[0].extended(), updated_item.extended()) + + def test_recreating_deleted_item_externally_brings_it_back_if_favored_by_conflict_resolution(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + removed_item = self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + db_map.commit_session("Add initial data.") + removed_item.remove() + db_map.commit_session("Remove entity class.") + self.assertTrue(removed_item.is_committed()) + with DatabaseMapping(url) as shadow_db_map: + items = shadow_db_map.fetch_more("entity_class") + self.assertEqual(len(items), 0) + self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) + shadow_db_map.commit_session("Added entity class back.") + items = db_map.get_items("entity_class", resolve_conflicts=select_in_db_item_always) + self.assertEqual(len(items), 1) + self.assertTrue(items[0].is_valid()) + self.assertFalse(items[0].is_committed()) + + def test_restoring_entity_whose_db_id_has_been_replaced_by_external_db_modification(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + item = self._assert_success(db_map.add_entity_item(class_name="my_class", name="my_entity")) + original_id = item["id"] + db_map.commit_session("Add initial data.") + items = db_map.fetch_more("entity") + self.assertEqual(len(items), 1) + db_map.remove_item("entity", original_id) + db_map.commit_session("Removed entity.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(class_name="my_class", name="other_entity")) + shadow_db_map.commit_session("Add entity with different name, probably reusing previous id.") + items = db_map.fetch_more("entity") + self.assertEqual(len(items), 1) + self.assertEqual(items[0]["name"], "my_entity") + all_items = db_map.get_entity_items() + self.assertEqual(len(all_items), 0) + restored_item = db_map.restore_item("entity", original_id) + self.assertEqual(restored_item["name"], "my_entity") + all_items = db_map.get_entity_items() + self.assertEqual(len(all_items), 1) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. 
pre-entity tests converted to work with the entity structure.""" @@ -2754,7 +2918,6 @@ def test_refresh_addition(self): import_functions.import_object_classes(self._db_map, ("second_class",)) entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) - self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) @@ -2765,7 +2928,6 @@ def test_refresh_removal(self): self._db_map.remove_items("entity_class", 1) entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) - self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) @@ -2776,7 +2938,6 @@ def test_refresh_update(self): self._db_map.get_item("entity_class", name="my_class").update(name="new_name") entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) - self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) From 222523609dcf9d6342430086ee5e66e2c0bbfd54 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 21 Dec 2023 13:39:39 +0200 Subject: [PATCH 220/317] Use in-memory ids instead of database ids This work fixes a lot of issues where the database has been modified externally between fetches/commits. The in-memory cache of DatabaseMapping now uses its own ids for items, called item ids, i.e. every item now has a temporary id. However, instead of a specific TempId type, we use negative integers for simplicity. There are mappings from item ids to database ids and back. Every time an item is added to the mapping tables, a new item id is created. When fetching or committing, we map the database id to the item id. We also check more rigorously whether a fetched item is actually the same item that may already be in the mapped tables.
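In rough terms, the bookkeeping works as in the sketch below, which uses the IdFactory and IdMap classes this patch adds in spinedb_api/item_id.py; the concrete ids (-1, 23) are made up for illustration:

    from spinedb_api.item_id import IdFactory, IdMap

    id_factory = IdFactory()
    id_map = IdMap()

    # A new in-memory item gets the next negative item id; no database id exists yet.
    item_id = id_factory.next_id()  # -1
    id_map.add_item_id(item_id)

    # Committing (or fetching) pairs the item id with a real, positive database id.
    id_map.set_db_id(item_id, 23)
    assert id_map.db_id(item_id) == 23
    assert id_map.item_id(23) == item_id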
Re spine-tools/Spine-Toolbox#2431 --- spinedb_api/conflict_resolution.py | 2 +- spinedb_api/db_mapping.py | 9 +- spinedb_api/db_mapping_base.py | 276 +++++++-- spinedb_api/db_mapping_commit_mixin.py | 46 +- spinedb_api/helpers.py | 17 +- spinedb_api/item_id.py | 60 ++ spinedb_api/item_status.py | 1 + spinedb_api/mapped_items.py | 195 +++++- spinedb_api/temp_id.py | 54 -- tests/test_DatabaseMapping.py | 814 +++++++++++++++++++++---- tests/test_db_mapping_base.py | 80 --- tests/test_helpers.py | 74 ++- tests/test_item_id.py | 70 +++ 13 files changed, 1371 insertions(+), 327 deletions(-) create mode 100644 spinedb_api/item_id.py delete mode 100644 spinedb_api/temp_id.py delete mode 100644 tests/test_db_mapping_base.py create mode 100644 tests/test_item_id.py diff --git a/spinedb_api/conflict_resolution.py b/spinedb_api/conflict_resolution.py index d81d2855..9e459b99 100644 --- a/spinedb_api/conflict_resolution.py +++ b/spinedb_api/conflict_resolution.py @@ -52,7 +52,7 @@ class KeepInMemoryAction: def __init__(self, conflict): self.in_memory = conflict.in_memory - self.set_uncommitted = conflict.in_memory.extended() != conflict.in_db.extended() + self.set_uncommitted = not conflict.in_memory.equal_ignoring_ids(conflict.in_db) @dataclass diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 6b758d54..213a939b 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -705,13 +705,18 @@ def commit_session(self, comment): date = datetime.now(timezone.utc) ins = self._metadata.tables["commit"].insert() with self.engine.begin() as connection: + commit_item = {"user": user, "date": date, "comment": comment} try: - commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + commit_id = connection.execute(ins, commit_item).inserted_primary_key[0] except DBAPIError as e: raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e + commit_item["id"] = commit_id + commit_table = self.mapped_table("commit") + commit_table.add_item_from_db(commit_item) + commit_item_id = commit_table.id_map.item_id(commit_id) for tablename, (to_add, to_update, to_remove) in dirty_items: for item in to_add + to_update + to_remove: - item.commit(commit_id) + item.commit(commit_item_id) # Remove before add, to help with keeping integrity constraints self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove}) self._do_update_items(connection, tablename, *to_update) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index bb5005c4..c3253258 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -8,8 +8,9 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see <http://www.gnu.org/licenses/>. ###################################################################################################################### - from difflib import SequenceMatcher +from enum import auto, Enum, unique +from typing import Iterable from .conflict_resolution import ( Conflict, @@ -18,20 +19,27 @@ select_in_memory_item_always, UpdateInMemoryAction, ) +from .item_id import IdFactory, IdMap from .item_status import Status -from .temp_id import TempId, resolve from .exception import SpineDBAPIError from .helpers import Asterisk # TODO: Implement MappedItem.pop() to do lookup?
+@unique +class _AddStatus(Enum): + ADDED = auto() + CONFLICT = auto() + DUPLICATE = auto() + + class DatabaseMappingBase: """An in-memory mapping of a DB, mapping item types (table names), to numeric ids, to items. This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_query`. + When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_sq`. """ def __init__(self): @@ -81,7 +89,7 @@ def item_factory(item_type): """ raise NotImplementedError() - def _make_query(self, item_type, **kwargs): + def make_query(self, item_type, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. Args: @@ -126,6 +134,15 @@ def make_item(self, item_type, **item): factory = self.item_factory(item_type) return factory(self, item_type, **item) + def any_uncommitted_items(self): + """Returns True if there are uncommitted changes.""" + available_types = tuple(item_type for item_type in self._sorted_item_types if item_type in self._mapped_tables) + return any( + not item.is_committed() + for item_type in available_types + for item in self._mapped_tables[item_type].valid_values() + ) + def dirty_ids(self, item_type): return { item["id"] @@ -201,8 +218,7 @@ def _rollback(self): for item_type, to_add in to_add_by_type: mapped_table = self.mapped_table(item_type) for item in to_add: - if mapped_table.remove_item(item) is not None: - item.invalidate_id() + mapped_table.remove_item(item) return True def _check_item_type(self, item_type): @@ -216,6 +232,14 @@ def mapped_table(self, item_type): self._mapped_tables[item_type] = _MappedTable(self, item_type) return self._mapped_tables[item_type] + def find_item_id(self, item_type, db_id): + """Searches for item id that corresponds to given database id.""" + return self.mapped_table(item_type).id_map.item_id(db_id) + + def find_db_id(self, item_type, item_id): + """Searches for database id that corresponds to given item id.""" + return self.mapped_table(item_type).id_map.db_id(item_id) if item_id < 0 else item_id + def reset(self, *item_types): """Resets the mapping for given item types as if nothing was fetched from the DB or modified in the mapping. Any modifications in the mapping that aren't committed to the DB are lost after this. @@ -235,9 +259,18 @@ if not changed: break + def reset_purging(self): + """Resets purging status for all item types. + + Fetching items of an item type that has been purged will automatically mark those items removed. + Resetting the purge status lets fetched items be added unmodified. + """ + for mapped_table in self._mapped_tables.values(): + mapped_table.wildcard_item.status = Status.committed + - def get_mapped_item(self, item_type, id_, fetch=True): + def get_mapped_item(self, item_type, id_): mapped_table = self.mapped_table(item_type) - return mapped_table.find_item_by_id(id_, fetch=fetch) or {} + return mapped_table.find_item_by_id(id_) or {} def _get_next_chunk(self, item_type, offset, limit, **kwargs): """Gets chunk of items from the DB. Returns: list(dict): list of dictionary items.
""" - qry = self._make_query(item_type, **kwargs) + qry = self.make_query(item_type, **kwargs) if not qry: return [] if not limit: @@ -270,14 +303,16 @@ def do_fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=selec # Add items first conflicts = [] for x in chunk: - item, new = mapped_table.add_item_from_db(x) - if not new: + item, add_status = mapped_table.add_item_from_db(x) + if add_status == _AddStatus.CONFLICT: fetched_item = self.make_item(item_type, **x) fetched_item.polish() conflicts.append(Conflict(item, fetched_item)) - else: + elif add_status == _AddStatus.ADDED: new_items.append(item) items.append(item) + elif add_status == _AddStatus.DUPLICATE: + items.append(item) if conflicts: resolved = resolve_conflicts(conflicts) items += self._apply_conflict_resolutions(resolved) @@ -323,6 +358,8 @@ def __init__(self, db_map, item_type, *args, **kwargs): super().__init__(*args, **kwargs) self._db_map = db_map self._item_type = item_type + self._id_factory = IdFactory() + self.id_map = IdMap() self._id_by_unique_key_value = {} self._temp_id_by_db_id = {} self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @@ -340,13 +377,9 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - temp_id = TempId(self._item_type) - - def _callback(db_id): - self._temp_id_by_db_id[db_id] = temp_id - - temp_id.add_resolve_callback(_callback) - return temp_id + item_id = self._id_factory.next_id() + self.id_map.add_item_id(item_id) + return item_id def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None if not found. @@ -360,11 +393,13 @@ def _unique_key_value_to_id(self, key, value, fetch=True): int """ id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - if not id_by_unique_value and fetch: + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) + item_id = id_by_unique_value.get(value) + if item_id is None and fetch: self._db_map.do_fetch_all(self._item_type) id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - return id_by_unique_value.get(value) + item_id = id_by_unique_value.get(value) + return item_id def _unique_key_value_to_item(self, key, value, fetch=True): return self.get(self._unique_key_value_to_id(key, value, fetch=fetch)) @@ -398,10 +433,19 @@ def find_item(self, item, skip_keys=(), fetch=True): return self.find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) def find_item_by_id(self, id_, fetch=True): + if id_ > 0: + try: + id_ = self.id_map.item_id(id_) + except KeyError: + if fetch: + self._db_map.do_fetch_all(self._item_type) + try: + id_ = self.id_map.item_id(id_) + except KeyError: + return {} + else: + return {} current_item = self.get(id_, {}) - if not current_item and fetch: - self._db_map.do_fetch_all(self._item_type) - current_item = self.get(id_, {}) return current_item def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): @@ -440,7 +484,10 @@ def checked_item_and_error(self, item, for_update=False): if error: return None, error valid_types = (type(None),) if for_update else () + self.check_fields_for_addition(candidate_item) self.check_fields(candidate_item._asdict(), valid_types=valid_types) + if not for_update: + candidate_item.convert_dicts_db_ids_to_item_ids(self._item_type, candidate_item, self._db_map) return candidate_item, merge_error def _prepare_item(self, 
candidate_item, current_item, original_item): @@ -500,10 +547,16 @@ def remove_unique(self, item): def _make_and_add_item(self, item): if not isinstance(item, MappedItemBase): item = self._make_item(item) + item.convert_db_ids_to_item_ids() item.polish() - if "id" not in item or not item.is_id_valid: - item["id"] = self._new_id() - self[item["id"]] = item + item_id = self._new_id() + db_id = item.get("id") + if db_id is not None: + self.id_map.set_db_id(item_id, db_id) + else: + self.id_map.add_item_id(item_id) + item["id"] = item_id + self[item_id] = item return item def add_item_from_db(self, item): @@ -515,16 +568,46 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ - current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( - item, fetch=False, complete=False - ) - if current: - return current, False + same_item = False + if current := self.find_item_by_id(item["id"], fetch=False): + same_item = current.same_db_item(item) + if same_item: + return ( + current, + _AddStatus.DUPLICATE + if not current.removed and self._compare_non_unique_fields(current, item) + else _AddStatus.CONFLICT, + ) + self.id_map.remove_db_id(current["id"]) + if not current.removed: + current.status = Status.to_add + if "commit_id" in current: + current["commit_id"] = None + else: + current.status = Status.overwritten + if not same_item: + current = self.find_item_by_unique_key(item, fetch=False, complete=False) + if current: + return ( + current, + _AddStatus.DUPLICATE if self._compare_non_unique_fields(current, item) else _AddStatus.CONFLICT, + ) item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. 
item.cascade_remove(source=self.wildcard_item) - return item, True + return item, _AddStatus.ADDED + + @staticmethod + def _compare_non_unique_fields(mapped_item, item): + unique_keys = mapped_item.unique_keys() + for key, value in item.items(): + if key not in mapped_item.fields or key in unique_keys: + continue + mapped_value = mapped_item[key] + if value != mapped_value and (not isinstance(mapped_value, tuple) or (mapped_value and value)): + return False + return True def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -553,6 +636,12 @@ def _error(key, value, valid_types): if errors: raise SpineDBAPIError("\n".join(errors)) + def check_fields_for_addition(self, item): + factory = self._db_map.item_factory(self._item_type) + for required_field in factory.required_fields: + if required_field not in item: + raise SpineDBAPIError(f"missing keyword argument {required_field}") + def add_item(self, item): item = self._make_and_add_item(item) self.add_unique(item) @@ -576,7 +665,6 @@ def remove_item(self, item): self.remove_unique(current_item) current_item.cascade_remove(source=self.wildcard_item) return self.wildcard_item - self.remove_unique(item) item.cascade_remove() return item @@ -589,7 +677,6 @@ def restore_item(self, id_): return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item: - self.add_unique(current_item) current_item.cascade_restore() return current_item @@ -600,6 +687,8 @@ class MappedItemBase(dict): fields = {} """A dictionary mapping keys to another dict mapping "type" to a Python type, "value" to a description of the value for the key, and "optional" to a bool.""" + required_fields = () + """A tuple of field names that are required to create new items in addition to unique constraints.""" _defaults = {} """A dictionary mapping keys to their default values""" _unique_keys = () @@ -621,6 +710,11 @@ Keys in _internal_fields are resolved to the reference key of the alternative reference pointed at by the source key. """ + _id_fields = {} + """A dictionary mapping item types to field names that contain database ids. + Required for conversion from database ids to item ids and back.""" + _external_id_fields = set() + """A set of external field names that contain database ids.""" _private_fields = set() """A set with fields that should be ignored in validations.""" @@ -637,7 +731,6 @@ def __init__(self, db_map, item_type, **kwargs): self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() - self._is_id_valid = True self._to_remove = False self._removed = False self._corrupted = False @@ -720,13 +813,84 @@ def key(self): return None return (self._item_type, id_) - @property - def is_id_valid(self): - return self._is_id_valid + def same_db_item(self, db_item): + """Tests if database item that has same db id is in fact the same as this item.
+ + Args: + db_item (dict): item fetched from database + + Returns: + bool: True if items are the same, False otherwise + """ + raise NotImplementedError() + + def convert_db_ids_to_item_ids(self): + for item_type, id_fields in self._id_fields.items(): + for id_field in id_fields: + try: + field = self[id_field] + except KeyError: + continue + if field is None: + continue + if isinstance(field, Iterable): + self[id_field] = tuple( + self._find_or_fetch_item_id(item_type, self._item_type, db_id, self._db_map) for db_id in field + ) + else: + self[id_field] = self._find_or_fetch_item_id(item_type, self._item_type, field, self._db_map) + + @staticmethod + def _find_or_fetch_item_id(item_type, requesting_item_type, db_id, db_map): + try: + item_id = db_map.find_item_id(item_type, db_id) + except KeyError: + pass + else: + item = db_map.mapped_table(item_type)[item_id] + if not item.removed: + return item_id + if item_type == requesting_item_type: + # We could be fetching everything already, so fetch only a specific id + # to avoid endless recursion. + db_map.do_fetch_all(item_type, id=db_id) + else: + db_map.do_fetch_all(item_type) + return db_map.find_item_id(item_type, db_id) + + @classmethod + def convert_dicts_db_ids_to_item_ids(cls, item_type, item_dict, db_map): + for field_item_type, id_fields in cls._id_fields.items(): + for id_field in id_fields: + try: + field = item_dict[id_field] + except KeyError: + continue + if field is None: + continue + if isinstance(field, Iterable): + item_dict[id_field] = tuple( + cls._find_or_fetch_item_id(field_item_type, item_type, id_, db_map) if id_ > 0 else id_ + for id_ in field + ) + else: + item_dict[id_field] = ( + cls._find_or_fetch_item_id(field_item_type, item_type, field, db_map) if field > 0 else field + ) - def invalidate_id(self): - """Sets id as invalid.""" - self._is_id_valid = False + def make_db_item(self, find_db_id): + db_item = dict(self) + db_item["id"] = find_db_id(self._item_type, db_item["id"]) + for item_type, id_fields in self._id_fields.items(): + for id_field in id_fields: + field = db_item[id_field] + if field is None: + continue + if isinstance(field, Iterable): + db_item[id_field] = tuple(find_db_id(item_type, item_id) for item_id in field) + else: + db_item[id_field] = find_db_id(item_type, field) + return db_item def extended(self): """Returns a dict from this item's original fields plus all the references resolved statically. @@ -746,8 +910,16 @@ def _asdict(self): """ return dict(self) - def resolve(self): - return {k: resolve(v) for k, v in self._asdict().items()} + def equal_ignoring_ids(self, other): + """Compares the non-id fields for equality. + + Args: + other (MappedItemBase): other item + + Returns: + bool: True if non-id fields are equal, False otherwise + """ + return all(self[field] == other[field] for field in self.fields) def merge(self, other): """Merges this item with another and returns the merged item together with any errors. @@ -760,6 +932,10 @@ def merge(self, other): dict: merged item. str: error description if any. 
""" + other = {key: value for key, value in other.items() if key not in self._external_id_fields} + self.convert_dicts_db_ids_to_item_ids(self._item_type, other, self._db_map) + if "id" in other: + del other["id"] if not self._something_to_update(other): # Nothing to update, that's fine return None, "" @@ -808,6 +984,10 @@ def _invalid_keys(self): elif not self._get_ref(ref_type, {ref_key: src_val}): yield src_key + @classmethod + def unique_keys(cls): + return set(sum(cls._unique_keys, ())) + @classmethod def unique_values_for_item(cls, item, skip_keys=()): for key in cls._unique_keys: @@ -999,7 +1179,7 @@ def cascade_restore(self, source=None): return if self.status in (Status.added_and_removed, Status.to_remove): self._status = self._status_when_removed - elif self.status == Status.committed: + elif self.status == Status.committed or self.status == Status.overwritten: self._status = Status.to_add else: raise RuntimeError("invalid status for item being restored") @@ -1074,12 +1254,12 @@ def cascade_remove_unique(self): referrer.cascade_remove_unique() def is_committed(self): - """Returns whether or not this item is committed to the DB. + """Returns whether this item is committed to the DB. Returns: bool """ - return self._status == Status.committed + return self._status == Status.committed or self._status == Status.overwritten def commit(self, commit_id): """Sets this item as committed with the given commit id.""" diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index ce105140..50d3019a 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -13,7 +13,6 @@ from sqlalchemy.sql.expression import bindparam from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError -from .temp_id import TempId, resolve from .helpers import group_consecutive, Asterisk @@ -36,15 +35,17 @@ def _do_add_items(self, connection, tablename, *items_to_add): if not items_to_add: return try: - table = self._metadata.tables[self.real_item_type(tablename)] + item_type = self.real_item_type(tablename) + table = self._metadata.tables[item_type] id_items, temp_id_items = [], [] + id_map = self.mapped_table(item_type).id_map for item in items_to_add: - if isinstance(item["id"], TempId): + if id_map.db_id(item["id"]) is None: temp_id_items.append(item) else: id_items.append(item) if id_items: - connection.execute(table.insert(), [x.resolve() for x in id_items]) + connection.execute(table.insert(), [x.make_db_item(self.find_db_id) for x in id_items]) if temp_id_items: current_ids = {x["id"] for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 @@ -53,35 +54,35 @@ def _do_add_items(self, connection, tablename, *items_to_add): new_ids = set(range(next_id, next_id + required_id_count)) ids = sorted(available_ids | new_ids) for id_, item in zip(ids, temp_id_items): - temp_id = item["id"] - temp_id.resolve(id_) - connection.execute(table.insert(), [x.resolve() for x in temp_id_items]) + id_map.set_db_id(item["id"], id_) + connection.execute(table.insert(), [x.make_db_item(self.find_db_id) for x in temp_id_items]) for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue table = self._metadata.tables[self.real_item_type(tablename_)] - connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) + connection.execute(table.insert(), items_to_add_) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} 
items: {e.orig.args}" raise SpineDBAPIError(msg) from e - @staticmethod - def _dimensions_for_classes(classes): + def _dimensions_for_classes(self, classes): + id_map = self.mapped_table("entity_class").id_map return [ - {"entity_class_id": x["id"], "position": position, "dimension_id": dimension_id} + {"entity_class_id": id_map.db_id(x["id"]), "position": position, "dimension_id": id_map.db_id(dimension_id)} for x in classes for position, dimension_id in enumerate(x["dimension_id_list"]) ] - @staticmethod - def _elements_for_entities(entities): + def _elements_for_entities(self, entities): + entity_id_map = self.mapped_table("entity").id_map + class_id_map = self.mapped_table("entity_class").id_map return [ { - "entity_id": x["id"], - "entity_class_id": x["class_id"], + "entity_id": entity_id_map.db_id(x["id"]), + "entity_class_id": class_id_map.db_id(x["class_id"]), "position": position, - "element_id": element_id, - "dimension_id": dimension_id, + "element_id": entity_id_map.db_id(element_id), + "dimension_id": class_id_map.db_id(dimension_id), } for x in entities for position, (element_id, dimension_id) in enumerate(zip(x["element_id_list"], x["dimension_id_list"])) @@ -117,12 +118,12 @@ def _do_update_items(self, connection, tablename, *items_to_update): return try: upd = self._make_update_stmt(tablename, items_to_update[0].keys()) - connection.execute(upd, [x.resolve() for x in items_to_update]) + connection.execute(upd, [x.make_db_item(self.find_db_id) for x in items_to_update]) for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): if not items_to_update_: continue upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) - connection.execute(upd, [resolve(x) for x in items_to_update_]) + connection.execute(upd, items_to_update_) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" raise SpineDBAPIError(msg) from e @@ -134,7 +135,10 @@ def _do_remove_items(self, connection, tablename, *ids): *ids: ids to remove """ tablename = self.real_item_type(tablename) - ids = {resolve(id_) for id_ in ids} + id_map = self.mapped_table(tablename).id_map + purging = Asterisk in ids + if not purging: + ids = {id_map.db_id(id_) for id_ in ids} if tablename == "alternative": # Do not remove the Base alternative ids.discard(1) @@ -150,7 +154,7 @@ def _do_remove_items(self, connection, tablename, *ids): for tablename_ in tablenames: table = self._metadata.tables[tablename_] delete = table.delete() - if Asterisk not in ids: + if not purging: id_field = self._id_fields.get(tablename_, "id") id_column = getattr(table.c, id_field) cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 98722b35..41f55b85 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -General helper functions. -""" +""" General helper functions. 
""" import os import json import warnings @@ -875,3 +872,15 @@ def group_consecutive(list_of_numbers): for _k, g in groupby(enumerate(sorted(list_of_numbers)), lambda x: x[0] - x[1]): group = list(map(itemgetter(1), g)) yield group[0], group[-1] + + +def query_byname(entity_row, db_map): + element_ids = entity_row["element_id_list"] + if element_ids is None: + return (entity_row["name"],) + sq = db_map.wide_entity_sq + byname = [] + for element_id in element_ids.split(","): + element_row = db_map.query(sq).filter(sq.c.id == element_id).one() + byname += list(query_byname(element_row, db_map)) + return tuple(byname) diff --git a/spinedb_api/item_id.py b/spinedb_api/item_id.py new file mode 100644 index 00000000..2aec7f84 --- /dev/null +++ b/spinedb_api/item_id.py @@ -0,0 +1,60 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +from collections import Counter + + +class IdFactory: + def __init__(self): + self._next_id = -1 + + def next_id(self): + item_id = self._next_id + self._next_id -= 1 + return item_id + + +class IdMap: + def __init__(self): + self._item_id_by_db_id = {} + self._db_id_by_item_id = {} + + def add_item_id(self, item_id): + self._db_id_by_item_id[item_id] = None + + def remove_item_id(self, item_id): + db_id = self._db_id_by_item_id.pop(item_id, None) + if db_id is not None: + del self._item_id_by_db_id[db_id] + + def set_db_id(self, item_id, db_id): + self._db_id_by_item_id[item_id] = db_id + self._item_id_by_db_id[db_id] = item_id + + def remove_db_id(self, id_): + if id_ > 0: + item_id = self._item_id_by_db_id.pop(id_) + else: + item_id = id_ + db_id = self._db_id_by_item_id[item_id] + del self._item_id_by_db_id[db_id] + self._db_id_by_item_id[item_id] = None + + def item_id(self, db_id): + return self._item_id_by_db_id[db_id] + + def has_db_id(self, item_id): + return item_id in self._db_id_by_item_id + + def db_id(self, item_id): + return self._db_id_by_item_id[item_id] + + def db_id_iter(self): + yield from self._db_id_by_item_id diff --git a/spinedb_api/item_status.py b/spinedb_api/item_status.py index 0fcc6fd8..7fb8e76f 100644 --- a/spinedb_api/item_status.py +++ b/spinedb_api/item_status.py @@ -21,3 +21,4 @@ class Status(Enum): to_update = auto() to_remove = auto() added_and_removed = auto() + overwritten = auto() diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 8a32d102..9b939a9e 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -11,7 +11,7 @@ from operator import itemgetter -from .helpers import name_from_elements +from .helpers import name_from_elements, query_byname from .parameter_value import to_database, 
from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase @@ -49,6 +49,9 @@ class CommitItem(MappedItemBase): _unique_keys = (("date",),) + def same_db_item(self, db_item): + return self["date"].replace(tzinfo=None) == db_item["date"] + def commit(self, commit_id): raise RuntimeError("Commits are created automatically when session is committed.") @@ -76,6 +79,7 @@ class EntityClassItem(MappedItemBase): _external_fields = {"dimension_name_list": ("dimension_id_list", "name")} _alt_references = {("dimension_name_list",): ("entity_class", ("name",))} _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")} + _id_fields = {"entity_class": ("dimension_id_list",)} _private_fields = {"dimension_count"} def __init__(self, *args, **kwargs): @@ -92,6 +96,9 @@ def __getitem__(self, key): return self._get_ref("superclass_subclass", {"subclass_id": self["id"]}, strong=False).get(key) return super().__getitem__(key) + def same_db_item(self, db_item): + return self["name"] == db_item["name"] + def merge(self, other): dimension_id_list = other.pop("dimension_id_list", None) error = ( @@ -139,6 +146,8 @@ class EntityItem(MappedItemBase): "class_id": (("class_name",), "id"), "element_id_list": (("dimension_name_list", "element_name_list"), "id"), } + _id_fields = {"entity_class": ("class_id",), "entity": ("element_id_list",), "commit": ("commit_id",)} + _external_id_fields = {"dimension_id_list", "superclass_id"} def __init__(self, *args, **kwargs): element_id_list = kwargs.get("element_id_list") @@ -149,6 +158,13 @@ kwargs["element_id_list"] = tuple(element_id_list) super().__init__(*args, **kwargs) + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + if self["name"] != db_item["name"]: + return False + return _fields_equal("entity_class", db_item["class_id"], "name", self["class_name"], self._db_map) + @classmethod def unique_values_for_item(cls, item, skip_keys=()): """Overridden to also yield unique values for the superclass.""" @@ -271,6 +287,8 @@ class EntityGroupItem(MappedItemBase): "entity_id": (("class_name", "group_name"), "id"), "member_id": (("class_name", "member_name"), "id"), } + _id_fields = {"entity_class": ("entity_class_id",), "entity": ("entity_id", "member_id")} + _external_id_fields = {"dimension_id_list"} def __getitem__(self, key): if key == "class_id": @@ -279,6 +297,14 @@ return self["entity_id"] return super().__getitem__(key) + def same_db_item(self, db_item): + db_map = self._db_map + if not _fields_equal("entity", db_item["entity_id"], "name", self["group_name"], db_map): + return False + if not _fields_equal("entity", db_item["member_id"], "name", self["member_name"], db_map): + return False + return _fields_equal("entity_class", db_item["entity_class_id"], "name", self["class_name"], db_map) + def commit(self, _commit_id): super().commit(None) @@ -323,6 +349,20 @@ class EntityAlternativeItem(MappedItemBase): "entity_id": (("entity_class_name", "entity_byname"), "id"), "alternative_id": (("alternative_name",), "id"), } + _id_fields = {"entity": ("entity_id",), "alternative": ("alternative_id",), "commit": ("commit_id",)} + _external_id_fields = {"entity_class_id", "dimension_id_list", "element_id_list"} + + def same_db_item(self, db_item): + if not _commit_ids_equal(self, db_item, self._db_map): + return False + entity_record = self._db_map.make_query("entity", id=db_item["entity_id"]).one() + if not _fields_equal(
"entity_class", entity_record["class_id"], "name", self["entity_class_name"], self._db_map + ): + return False + if query_byname(entity_record, self._db_map) != self["entity_byname"]: + return False + return _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map) class ParsedValueBase(MappedItemBase): @@ -332,6 +372,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parsed_value = None + def same_db_item(self, db_item): + raise NotImplementedError() + @property def parsed_value(self): if self._parsed_value is None: @@ -376,6 +419,9 @@ def _something_to_update(self, other): class ParameterItemBase(ParsedValueBase): + def same_db_item(self, db_item): + raise NotImplementedError() + @property def _value_key(self): raise NotImplementedError() @@ -395,12 +441,13 @@ def ref_types(cls): def list_value_id(self): return self["list_value_id"] - def resolve(self): - d = super().resolve() - list_value_id = d.get("list_value_id") + def make_db_item(self, find_db_id): + db_item = super().make_db_item(find_db_id) + list_value_id = db_item.get("list_value_id") if list_value_id is not None: - d[self._value_key] = to_database(list_value_id)[0] - return d + list_value_db_id = self._db_map.find_db_id("list_value", list_value_id) + db_item[self._value_key] = to_database(list_value_db_id)[0] + return db_item def polish(self): self["list_value_id"] = None @@ -456,6 +503,12 @@ class ParameterDefinitionItem(ParameterItemBase): "entity_class_id": (("entity_class_name",), "id"), "parameter_value_list_id": (("parameter_value_list_name",), "id"), } + _id_fields = { + "entity_class": ("entity_class_id",), + "parameter_value_list": ("parameter_value_list_id",), + "commit": ("commit_id",), + } + _external_id_fields = {"dimension_id_list"} @property def _value_key(self): @@ -482,6 +535,15 @@ def __getitem__(self, key): return dict.get(self, key) return super().__getitem__(key) + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + if self["name"] != db_item["name"]: + return False + return _fields_equal( + "entity_class", db_item["entity_class_id"], "name", self["entity_class_name"], self._db_map + ) + def merge(self, other): other_parameter_value_list_id = other.get("parameter_value_list_id") if ( @@ -515,6 +577,7 @@ class ParameterValueItem(ParameterItemBase): 'type': {'type': str, 'value': 'The value type.', 'optional': True}, 'alternative_name': {'type': str, 'value': "The alternative name - defaults to 'Base'.", 'optional': True}, } + required_fields = ("value", "type") _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_id": ("entity_class", "id"), @@ -547,6 +610,15 @@ class ParameterValueItem(ParameterItemBase): "entity_id": (("entity_class_name", "entity_byname"), "id"), "alternative_id": (("alternative_name",), "id"), } + _id_fields = { + "parameter_definition": ("parameter_definition_id",), + "entity_class": ("entity_class_id",), + "entity": ("entity_id",), + "list_value": ("list_value_id",), + "alternative": ("alternative_id",), + "commit": ("commit_id",), + } + _external_id_fields = {"dimension_id_list", "element_id_list", "parameter_value_list_id"} @property def _value_key(self): @@ -567,6 +639,25 @@ def __getitem__(self, key): return self._get_ref("list_value", {"id": list_value_id}, strong=False).get(key) return super().__getitem__(key) + def same_db_item(self, db_item): + if _commit_ids_equal(self, 
db_item, self._db_map): + return True + if not _fields_equal( + "entity_class", db_item["entity_class_id"], "name", self["entity_class_name"], self._db_map + ): + return False + if not _fields_equal( + "parameter_definition", + db_item["parameter_definition_id"], + "name", + self["parameter_definition_name"], + self._db_map, + ): + return False + if not _fields_equal("entity", db_item["entity_id"], "name", self["entity_name"], self._db_map): + return False + return _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map) + def _value_not_in_list_error(self, parsed_value, list_name): return ( f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " @@ -577,6 +668,10 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueListItem(MappedItemBase): fields = {'name': {'type': str, 'value': 'The parameter value list name.'}} _unique_keys = (("name",),) + _id_fields = {"commit": ("commit_id",)} + + def same_db_item(self, db_item): + return db_item["name"] == self["name"] class ListValueItem(ParsedValueBase): @@ -591,6 +686,7 @@ class ListValueItem(ParsedValueBase): _external_fields = {"parameter_value_list_name": ("parameter_value_list_id", "name")} _alt_references = {("parameter_value_list_name",): ("parameter_value_list", ("name",))} _internal_fields = {"parameter_value_list_id": (("parameter_value_list_name",), "id")} + _id_fields = {"parameter_value_list": ("parameter_value_list_id",), "commit": ("commit_id",)} @property def _value_key(self): @@ -605,6 +701,19 @@ def __getitem__(self, key): return (self["value"], self["type"]) return super().__getitem__(key) + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + if self["index"] != db_item["index"]: + return False + return _fields_equal( + "parameter_value_list", + db_item["parameter_value_list_id"], + "name", + self["parameter_value_list_name"], + self._db_map, + ) + class AlternativeItem(MappedItemBase): fields = { @@ -613,6 +722,10 @@ class AlternativeItem(MappedItemBase): } _defaults = {"description": None} _unique_keys = (("name",),) + _id_fields = {"commit": ("commit_id",)} + + def same_db_item(self, db_item): + return self["name"] == db_item["name"] class ScenarioItem(MappedItemBase): @@ -623,6 +736,7 @@ class ScenarioItem(MappedItemBase): } _defaults = {"active": False, "description": None} _unique_keys = (("name",),) + _id_fields = {"commit": ("commit_id",)} def __getitem__(self, key): if key == "alternative_id_list": @@ -641,6 +755,9 @@ def __getitem__(self, key): ) return super().__getitem__(key) + def same_db_item(self, db_item): + return self["name"] == db_item["name"] + class ScenarioAlternativeItem(MappedItemBase): fields = { @@ -653,6 +770,7 @@ class ScenarioAlternativeItem(MappedItemBase): _external_fields = {"scenario_name": ("scenario_id", "name"), "alternative_name": ("alternative_id", "name")} _alt_references = {("scenario_name",): ("scenario", ("name",)), ("alternative_name",): ("alternative", ("name",))} _internal_fields = {"scenario_id": (("scenario_name",), "id"), "alternative_id": (("alternative_name",), "id")} + _id_fields = {"scenario": ("scenario_id",), "alternative": ("alternative_id",), "commit": ("commit_id",)} def __getitem__(self, key): # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. 
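For orientation, most of the same_db_item() implementations above share one pattern: first test whether the mapped item's commit id, translated through find_db_id(), equals the commit id of the raw database row; only when that fails, fall back to comparing unique values (via _fields_equal() and query_byname()). Below is a minimal self-contained sketch of that pattern, not part of the patch itself; FakeDBMap and the plain dicts are stand-ins for DatabaseMapping and the raw database records.

# Self-contained sketch of the same_db_item() pattern; FakeDBMap and the plain
# dicts are stand-ins for DatabaseMapping and the raw database records.
class FakeDBMap:
    def __init__(self, commit_db_ids):
        # mapped (negative) commit id -> database commit id
        self._commit_db_ids = commit_db_ids

    def find_db_id(self, item_type, item_id):
        assert item_type == "commit"
        return self._commit_db_ids.get(item_id)


def same_db_item(item, db_row, db_map):
    # Step 1: the row comes from the very commit the mapped item was saved in.
    if db_map.find_db_id("commit", item["commit_id"]) == db_row["commit_id"]:
        return True
    # Step 2: fall back to unique values; the real methods compare them through
    # _fields_equal() and query_byname() instead of a plain dict lookup.
    return item["name"] == db_row["name"]


db_map = FakeDBMap({-2: 5})
item = {"commit_id": -2, "name": "Base"}
assert same_db_item(item, {"commit_id": 5, "name": "renamed"}, db_map)
assert not same_db_item(item, {"commit_id": 6, "name": "other"}, db_map)

The real methods differ only in which unique values they compare; AlternativeItem, for instance, skips the commit check and compares names alone.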
@@ -669,6 +787,15 @@ def __getitem__(self, key): return None return super().__getitem__(key) + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + if not _fields_equal("scenario", db_item["scenario_id"], "name", self["scenario_name"], self._db_map): + return False + if not _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map): + return False + return self["rank"] == db_item["rank"] + class MetadataItem(MappedItemBase): fields = { @@ -676,6 +803,10 @@ class MetadataItem(MappedItemBase): 'value': {'type': str, 'value': 'The metadata entry value.'}, } _unique_keys = (("name", "value"),) + _id_fields = {"commit": ("commit_id",)} + + def same_db_item(self, db_item): + return self["name"] == db_item["name"] and self["value"] == db_item["value"] class EntityMetadataItem(MappedItemBase): @@ -707,6 +838,18 @@ class EntityMetadataItem(MappedItemBase): "entity_id": (("class_name", "entity_byname"), "id"), "metadata_id": (("metadata_name", "metadata_value"), "id"), } + _id_fields = {"entity": ("entity_id",), "metadata": ("metadata_id",), "commit": ("commit_id",)} + + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + entity_record = self._db_map.make_query("entity", id=db_item["entity_id"]).one_or_none() + if not entity_record or query_byname(entity_record, self._db_map) != self["entity_byname"]: + return False + if not _fields_equal("entity_class", db_item["class_id"], "name", self["class_name"], self._db_map): + return False + record = self._db_map.make_query("metadata", id=db_item["metadata_id"]).one_or_none() + return record and self["metadata_name"] == record["name"] and self["metadata_value"] == record["value"] class ParameterValueMetadataItem(MappedItemBase): @@ -754,6 +897,31 @@ class ParameterValueMetadataItem(MappedItemBase): ), "metadata_id": (("metadata_name", "metadata_value"), "id"), } + _id_fields = {"parameter_value": ("parameter_value_id",), "metadata": ("metadata_id",), "commit": ("commit_id",)} + + def same_db_item(self, db_item): + if _commit_ids_equal(self, db_item, self._db_map): + return True + value_record = self._db_map.make_query("parameter_value", id=db_item["parameter_value_id"]).one() + entity_record = self._db_map.make_query("entity", id=value_record["entity_id"]).one() + if query_byname(entity_record, self._db_map) != self["entity_byname"]: + return False + if not _fields_equal("entity_class", value_record["entity_class_id"], "name", self["class_name"], self._db_map): + return False + if not _fields_equal( + "parameter_definition", + value_record["parameter_definition_id"], + "name", + self["parameter_definition_name"], + self._db_map, + ): + return False + if not _fields_equal( + "alternative", value_record["alternative_id"], "name", self["alternative_name"], self._db_map + ): + return False + record = self._db_map.make_query("metadata", id=db_item["metadata_id"]).one() + return self["metadata_name"] == record["name"] and self["metadata_value"] == record["value"] class SuperclassSubclassItem(MappedItemBase): @@ -772,6 +940,7 @@ class SuperclassSubclassItem(MappedItemBase): ("subclass_name",): ("entity_class", ("name",)), } _internal_fields = {"superclass_id": (("superclass_name",), "id"), "subclass_id": (("subclass_name",), "id")} + _id_fields = {"entity_class": ("superclass_id", "subclass_id")} def _subclass_entities(self): return self._db_map.get_items("entity", class_id=self["subclass_id"]) @@ -781,5 +950,19 @@ def 
check_mutability(self): return "can't set or modify the superclass for a class that already has entities" return super().check_mutability() + def same_db_item(self, db_item): + return _fields_equal("entity_class", db_item["subclass_id"], "name", self["subclass_name"], self._db_map) + def commit(self, _commit_id): super().commit(None) + + +def _commit_ids_equal(item, db_item, db_map): + db_commit_id = db_map.find_db_id("commit", item["commit_id"]) + return db_commit_id == db_item["commit_id"] + + +def _fields_equal(item_type, db_id, field, expected_value, db_map): + # Use plain query as we want the raw data from database, not something that may have been conflict resolved. + record = db_map.make_query(item_type, id=db_id).one() + return expected_value == record[field] diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py deleted file mode 100644 index 79066941..00000000 --- a/spinedb_api/temp_id.py +++ /dev/null @@ -1,54 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . 
-###################################################################################################################### - - -class TempId(int): - _next_id = {} - - def __new__(cls, item_type): - id_ = cls._next_id.setdefault(item_type, -1) - cls._next_id[item_type] -= 1 - return super().__new__(cls, id_) - - def __init__(self, item_type): - super().__init__() - self._item_type = item_type - self._resolve_callbacks = [] - self._db_id = None - - @property - def db_id(self): - return self._db_id - - def __eq__(self, other): - return super().__eq__(other) or (self._db_id is not None and other == self._db_id) - - def __hash__(self): - return int(self) - - def __repr__(self): - return f"TempId({self._item_type}, {super().__repr__()})" - - def add_resolve_callback(self, callback): - self._resolve_callbacks.append(callback) - - def resolve(self, db_id): - self._db_id = db_id - while self._resolve_callbacks: - self._resolve_callbacks.pop(0)(db_id) - - -def resolve(value): - if isinstance(value, dict): - return {k: resolve(v) for k, v in value.items()} - if isinstance(value, TempId): - return value.db_id - return value diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 862a0098..c45780cd 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -27,6 +27,7 @@ ) from spinedb_api.conflict_resolution import select_in_db_item_always from spinedb_api.helpers import name_from_elements +from spinedb_api.mapped_items import EntityItem from tests.custom_db_mapping import CustomDatabaseMapping @@ -90,6 +91,40 @@ def _assert_success(self, result): self.assertIsNone(error) return item + def test_add_parameter_value_without_value_gives_error(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="Widget")) + self._assert_success(db_map.add_entity_item(name="spoon", class_name="Widget")) + self._assert_success(db_map.add_parameter_definition_item(name="size", entity_class_name="Widget")) + self.assertRaises( + SpineDBAPIError, + db_map.add_parameter_value_item, + **dict( + parameter_definition_name="size", + entity_class_name="Widget", + entity_byname=("spoon",), + alternative_name="Base", + type=None, + ) + ) + + def test_add_parameter_value_without_type_gives_error(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="Widget")) + self._assert_success(db_map.add_entity_item(name="spoon", class_name="Widget")) + self._assert_success(db_map.add_parameter_definition_item(name="size", entity_class_name="Widget")) + self.assertRaises( + SpineDBAPIError, + db_map.add_parameter_value_item, + **dict( + parameter_definition_name="size", + entity_class_name="Widget", + entity_byname=("spoon",), + alternative_name="Base", + value=to_database(2.3)[0], + ) + ) + def test_restore_uncommitted_item(self): with DatabaseMapping("sqlite://", create=True) as db_map: item, error = db_map.add_entity_class_item(name="my_class") @@ -282,21 +317,22 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): def test_update_entity_metadata_by_changing_its_entity(self): with DatabaseMapping("sqlite://", create=True) as db_map: - entity_class, _ = db_map.add_entity_class_item(name="my_class") - db_map.add_entity_item(name="entity_1", class_name="my_class") - entity_2, _ = db_map.add_entity_item(name="entity_2", class_name="my_class") + self._assert_success(db_map.add_entity_class_item(name="my_class")) + 
self._assert_success(db_map.add_entity_item(name="entity_1", class_name="my_class")) + entity_2 = self._assert_success(db_map.add_entity_item(name="entity_2", class_name="my_class")) metadata_value = '{"sources": [], "contributors": []}' - metadata, _ = db_map.add_metadata_item(name="my_metadata", value=metadata_value) - entity_metadata, error = db_map.add_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - class_name="my_class", - entity_byname=("entity_1",), + metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + entity_metadata = self._assert_success( + db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("entity_1",), + ) ) - self.assertIsNone(error) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( - entity_metadata._extended(), + entity_metadata.extended(), { "class_name": "my_class", "entity_byname": ("entity_2",), @@ -343,46 +379,47 @@ def test_update_entity_metadata_by_changing_its_entity(self): def test_update_parameter_value_metadata_by_changing_its_parameter(self): with DatabaseMapping("sqlite://", create=True) as db_map: - entity_class, _ = db_map.add_entity_class_item(name="my_class") - _, error = db_map.add_parameter_definition_item(name="x", entity_class_name="my_class") - self.assertIsNone(error) - db_map.add_parameter_definition_item(name="y", entity_class_name="my_class") - entity, _ = db_map.add_entity_item(name="my_entity", class_name="my_class") + entity_class = self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) value, value_type = to_database(2.3) - _, error = db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) ) - self.assertIsNone(error) value, value_type = to_database(-2.3) - y, error = db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="y", - alternative_name="Base", - value=value, - type=value_type, + y = self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) ) - self.assertIsNone(error) metadata_value = '{"sources": [], "contributors": []}' - metadata, error = db_map.add_metadata_item(name="my_metadata", value=metadata_value) - self.assertIsNone(error) - value_metadata, error = db_map.add_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", + metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + value_metadata = self._assert_success( + db_map.add_parameter_value_metadata_item( + 
metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) ) - self.assertIsNone(error) value_metadata.update(parameter_definition_name="y") self.assertEqual( - value_metadata._extended(), + value_metadata.extended(), { "class_name": "my_class", "entity_byname": ("my_entity",), @@ -446,7 +483,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): def test_fetch_more(self): with DatabaseMapping("sqlite://", create=True) as db_map: alternatives = db_map.fetch_more("alternative") - expected = [{"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}] + expected = [{"id": -1, "name": "Base", "description": "Base alternative", "commit_id": -1}] self.assertEqual([a._asdict() for a in alternatives], expected) def test_fetch_more_after_commit(self): @@ -491,18 +528,20 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - db_map.add_entity_class_item(name="dog") - db_map.add_entity_class_item(name="cat") - db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat")) - db_map.add_entity_item(name="Pulgoso", class_name="dog") - db_map.add_entity_item(name="Sylvester", class_name="cat") - db_map.add_entity_item(name="Tom", class_name="cat") + self._assert_success(db_map.add_entity_class_item(name="dog")) + self._assert_success(db_map.add_entity_class_item(name="cat")) + self._assert_success(db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat"))) + self._assert_success(db_map.add_entity_item(name="Pulgoso", class_name="dog")) + self._assert_success(db_map.add_entity_item(name="Sylvester", class_name="cat")) + self._assert_success(db_map.add_entity_item(name="Tom", class_name="cat")) db_map.commit_session("Arf!") with DatabaseMapping(url) as db_map: # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. db_map.get_entity_item(name="Sylvester", class_name="cat").remove() - db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), class_name="dog__cat") + self._assert_success( + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), class_name="dog__cat") + ) db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". 
@@ -541,32 +580,34 @@ def test_committing_scenario_alternatives(self): def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="my_class") + self._assert_success(db_map.add_entity_class_item(name="my_class")) db_map.commit_session("Add class.") classes = db_map.get_entity_class_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0]._extended()) + self.assertNotIn("commit_id", classes[0].extended()) def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="high") - db_map.add_entity_class_item(name="low") - db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low") + self._assert_success(db_map.add_entity_class_item(name="high")) + self._assert_success(db_map.add_entity_class_item(name="low")) + self._assert_success(db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low")) db_map.commit_session("Add class hierarchy.") classes = db_map.get_superclass_subclass_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0]._extended()) + self.assertNotIn("commit_id", classes[0].extended()) def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="my_class") - db_map.add_entity_item(name="element", class_name="my_class") - db_map.add_entity_item(name="container", class_name="my_class") - db_map.add_entity_group_item(group_name="container", member_name="element", class_name="my_class") + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="element", class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="container", class_name="my_class")) + self._assert_success( + db_map.add_entity_group_item(group_name="container", member_name="element", class_name="my_class") + ) db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) - self.assertNotIn("commit_id", groups[0]._extended()) + self.assertNotIn("commit_id", groups[0].extended()) def test_additive_commit_from_another_db_map_gets_fetched(self): with TemporaryDirectory() as temp_dir: @@ -583,13 +624,13 @@ def test_additive_commit_from_another_db_map_gets_fetched(self): self.assertEqual( items[0]._asdict(), { - "id": 1, + "id": -1, "name": "my_entity", "description": None, - "class_id": 1, + "class_id": -1, "element_name_list": None, "element_id_list": (), - "commit_id": 2, + "commit_id": -2, }, ) @@ -598,18 +639,37 @@ def test_updating_item_from_another_db_map_is_overwritten_by_default_conflict_re url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: self._assert_success(db_map.add_entity_class_item(name="my_class")) - original_item = self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + value, type_ = to_database(2.3) + original_item = self._assert_success( + db_map.add_parameter_definition_item( + entity_class_name="my_class", name="measurable", default_type=type_, default_value=value + ) + ) db_map.commit_session("Add initial data.") self.assertTrue(original_item.is_committed()) + definitions = db_map.query(db_map.parameter_definition_sq).all() + 
self.assertEqual(len(definitions), 1) + value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) + self.assertEqual(value, 2.3) with DatabaseMapping(url) as shadow_db_map: - items = shadow_db_map.get_items("entity") + items = shadow_db_map.get_items("parameter_definition") self.assertEqual(len(items), 1) - items[0].update(name="renamed_entity") - shadow_db_map.commit_session("Renamed the entity.") - items = db_map.get_items("entity") + value, type_ = to_database(5.0) + items[0].update(default_value=value, default_type=type_) + shadow_db_map.commit_session("Changed default value.") + definitions = shadow_db_map.query(shadow_db_map.parameter_definition_sq).all() + self.assertEqual(len(definitions), 1) + value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) + self.assertEqual(value, 5.0) + items = db_map.get_items("parameter_definition") self.assertEqual(len(items), 1) self.assertEqual(items[0], original_item) self.assertFalse(items[0].is_committed()) + db_map.commit_session("Restore default value back to original.") + definitions = db_map.query(db_map.parameter_definition_sq).all() + self.assertEqual(len(definitions), 1) + value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) + self.assertEqual(value, 2.3) def test_resolve_an_update_conflict_in_favor_of_external_modification(self): with TemporaryDirectory() as temp_dir: @@ -627,7 +687,10 @@ def test_resolve_an_update_conflict_in_favor_of_external_modification(self): items = db_map.fetch_more("entity", resolve_conflicts=select_in_db_item_always) self.assertEqual(len(items), 1) self.assertEqual(items[0].item_type, updated_item.item_type) - self.assertEqual(items[0].extended(), updated_item.extended()) + for keys, values in EntityItem.unique_values_for_item(items[0]): + for key, value in zip(keys, values): + with self.subTest(key=key): + self.assertEqual(value, updated_item[key]) def test_recreating_deleted_item_externally_brings_it_back_if_favored_by_conflict_resolution(self): with TemporaryDirectory() as temp_dir: @@ -661,18 +724,520 @@ def test_restoring_entity_whose_db_id_has_been_replaced_by_external_db_modificat self.assertEqual(len(items), 1) db_map.remove_item("entity", original_id) db_map.commit_session("Removed entity.") + self.assertEqual(len(db_map.get_entity_items()), 0) with DatabaseMapping(url) as shadow_db_map: self._assert_success(shadow_db_map.add_entity_item(class_name="my_class", name="other_entity")) shadow_db_map.commit_session("Add entity with different name, probably reusing previous id.") items = db_map.fetch_more("entity") self.assertEqual(len(items), 1) - self.assertEqual(items[0]["name"], "my_entity") + self.assertEqual(items[0]["name"], "other_entity") all_items = db_map.get_entity_items() - self.assertEqual(len(all_items), 0) + self.assertEqual(len(all_items), 1) restored_item = db_map.restore_item("entity", original_id) self.assertEqual(restored_item["name"], "my_entity") all_items = db_map.get_entity_items() - self.assertEqual(len(all_items), 1) + self.assertEqual(len(all_items), 2) + + def test_cunning_ways_to_make_external_changes(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="interesting_class")) + self._assert_success(db_map.add_entity_class_item(name="filler_class")) + self._assert_success( + db_map.add_parameter_definition_item(name="quality", 
entity_class_name="interesting_class") + ) + self._assert_success( + db_map.add_parameter_definition_item(name="quantity", entity_class_name="filler_class") + ) + self._assert_success(db_map.add_entity_item(name="object_of_interest", class_name="interesting_class")) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + parameter_definition_name="quality", + entity_class_name="interesting_class", + entity_byname=("object_of_interest",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + db_map.commit_session("Add initial data") + removed_item = db_map.get_entity_item(name="object_of_interest", class_name="interesting_class") + removed_item.remove() + db_map.commit_session("Remove object of interest") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", class_name="interesting_class") + ) + self._assert_success(shadow_db_map.add_entity_item(name="filler", class_name="filler_class")) + value, value_type = to_database(-2.3) + self._assert_success( + shadow_db_map.add_parameter_value_item( + parameter_definition_name="quantity", + entity_class_name="filler_class", + entity_byname=("filler",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(99.9) + self._assert_success( + shadow_db_map.add_parameter_value_item( + parameter_definition_name="quality", + entity_class_name="interesting_class", + entity_byname=("other_entity",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + shadow_db_map.commit_session("Add entities.") + entity_items = db_map.get_entity_items() + self.assertEqual(len(entity_items), 2) + self.assertEqual( + entity_items[0].extended(), + { + "id": -2, + "name": "other_entity", + "description": None, + "class_id": -1, + "element_id_list": (), + "element_name_list": (), + "commit_id": -4, + "class_name": "interesting_class", + "dimension_id_list": (), + "dimension_name_list": (), + "element_byname_list": (), + "superclass_id": None, + "superclass_name": None, + }, + ) + self.assertEqual( + entity_items[1].extended(), + { + "id": -3, + "name": "filler", + "description": None, + "class_id": -2, + "element_id_list": (), + "element_name_list": (), + "commit_id": -4, + "class_name": "filler_class", + "dimension_id_list": (), + "dimension_name_list": (), + "element_byname_list": (), + "superclass_id": None, + "superclass_name": None, + }, + ) + value_items = db_map.get_parameter_value_items() + self.assertEqual(len(value_items), 2) + self.assertTrue(removed_item.is_committed()) + self.assertEqual( + value_items[0].extended(), + { + "alternative_id": -1, + "alternative_name": "Base", + "commit_id": -4, + "dimension_id_list": (), + "dimension_name_list": (), + "element_id_list": (), + "element_name_list": (), + "entity_byname": ("filler",), + "entity_class_id": -2, + "entity_class_name": "filler_class", + "entity_id": -3, + "entity_name": "filler", + "id": -2, + "list_value_id": None, + "parameter_definition_id": -2, + "parameter_definition_name": "quantity", + "parameter_value_list_id": None, + "parameter_value_list_name": None, + "type": to_database(-2.3)[1], + "value": to_database(-2.3)[0], + }, + ) + self.assertEqual( + value_items[1].extended(), + { + "alternative_id": -1, + "alternative_name": "Base", + "commit_id": -4, + "dimension_id_list": (), + "dimension_name_list": (), + "element_id_list": (), + "element_name_list": (), + "entity_byname": ("other_entity",), + 
"entity_class_id": -1, + "entity_class_name": "interesting_class", + "entity_id": -2, + "entity_name": "other_entity", + "id": -3, + "list_value_id": None, + "parameter_definition_id": -1, + "parameter_definition_name": "quality", + "parameter_value_list_id": None, + "parameter_value_list_name": None, + "type": to_database(99.9)[1], + "value": to_database(99.9)[0], + }, + ) + + def test_update_entity_metadata_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + metadata_value = '{"sources": [], "contributors": []}' + self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + self._assert_success( + db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + ) + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + metadata_item = shadow_db_map.get_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + ) + self.assertTrue(metadata_item) + metadata_item.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move entity metadata to another entity") + metadata_items = db_map.get_entity_metadata_items() + self.assertEqual(len(metadata_items), 2) + self.assertEqual( + metadata_items[0].extended(), + { + "id": -1, + "class_name": "my_class", + "entity_byname": ("my_entity",), + "entity_id": -1, + "metadata_id": -1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "commit_id": None, + }, + ) + self.assertFalse(metadata_items[0].is_committed()) + self.assertEqual( + metadata_items[1].extended(), + { + "id": -2, + "class_name": "my_class", + "entity_byname": ("other_entity",), + "entity_id": -2, + "metadata_id": -1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "commit_id": -3, + }, + ) + self.assertTrue(metadata_items[1].is_committed()) + + def test_update_parameter_value_metadata_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + metadata_value = '{"sources": [], "contributors": []}' + self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + self._assert_success( + db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + ) + db_map.commit_session("Add initial data.") + with 
DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + value, value_type = to_database(5.0) + self._assert_success( + shadow_db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("other_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + metadata_item = shadow_db_map.get_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + self.assertTrue(metadata_item) + metadata_item.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move parameter value metadata to another entity") + metadata_items = db_map.get_parameter_value_metadata_items() + self.assertEqual(len(metadata_items), 2) + self.assertEqual( + metadata_items[0].extended(), + { + "id": -1, + "class_name": "my_class", + "parameter_definition_name": "x", + "parameter_value_id": -1, + "entity_byname": ("my_entity",), + "metadata_id": -1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "alternative_name": "Base", + "commit_id": None, + }, + ) + self.assertFalse(metadata_items[0].is_committed()) + self.assertEqual( + metadata_items[1].extended(), + { + "id": -2, + "class_name": "my_class", + "parameter_definition_name": "x", + "parameter_value_id": -2, + "entity_byname": ("other_entity",), + "metadata_id": -1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "alternative_name": "Base", + "commit_id": -3, + }, + ) + self.assertTrue(metadata_items[1].is_committed()) + + def test_update_entity_alternative_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + self._assert_success( + db_map.add_entity_alternative_item( + entity_byname=("my_entity",), + entity_class_name="my_class", + alternative_name="Base", + active=False, + ) + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + entity_alternative = shadow_db_map.get_entity_alternative_item( + entity_class_name="my_class", entity_byname=("my_entity",), alternative_name="Base" + ) + self.assertTrue(entity_alternative) + entity_alternative.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move entity alternative to another entity.") + entity_alternatives = db_map.get_entity_alternative_items() + self.assertEqual(len(entity_alternatives), 2) + self.assertEqual( + entity_alternatives[0].extended(), + { + "id": -1, + "entity_class_name": "my_class", + "entity_class_id": -1, + "entity_byname": ("my_entity",), + "entity_name": "my_entity", + "entity_id": -1, + "dimension_name_list": (), + "dimension_id_list": (), + "element_name_list": (), + "element_id_list": (), + "alternative_name": "Base", + "alternative_id": -1, + "active": False, + "commit_id": None, + }, + ) + self.assertFalse(entity_alternatives[0].is_committed()) + self.assertEqual( + entity_alternatives[1].extended(), + { + "id": -2, + "entity_class_name": "my_class", + "entity_class_id": -1, + "entity_byname": 
("other_entity",), + "entity_name": "other_entity", + "entity_id": -2, + "dimension_name_list": (), + "dimension_id_list": (), + "element_name_list": (), + "element_id_list": (), + "alternative_name": "Base", + "alternative_id": -1, + "active": False, + "commit_id": -3, + }, + ) + self.assertTrue(entity_alternatives[1].is_committed()) + + def test_update_superclass_subclass_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="ceiling")) + self._assert_success(db_map.add_entity_class_item(name="floor")) + self._assert_success(db_map.add_entity_class_item(name="soil")) + self._assert_success( + db_map.add_superclass_subclass_item(superclass_name="ceiling", subclass_name="floor") + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + superclass_subclass = shadow_db_map.get_superclass_subclass_item(subclass_name="floor") + superclass_subclass.update(subclass_name="soil") + shadow_db_map.commit_session("Changes subclass to another one.") + superclass_subclasses = db_map.get_superclass_subclass_items() + self.assertEqual(len(superclass_subclasses), 2) + self.assertEqual( + superclass_subclasses[0].extended(), + { + "id": -1, + "superclass_name": "ceiling", + "superclass_id": -1, + "subclass_name": "floor", + "subclass_id": -2, + }, + ) + self.assertFalse(superclass_subclasses[0].is_committed()) + self.assertEqual( + superclass_subclasses[1].extended(), + { + "id": -2, + "superclass_name": "ceiling", + "superclass_id": -1, + "subclass_name": "soil", + "subclass_id": -3, + }, + ) + self.assertTrue(superclass_subclasses[1].is_committed()) + + def test_adding_same_parameters_values_to_different_entities_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + my_entity = self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + db_map.commit_session("Add initial data.") + my_entity.remove() + db_map.commit_session("Remove entity.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + self._assert_success( + shadow_db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("other_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + shadow_db_map.commit_session("Add another entity.") + values = db_map.get_parameter_value_items() + self.assertEqual(len(values), 1) + self.assertEqual( + values[0].extended(), + { + "id": -2, + "entity_class_name": "my_class", + "entity_class_id": -1, + "dimension_name_list": (), + "dimension_id_list": (), + "parameter_definition_name": "x", + "parameter_definition_id": -1, + "entity_byname": ("other_entity",), + "entity_name": "other_entity", + "entity_id": -2, + "element_name_list": (), + "element_id_list": (), + 
"alternative_name": "Base", + "alternative_id": -1, + "parameter_value_list_name": None, + "parameter_value_list_id": None, + "list_value_id": None, + "type": value_type, + "value": value, + "commit_id": -4, + }, + ) + + def test_committing_changes_purged_entity_has_been_overwritten_by_external_change(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="ghost", class_name="my_class")) + db_map.commit_session("Add soon-to-be-removed entity.") + db_map.purge_items("entity") + db_map.commit_session("Purge entities.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + shadow_db_map.commit_session("Add another entity that steals ghost's id.") + db_map.do_fetch_all("entity") + self.assertFalse(db_map.any_uncommitted_items()) + self._assert_success(db_map.add_entity_item(name="dirty_entity", class_name="my_class")) + self.assertTrue(db_map.any_uncommitted_items()) + db_map.commit_session("Add still uncommitted entity.") + entities = db_map.query(db_map.wide_entity_sq).all() + self.assertEqual(len(entities), 2) + + def test_reset_purging(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + db_map.commit_session("Add entity_class.") + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + db_map.purge_items("entity") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_item(name="other_entity", class_name="my_class")) + shadow_db_map.commit_session("Add another entity that should not be purged.") + db_map.reset_purging() + entities = db_map.get_entity_items("entity") + self.assertEqual(len(entities), 1) + self.assertEqual(entities[0]["name"], "other_entity") class TestDatabaseMappingLegacy(unittest.TestCase): @@ -708,6 +1273,7 @@ def test_construction_with_sqlalchemy_url_and_filters(self): ) as mock_load: db_map = CustomDatabaseMapping(sa_url, create=True) db_map.close() + mock_load.assert_called_once_with(["fltr1", "fltr2"]) mock_apply.assert_called_once_with(db_map, [{"fltr1": "config1", "fltr2": "config2"}]) @@ -1640,6 +2206,7 @@ def test_add_parameter_values(self): "entity_id": nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"orange"', + "type": None, "alternative_id": 1, }, { @@ -1647,6 +2214,7 @@ def test_add_parameter_values(self): "entity_id": nemo__pluto_row.id, "entity_class_id": nemo__pluto_row.class_id, "value": b"125", + "type": None, "alternative_id": 1, }, ) @@ -1698,6 +2266,7 @@ def test_add_same_parameter_value_twice(self): "entity_id": nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"orange"', + "type": None, "alternative_id": 1, }, { @@ -1705,6 +2274,7 @@ def test_add_same_parameter_value_twice(self): "entity_id": nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"blue"', + "type": None, "alternative_id": 1, }, ) @@ -2014,8 +2584,12 @@ def test_add_entity_to_a_class_with_abstract_dimensions(self): import_functions.import_entity_classes( self._db_map, (("fish", ()), ("dog", ()), ("animal", ()), ("two_animals", ("animal", "animal"))) ) - 
import_functions.import_superclass_subclasses(self._db_map, (("animal", "fish"), ("animal", "dog"))) - import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) + count, errors = import_functions.import_superclass_subclasses( + self._db_map, (("animal", "fish"), ("animal", "dog")) + ) + self.assertEqual(errors, []) + count, errors = import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) + self.assertEqual(errors, []) self._db_map.commit_session("Add test data.") item, error = self._db_map.add_item("entity", class_name="two_animals", element_name_list=("Nemo", "Pulgoso")) self.assertTrue(item) @@ -2033,13 +2607,18 @@ def setUp(self): def tearDown(self): self._db_map.close() + def _assert_success(self, result): + items, errors = result + self.assertEqual(errors, []) + return items + def test_update_object_classes(self): """Test that updating object classes works.""" self._db_map.add_object_classes({"id": 1, "name": "fish"}, {"id": 2, "name": "dog"}) items, intgr_error_log = self._db_map.update_object_classes( {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} ) - ids = {x["id"] for x in items} + ids = {1, 2} self._db_map.commit_session("test commit") sq = self._db_map.object_class_sq object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2052,7 +2631,7 @@ def test_update_objects(self): self._db_map.add_object_classes({"id": 1, "name": "fish"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) - ids = {x["id"] for x in items} + ids = {1, 2} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2062,11 +2641,11 @@ def test_update_objects(self): def test_update_committed_object(self): """Test that updating objects works.""" - self._db_map.add_object_classes({"id": 1, "name": "some_class"}) - self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) + self._assert_success(self._db_map.add_object_classes({"id": 1, "name": "some_class"})) + self._assert_success(self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1})) self._db_map.commit_session("update") items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) - ids = {x["id"] for x in items} + ids = {1} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2084,7 +2663,7 @@ def test_update_relationship_classes(self): items, intgr_error_log = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} ) - ids = {x["id"] for x in items} + ids = {3, 4} self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_class_sq rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2093,13 +2672,15 @@ def test_update_relationship_classes(self): self.assertEqual(rel_clss[4], "octopus__dog") def test_update_committed_relationship_class(self): - _ = import_functions.import_object_classes(self._db_map, ("object_class_1",)) - _ = import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) + self._assert_success(import_functions.import_object_classes(self._db_map, ("object_class_1",))) + 
self._assert_success( + import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) + ) self._db_map.commit_session("Add test data") items, errors = self._db_map.update_wide_relationship_classes({"id": 2, "name": "renamed"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {2}) + self.assertEqual(updated_ids, {-2}) self._db_map.commit_session("Update data.") classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(classes), 1) @@ -2122,29 +2703,38 @@ def test_update_relationship_class_does_not_update_member_class_id(self): def test_update_relationships(self): """Test that updating relationships works.""" - self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2}) - self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) - self._db_map.add_objects( - {"name": "nemo", "id": 1, "class_id": 1}, - {"name": "pluto", "id": 2, "class_id": 2}, - {"name": "scooby", "id": 3, "class_id": 2}, + self._assert_success(self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2})) + self._assert_success( + self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) + ) + self._assert_success( + self._db_map.add_objects( + {"name": "nemo", "id": 1, "class_id": 1}, + {"name": "pluto", "id": 2, "class_id": 2}, + {"name": "scooby", "id": 3, "class_id": 2}, + ) ) - self._db_map.add_wide_relationships( - {"id": 4, "name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2], "object_class_id_list": [1, 2]} + self._assert_success( + self._db_map.add_wide_relationships( + { + "id": 4, + "name": "nemo__pluto", + "class_id": 3, + "object_id_list": [1, 2], + "object_class_id_list": [1, 2], + } + ) ) items, intgr_error_log = self._db_map.update_wide_relationships( {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} ) - ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_sq - rels = { - x.id: {"name": x.name, "object_id_list": x.object_id_list} - for x in self._db_map.query(sq).filter(sq.c.id.in_(ids)) - } + rels = [{"name": x.name, "object_id_list": x.object_id_list} for x in self._db_map.query(sq)] self.assertEqual(intgr_error_log, []) - self.assertEqual(rels[4]["name"], "nemo__scooby") - self.assertEqual(rels[4]["object_id_list"], "1,3") + self.assertEqual(len(rels), 1) + self.assertEqual(rels[0]["name"], "nemo__scooby") + self.assertEqual(rels[0]["object_id_list"], "1,3") def test_update_committed_relationship(self): import_functions.import_object_classes(self._db_map, ("object_class_1", "object_class_2")) @@ -2160,7 +2750,7 @@ def test_update_committed_relationship(self): items, errors = self._db_map.update_wide_relationships({"id": 4, "name": "renamed", "object_id_list": [2, 3]}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {4}) + self.assertEqual(updated_ids, {-4}) self._db_map.commit_session("Update data.") relationships = self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationships), 1) @@ -2178,7 +2768,7 @@ def test_update_parameter_value_by_id_only(self): items, errors = self._db_map.update_parameter_values({"id": 1, "value": b"something else"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - 
self.assertEqual(updated_ids, {1}) + self.assertEqual(updated_ids, {-1}) self._db_map.commit_session("Update data.") pvals = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(pvals), 1) @@ -2210,7 +2800,7 @@ def test_update_parameter_definition_by_id_only(self): items, errors = self._db_map.update_parameter_definitions({"id": 1, "name": "parameter2"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {1}) + self.assertEqual(updated_ids, {-1}) self._db_map.commit_session("Update data.") pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) @@ -2226,7 +2816,7 @@ def test_update_parameter_definition_value_list(self): ) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {1}) + self.assertEqual(updated_ids, {-1}) self._db_map.commit_session("Update data.") pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) @@ -2326,7 +2916,7 @@ def test_update_object_metadata_reuses_existing_metadata(self): ) ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(ids, {-1}) self._db_map.remove_unused_metadata() self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() @@ -2432,7 +3022,7 @@ def test_update_metadata(self): items, errors = self._db_map.update_metadata(*({"id": 1, "name": "author", "value": "Prof. T. Est"},)) ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {1}) + self.assertEqual(ids, {-1}) self._db_map.commit_session("Update data") metadata_records = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_records), 1) @@ -2566,6 +3156,7 @@ def test_remove_parameter_value(self): self._db_map.add_parameter_values( { "value": b"0", + "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -2588,6 +3179,7 @@ def test_remove_parameter_value_from_committed_session(self): self._db_map.add_parameter_values( { "value": b"0", + "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -2610,6 +3202,7 @@ def test_cascade_remove_object_removes_parameter_value_as_well(self): self._db_map.add_parameter_values( { "value": b"0", + "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -2632,6 +3225,7 @@ def test_cascade_remove_object_from_committed_session_removes_parameter_value_as self._db_map.add_parameter_values( { "value": b"0", + "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -2925,7 +3519,7 @@ def test_refresh_addition(self): def test_refresh_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") - self._db_map.remove_items("entity_class", 1) + self._db_map.remove_items("entity_class", -1) entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) self._db_map.fetch_all() @@ -2940,7 +3534,7 @@ def test_refresh_update(self): self.assertEqual(entity_class_names, {"new_name"}) self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} - self.assertEqual(entity_class_names, {"new_name"}) + self.assertEqual(entity_class_names, {"new_name", "my_class"}) def test_cascade_remove_unfetched(self): 
import_functions.import_object_classes(self._db_map, ("my_class",)) @@ -2949,8 +3543,8 @@ def test_cascade_remove_unfetched(self): self._db_map.reset() self._db_map.remove_items("entity_class", 1) self._db_map.commit_session("test commit") - ents = self._db_map.query(self._db_map.entity_sq).all() - self.assertEqual(ents, []) + entities = self._db_map.query(self._db_map.entity_sq).all() + self.assertEqual(entities, []) if __name__ == "__main__": diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py deleted file mode 100644 index 179c1ed8..00000000 --- a/tests/test_db_mapping_base.py +++ /dev/null @@ -1,80 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# This file is part of Spine Database API. -# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### -import unittest - -from spinedb_api.db_mapping_base import MappedItemBase, DatabaseMappingBase - - -class TestDBMapping(DatabaseMappingBase): - @staticmethod - def item_types(): - return ["cutlery"] - - @staticmethod - def all_item_types(): - return ["cutlery"] - - @staticmethod - def item_factory(item_type): - if item_type == "cutlery": - return MappedItemBase - raise RuntimeError(f"unknown item_type '{item_type}'") - - -class TestDBMappingBase(unittest.TestCase): - def test_rolling_back_new_item_invalidates_its_id(self): - db_map = TestDBMapping() - mapped_table = db_map.mapped_table("cutlery") - item = mapped_table.add_item({}) - self.assertTrue(item.is_id_valid) - self.assertIn("id", item) - id_ = item["id"] - db_map._rollback() - self.assertFalse(item.is_id_valid) - self.assertEqual(item["id"], id_) - - -class TestMappedTable(unittest.TestCase): - def test_readding_item_with_invalid_id_creates_new_id(self): - db_map = TestDBMapping() - mapped_table = db_map.mapped_table("cutlery") - item = mapped_table.add_item({}) - id_ = item["id"] - db_map._rollback() - self.assertFalse(item.is_id_valid) - mapped_table.add_item(item) - self.assertTrue(item.is_id_valid) - self.assertNotEqual(item["id"], id_) - - -class TestMappedItemBase(unittest.TestCase): - def test_id_is_valid_initially(self): - db_map = TestDBMapping() - item = MappedItemBase(db_map, "cutlery") - self.assertTrue(item.is_id_valid) - - def test_id_can_be_invalidated(self): - db_map = TestDBMapping() - item = MappedItemBase(db_map, "cutlery") - item.invalidate_id() - self.assertFalse(item.is_id_valid) - - def test_setting_new_id_validates_it(self): - db_map = TestDBMapping() - item = MappedItemBase(db_map, "cutlery") - item.invalidate_id() - self.assertFalse(item.is_id_valid) - item["id"] = 23 - self.assertTrue(item.is_id_valid) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5ba88aac..e19f1979 
100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -10,15 +10,16 @@ ###################################################################################################################### """Unit tests for helpers.py.""" - import unittest from spinedb_api.helpers import ( compare_schemas, create_new_spine_database, name_from_dimensions, name_from_elements, + query_byname, remove_credentials_from_url, ) +from spinedb_api.db_mapping import DatabaseMapping class TestNameFromElements(unittest.TestCase): @@ -67,5 +68,76 @@ def test_password_with_special_characters(self): self.assertEqual(sanitized, "mysql://example.com/db") +class TestQueryByname(unittest.TestCase): + def _assert_success(self, result): + item, error = result + self.assertIsNone(error) + return item + + def test_zero_dimension_entity(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", class_name="my_class")) + db_map.commit_session("Add entity.") + entity_row = db_map.query(db_map.wide_entity_sq).one() + self.assertEqual(query_byname(entity_row, db_map), ("my_entity",)) + + def test_dimensioned_entity(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="class_1")) + self._assert_success(db_map.add_entity_class_item(name="class_2")) + self._assert_success(db_map.add_entity_item(name="entity_1", class_name="class_1")) + self._assert_success(db_map.add_entity_item(name="entity_2", class_name="class_2")) + self._assert_success( + db_map.add_entity_class_item(name="relationship", dimension_name_list=("class_1", "class_2")) + ) + relationship = self._assert_success( + db_map.add_entity_item(class_name="relationship", element_name_list=("entity_1", "entity_2")) + ) + db_map.commit_session("Add entities") + entity_row = ( + db_map.query(db_map.wide_entity_sq) + .filter(db_map.wide_entity_sq.c.id == db_map.find_db_id("entity", relationship["id"])) + .one() + ) + self.assertEqual(query_byname(entity_row, db_map), ("entity_1", "entity_2")) + + def test_deep_dimensioned_entity(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="class_1")) + self._assert_success(db_map.add_entity_class_item(name="class_2")) + self._assert_success(db_map.add_entity_item(name="entity_1", class_name="class_1")) + self._assert_success(db_map.add_entity_item(name="entity_2", class_name="class_2")) + self._assert_success( + db_map.add_entity_class_item(name="relationship_1", dimension_name_list=("class_1", "class_2")) + ) + relationship_1 = self._assert_success( + db_map.add_entity_item(class_name="relationship_1", element_name_list=("entity_1", "entity_2")) + ) + self._assert_success( + db_map.add_entity_class_item(name="relationship_2", dimension_name_list=("class_2", "class_1")) + ) + relationship_2 = self._assert_success( + db_map.add_entity_item(class_name="relationship_2", element_name_list=("entity_2", "entity_1")) + ) + self._assert_success( + db_map.add_entity_class_item( + name="super_relationship", dimension_name_list=("relationship_1", "relationship_2") + ) + ) + superrelationship = self._assert_success( + db_map.add_entity_item( + class_name="super_relationship", element_name_list=(relationship_1["name"], relationship_2["name"]) + ) + ) + db_map.commit_session("Add entities") + entity_row = ( + db_map.query(db_map.wide_entity_sq) + 
.filter(db_map.wide_entity_sq.c.id == db_map.find_db_id("entity", superrelationship["id"])) + .one() + ) + self.assertEqual(query_byname(entity_row, db_map), ("entity_1", "entity_2", "entity_2", "entity_1")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_item_id.py b/tests/test_item_id.py new file mode 100644 index 00000000..fda8d602 --- /dev/null +++ b/tests/test_item_id.py @@ -0,0 +1,70 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +import unittest + +from spinedb_api.item_id import IdFactory, IdMap + + +class TestIdFactory(unittest.TestCase): + def test_ids_are_negative_and_consecutive(self): + factory = IdFactory() + self.assertEqual(factory.next_id(), -1) + self.assertEqual(factory.next_id(), -2) + + +class TestIdMap(unittest.TestCase): + def test_add_item_id(self): + id_map = IdMap() + id_map.add_item_id(-2) + self.assertIsNone(id_map.db_id(-2)) + + def test_remove_item_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + id_map.remove_item_id(-2) + self.assertRaises(KeyError, id_map.item_id, 3) + self.assertRaises(KeyError, id_map.db_id, -2) + + def test_set_db_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + self.assertEqual(id_map.db_id(-2), 3) + self.assertEqual(id_map.item_id(3), -2) + + def test_remove_db_id_using_db_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + id_map.remove_db_id(3) + self.assertIsNone(id_map.db_id(-2)) + self.assertRaises(KeyError, id_map.item_id, 3) + + def test_remove_db_id_using_item_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + id_map.remove_db_id(-2) + self.assertIsNone(id_map.db_id(-2)) + self.assertRaises(KeyError, id_map.item_id, 3) + + def test_item_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + self.assertEqual(id_map.item_id(3), -2) + self.assertRaises(KeyError, id_map.item_id, 99) + + def test_db_id(self): + id_map = IdMap() + id_map.set_db_id(-2, 3) + self.assertEqual(id_map.db_id(-2), 3) + self.assertRaises(KeyError, id_map.db_id, -99) + + +if __name__ == '__main__': + unittest.main() From 3e88fb4727cde9aa87b3cbe4c319b44e23191c02 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 11 Jan 2024 09:36:58 +0200 Subject: [PATCH 221/317] Fix removing items using Asterisk as id Re spine-tools/Spine-Toolbox#2431 --- spinedb_api/db_mapping_commit_mixin.py | 6 +++--- tests/test_DatabaseMapping.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 50d3019a..14f80d4e 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ 
-139,9 +139,9 @@ def _do_remove_items(connection, tablename, *ids): purging = Asterisk in ids if not purging: ids = {id_map.db_id(id_) for id_ in ids} - if tablename == "alternative": - # Do not remove the Base alternative - ids.discard(1) + if tablename == "alternative": + # Do not remove the Base alternative + ids.discard(1) if not ids: return tablenames = [tablename]
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index c45780cd..56957cf9 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -26,7 +26,7 @@ SpineIntegrityError, ) from spinedb_api.conflict_resolution import select_in_db_item_always -from spinedb_api.helpers import name_from_elements +from spinedb_api.helpers import Asterisk, name_from_elements from spinedb_api.mapped_items import EntityItem from tests.custom_db_mapping import CustomDatabaseMapping @@ -1239,6 +1239,18 @@ def test_reset_purging(self): self.assertEqual(len(entities), 1) self.assertEqual(entities[0]["name"], "other_entity") + def test_remove_items_by_asterisk(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_alternative_item(name="alt_1")) + self._assert_success(db_map.add_alternative_item(name="alt_2")) + db_map.commit_session("Add alternatives.") + alternatives = db_map.get_alternative_items() + self.assertEqual(len(alternatives), 3) + db_map.remove_items("alternative", Asterisk) + db_map.commit_session("Remove all alternatives.") + alternatives = db_map.get_alternative_items() + self.assertEqual(alternatives, []) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure."""
From 92a2bf09d4ea2bffa1e977f1259950616c4171fd Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Thu, 11 Jan 2024 10:56:06 +0200
Subject: [PATCH 222/317] Use entity_byname consistently everywhere

This replaces uses of 'byname' in entity items with 'entity_byname'.
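As a minimal sketch (illustration only, not part of this patch), the renamed key reads like this in client code, assuming the in-memory DatabaseMapping API exercised by the tests in this series:

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        db_map.add_entity_class_item(name="fish")
        db_map.add_entity_item(name="Nemo", class_name="fish")
        # The unique key formerly spelled 'byname' is now 'entity_byname';
        # for a zero-dimensional entity it is a one-element tuple.
        nemo = db_map.get_entity_item(name="Nemo", class_name="fish")
        assert nemo["entity_byname"] == ("Nemo",)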
Re #318 --- spinedb_api/import_functions.py | 2 +- spinedb_api/mapped_items.py | 34 +++++++++++++++++---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 19f64341..498ee376 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -498,7 +498,7 @@ def _get_superclass_subclasses_for_import(db_map, data): def _get_entities_for_import(db_map, data): items_by_el_count = {} - key = ("class_name", "byname", "description") + key = ("class_name", "entity_byname", "description") for class_name, name_or_el_name_list, *optionals in data: if isinstance(name_or_el_name_list, (list, tuple)): el_count = len(name_or_el_name_list) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 8a32d102..8dad5012 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -111,7 +111,7 @@ class EntityItem(MappedItemBase): 'class_name': {'type': str, 'value': 'The entity class name.'}, 'name': {'type': str, 'value': 'The entity name.'}, 'element_name_list': {'type': tuple, 'value': 'The element names if the entity is multi-dimensional.'}, - 'byname': { + 'entity_byname': { 'type': tuple, 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional,' 'or the element names if it is multi-dimensional.', @@ -120,7 +120,7 @@ class EntityItem(MappedItemBase): } _defaults = {"description": None} - _unique_keys = (("class_name", "name"), ("class_name", "byname")) + _unique_keys = (("class_name", "name"), ("class_name", "entity_byname")) _references = {"class_id": ("entity_class", "id"), "element_id_list": ("entity", "id")} _external_fields = { "class_name": ("class_id", "name"), @@ -129,7 +129,7 @@ class EntityItem(MappedItemBase): "superclass_id": ("class_id", "superclass_id"), "superclass_name": ("class_id", "superclass_name"), "element_name_list": ("element_id_list", "name"), - "element_byname_list": ("element_id_list", "byname"), + "element_byname_list": ("element_id_list", "entity_byname"), } _alt_references = { ("class_name",): ("entity_class", ("name",)), @@ -168,7 +168,7 @@ def _byname_iter(self, entity): yield from self._byname_iter(element) def __getitem__(self, key): - if key == "byname": + if key == "entity_byname": return tuple(self._byname_iter(self)) return super().__getitem__(key) @@ -177,7 +177,7 @@ def resolve_internal_fields(self, skip_keys=()): error = super().resolve_internal_fields(skip_keys=skip_keys) if error: return error - byname = dict.pop(self, "byname", None) + byname = dict.pop(self, "entity_byname", None) if byname is None: return dim_count = len(self["dimension_id_list"]) @@ -193,7 +193,7 @@ def resolve_internal_fields(self, skip_keys=()): self["element_name_list"] = element_name_list return self._do_resolve_internal_field("element_id_list") - def _element_name_list_recursive(self, class_name, byname): + def _element_name_list_recursive(self, class_name, entity_byname): """Returns the element name list corresponding to given class and byname. If the class is multi-dimensional then recurses for each dimension. If the class is a superclass then it tries for each subclass until finding something useful. 
@@ -205,18 +205,20 @@ def _element_name_list_recursive(self, class_name, byname): dimension_name_list = self._db_map.get_item("entity_class", name=class_name_).get("dimension_name_list") if not dimension_name_list: continue - byname_backup = list(byname) + byname_backup = list(entity_byname) element_name_list = tuple( self._db_map.get_item( "entity", - **dict(zip(("byname", "class_name"), self._element_name_list_recursive(dim_name, byname))), + **dict( + zip(("entity_byname", "class_name"), self._element_name_list_recursive(dim_name, entity_byname)) + ), ).get("name") for dim_name in dimension_name_list ) if None not in element_name_list: return element_name_list, class_name_ - byname = byname_backup - name = byname.pop(0) if byname else None + entity_byname = byname_backup + name = entity_byname.pop(0) if entity_byname else None return (name,), class_name def polish(self): @@ -310,13 +312,13 @@ class EntityAlternativeItem(MappedItemBase): "dimension_id_list": ("entity_class_id", "dimension_id_list"), "dimension_name_list": ("entity_class_id", "dimension_name_list"), "entity_name": ("entity_id", "name"), - "entity_byname": ("entity_id", "byname"), + "entity_byname": ("entity_id", "entity_byname"), "element_id_list": ("entity_id", "element_id_list"), "element_name_list": ("entity_id", "element_name_list"), "alternative_name": ("alternative_id", "name"), } _alt_references = { - ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ("entity_class_name", "entity_byname"): ("entity", ("class_name", "entity_byname")), ("alternative_name",): ("alternative", ("name",)), } _internal_fields = { @@ -530,7 +532,7 @@ class ParameterValueItem(ParameterItemBase): "parameter_value_list_id": ("parameter_definition_id", "parameter_value_list_id"), "parameter_value_list_name": ("parameter_definition_id", "parameter_value_list_name"), "entity_name": ("entity_id", "name"), - "entity_byname": ("entity_id", "byname"), + "entity_byname": ("entity_id", "entity_byname"), "element_id_list": ("entity_id", "element_id_list"), "element_name_list": ("entity_id", "element_name_list"), "alternative_name": ("alternative_id", "name"), @@ -538,7 +540,7 @@ class ParameterValueItem(ParameterItemBase): _alt_references = { ("entity_class_name",): ("entity_class", ("name",)), ("entity_class_name", "parameter_definition_name"): ("parameter_definition", ("entity_class_name", "name")), - ("entity_class_name", "entity_byname"): ("entity", ("class_name", "byname")), + ("entity_class_name", "entity_byname"): ("entity", ("class_name", "entity_byname")), ("alternative_name",): ("alternative", ("name",)), } _internal_fields = { @@ -692,7 +694,7 @@ class EntityMetadataItem(MappedItemBase): } _external_fields = { "class_name": ("entity_id", "class_name"), - "entity_byname": ("entity_id", "byname"), + "entity_byname": ("entity_id", "entity_byname"), "metadata_name": ("metadata_id", "name"), "metadata_value": ("metadata_id", "value"), } @@ -700,7 +702,7 @@ class EntityMetadataItem(MappedItemBase): ( "class_name", "entity_byname", - ): ("entity", ("class_name", "byname")), + ): ("entity", ("class_name", "entity_byname")), ("metadata_name", "metadata_value"): ("metadata", ("name", "value")), } _internal_fields = { From 5e489d84cb00c02308084e94053c5f255c7692c1 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 11 Jan 2024 13:58:33 +0200 Subject: [PATCH 223/317] Rename 'class_name' fields to 'entity_class_name' for consistency This renames most cases of 'class_name' to 'entity_class_name' for consistency. 
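For illustration (not part of the patch), a minimal sketch of client code after the rename, assuming the add/get item API used in the docstrings below:

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        db_map.add_entity_class_item(name="musician")
        # The 'class_name' keyword and field are now spelled 'entity_class_name'.
        db_map.add_entity_item(name="Prince", entity_class_name="musician")
        prince = db_map.get_entity_item(name="Prince", entity_class_name="musician")
        assert prince["entity_class_name"] == "musician"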
The most important exceptions are object_sq, relationship_sq and their wide versions which were left as-is for backwards compatibility. Re #318 --- spinedb_api/db_mapping.py | 10 ++-- spinedb_api/db_mapping_query_mixin.py | 2 +- spinedb_api/export_functions.py | 7 ++- spinedb_api/import_functions.py | 8 +-- spinedb_api/mapped_items.py | 71 +++++++++++++------------ spinedb_api/spine_io/exporters/excel.py | 8 +-- tests/test_DatabaseMapping.py | 66 ++++++++++++----------- 7 files changed, 90 insertions(+), 82 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 5a464839..8d1c5f31 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -342,7 +342,7 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): Example:: with DatabaseMapping(db_url) as db_map: - prince = db_map.get_item("entity", class_name="musician", name="Prince") + prince = db_map.get_item("entity", entity_class_name="musician", name="Prince") Args: item_type (str): One of . @@ -405,7 +405,7 @@ def add_item(self, item_type, check=True, **kwargs): with DatabaseMapping(db_url) as db_map: db_map.add_item("entity_class", name="musician") - db_map.add_item("entity", class_name="musician", name="Prince") + db_map.add_item("entity", entity_class_name="musician", name="Prince") Args: item_type (str): One of . @@ -443,7 +443,7 @@ def update_item(self, item_type, check=True, **kwargs): Example:: with DatabaseMapping(db_url) as db_map: - prince = db_map.get_item("entity", class_name="musician", name="Prince") + prince = db_map.get_item("entity", entity_class_name="musician", name="Prince") db_map.update_item( "entity", id=prince["id"], name="the Artist", description="Formerly known as Prince." ) @@ -529,7 +529,7 @@ def remove_item(self, item_type, id_, check=True): Example:: with DatabaseMapping(db_url) as db_map: - prince = db_map.get_item("entity", class_name="musician", name="Prince") + prince = db_map.get_item("entity", entity_class_name="musician", name="Prince") db_map.remove_item("entity", prince["id"]) Args: @@ -578,7 +578,7 @@ def restore_item(self, item_type, id_): Example:: with DatabaseMapping(db_url) as db_map: - prince = db_map.get_item("entity", skip_remove=False, class_name="musician", name="Prince") + prince = db_map.get_item("entity", skip_remove=False, entity_class_name="musician", name="Prince") db_map.restore_item("entity", prince["id"]) Args: diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index 89cf0ae8..366c5bed 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -893,7 +893,7 @@ def ext_entity_group_sq(self): self.entity_group_sq.c.entity_class_id.label("class_id"), self.entity_group_sq.c.entity_id.label("group_id"), self.entity_group_sq.c.member_id.label("member_id"), - self.wide_entity_class_sq.c.name.label("class_name"), + self.wide_entity_class_sq.c.name.label("entity_class_name"), group_entity.c.name.label("group_name"), member_entity.c.name.label("member_name"), label("object_class_id", self._object_class_id()), diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 4bb5b822..9fbd853e 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -135,13 +135,16 @@ def export_superclass_subclasses(db_map, ids=Asterisk): def export_entities(db_map, ids=Asterisk): return sorted( - ((x.class_name, x.element_name_list or x.name, x.description) for x in _get_items(db_map, "entity", ids)), + ( + 
(x.entity_class_name, x.element_name_list or x.name, x.description) + for x in _get_items(db_map, "entity", ids) + ), key=lambda x: (0 if isinstance(x[1], str) else len(x[1]), x[0], (x[1],) if isinstance(x[1], str) else x[1]), ) def export_entity_groups(db_map, ids=Asterisk): - return sorted((x.class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids)) + return sorted((x.entity_class_name, x.group_name, x.member_name) for x in _get_items(db_map, "entity_group", ids)) def export_entity_alternatives(db_map, ids=Asterisk): diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 498ee376..e35e38a9 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -498,7 +498,7 @@ def _get_superclass_subclasses_for_import(db_map, data): def _get_entities_for_import(db_map, data): items_by_el_count = {} - key = ("class_name", "entity_byname", "description") + key = ("entity_class_name", "entity_byname", "description") for class_name, name_or_el_name_list, *optionals in data: if isinstance(name_or_el_name_list, (list, tuple)): el_count = len(name_or_el_name_list) @@ -520,7 +520,7 @@ def _get_entity_alternatives_for_import(db_map, data): def _get_entity_groups_for_import(db_map, data): - key = ("class_name", "group_name", "member_name") + key = ("entity_class_name", "group_name", "member_name") return (dict(zip(key, x)) for x in data) @@ -642,7 +642,7 @@ def _get_metadata_for_import(db_map, data): def _get_entity_metadata_for_import(db_map, data): - key = ("class_name", "entity_byname", "metadata_name", "metadata_value") + key = ("entity_class_name", "entity_byname", "metadata_name", "metadata_value") for class_name, entity_byname, metadata in data: if isinstance(entity_byname, str): entity_byname = (entity_byname,) @@ -652,7 +652,7 @@ def _get_entity_metadata_for_import(db_map, data): def _get_parameter_value_metadata_for_import(db_map, data): key = ( - "class_name", + "entity_class_name", "entity_byname", "parameter_definition_name", "metadata_name", diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 8dad5012..15decff2 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -108,7 +108,7 @@ def commit(self, _commit_id): class EntityItem(MappedItemBase): fields = { - 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, 'name': {'type': str, 'value': 'The entity name.'}, 'element_name_list': {'type': tuple, 'value': 'The element names if the entity is multi-dimensional.'}, 'entity_byname': { @@ -120,10 +120,10 @@ class EntityItem(MappedItemBase): } _defaults = {"description": None} - _unique_keys = (("class_name", "name"), ("class_name", "entity_byname")) + _unique_keys = (("entity_class_name", "name"), ("entity_class_name", "entity_byname")) _references = {"class_id": ("entity_class", "id"), "element_id_list": ("entity", "id")} _external_fields = { - "class_name": ("class_id", "name"), + "entity_class_name": ("class_id", "name"), "dimension_id_list": ("class_id", "dimension_id_list"), "dimension_name_list": ("class_id", "dimension_name_list"), "superclass_id": ("class_id", "superclass_id"), @@ -132,11 +132,11 @@ class EntityItem(MappedItemBase): "element_byname_list": ("element_id_list", "entity_byname"), } _alt_references = { - ("class_name",): ("entity_class", ("name",)), - ("dimension_name_list", "element_name_list"): ("entity", ("class_name", "name")), + ("entity_class_name",): 
("entity_class", ("name",)), + ("dimension_name_list", "element_name_list"): ("entity", ("entity_class_name", "name")), } _internal_fields = { - "class_id": (("class_name",), "id"), + "class_id": (("entity_class_name",), "id"), "element_id_list": (("dimension_name_list", "element_name_list"), "id"), } @@ -154,7 +154,7 @@ def unique_values_for_item(cls, item, skip_keys=()): """Overriden to also yield unique values for the superclass.""" for key, value in super().unique_values_for_item(item, skip_keys=skip_keys): yield key, value - sc_value = tuple(item.get("superclass_name" if k == "class_name" else k) for k in key) + sc_value = tuple(item.get("superclass_name" if k == "entity_class_name" else k) for k in key) if None not in sc_value: yield (key, sc_value) @@ -185,7 +185,7 @@ def resolve_internal_fields(self, skip_keys=()): self["name"] = byname[0] return byname_remainder = list(byname) - element_name_list, _ = self._element_name_list_recursive(self["class_name"], byname_remainder) + element_name_list, _ = self._element_name_list_recursive(self["entity_class_name"], byname_remainder) if len(element_name_list) < dim_count: return f"too few elements given for entity ({byname})" if byname_remainder: @@ -210,7 +210,10 @@ def _element_name_list_recursive(self, class_name, entity_byname): self._db_map.get_item( "entity", **dict( - zip(("entity_byname", "class_name"), self._element_name_list_recursive(dim_name, entity_byname)) + zip( + ("entity_byname", "entity_class_name"), + self._element_name_list_recursive(dim_name, entity_byname), + ) ), ).get("name") for dim_name in dimension_name_list @@ -228,7 +231,7 @@ def polish(self): dim_name_lst, el_name_lst = dict.get(self, "dimension_name_list"), dict.get(self, "element_name_list") if dim_name_lst and el_name_lst: for dim_name, el_name in zip(dim_name_lst, el_name_lst): - if not self._db_map.get_item("entity", class_name=dim_name, name=el_name, fetch=False): + if not self._db_map.get_item("entity", entity_class_name=dim_name, name=el_name, fetch=False): return f"element '{el_name}' is not an instance of class '{dim_name}'" if self.get("name") is not None: return @@ -236,8 +239,8 @@ def polish(self): name = base_name index = 1 while any( - self._db_map.get_item("entity", class_name=self[k], name=name) - for k in ("class_name", "superclass_name") + self._db_map.get_item("entity", entity_class_name=self[k], name=name) + for k in ("entity_class_name", "superclass_name") if self[k] is not None ): name = f"{base_name}_{index}" @@ -247,31 +250,31 @@ def polish(self): class EntityGroupItem(MappedItemBase): fields = { - 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, 'group_name': {'type': str, 'value': 'The group entity name.'}, 'member_name': {'type': str, 'value': 'The member entity name.'}, } - _unique_keys = (("class_name", "group_name", "member_name"),) + _unique_keys = (("entity_class_name", "group_name", "member_name"),) _references = { "entity_class_id": ("entity_class", "id"), "entity_id": ("entity", "id"), "member_id": ("entity", "id"), } _external_fields = { - "class_name": ("entity_class_id", "name"), + "entity_class_name": ("entity_class_id", "name"), "dimension_id_list": ("entity_class_id", "dimension_id_list"), "group_name": ("entity_id", "name"), "member_name": ("member_id", "name"), } _alt_references = { - ("class_name",): ("entity_class", ("name",)), - ("class_name", "group_name"): ("entity", ("class_name", "name")), - ("class_name", "member_name"): 
("entity", ("class_name", "name")), + ("entity_class_name",): ("entity_class", ("name",)), + ("entity_class_name", "group_name"): ("entity", ("entity_class_name", "name")), + ("entity_class_name", "member_name"): ("entity", ("entity_class_name", "name")), } _internal_fields = { - "entity_class_id": (("class_name",), "id"), - "entity_id": (("class_name", "group_name"), "id"), - "member_id": (("class_name", "member_name"), "id"), + "entity_class_id": (("entity_class_name",), "id"), + "entity_id": (("entity_class_name", "group_name"), "id"), + "member_id": (("entity_class_name", "member_name"), "id"), } def __getitem__(self, key): @@ -318,7 +321,7 @@ class EntityAlternativeItem(MappedItemBase): "alternative_name": ("alternative_id", "name"), } _alt_references = { - ("entity_class_name", "entity_byname"): ("entity", ("class_name", "entity_byname")), + ("entity_class_name", "entity_byname"): ("entity", ("entity_class_name", "entity_byname")), ("alternative_name",): ("alternative", ("name",)), } _internal_fields = { @@ -540,7 +543,7 @@ class ParameterValueItem(ParameterItemBase): _alt_references = { ("entity_class_name",): ("entity_class", ("name",)), ("entity_class_name", "parameter_definition_name"): ("parameter_definition", ("entity_class_name", "name")), - ("entity_class_name", "entity_byname"): ("entity", ("class_name", "entity_byname")), + ("entity_class_name", "entity_byname"): ("entity", ("entity_class_name", "entity_byname")), ("alternative_name",): ("alternative", ("name",)), } _internal_fields = { @@ -682,38 +685,38 @@ class MetadataItem(MappedItemBase): class EntityMetadataItem(MappedItemBase): fields = { - 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, 'entity_byname': {'type': tuple, 'value': _ENTITY_BYNAME_VALUE}, 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, } - _unique_keys = (("class_name", "entity_byname", "metadata_name", "metadata_value"),) + _unique_keys = (("entity_class_name", "entity_byname", "metadata_name", "metadata_value"),) _references = { "entity_id": ("entity", "id"), "metadata_id": ("metadata", "id"), } _external_fields = { - "class_name": ("entity_id", "class_name"), + "entity_class_name": ("entity_id", "entity_class_name"), "entity_byname": ("entity_id", "entity_byname"), "metadata_name": ("metadata_id", "name"), "metadata_value": ("metadata_id", "value"), } _alt_references = { ( - "class_name", + "entity_class_name", "entity_byname", - ): ("entity", ("class_name", "entity_byname")), + ): ("entity", ("entity_class_name", "entity_byname")), ("metadata_name", "metadata_value"): ("metadata", ("name", "value")), } _internal_fields = { - "entity_id": (("class_name", "entity_byname"), "id"), + "entity_id": (("entity_class_name", "entity_byname"), "id"), "metadata_id": (("metadata_name", "metadata_value"), "id"), } class ParameterValueMetadataItem(MappedItemBase): fields = { - 'class_name': {'type': str, 'value': 'The entity class name.'}, + 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, 'entity_byname': { 'type': tuple, @@ -725,7 +728,7 @@ class ParameterValueMetadataItem(MappedItemBase): } _unique_keys = ( ( - "class_name", + "entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name", @@ -735,7 +738,7 @@ class ParameterValueMetadataItem(MappedItemBase): ) 
_references = {"parameter_value_id": ("parameter_value", "id"), "metadata_id": ("metadata", "id")} _external_fields = { - "class_name": ("parameter_value_id", "entity_class_name"), + "entity_class_name": ("parameter_value_id", "entity_class_name"), "parameter_definition_name": ("parameter_value_id", "parameter_definition_name"), "entity_byname": ("parameter_value_id", "entity_byname"), "alternative_name": ("parameter_value_id", "alternative_name"), @@ -743,7 +746,7 @@ class ParameterValueMetadataItem(MappedItemBase): "metadata_value": ("metadata_id", "value"), } _alt_references = { - ("class_name", "parameter_definition_name", "entity_byname", "alternative_name"): ( + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"): ( "parameter_value", ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), ), @@ -751,7 +754,7 @@ class ParameterValueMetadataItem(MappedItemBase): } _internal_fields = { "parameter_value_id": ( - ("class_name", "parameter_definition_name", "entity_byname", "alternative_name"), + ("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"), "id", ), "metadata_id": (("metadata_name", "metadata_value"), "id"), diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index 2cacb664..d75ecd95 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -122,8 +122,8 @@ def _make_scenario_alternative_mapping(): def _make_object_group_mappings(db_map): - for obj_grp in db_map.query(db_map.ext_entity_group_sq).group_by(db_map.ext_entity_group_sq.c.class_name): - root_mapping = EntityClassMapping(Position.table_name, filter_re=obj_grp.class_name) + for obj_grp in db_map.query(db_map.ext_entity_group_sq).group_by(db_map.ext_entity_group_sq.c.entity_class_name): + root_mapping = EntityClassMapping(Position.table_name, filter_re=obj_grp.entity_class_name) group_mapping = root_mapping.child = FixedValueMapping(Position.table_name, value="group") object_mapping = group_mapping.child = EntityMapping(1, header="member") object_mapping.child = EntityGroupMapping(0, header="group") @@ -184,10 +184,10 @@ def _make_relationship_mapping(relationship_class_name, object_class_name_list, root_mapping = EntityClassMapping(Position.table_name, filter_re=f"^{relationship_class_name}$") relationship_mapping = root_mapping.child = EntityMapping(Position.hidden) parent_mapping = relationship_mapping - for d, class_name in enumerate(object_class_name_list): + for d, entity_class_name in enumerate(object_class_name_list): if pivoted: d = -(d + 1) - object_mapping = parent_mapping.child = ElementMapping(d, header=class_name) + object_mapping = parent_mapping.child = ElementMapping(d, header=entity_class_name) parent_mapping = object_mapping return root_mapping diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 70c7055c..86cc7dc5 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -90,7 +90,7 @@ def test_commit_parameter_value(self): _, error = db_map.add_item("entity_class", name="fish", description="It swims.") self.assertIsNone(error) _, error = db_map.add_item( - "entity", class_name="fish", name="Nemo", description="Peacefully swimming away." + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." 
) self.assertIsNone(error) _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") @@ -133,13 +133,13 @@ def test_commit_multidimensional_parameter_value(self): description="A fish getting eaten by a cat?", ) self.assertIsNone(error) - _, error = db_map.add_item("entity", class_name="fish", name="Nemo", description="Lost (soon).") + _, error = db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") self.assertIsNone(error) _, error = db_map.add_item( - "entity", class_name="cat", name="Felix", description="The wonderful wonderful cat." + "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." ) self.assertIsNone(error) - _, error = db_map.add_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) + _, error = db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) self.assertIsNone(error) _, error = db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") self.assertIsNone(error) @@ -171,7 +171,7 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): _, error = db_map.add_item("entity_class", name="fish", description="It swims.") self.assertIsNone(error) _, error = db_map.add_item( - "entity", class_name="fish", name="Nemo", description="Peacefully swimming away." + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." ) self.assertIsNone(error) _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") @@ -195,7 +195,7 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): alternative_name="Base", ) self.assertIsNotNone(color) - fish = db_map.get_item("entity", class_name="fish", name="Nemo") + fish = db_map.get_item("entity", entity_class_name="fish", name="Nemo") self.assertIsNotNone(fish) fish.update(name="NotNemo") self.assertEqual(fish["name"], "NotNemo") @@ -219,14 +219,14 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): def test_update_entity_metadata_by_changing_its_entity(self): with DatabaseMapping("sqlite://", create=True) as db_map: entity_class, _ = db_map.add_entity_class_item(name="my_class") - db_map.add_entity_item(name="entity_1", class_name="my_class") - entity_2, _ = db_map.add_entity_item(name="entity_2", class_name="my_class") + db_map.add_entity_item(name="entity_1", entity_class_name="my_class") + entity_2, _ = db_map.add_entity_item(name="entity_2", entity_class_name="my_class") metadata_value = '{"sources": [], "contributors": []}' metadata, _ = db_map.add_metadata_item(name="my_metadata", value=metadata_value) entity_metadata, error = db_map.add_entity_metadata_item( metadata_name="my_metadata", metadata_value=metadata_value, - class_name="my_class", + entity_class_name="my_class", entity_byname=("entity_1",), ) self.assertIsNone(error) @@ -234,7 +234,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): self.assertEqual( entity_metadata._extended(), { - "class_name": "my_class", + "entity_class_name": "my_class", "entity_byname": ("entity_2",), "entity_id": entity_2["id"], "id": entity_metadata["id"], @@ -247,7 +247,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): entity_sq = ( db_map.query( db_map.entity_sq.c.id.label("entity_id"), - db_map.entity_class_sq.c.name.label("class_name"), + db_map.entity_class_sq.c.name.label("entity_class_name"), db_map.entity_sq.c.name.label("entity_name"), ) 
.join(db_map.entity_class_sq, db_map.entity_class_sq.c.id == db_map.entity_sq.c.class_id) @@ -256,7 +256,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): metadata_records = ( db_map.query( db_map.entity_metadata_sq.c.id, - entity_sq.c.class_name, + entity_sq.c.entity_class_name, entity_sq.c.entity_name, db_map.metadata_sq.c.name.label("metadata_name"), db_map.metadata_sq.c.value.label("metadata_value"), @@ -270,7 +270,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): dict(**metadata_records[0]), { "id": 1, - "class_name": "my_class", + "entity_class_name": "my_class", "entity_name": "entity_2", "metadata_name": "my_metadata", "metadata_value": metadata_value, @@ -283,7 +283,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): _, error = db_map.add_parameter_definition_item(name="x", entity_class_name="my_class") self.assertIsNone(error) db_map.add_parameter_definition_item(name="y", entity_class_name="my_class") - entity, _ = db_map.add_entity_item(name="my_entity", class_name="my_class") + entity, _ = db_map.add_entity_item(name="my_entity", entity_class_name="my_class") value, value_type = to_database(2.3) _, error = db_map.add_parameter_value_item( entity_class_name="my_class", @@ -310,7 +310,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): value_metadata, error = db_map.add_parameter_value_metadata_item( metadata_name="my_metadata", metadata_value=metadata_value, - class_name="my_class", + entity_class_name="my_class", entity_byname=("my_entity",), parameter_definition_name="x", alternative_name="Base", @@ -320,7 +320,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): self.assertEqual( value_metadata._extended(), { - "class_name": "my_class", + "entity_class_name": "my_class", "entity_byname": ("my_entity",), "alternative_name": "Base", "parameter_definition_name": "y", @@ -335,7 +335,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): parameter_sq = ( db_map.query( db_map.parameter_value_sq.c.id.label("value_id"), - db_map.entity_class_sq.c.name.label("class_name"), + db_map.entity_class_sq.c.name.label("entity_class_name"), db_map.entity_sq.c.name.label("entity_name"), db_map.parameter_definition_sq.c.name.label("parameter_definition_name"), db_map.alternative_sq.c.name.label("alternative_name"), @@ -354,7 +354,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): metadata_records = ( db_map.query( db_map.parameter_value_metadata_sq.c.id, - parameter_sq.c.class_name, + parameter_sq.c.entity_class_name, parameter_sq.c.entity_name, parameter_sq.c.parameter_definition_name, parameter_sq.c.alternative_name, @@ -370,7 +370,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): dict(**metadata_records[0]), { "id": 1, - "class_name": "my_class", + "entity_class_name": "my_class", "entity_name": "my_entity", "parameter_definition_name": "y", "alternative_name": "Base", @@ -388,11 +388,11 @@ def test_fetch_more(self): def test_fetch_more_after_commit_and_refresh(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_item("entity_class", name="Widget") - db_map.add_item("entity", class_name="Widget", name="gadget") + db_map.add_item("entity", entity_class_name="Widget", name="gadget") db_map.commit_session("Add test data.") db_map.refresh_session() entities = db_map.fetch_more("entity") - self.assertEqual([(x["class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) + 
self.assertEqual([(x["entity_class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) def test_has_external_commits_returns_false_initially(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -431,15 +431,15 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): db_map.add_entity_class_item(name="dog") db_map.add_entity_class_item(name="cat") db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat")) - db_map.add_entity_item(name="Pulgoso", class_name="dog") - db_map.add_entity_item(name="Sylvester", class_name="cat") - db_map.add_entity_item(name="Tom", class_name="cat") + db_map.add_entity_item(name="Pulgoso", entity_class_name="dog") + db_map.add_entity_item(name="Sylvester", entity_class_name="cat") + db_map.add_entity_item(name="Tom", entity_class_name="cat") db_map.commit_session("Arf!") with DatabaseMapping(url) as db_map: # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. - db_map.get_entity_item(name="Sylvester", class_name="cat").remove() - db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), class_name="dog__cat") + db_map.get_entity_item(name="Sylvester", entity_class_name="cat").remove() + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". @@ -502,9 +502,9 @@ def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_entity_class_item(name="my_class") - db_map.add_entity_item(name="element", class_name="my_class") - db_map.add_entity_item(name="container", class_name="my_class") - db_map.add_entity_group_item(group_name="container", member_name="element", class_name="my_class") + db_map.add_entity_item(name="element", entity_class_name="my_class") + db_map.add_entity_item(name="container", entity_class_name="my_class") + db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class") db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) @@ -1059,7 +1059,7 @@ def tearDown(self): def test_get_entity_class_items_check_fields(self): import_functions.import_data(self._db_map, entity_classes=(("fish",),)) with self.assertRaises(SpineDBAPIError): - self._db_map.get_entity_class_item(class_name="fish") + self._db_map.get_entity_class_item(entity_class_name="fish") with self.assertRaises(SpineDBAPIError): self._db_map.get_entity_class_item(name=("fish",)) self._db_map.get_entity_class_item(name="fish") @@ -1142,7 +1142,7 @@ def test_add_object_with_invalid_name(self): """Test that adding object classes with empty name raises error""" self._db_map.add_object_classes({"name": "fish"}) with self.assertRaises(SpineIntegrityError): - self._db_map.add_objects({"name": "", "class_name": "fish"}, strict=True) + self._db_map.add_objects({"name": "", "entity_class_name": "fish"}, strict=True) def test_add_objects_with_same_name(self): """Test that adding two objects with the same name only adds one of them.""" @@ -1853,7 +1853,9 @@ def test_add_entity_to_a_class_with_abstract_dimensions(self): import_functions.import_superclass_subclasses(self._db_map, (("animal", "fish"), ("animal", "dog"))) 
import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) self._db_map.commit_session("Add test data.") - item, error = self._db_map.add_item("entity", class_name="two_animals", element_name_list=("Nemo", "Pulgoso")) + item, error = self._db_map.add_item( + "entity", entity_class_name="two_animals", element_name_list=("Nemo", "Pulgoso") + ) self.assertTrue(item) self.assertFalse(error) self._db_map.commit_session("Add test data.") From 86c24821f5f4c4a6c6d392e65fb446afd6283b19 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 12 Jan 2024 09:35:05 +0200 Subject: [PATCH 224/317] Fix tutorial Bring tutorial up-to-date with the class_name->entity_class_name and byname->entity_byname renames. Re #318 --- docs/source/tutorial.rst | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 92e788a8..f4adfa56 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -84,15 +84,15 @@ as `dimension_name_list`:: Let's add entities to our zero-dimensional classes:: - db_map.add_entity_item(class_name="fish", name="Nemo", description="Lost (for now).") + db_map.add_entity_item(entity_class_name="fish", name="Nemo", description="Lost (for now).") db_map.add_entity_item( - class_name="cat", name="Felix", description="The wonderful wonderful cat." + entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." ) Let's add a multi-dimensional entity to our multi-dimensional class. For this we need to specify the entity names as `element_name_list`:: - db_map.add_entity_item(class_name="fish__cat", element_name_list=("Nemo", "Felix")) + db_map.add_entity_item(entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) Let's add a parameter definition for one of our entity classes:: @@ -132,14 +132,14 @@ This implicitly fetches data from the DB into the in-memory mapping, if not already there. For example, let's find one of the entities we inserted above:: - felix_item = db_map.get_entity_item(class_name="cat", name="Felix") + felix_item = db_map.get_entity_item(entity_class_name="cat", name="Felix") assert felix_item["description"] == "The wonderful wonderful cat." Above, ``felix_item`` is a :class:`~.PublicItem` object, representing an item. 
Let's find our multi-dimensional entity:: - nemo_felix_item = db_map.get_entity_item("entity", class_name="fish__cat", element_name_list=("Nemo", "Felix")) + nemo_felix_item = db_map.get_entity_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) assert nemo_felix_item["dimension_name_list"] == ('fish', 'cat') Now let's retrieve our parameter value:: @@ -151,13 +151,14 @@ Now let's retrieve our parameter value:: alternative_name="Base" ) -We use :func:`.from_database` to convert the value and type from the parameter value into our original value:: +We use :func:`.from_database` to convert the value and type from the parameter value into our original value:: + nemo_color = api.from_database(nemo_color_item["value"], nemo_color_item["type"]) assert nemo_color == "mainly orange" To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`:: - assert [entity["byname"] for entity in db_map.get_items("entity")] == [ + assert [entity["entity_byname"] for entity in db_map.get_items("entity")] == [ ("Nemo",), ("Felix",), ("Nemo", "Felix") ] @@ -171,7 +172,7 @@ To update data, we use the :meth:`~.PublicItem.update` method of :class:`~.Publi Let's rename our fish entity to avoid any copyright infringements:: - db_map.get_entity_item(class_name="fish", name="Nemo").update(name="NotNemo") + db_map.get_entity_item(entity_class_name="fish", name="Nemo").update(name="NotNemo") To be safe, let's also change the color:: @@ -191,7 +192,7 @@ Removing data You know what, let's just remove the entity entirely. To do this we use the :meth:`~.PublicItem.remove` method of :class:`~.PublicItem`:: - db_map.get_entity_item(class_name="fish", name="NotNemo").remove() + db_map.get_entity_item(entity_class_name="fish", name="NotNemo").remove() Note that the above call removes items in *cascade*, meaning that items that depend on ``"NotNemo"`` will get removed as well.
From 7583f0d2e5a45532793b5f1cc7ff4253f02d3e8c Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 12 Jan 2024 09:38:52 +0200 Subject: [PATCH 225/317] Fix sphinx errors from the auto-generated db_mapping_schema.rst --- docs/source/conf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index b82aec6e..80b15481 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -161,7 +161,8 @@ def type_(f_dict): lines.extend([".. list-table:: Unique keys", " :header-rows: 0", ""]) for f_names in factory._unique_keys: f_names = ", ".join(f_names) - lines.extend([f" * - {f_names}"]) + lines.append(f" * - {f_names}") + lines.append("") return lines
From 7441bfb2310d7f4265afe96dbbdae07c72e58ca4 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Mon, 15 Jan 2024 10:42:35 +0200
Subject: [PATCH 226/317] Add active_by_default flag to entity_class table

Added an active_by_default flag to the entity_class table which specifies the fallback activity for entities of the class that have no explicit entity alternative. By default, the flag is unset for zero-dimensional entity classes and set for multidimensional ones. Since this change alters the entity_class table, we now have a new migration script as well. Queries, filters, etc. were updated, too.
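A minimal sketch of the new flag (illustration only, not part of the patch), assuming the in-memory DatabaseMapping API used in the tests below:

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        # Zero-dimensional classes default to active_by_default=False...
        fish, error = db_map.add_entity_class_item(name="fish")
        assert not fish["active_by_default"]
        # ...but the flag can be set explicitly when the class is added.
        cat, error = db_map.add_entity_class_item(name="cat", active_by_default=True)
        assert cat["active_by_default"]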
Re #316 --- ...b_add_active_by_default_to_entity_class.py | 42 +++++++++++ spinedb_api/db_mapping_query_mixin.py | 2 + spinedb_api/filters/renamer.py | 7 +- spinedb_api/filters/scenario_filter.py | 17 ++--- spinedb_api/helpers.py | 16 +++-- spinedb_api/mapped_items.py | 12 ++++ tests/filters/test_renamer.py | 10 +-- tests/filters/test_scenario_filter.py | 69 +++++++++++++++++-- tests/test_DatabaseMapping.py | 37 ++++++++++ tests/test_helpers.py | 7 ++ 10 files changed, 189 insertions(+), 30 deletions(-) create mode 100644 spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py new file mode 100644 index 00000000..5eb0e49e --- /dev/null +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -0,0 +1,42 @@ +"""add active_by_default to entity_class + +Revision ID: 8b0eff478bcb +Revises: 5385f063bef2 +Create Date: 2024-01-12 09:55:08.934574 + +""" +from alembic import op +import sqlalchemy as sa +import sqlalchemy.orm + + +# revision identifiers, used by Alembic. +revision = '8b0eff478bcb' +down_revision = '5385f063bef2' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("entity_class") as batch_op: + batch_op.add_column( + sa.Column( + "active_by_default", sa.Boolean(name="active_by_default"), server_default=sa.false(), nullable=False + ), + ) + conn = op.get_bind() + session = sa.orm.sessionmaker(bind=conn)() + metadata = sa.MetaData() + metadata.reflect(bind=conn) + dimension_table = metadata.tables["entity_class_dimension"] + dimensional_class_ids = {row.entity_class_id for row in session.query(dimension_table)} + metadata.reflect(bind=conn) + class_table = metadata.tables["entity_class"] + update_statement = ( + class_table.update().where(class_table.c.id == sa.bindparam("target_id")).values(active_by_default=True) + ) + conn.execute(update_statement, [{"target_id": class_id} for class_id in dimensional_class_ids]) + + +def downgrade(): + raise NotImplementedError() diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index 366c5bed..c5fd4f37 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -207,6 +207,7 @@ def wide_entity_class_sq(self): self.entity_class_sq.c.display_order, self.entity_class_sq.c.display_icon, self.entity_class_sq.c.hidden, + self.entity_class_sq.c.active_by_default, entity_class_dimension_sq.c.dimension_id, entity_class_dimension_sq.c.dimension_name, entity_class_dimension_sq.c.position, @@ -226,6 +227,7 @@ def wide_entity_class_sq(self): ecd_sq.c.display_order, ecd_sq.c.display_icon, ecd_sq.c.hidden, + ecd_sq.c.active_by_default, group_concat(ecd_sq.c.dimension_id, ecd_sq.c.position).label("dimension_id_list"), group_concat(ecd_sq.c.dimension_name, ecd_sq.c.position).label("dimension_name_list"), func.count(ecd_sq.c.dimension_id).label("dimension_count"), diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index fc7cc05f..a3970ac7 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . 
###################################################################################################################### - -""" -Provides a database query manipulator that renames database items. - -""" +""" Provides a database query manipulator that renames database items. """ from functools import partial from sqlalchemy import case @@ -215,6 +211,7 @@ def _make_renaming_entity_class_sq(db_map, state): subquery.c.display_order, subquery.c.display_icon, subquery.c.hidden, + subquery.c.active_by_default, ).subquery() return entity_class_sq diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index bacb6c51..3bbcb5e7 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -8,14 +8,10 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Provides functions to apply filtering based on scenarios to subqueries. - -""" +""" Provides functions to apply filtering based on scenarios to subqueries. """ from functools import partial -from sqlalchemy import desc, func, or_ +from sqlalchemy import and_, desc, func, or_ from ..exception import SpineDBAPIError SCENARIO_FILTER_TYPE = "scenario_filter" @@ -197,11 +193,13 @@ def _ext_entity_sq(db_map, state): ) .label("desc_rank_row_number"), db_map.entity_alternative_sq.c.active, + db_map.entity_class_sq.c.active_by_default, db_map.scenario_alternative_sq.c.scenario_id, ) .outerjoin( db_map.entity_alternative_sq, state.original_entity_sq.c.id == db_map.entity_alternative_sq.c.entity_id ) + .outerjoin(db_map.entity_class_sq, state.original_entity_sq.c.class_id == db_map.entity_class_sq.c.id) .outerjoin( db_map.scenario_alternative_sq, db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id, @@ -240,7 +238,7 @@ def _make_scenario_filtered_entity_element_sq(db_map, state): ) .filter( element_sq.c.desc_rank_row_number == 1, - or_(element_sq.c.active == True, element_sq.c.active == None), + or_(element_sq.c.active == True, and_(element_sq.c.active == None, element_sq.c.active_by_default == True)), ) .subquery() ) @@ -277,7 +275,10 @@ def _make_scenario_filtered_entity_sq(db_map, state): ) .filter( ext_entity_sq.c.desc_rank_row_number == 1, - or_(ext_entity_sq.c.active == True, ext_entity_sq.c.active == None), + or_( + ext_entity_sq.c.active == True, + and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True), + ), ) .outerjoin( ext_entity_class_dimension_count_sq, diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 98722b35..da671683 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -General helper functions. -""" +""" General helper functions. 
""" import os import json import warnings @@ -318,6 +315,13 @@ def is_empty(db_url): return True +def get_head_alembic_version(): + config = Config() + config.set_main_option("script_location", "spinedb_api:alembic") + script = ScriptDirectory.from_config(config) + return script.get_current_head() + + def create_spine_metadata(): meta = MetaData(naming_convention=naming_convention) Table( @@ -374,6 +378,7 @@ def create_spine_metadata(): Column("display_order", Integer, server_default="99"), Column("display_icon", BigInteger, server_default=null()), Column("hidden", Integer, server_default="0"), + Column("active_by_default", Boolean(name="active_by_default"), server_default=false(), nullable=False), ) Table( "superclass_subclass", @@ -622,11 +627,12 @@ def create_new_spine_database(db_url): meta.drop_all() # Create new tables meta = create_spine_metadata() + version = get_head_alembic_version() try: meta.create_all(engine) engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute("INSERT INTO alembic_version VALUES ('5385f063bef2')") + engine.execute(f"INSERT INTO alembic_version VALUES ('{version}')") except DatabaseError as e: raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None return engine diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 15decff2..18cfcdc2 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -69,6 +69,11 @@ class EntityClassItem(MappedItemBase): }, 'display_order': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, 'hidden': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, + "active_by_default": { + "type": bool, + "value": "Default activity for the entity alternatives of the class.", + "optional": True, + }, } _defaults = {"description": None, "display_icon": None, "display_order": 99, "hidden": False} _unique_keys = (("name",),) @@ -92,6 +97,13 @@ def __getitem__(self, key): return self._get_ref("superclass_subclass", {"subclass_id": self["id"]}, strong=False).get(key) return super().__getitem__(key) + def polish(self): + error = super().polish() + if error: + return error + if "active_by_default" not in self: + self["active_by_default"] = bool(dict.get(self, "dimension_id_list")) + def merge(self, other): dimension_id_list = other.pop("dimension_id_list", None) error = ( diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index d92dc634..d3bf469e 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Unit tests for ``renamer`` module. - -""" +""" Unit tests for ``renamer`` module. 
""" from pathlib import Path from tempfile import TemporaryDirectory import unittest @@ -67,7 +63,7 @@ def test_renaming_singe_entity_class(self): self.assertEqual(len(classes), 1) class_row = classes[0] keys = tuple(class_row.keys()) - expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden") + expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default") self.assertEqual(len(keys), len(expected_keys)) for expected_key in expected_keys: self.assertIn(expected_key, keys) @@ -124,7 +120,7 @@ def test_entity_class_renamer_from_dict(self): self.assertEqual(len(classes), 1) class_row = classes[0] keys = tuple(class_row.keys()) - expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden") + expected_keys = ("id", "name", "description", "display_order", "display_icon", "hidden", "active_by_default") self.assertEqual(len(keys), len(expected_keys)) for expected_key in expected_keys: self.assertIn(expected_key, keys) diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index cff394f7..de1be424 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -8,16 +8,13 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Unit tests for ``alternative_value_filter`` module. - -""" +""" Unit tests for ``alternative_value_filter`` module. """ from pathlib import Path from tempfile import TemporaryDirectory import unittest from sqlalchemy.engine.url import URL from spinedb_api import ( + apply_filter_stack, apply_scenario_filter_to_subqueries, create_new_spine_database, DatabaseMapping, @@ -46,6 +43,62 @@ ) +class TestScenarioFilterInMemory(unittest.TestCase): + def _assert_success(self, result): + item, error = result + self.assertIsNone(error) + return item + + def test_filter_entities_with_default_activity_only(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="visible", active_by_default=True)) + self._assert_success(db_map.add_entity_item(name="visible_object", entity_class_name="visible")) + self._assert_success(db_map.add_entity_class_item(name="hidden", active_by_default=False)) + self._assert_success(db_map.add_entity_item(name="invisible_object", entity_class_name="hidden")) + self._assert_success(db_map.add_scenario_item(name="S")) + db_map.commit_session("Add data.") + apply_filter_stack(db_map, [scenario_filter_config("S")]) + entities = db_map.query(db_map.wide_entity_sq).all() + self.assertEqual(len(entities), 1) + self.assertEqual(entities[0]["name"], "visible_object") + + def test_filter_entities_with_default_activity(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_alternative_item(name="alt")) + self._assert_success(db_map.add_entity_class_item(name="visible_by_default", active_by_default=True)) + self._assert_success(db_map.add_entity_item(name="visible", entity_class_name="visible_by_default")) + self._assert_success(db_map.add_entity_item(name="hidden", entity_class_name="visible_by_default")) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="visible_by_default", + entity_byname=("hidden",), + alternative_name="alt", + 
active=False,
+                )
+            )
+            self._assert_success(db_map.add_entity_class_item(name="hidden_by_default", active_by_default=False))
+            self._assert_success(db_map.add_entity_item(name="visible", entity_class_name="hidden_by_default"))
+            self._assert_success(
+                db_map.add_entity_alternative_item(
+                    entity_class_name="hidden_by_default",
+                    entity_byname=("visible",),
+                    alternative_name="alt",
+                    active=True,
+                )
+            )
+            self._assert_success(db_map.add_entity_item(name="hidden", entity_class_name="hidden_by_default"))
+            self._assert_success(db_map.add_scenario_item(name="S"))
+            self._assert_success(
+                db_map.add_scenario_alternative_item(scenario_name="S", alternative_name="alt", rank=0)
+            )
+            db_map.commit_session("Add data.")
+            apply_filter_stack(db_map, [scenario_filter_config("S")])
+            entities = db_map.query(db_map.wide_entity_sq).all()
+            self.assertEqual(len(entities), 2)
+            self.assertEqual(entities[0]["name"], "visible")
+            self.assertEqual(entities[1]["name"], "visible")
+
+
 class TestScenarioFilter(unittest.TestCase):
     _db_url = None
     _temp_dir = None
@@ -138,6 +191,8 @@ def test_scenario_filter_works_for_entity_sq(self):
         import_scenario_alternatives(
             self._out_db_map, [("scenario1", "alternative2"), ("scenario1", "alternative1", "alternative2")]
         )
+        for entity_class in self._out_db_map.get_entity_class_items():
+            entity_class.update(active_by_default=True)
         self._out_db_map.commit_session("Add test data")
         entities = self._db_map.query(self._db_map.entity_sq).all()
         self.assertEqual(len(entities), 5)
@@ -356,6 +411,8 @@ def test_scenario_filter_for_multiple_objects_and_parameters(self):
         )
         import_scenarios(self._out_db_map, [("scenario", True)])
         import_scenario_alternatives(self._out_db_map, [("scenario", "alternative")])
+        for item in self._out_db_map.get_entity_class_items():
+            item.update(active_by_default=True)
         self._out_db_map.commit_session("Add test data")
         apply_scenario_filter_to_subqueries(self._db_map, "scenario")
         parameters = self._db_map.query(self._db_map.parameter_value_sq).all()
@@ -469,6 +526,8 @@ def _build_data_with_single_scenario(db_map, commit=True):
     import_object_parameter_values(db_map, [("object_class", "object", "parameter", 23.0, "alternative")])
     import_scenarios(db_map, [("scenario", True)])
     import_scenario_alternatives(db_map, [("scenario", "alternative")])
+    for entity_class in db_map.get_entity_class_items():
+        entity_class.update(active_by_default=True)
     if commit:
         db_map.commit_session("Add test data.")
 
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 86cc7dc5..cb2596ce 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -83,6 +83,47 @@ def test_shorthand_filter_query_works(self):
 
 
 class TestDatabaseMapping(unittest.TestCase):
+    def test_active_by_default_is_initially_false_for_zero_dimensional_entity_class(self):
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            item, error = db_map.add_entity_class_item(name="Entity")
+            self.assertIsNone(error)
+            self.assertFalse(item["active_by_default"])
+
+    def test_active_by_default_is_initially_true_for_multi_dimensional_entity_class(self):
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            db_map.add_entity_class_item(name="Dimension")
+            item, error = db_map.add_entity_class_item(name="Entity", dimension_name_list=("Dimension",))
+            self.assertIsNone(error)
+            self.assertTrue(item["active_by_default"])
+
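+    # EntityClassItem.polish() (added earlier in this patch series) derives the
+    # default: classes with dimensions get active_by_default=True while
+    # zero-dimensional classes get False, which the two tests above assert.
+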
"database.sqlite") + with DatabaseMapping(url, create=True) as out_db_map: + _, error = out_db_map.add_entity_class_item(name="HiddenStuff", active_by_default=False) + self.assertIsNone(error) + _, error = out_db_map.add_entity_class_item(name="VisibleStuff", active_by_default=True) + self.assertIsNone(error) + out_db_map.commit_session("Add entity classes.") + entity_classes = out_db_map.query(out_db_map.wide_entity_class_sq).all() + self.assertEqual(len(entity_classes), 2) + activities = ((row.name, row.active_by_default) for row in entity_classes) + expected = (("HiddenStuff", False), ("VisibleStuff", True)) + self.assertCountEqual(activities, expected) + with DatabaseMapping(url) as db_map: + entity_classes = db_map.get_entity_class_items() + self.assertEqual(len(entity_classes), 2) + active_by_default = {c["name"]: c["active_by_default"] for c in entity_classes} + expected = {"HiddenStuff": False, "VisibleStuff": True} + for name, activity in active_by_default.items(): + expected_activity = expected.pop(name) + with self.subTest(class_name=name): + self.assertEqual(activity, expected_activity) + def test_commit_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5ba88aac..5497ae66 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -15,6 +15,7 @@ from spinedb_api.helpers import ( compare_schemas, create_new_spine_database, + get_head_alembic_version, name_from_dimensions, name_from_elements, remove_credentials_from_url, @@ -67,5 +68,11 @@ def test_password_with_special_characters(self): self.assertEqual(sanitized, "mysql://example.com/db") +class TestGetHeadAlembicVersion(unittest.TestCase): + def test_returns_latest_version(self): + # This test must be updated each time new migration script is added. 
+ self.assertEqual(get_head_alembic_version(), "8b0eff478bcb") + + if __name__ == "__main__": unittest.main() From 1cea3590c7d635e5faacbf510718413821202798 Mon Sep 17 00:00:00 2001 From: Henrik Koski <98282892+PiispaH@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:47:09 +0200 Subject: [PATCH 227/317] Fix opening infinite db editor tabs from same db (#337) --- spinedb_api/db_mapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 8d1c5f31..057e34cb 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -211,7 +211,7 @@ def _make_codename(self, codename): if not self.sa_url.drivername.startswith("sqlite"): return self.sa_url.database if self.sa_url.database is not None: - return os.path.basename(self.sa_url.database) + return os.path.splitext(os.path.basename(self.sa_url.database))[0] hashing = hashlib.sha1() hashing.update(bytes(str(time.time()), "utf-8")) return hashing.hexdigest() From 8775ef0c37f7757abc7385dcecdb91048f323a05 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 15 Jan 2024 12:44:02 +0200 Subject: [PATCH 228/317] Add active_by_default to import_functions and export_functions Re #316 --- spinedb_api/export_functions.py | 7 +- spinedb_api/import_functions.py | 6 +- tests/test_export_functions.py | 181 ++++++++++++++++++-------------- tests/test_import_functions.py | 43 +++++++- 4 files changed, 145 insertions(+), 92 deletions(-) diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 9fbd853e..1040a427 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -8,10 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Functions for exporting data from a Spine database in a standard format. -""" +""" Functions for exporting data from a Spine database in a standard format. 
""" from operator import itemgetter from sqlalchemy.util import KeyedTuple @@ -122,7 +119,7 @@ def export_parameter_value_lists(db_map, ids=Asterisk, parse_value=from_database def export_entity_classes(db_map, ids=Asterisk): return sorted( ( - (x.name, x.dimension_name_list, x.description, x.display_icon) + (x.name, x.dimension_name_list, x.description, x.display_icon, x.active_by_default) for x in _get_items(db_map, "entity_class", ids) ), key=lambda x: (len(x[1]), x[0]), diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index e35e38a9..c270c3e4 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -262,8 +262,8 @@ def import_entity_classes(db_map, data): Args: db_map (spinedb_api.DiffDatabaseMapping): database mapping - data (list(tuple(str,tuple,str,int)): tuples of - (name, dimension name tuple, description, display icon integer) + data (list(tuple(str,tuple,str,int,bool)): tuples of + (name, dimension name tuple, description, display icon integer, active by default flag) Returns: int: number of items imported @@ -471,7 +471,7 @@ def import_relationship_parameter_value_metadata(db_map, data): def _get_entity_classes_for_import(db_map, data): dim_name_list_by_name = {} items = [] - key = ("name", "dimension_name_list", "description", "display_icon") + key = ("name", "dimension_name_list", "description", "display_icon", "active_by_default") for x in data: if isinstance(x, str): x = x, () diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 489d549a..3a4e9ccc 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -8,17 +8,14 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Unit tests for export_functions. - -""" +""" Unit tests for export_functions. 
""" import unittest from spinedb_api import ( DatabaseMapping, export_alternatives, export_data, + export_entity_classes, export_scenarios, export_scenario_alternatives, import_alternatives, @@ -37,91 +34,117 @@ class TestExportFunctions(unittest.TestCase): - def setUp(self): - db_url = "sqlite://" - self._db_map = DatabaseMapping(db_url, username="UnitTest", create=True) + def _assert_import_success(self, result): + errors = result[1] + self.assertEqual(errors, []) - def tearDown(self): - self._db_map.close() + def _assert_addition_success(self, result): + error = result[1] + self.assertIsNone(error) def test_export_alternatives(self): - import_alternatives(self._db_map, [("alternative", "Description")]) - exported = export_alternatives(self._db_map) - self.assertEqual(exported, [("Base", "Base alternative"), ("alternative", "Description")]) + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_import_success(import_alternatives(db_map, [("alternative", "Description")])) + exported = export_alternatives(db_map) + self.assertEqual(exported, [("Base", "Base alternative"), ("alternative", "Description")]) def test_export_scenarios(self): - import_scenarios(self._db_map, [("scenario", False, "Description")]) - exported = export_scenarios(self._db_map) - self.assertEqual(exported, [("scenario", False, "Description")]) + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_import_success(import_scenarios(db_map, [("scenario", False, "Description")])) + exported = export_scenarios(db_map) + self.assertEqual(exported, [("scenario", False, "Description")]) def test_export_scenario_alternatives(self): - import_alternatives(self._db_map, ["alternative"]) - import_scenarios(self._db_map, ["scenario"]) - import_scenario_alternatives(self._db_map, (("scenario", "alternative"),)) - exported = export_scenario_alternatives(self._db_map) - self.assertEqual(exported, [("scenario", "alternative", None)]) + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_import_success(import_alternatives(db_map, ["alternative"])) + self._assert_import_success(import_scenarios(db_map, ["scenario"])) + self._assert_import_success(import_scenario_alternatives(db_map, (("scenario", "alternative"),))) + exported = export_scenario_alternatives(db_map) + self.assertEqual(exported, [("scenario", "alternative", None)]) def test_export_multiple_scenario_alternatives(self): - import_alternatives(self._db_map, ["alternative1"]) - import_alternatives(self._db_map, ["alternative2"]) - import_scenarios(self._db_map, ["scenario"]) - import_scenario_alternatives(self._db_map, (("scenario", "alternative1"),)) - import_scenario_alternatives(self._db_map, (("scenario", "alternative2", "alternative1"),)) - exported = export_scenario_alternatives(self._db_map) - self.assertEqual( - set(exported), {("scenario", "alternative2", "alternative1"), ("scenario", "alternative1", None)} - ) + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_import_success(import_alternatives(db_map, ["alternative1"])) + self._assert_import_success(import_alternatives(db_map, ["alternative2"])) + self._assert_import_success(import_scenarios(db_map, ["scenario"])) + self._assert_import_success(import_scenario_alternatives(db_map, (("scenario", "alternative1"),))) + self._assert_import_success( + import_scenario_alternatives(db_map, (("scenario", "alternative2", "alternative1"),)) + ) + exported = 
export_scenario_alternatives(db_map) + self.assertEqual( + set(exported), {("scenario", "alternative2", "alternative1"), ("scenario", "alternative1", None)} + ) + + def test_export_entity_classes(self): + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_addition_success(db_map.add_entity_class_item(name="Object")) + self._assert_addition_success( + db_map.add_entity_class_item(name="Relation", dimension_name_list=("Object",)) + ) + exported = export_entity_classes(db_map) + expected = (("Object", (), None, None, False), ("Relation", ("Object",), None, None, True)) + self.assertCountEqual(exported, expected) def test_export_data(self): - import_object_classes(self._db_map, ["object_class"]) - import_object_parameters(self._db_map, [("object_class", "object_parameter")]) - import_objects(self._db_map, [("object_class", "object")]) - import_object_parameter_values(self._db_map, [("object_class", "object", "object_parameter", 2.3)]) - import_relationship_classes(self._db_map, [("relationship_class", ["object_class"])]) - import_relationship_parameters(self._db_map, [("relationship_class", "relationship_parameter")]) - import_relationships(self._db_map, [("relationship_class", ["object"])]) - import_relationship_parameter_values( - self._db_map, [("relationship_class", ["object"], "relationship_parameter", 3.14)] - ) - import_parameter_value_lists(self._db_map, [("value_list", "5.5"), ("value_list", "6.4")]) - import_alternatives(self._db_map, ["alternative"]) - import_scenarios(self._db_map, ["scenario"]) - import_scenario_alternatives(self._db_map, [("scenario", "alternative")]) - exported = export_data(self._db_map) - self.assertEqual(len(exported), 8) - self.assertIn("entity_classes", exported) - self.assertEqual( - exported["entity_classes"], - [("object_class", (), None, None), ("relationship_class", ("object_class",), None, None)], - ) - self.assertIn("parameter_definitions", exported) - self.assertEqual( - exported["parameter_definitions"], - [ - ("object_class", "object_parameter", None, None, None), - ("relationship_class", "relationship_parameter", None, None, None), - ], - ) - self.assertIn("entities", exported) - self.assertEqual( - exported["entities"], [("object_class", "object", None), ("relationship_class", ("object",), None)] - ) - self.assertIn("parameter_values", exported) - self.assertEqual( - exported["parameter_values"], - [ - ("object_class", "object", "object_parameter", 2.3, "Base"), - ("relationship_class", ("object",), "relationship_parameter", 3.14, "Base"), - ], - ) - self.assertIn("parameter_value_lists", exported) - self.assertEqual(exported["parameter_value_lists"], [("value_list", "5.5"), ("value_list", "6.4")]) - self.assertIn("alternatives", exported) - self.assertEqual(exported["alternatives"], [("Base", "Base alternative"), ("alternative", None)]) - self.assertIn("scenarios", exported) - self.assertEqual(exported["scenarios"], [("scenario", False, None)]) - self.assertIn("scenario_alternatives", exported) - self.assertEqual(exported["scenario_alternatives"], [("scenario", "alternative", None)]) + with DatabaseMapping("sqlite://", username="UnitTest", create=True) as db_map: + self._assert_import_success(import_object_classes(db_map, ["object_class"])) + self._assert_import_success(import_object_parameters(db_map, [("object_class", "object_parameter")])) + self._assert_import_success(import_objects(db_map, [("object_class", "object")])) + self._assert_import_success( + import_object_parameter_values(db_map, 
[("object_class", "object", "object_parameter", 2.3)]) + ) + self._assert_import_success(import_relationship_classes(db_map, [("relationship_class", ["object_class"])])) + self._assert_import_success( + import_relationship_parameters(db_map, [("relationship_class", "relationship_parameter")]) + ) + self._assert_import_success(import_relationships(db_map, [("relationship_class", ["object"])])) + self._assert_import_success( + import_relationship_parameter_values( + db_map, [("relationship_class", ["object"], "relationship_parameter", 3.14)] + ) + ) + self._assert_import_success( + import_parameter_value_lists(db_map, [("value_list", "5.5"), ("value_list", "6.4")]) + ) + self._assert_import_success(import_alternatives(db_map, ["alternative"])) + self._assert_import_success(import_scenarios(db_map, ["scenario"])) + self._assert_import_success(import_scenario_alternatives(db_map, [("scenario", "alternative")])) + exported = export_data(db_map) + self.assertEqual(len(exported), 8) + self.assertIn("entity_classes", exported) + self.assertEqual( + exported["entity_classes"], + [("object_class", (), None, None, False), ("relationship_class", ("object_class",), None, None, True)], + ) + self.assertIn("parameter_definitions", exported) + self.assertEqual( + exported["parameter_definitions"], + [ + ("object_class", "object_parameter", None, None, None), + ("relationship_class", "relationship_parameter", None, None, None), + ], + ) + self.assertIn("entities", exported) + self.assertEqual( + exported["entities"], [("object_class", "object", None), ("relationship_class", ("object",), None)] + ) + self.assertIn("parameter_values", exported) + self.assertEqual( + exported["parameter_values"], + [ + ("object_class", "object", "object_parameter", 2.3, "Base"), + ("relationship_class", ("object",), "relationship_parameter", 3.14, "Base"), + ], + ) + self.assertIn("parameter_value_lists", exported) + self.assertEqual(exported["parameter_value_lists"], [("value_list", "5.5"), ("value_list", "6.4")]) + self.assertIn("alternatives", exported) + self.assertEqual(exported["alternatives"], [("Base", "Base alternative"), ("alternative", None)]) + self.assertIn("scenarios", exported) + self.assertEqual(exported["scenarios"], [("scenario", False, None)]) + self.assertIn("scenario_alternatives", exported) + self.assertEqual(exported["scenario_alternatives"], [("scenario", "alternative", None)]) if __name__ == '__main__': diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index a7a813b8..6cfec085 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -8,11 +8,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### - -""" -Unit tests for import_functions.py. - -""" +""" Unit tests for import_functions.py. 
""" import unittest @@ -20,6 +16,7 @@ from spinedb_api.db_mapping import DatabaseMapping from spinedb_api.import_functions import ( import_alternatives, + import_entity_classes, import_object_classes, import_object_parameter_values, import_object_parameters, @@ -340,6 +337,42 @@ def test_import_existing_relationship_class_parameter(self): db_map.close() +class TestImportEntityClasses(unittest.TestCase): + def _assert_success(self, result): + items, errors = result + self.assertEqual(errors, []) + return items + + def test_import_object_class_with_all_optional_data(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success( + import_entity_classes( + db_map, + ( + ("Object", (), "The test class.", 23, True), + ("Relation", ("Object",), "The test relationship.", 5, False), + ), + ) + ) + entity_classes = db_map.get_entity_class_items() + self.assertEqual(len(entity_classes), 2) + data = ( + ( + row["name"], + row["dimension_name_list"], + row["description"], + row["display_icon"], + row["active_by_default"], + ) + for row in entity_classes + ) + expected = ( + ("Object", (), "The test class.", 23, True), + ("Relation", ("Object",), "The test relationship.", 5, False), + ) + self.assertCountEqual(data, expected) + + class TestImportEntity(unittest.TestCase): def test_import_multi_d_entity_twice(self): db_map = DatabaseMapping("sqlite://", create=True) From ab1b8e228c936ec9a60f0c385e2ee8c3e2923835 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 15 Jan 2024 12:52:36 +0200 Subject: [PATCH 229/317] Fix unit tests Re #316 --- tests/filters/test_tools.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index f9a65abf..41e7f03b 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -100,7 +100,7 @@ def test_empty_stack(self): try: apply_filter_stack(db_map, []) object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("object_class", (), None, None)]) + self.assertEqual(object_classes, [("object_class", (), None, None, False)]) finally: db_map.close() @@ -110,7 +110,7 @@ def test_single_renaming_filter(self): stack = [entity_class_renamer_config(object_class="renamed_once")] apply_filter_stack(db_map, stack) object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", (), None, None)]) + self.assertEqual(object_classes, [("renamed_once", (), None, None, False)]) finally: db_map.close() @@ -123,7 +123,7 @@ def test_two_renaming_filters(self): ] apply_filter_stack(db_map, stack) object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("renamed_twice", (), None, None)]) + self.assertEqual(object_classes, [("renamed_twice", (), None, None, False)]) finally: db_map.close() @@ -146,7 +146,7 @@ def test_without_filters(self): db_map = DatabaseMapping(self._db_url, self._engine) try: object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("object_class", (), None, None)]) + self.assertEqual(object_classes, [("object_class", (), None, None, False)]) finally: db_map.close() @@ -158,7 +158,7 @@ def test_single_renaming_filter(self): db_map = DatabaseMapping(url, self._engine) try: object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", (), None, None)]) + self.assertEqual(object_classes, [("renamed_once", (), None, None, False)]) finally: db_map.close() @@ -174,7 +174,7 @@ def 
test_two_renaming_filters(self): db_map = DatabaseMapping(url, self._engine) try: object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("renamed_twice", (), None, None)]) + self.assertEqual(object_classes, [("renamed_twice", (), None, None, False)]) finally: db_map.close() @@ -184,7 +184,7 @@ def test_config_embedded_to_url(self): db_map = DatabaseMapping(url, self._engine) try: object_classes = export_entity_classes(db_map) - self.assertEqual(object_classes, [("renamed_once", (), None, None)]) + self.assertEqual(object_classes, [("renamed_once", (), None, None, False)]) finally: db_map.close() From c874f4289e3f65bf3afeaf8fdb3d666f6bb49a3b Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 23 Jan 2024 13:56:04 +0100 Subject: [PATCH 230/317] Add changelog. --- CHANGELOG.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..91fc8b81 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,38 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +This is the first release where we keep a Spine-Database-API specific changelog. + +The database structure has changed quite a bit. +Large parts of the API have been rewritten or replaced by new systems. +We are still keeping many old entry points for backwards compatibility, +but those functions and methods are pending deprecation. + +### Changed + +- Objects and relationships have been replaced by *entities*. + Zero-dimensional entities correspond to objects while multidimensional entities to relationships. + +### Added + +- *Entity alternatives* control the visibility of entities. + This replaces previous tools, features and methods. +- Support for *superclasses*. + It is not possible to set a superclass for an entity class. + The class then inherits parameter definitions from its superclass. + +### Fixed + +### Removed + +- Tools, features and methods have been removed. + +### Deprecated + +### Security From 7fc383e8cfa99e4bc4c94dfb4c13144b01a6ce63 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 23 Jan 2024 15:22:15 +0100 Subject: [PATCH 231/317] Fix typo in CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 91fc8b81..4ce2b6bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ but those functions and methods are pending deprecation. - *Entity alternatives* control the visibility of entities. This replaces previous tools, features and methods. - Support for *superclasses*. - It is not possible to set a superclass for an entity class. + It is now possible to set a superclass for an entity class. The class then inherits parameter definitions from its superclass. 
### Fixed From 782b0413c8ba4c753a88e6b0c6a6eebf9c3c08c9 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 24 Jan 2024 09:35:39 +0100 Subject: [PATCH 232/317] Update copyright notices Re spine-tools/Spine-Toolbox#2514 --- bin/update_copyrights.py | 6 +++--- spinedb_api/__init__.py | 1 + spinedb_api/compatibility.py | 1 + spinedb_api/db_mapping.py | 1 + spinedb_api/db_mapping_base.py | 1 + spinedb_api/db_mapping_commit_mixin.py | 1 + spinedb_api/db_mapping_query_mixin.py | 1 + spinedb_api/exception.py | 1 + spinedb_api/export_functions.py | 1 + spinedb_api/export_mapping/__init__.py | 1 + spinedb_api/export_mapping/export_mapping.py | 1 + spinedb_api/export_mapping/generator.py | 1 + spinedb_api/export_mapping/group_functions.py | 1 + spinedb_api/export_mapping/pivot.py | 1 + spinedb_api/export_mapping/settings.py | 1 + spinedb_api/filters/__init__.py | 1 + spinedb_api/filters/alternative_filter.py | 1 + spinedb_api/filters/execution_filter.py | 1 + spinedb_api/filters/renamer.py | 1 + spinedb_api/filters/scenario_filter.py | 1 + spinedb_api/filters/tools.py | 1 + spinedb_api/filters/value_transformer.py | 1 + spinedb_api/graph_layout_generator.py | 1 + spinedb_api/helpers.py | 1 + spinedb_api/import_functions.py | 1 + spinedb_api/import_mapping/__init__.py | 1 + spinedb_api/import_mapping/generator.py | 1 + spinedb_api/import_mapping/import_mapping.py | 1 + spinedb_api/import_mapping/import_mapping_compat.py | 1 + spinedb_api/import_mapping/type_conversion.py | 1 + spinedb_api/mapped_items.py | 1 + spinedb_api/mapping.py | 1 + spinedb_api/parameter_value.py | 1 + spinedb_api/perfect_split.py | 1 + spinedb_api/purge.py | 1 + spinedb_api/query.py | 1 + spinedb_api/server_client_helpers.py | 1 + spinedb_api/spine_db_client.py | 1 + spinedb_api/spine_db_server.py | 1 + spinedb_api/spine_io/__init__.py | 1 + spinedb_api/spine_io/exporters/__init__.py | 1 + spinedb_api/spine_io/exporters/csv_writer.py | 1 + spinedb_api/spine_io/exporters/excel.py | 1 + spinedb_api/spine_io/exporters/excel_writer.py | 1 + spinedb_api/spine_io/exporters/gdx_writer.py | 1 + spinedb_api/spine_io/exporters/sql_writer.py | 1 + spinedb_api/spine_io/exporters/writer.py | 1 + spinedb_api/spine_io/gdx_utils.py | 1 + spinedb_api/spine_io/importers/__init__.py | 1 + spinedb_api/spine_io/importers/csv_reader.py | 1 + spinedb_api/spine_io/importers/datapackage_reader.py | 1 + spinedb_api/spine_io/importers/excel_reader.py | 1 + spinedb_api/spine_io/importers/gdx_connector.py | 1 + spinedb_api/spine_io/importers/json_reader.py | 1 + spinedb_api/spine_io/importers/reader.py | 1 + spinedb_api/spine_io/importers/sqlalchemy_connector.py | 1 + spinedb_api/temp_id.py | 1 + tests/__init__.py | 1 + tests/custom_db_mapping.py | 1 + tests/export_mapping/__init__.py | 1 + tests/export_mapping/test_export_mapping.py | 1 + tests/export_mapping/test_pivot.py | 1 + tests/export_mapping/test_settings.py | 1 + tests/filters/__init__.py | 1 + tests/filters/test_alternative_filter.py | 1 + tests/filters/test_execution_filter.py | 1 + tests/filters/test_renamer.py | 1 + tests/filters/test_scenario_filter.py | 1 + tests/filters/test_tool_filter.py | 1 + tests/filters/test_tools.py | 1 + tests/filters/test_value_transformer.py | 1 + tests/import_mapping/__init__.py | 1 + tests/import_mapping/test_generator.py | 1 + tests/import_mapping/test_import_mapping.py | 1 + tests/import_mapping/test_type_conversion.py | 1 + tests/spine_io/__init__.py | 1 + tests/spine_io/exporters/__init__.py | 1 + tests/spine_io/exporters/test_csv_writer.py | 1 + 
tests/spine_io/exporters/test_excel_writer.py | 1 + tests/spine_io/exporters/test_gdx_writer.py | 1 + tests/spine_io/exporters/test_sql_writer.py | 1 + tests/spine_io/exporters/test_writer.py | 1 + tests/spine_io/importers/__init__.py | 1 + tests/spine_io/importers/test_CSVConnector.py | 1 + tests/spine_io/importers/test_GdxConnector.py | 1 + tests/spine_io/importers/test_datapackage_reader.py | 1 + tests/spine_io/importers/test_excel_reader.py | 1 + tests/spine_io/importers/test_json_reader.py | 1 + tests/spine_io/importers/test_reader.py | 1 + tests/spine_io/importers/test_sqlalchemy_connector.py | 1 + tests/spine_io/test_excel_integration.py | 1 + tests/test_DatabaseMapping.py | 1 + tests/test_check_integrity.py | 1 + tests/test_db_mapping_base.py | 1 + tests/test_export_functions.py | 1 + tests/test_helpers.py | 1 + tests/test_import_functions.py | 1 + tests/test_mapping.py | 1 + tests/test_migration.py | 1 + tests/test_parameter_value.py | 1 + tests/test_purge.py | 1 + 101 files changed, 103 insertions(+), 3 deletions(-) diff --git a/bin/update_copyrights.py b/bin/update_copyrights.py index 6194729c..f88a9ade 100644 --- a/bin/update_copyrights.py +++ b/bin/update_copyrights.py @@ -9,7 +9,7 @@ project_source_dir = Path(root_dir, "spinedb_api") test_source_dir = Path(root_dir, "tests") -expected = f"# Copyright (C) 2017-{current_year} Spine project consortium" +expected = f"# Copyright (C) 2023-{current_year} Mopo project consortium" def update_copyrights(path, suffix, recursive=True): @@ -18,8 +18,8 @@ def update_copyrights(path, suffix, recursive=True): i = 0 with open(path) as python_file: lines = python_file.readlines() - for i, line in enumerate(lines[1:4]): - if line.startswith("# Copyright (C) "): + for i, line in enumerate(lines[1:5]): + if line.startswith("# Copyright (C) ") and "Mopo" in line: lines[i + 1] = lines[i + 1][:21] + str(current_year) + lines[i + 1][25:] break if len(lines) <= i + 1 or not lines[i + 1].startswith(expected): diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index b6b41a6b..5f074081 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 92a6783b..be88f526 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 057e34cb..5a4fe296 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 1602d613..f342786f 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index ce105140..9f5db582 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index c5fd4f37..bbe6b4eb 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index c2554dab..e598fe75 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 1040a427..f4836de8 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index c75b202c..8a6b7acb 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index e5ccef3e..942a5615 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/generator.py b/spinedb_api/export_mapping/generator.py index 026454be..bb1982ed 100644 --- a/spinedb_api/export_mapping/generator.py +++ b/spinedb_api/export_mapping/generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/group_functions.py b/spinedb_api/export_mapping/group_functions.py index ce4598fb..c43532a3 100644 --- a/spinedb_api/export_mapping/group_functions.py +++ b/spinedb_api/export_mapping/group_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/pivot.py b/spinedb_api/export_mapping/pivot.py index afd78344..fbd5fad0 100644 --- a/spinedb_api/export_mapping/pivot.py +++ b/spinedb_api/export_mapping/pivot.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index ed10ff58..6d6a5ae5 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/__init__.py b/spinedb_api/filters/__init__.py index 46105c99..1eaeee9f 100644 --- a/spinedb_api/filters/__init__.py +++ b/spinedb_api/filters/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index f406f793..82bf0ed7 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index 51a9f2db..182e5586 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index a3970ac7..68f372ad 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 3bbcb5e7..abe9977b 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/tools.py b/spinedb_api/filters/tools.py index 94d60545..46ef738d 100644 --- a/spinedb_api/filters/tools.py +++ b/spinedb_api/filters/tools.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 956de19d..5d3dfeb2 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index be0a149a..804a290f 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Engine. # Spine Engine is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index da671683..b86037b5 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index c270c3e4..4d445941 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/import_mapping/__init__.py b/spinedb_api/import_mapping/__init__.py index 9966601e..a1c7afd5 100644 --- a/spinedb_api/import_mapping/__init__.py +++ b/spinedb_api/import_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index b73037a4..7b4f2acd 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index c07d4f3c..7d804e10 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index c8170b51..8346d094 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py index 4f054c9c..a55ed388 100644 --- a/spinedb_api/import_mapping/type_conversion.py +++ b/spinedb_api/import_mapping/type_conversion.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 18cfcdc2..c6cdd39d 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py index 5b197b17..3b7d08b0 100644 --- a/spinedb_api/mapping.py +++ b/spinedb_api/mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 7dbd704d..93a9ae75 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index 6028e47a..5bbd4f13 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 472fa5c3..49f5edc6 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index db23f1e7..4201a842 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index f9b3b0b1..6ba2f737 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py index 570043b4..57a30ccd 100644 --- a/spinedb_api/spine_db_client.py +++ b/spinedb_api/spine_db_client.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Engine. # Spine Engine is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 4ac5b7d8..5b1a7c42 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/__init__.py b/spinedb_api/spine_io/__init__.py index adea0648..a9fd7693 100644 --- a/spinedb_api/spine_io/__init__.py +++ b/spinedb_api/spine_io/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/__init__.py b/spinedb_api/spine_io/exporters/__init__.py index 3d6ed59b..298d66f0 100644 --- a/spinedb_api/spine_io/exporters/__init__.py +++ b/spinedb_api/spine_io/exporters/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/csv_writer.py b/spinedb_api/spine_io/exporters/csv_writer.py index 4974be65..b018b0d5 100644 --- a/spinedb_api/spine_io/exporters/csv_writer.py +++ b/spinedb_api/spine_io/exporters/csv_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index d75ecd95..75b43cdb 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/excel_writer.py b/spinedb_api/spine_io/exporters/excel_writer.py index 2c1996d2..39362126 100644 --- a/spinedb_api/spine_io/exporters/excel_writer.py +++ b/spinedb_api/spine_io/exporters/excel_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/gdx_writer.py b/spinedb_api/spine_io/exporters/gdx_writer.py index faae3d53..8e7f9f3c 100644 --- a/spinedb_api/spine_io/exporters/gdx_writer.py +++ b/spinedb_api/spine_io/exporters/gdx_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/sql_writer.py b/spinedb_api/spine_io/exporters/sql_writer.py index 6065d997..c726baac 100644 --- a/spinedb_api/spine_io/exporters/sql_writer.py +++ b/spinedb_api/spine_io/exporters/sql_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py index 26698854..881ca9c3 100644 --- a/spinedb_api/spine_io/exporters/writer.py +++ b/spinedb_api/spine_io/exporters/writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/gdx_utils.py b/spinedb_api/spine_io/gdx_utils.py index 229c9338..9732a22a 100644 --- a/spinedb_api/spine_io/gdx_utils.py +++ b/spinedb_api/spine_io/gdx_utils.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/__init__.py b/spinedb_api/spine_io/importers/__init__.py index ab3c7b4b..d59c4b8a 100644 --- a/spinedb_api/spine_io/importers/__init__.py +++ b/spinedb_api/spine_io/importers/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/csv_reader.py b/spinedb_api/spine_io/importers/csv_reader.py index b3381bba..e8e19f4c 100644 --- a/spinedb_api/spine_io/importers/csv_reader.py +++ b/spinedb_api/spine_io/importers/csv_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/datapackage_reader.py b/spinedb_api/spine_io/importers/datapackage_reader.py index 8fa1bc5b..a12e9584 100644 --- a/spinedb_api/spine_io/importers/datapackage_reader.py +++ b/spinedb_api/spine_io/importers/datapackage_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/excel_reader.py b/spinedb_api/spine_io/importers/excel_reader.py index 22cabe0f..498ecfe0 100644 --- a/spinedb_api/spine_io/importers/excel_reader.py +++ b/spinedb_api/spine_io/importers/excel_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/gdx_connector.py b/spinedb_api/spine_io/importers/gdx_connector.py index 61f28935..22eef1f1 100644 --- a/spinedb_api/spine_io/importers/gdx_connector.py +++ b/spinedb_api/spine_io/importers/gdx_connector.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/json_reader.py b/spinedb_api/spine_io/importers/json_reader.py index 024b98d7..06011b11 100644 --- a/spinedb_api/spine_io/importers/json_reader.py +++ b/spinedb_api/spine_io/importers/json_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index 3a645e96..1d33837b 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/sqlalchemy_connector.py b/spinedb_api/spine_io/importers/sqlalchemy_connector.py index e187356d..eb382b75 100644 --- a/spinedb_api/spine_io/importers/sqlalchemy_connector.py +++ b/spinedb_api/spine_io/importers/sqlalchemy_connector.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 79066941..7bd502a3 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/__init__.py b/tests/__init__.py index f9452d16..6516cd9e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/custom_db_mapping.py b/tests/custom_db_mapping.py index ab578e3f..39880507 100644 --- a/tests/custom_db_mapping.py +++ b/tests/custom_db_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/__init__.py b/tests/export_mapping/__init__.py index 46105c99..1eaeee9f 100644 --- a/tests/export_mapping/__init__.py +++ b/tests/export_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 826588ac..0caa5629 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_pivot.py b/tests/export_mapping/test_pivot.py index cfe2e12c..5f3e2772 100644 --- a/tests/export_mapping/test_pivot.py +++ b/tests/export_mapping/test_pivot.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index fb84d20a..bd8b5fc7 100644 --- a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/__init__.py b/tests/filters/__init__.py index 46105c99..1eaeee9f 100644 --- a/tests/filters/__init__.py +++ b/tests/filters/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index a677bbe2..b6c20846 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_execution_filter.py b/tests/filters/test_execution_filter.py index 6a092ee6..fc819bd1 100644 --- a/tests/filters/test_execution_filter.py +++ b/tests/filters/test_execution_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index d3bf469e..82599da0 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index de1be424..52a92c12 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index a85455da..69efaab3 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index 41e7f03b..834943c0 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index 3660e85a..7b43bd79 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/__init__.py b/tests/import_mapping/__init__.py index 46105c99..1eaeee9f 100644 --- a/tests/import_mapping/__init__.py +++ b/tests/import_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index df0a8dac..ca9d9be3 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index ff2c5e6c..0deec7ac 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_type_conversion.py b/tests/import_mapping/test_type_conversion.py index c772dabd..2ac3eac9 100644 --- a/tests/import_mapping/test_type_conversion.py +++ b/tests/import_mapping/test_type_conversion.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/__init__.py b/tests/spine_io/__init__.py index 219c44b8..85466aa1 100644 --- a/tests/spine_io/__init__.py +++ b/tests/spine_io/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/__init__.py b/tests/spine_io/exporters/__init__.py index a0581eb2..99ed4315 100644 --- a/tests/spine_io/exporters/__init__.py +++ b/tests/spine_io/exporters/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_csv_writer.py b/tests/spine_io/exporters/test_csv_writer.py index 9ddf6dfc..ae452acc 100644 --- a/tests/spine_io/exporters/test_csv_writer.py +++ b/tests/spine_io/exporters/test_csv_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index 65f8cf02..bfe4d5b1 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 48407fd4..785d5e16 100644 --- a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py index 702adeac..b674e8df 100644 --- a/tests/spine_io/exporters/test_sql_writer.py +++ b/tests/spine_io/exporters/test_sql_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_writer.py b/tests/spine_io/exporters/test_writer.py index 9a9283c0..9a45f1a4 100644 --- a/tests/spine_io/exporters/test_writer.py +++ b/tests/spine_io/exporters/test_writer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/__init__.py b/tests/spine_io/importers/__init__.py index bb67bf8a..d705009c 100644 --- a/tests/spine_io/importers/__init__.py +++ b/tests/spine_io/importers/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_CSVConnector.py b/tests/spine_io/importers/test_CSVConnector.py index b8a9d93a..956c86e2 100644 --- a/tests/spine_io/importers/test_CSVConnector.py +++ b/tests/spine_io/importers/test_CSVConnector.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_GdxConnector.py b/tests/spine_io/importers/test_GdxConnector.py index 7a55c992..ac0f4b28 100644 --- a/tests/spine_io/importers/test_GdxConnector.py +++ b/tests/spine_io/importers/test_GdxConnector.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_datapackage_reader.py b/tests/spine_io/importers/test_datapackage_reader.py index 75906b29..4b991f65 100644 --- a/tests/spine_io/importers/test_datapackage_reader.py +++ b/tests/spine_io/importers/test_datapackage_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_excel_reader.py b/tests/spine_io/importers/test_excel_reader.py index 8c3f59ed..5477188c 100644 --- a/tests/spine_io/importers/test_excel_reader.py +++ b/tests/spine_io/importers/test_excel_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_json_reader.py b/tests/spine_io/importers/test_json_reader.py index 303bf7ec..7b20c007 100644 --- a/tests/spine_io/importers/test_json_reader.py +++ b/tests/spine_io/importers/test_json_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_reader.py b/tests/spine_io/importers/test_reader.py index 71e90535..c4dc1826 100644 --- a/tests/spine_io/importers/test_reader.py +++ b/tests/spine_io/importers/test_reader.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_sqlalchemy_connector.py b/tests/spine_io/importers/test_sqlalchemy_connector.py index 9fab91b1..50230112 100644 --- a/tests/spine_io/importers/test_sqlalchemy_connector.py +++ b/tests/spine_io/importers/test_sqlalchemy_connector.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index 7ca1a318..d1bbe6a7 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index cb2596ce..f6d03af3 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_check_integrity.py b/tests/test_check_integrity.py index 9d54caae..b4476d83 100644 --- a/tests/test_check_integrity.py +++ b/tests/test_check_integrity.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index 179c1ed8..e55a4b32 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 3a4e9ccc..d442e372 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5497ae66..a9879261 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 6cfec085..17dcdf4c 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 94358de2..a00aebba 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_migration.py b/tests/test_migration.py index c9cf76d1..d9168e49 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index 8ae63ef5..77e794f1 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_purge.py b/tests/test_purge.py index f6093e1c..d465afe8 100644 --- a/tests/test_purge.py +++ b/tests/test_purge.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) From 74d39adf79040c6aac8c95caea930be425ac0855 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 24 Jan 2024 10:21:49 +0100 Subject: [PATCH 233/317] Fix database migration when database has no entity classes We must bail out of the migration script if there are no active_by_default values to set, e.g. when the database is empty.
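For context, a minimal sketch of the guarded upgrade() this patch produces. It is an illustration only, not the migration verbatim: it assumes the Session/Core setup visible in the hunk below, omits the add_column step that creates the active_by_default column in the real migration, and the value written at the end is a placeholder.

from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session

def upgrade():
    conn = op.get_bind()
    session = Session(bind=conn)
    metadata = sa.MetaData()
    metadata.reflect(bind=conn)
    dimension_table = metadata.tables["entity_class_dimension"]
    # Ids of all multi-dimensional entity classes in the database.
    dimensional_class_ids = {row.entity_class_id for row in session.query(dimension_table)}
    if not dimensional_class_ids:
        # Nothing to update, e.g. the database is empty: bail out instead of
        # emitting an UPDATE over an empty id set.
        return
    class_table = metadata.tables["entity_class"]
    conn.execute(
        class_table.update()
        .where(class_table.c.id.in_(dimensional_class_ids))
        .values(active_by_default=False)  # placeholder value for illustration
    )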
Re spine-tools/Spine-Toolbox#2512 --- .../8b0eff478bcb_add_active_by_default_to_entity_class.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index 5eb0e49e..17afafc2 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -30,6 +30,8 @@ def upgrade(): metadata.reflect(bind=conn) dimension_table = metadata.tables["entity_class_dimension"] dimensional_class_ids = {row.entity_class_id for row in session.query(dimension_table)} + if not dimensional_class_ids: + return metadata.reflect(bind=conn) class_table = metadata.tables["entity_class"] update_statement = ( From a97c7940a82aec9e9921b514e63d05cd113c9790 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 24 Jan 2024 15:12:06 +0100 Subject: [PATCH 234/317] Revert "Update copyright notices" This reverts commit 782b0413c8ba4c753a88e6b0c6a6eebf9c3c08c9. --- bin/update_copyrights.py | 6 +++--- spinedb_api/__init__.py | 1 - spinedb_api/compatibility.py | 1 - spinedb_api/db_mapping.py | 1 - spinedb_api/db_mapping_base.py | 1 - spinedb_api/db_mapping_commit_mixin.py | 1 - spinedb_api/db_mapping_query_mixin.py | 1 - spinedb_api/exception.py | 1 - spinedb_api/export_functions.py | 1 - spinedb_api/export_mapping/__init__.py | 1 - spinedb_api/export_mapping/export_mapping.py | 1 - spinedb_api/export_mapping/generator.py | 1 - spinedb_api/export_mapping/group_functions.py | 1 - spinedb_api/export_mapping/pivot.py | 1 - spinedb_api/export_mapping/settings.py | 1 - spinedb_api/filters/__init__.py | 1 - spinedb_api/filters/alternative_filter.py | 1 - spinedb_api/filters/execution_filter.py | 1 - spinedb_api/filters/renamer.py | 1 - spinedb_api/filters/scenario_filter.py | 1 - spinedb_api/filters/tools.py | 1 - spinedb_api/filters/value_transformer.py | 1 - spinedb_api/graph_layout_generator.py | 1 - spinedb_api/helpers.py | 1 - spinedb_api/import_functions.py | 1 - spinedb_api/import_mapping/__init__.py | 1 - spinedb_api/import_mapping/generator.py | 1 - spinedb_api/import_mapping/import_mapping.py | 1 - spinedb_api/import_mapping/import_mapping_compat.py | 1 - spinedb_api/import_mapping/type_conversion.py | 1 - spinedb_api/mapped_items.py | 1 - spinedb_api/mapping.py | 1 - spinedb_api/parameter_value.py | 1 - spinedb_api/perfect_split.py | 1 - spinedb_api/purge.py | 1 - spinedb_api/query.py | 1 - spinedb_api/server_client_helpers.py | 1 - spinedb_api/spine_db_client.py | 1 - spinedb_api/spine_db_server.py | 1 - spinedb_api/spine_io/__init__.py | 1 - spinedb_api/spine_io/exporters/__init__.py | 1 - spinedb_api/spine_io/exporters/csv_writer.py | 1 - spinedb_api/spine_io/exporters/excel.py | 1 - spinedb_api/spine_io/exporters/excel_writer.py | 1 - spinedb_api/spine_io/exporters/gdx_writer.py | 1 - spinedb_api/spine_io/exporters/sql_writer.py | 1 - spinedb_api/spine_io/exporters/writer.py | 1 - spinedb_api/spine_io/gdx_utils.py | 1 - spinedb_api/spine_io/importers/__init__.py | 1 - spinedb_api/spine_io/importers/csv_reader.py | 1 - spinedb_api/spine_io/importers/datapackage_reader.py | 1 - spinedb_api/spine_io/importers/excel_reader.py | 1 - spinedb_api/spine_io/importers/gdx_connector.py | 1 - spinedb_api/spine_io/importers/json_reader.py | 1 - spinedb_api/spine_io/importers/reader.py | 1 - spinedb_api/spine_io/importers/sqlalchemy_connector.py | 1 - 
spinedb_api/temp_id.py | 1 - tests/__init__.py | 1 - tests/custom_db_mapping.py | 1 - tests/export_mapping/__init__.py | 1 - tests/export_mapping/test_export_mapping.py | 1 - tests/export_mapping/test_pivot.py | 1 - tests/export_mapping/test_settings.py | 1 - tests/filters/__init__.py | 1 - tests/filters/test_alternative_filter.py | 1 - tests/filters/test_execution_filter.py | 1 - tests/filters/test_renamer.py | 1 - tests/filters/test_scenario_filter.py | 1 - tests/filters/test_tool_filter.py | 1 - tests/filters/test_tools.py | 1 - tests/filters/test_value_transformer.py | 1 - tests/import_mapping/__init__.py | 1 - tests/import_mapping/test_generator.py | 1 - tests/import_mapping/test_import_mapping.py | 1 - tests/import_mapping/test_type_conversion.py | 1 - tests/spine_io/__init__.py | 1 - tests/spine_io/exporters/__init__.py | 1 - tests/spine_io/exporters/test_csv_writer.py | 1 - tests/spine_io/exporters/test_excel_writer.py | 1 - tests/spine_io/exporters/test_gdx_writer.py | 1 - tests/spine_io/exporters/test_sql_writer.py | 1 - tests/spine_io/exporters/test_writer.py | 1 - tests/spine_io/importers/__init__.py | 1 - tests/spine_io/importers/test_CSVConnector.py | 1 - tests/spine_io/importers/test_GdxConnector.py | 1 - tests/spine_io/importers/test_datapackage_reader.py | 1 - tests/spine_io/importers/test_excel_reader.py | 1 - tests/spine_io/importers/test_json_reader.py | 1 - tests/spine_io/importers/test_reader.py | 1 - tests/spine_io/importers/test_sqlalchemy_connector.py | 1 - tests/spine_io/test_excel_integration.py | 1 - tests/test_DatabaseMapping.py | 1 - tests/test_check_integrity.py | 1 - tests/test_db_mapping_base.py | 1 - tests/test_export_functions.py | 1 - tests/test_helpers.py | 1 - tests/test_import_functions.py | 1 - tests/test_mapping.py | 1 - tests/test_migration.py | 1 - tests/test_parameter_value.py | 1 - tests/test_purge.py | 1 - 101 files changed, 3 insertions(+), 103 deletions(-) diff --git a/bin/update_copyrights.py b/bin/update_copyrights.py index f88a9ade..6194729c 100644 --- a/bin/update_copyrights.py +++ b/bin/update_copyrights.py @@ -9,7 +9,7 @@ project_source_dir = Path(root_dir, "spinedb_api") test_source_dir = Path(root_dir, "tests") -expected = f"# Copyright (C) 2023-{current_year} Mopo project consortium" +expected = f"# Copyright (C) 2017-{current_year} Spine project consortium" def update_copyrights(path, suffix, recursive=True): @@ -18,8 +18,8 @@ def update_copyrights(path, suffix, recursive=True): i = 0 with open(path) as python_file: lines = python_file.readlines() - for i, line in enumerate(lines[1:5]): - if line.startswith("# Copyright (C) ") and "Mopo" in line: + for i, line in enumerate(lines[1:4]): + if line.startswith("# Copyright (C) "): lines[i + 1] = lines[i + 1][:21] + str(current_year) + lines[i + 1][25:] break if len(lines) <= i + 1 or not lines[i + 1].startswith(expected): diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 5f074081..b6b41a6b 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index be88f526..92a6783b 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 5a4fe296..057e34cb 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f342786f..1602d613 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 9f5db582..ce105140 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index bbe6b4eb..c5fd4f37 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index e598fe75..c2554dab 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index f4836de8..1040a427 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index 8a6b7acb..c75b202c 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 942a5615..e5ccef3e 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/generator.py b/spinedb_api/export_mapping/generator.py index bb1982ed..026454be 100644 --- a/spinedb_api/export_mapping/generator.py +++ b/spinedb_api/export_mapping/generator.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/group_functions.py b/spinedb_api/export_mapping/group_functions.py index c43532a3..ce4598fb 100644 --- a/spinedb_api/export_mapping/group_functions.py +++ b/spinedb_api/export_mapping/group_functions.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/pivot.py b/spinedb_api/export_mapping/pivot.py index fbd5fad0..afd78344 100644 --- a/spinedb_api/export_mapping/pivot.py +++ b/spinedb_api/export_mapping/pivot.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index 6d6a5ae5..ed10ff58 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/__init__.py b/spinedb_api/filters/__init__.py index 1eaeee9f..46105c99 100644 --- a/spinedb_api/filters/__init__.py +++ b/spinedb_api/filters/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index 82bf0ed7..f406f793 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index 182e5586..51a9f2db 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index 68f372ad..a3970ac7 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index abe9977b..3bbcb5e7 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/tools.py b/spinedb_api/filters/tools.py index 46ef738d..94d60545 100644 --- a/spinedb_api/filters/tools.py +++ b/spinedb_api/filters/tools.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 5d3dfeb2..956de19d 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index 804a290f..be0a149a 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Engine. # Spine Engine is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index b86037b5..da671683 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 4d445941..c270c3e4 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/import_mapping/__init__.py b/spinedb_api/import_mapping/__init__.py index a1c7afd5..9966601e 100644 --- a/spinedb_api/import_mapping/__init__.py +++ b/spinedb_api/import_mapping/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index 7b4f2acd..b73037a4 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index 7d804e10..c07d4f3c 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index 8346d094..c8170b51 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py index a55ed388..4f054c9c 100644 --- a/spinedb_api/import_mapping/type_conversion.py +++ b/spinedb_api/import_mapping/type_conversion.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index c6cdd39d..18cfcdc2 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py index 3b7d08b0..5b197b17 100644 --- a/spinedb_api/mapping.py +++ b/spinedb_api/mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 93a9ae75..7dbd704d 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index 5bbd4f13..6028e47a 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 49f5edc6..472fa5c3 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index 4201a842..db23f1e7 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index 6ba2f737..f9b3b0b1 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py index 57a30ccd..570043b4 100644 --- a/spinedb_api/spine_db_client.py +++ b/spinedb_api/spine_db_client.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Engine. # Spine Engine is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 5b1a7c42..4ac5b7d8 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/__init__.py b/spinedb_api/spine_io/__init__.py index a9fd7693..adea0648 100644 --- a/spinedb_api/spine_io/__init__.py +++ b/spinedb_api/spine_io/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/__init__.py b/spinedb_api/spine_io/exporters/__init__.py index 298d66f0..3d6ed59b 100644 --- a/spinedb_api/spine_io/exporters/__init__.py +++ b/spinedb_api/spine_io/exporters/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/csv_writer.py b/spinedb_api/spine_io/exporters/csv_writer.py index b018b0d5..4974be65 100644 --- a/spinedb_api/spine_io/exporters/csv_writer.py +++ b/spinedb_api/spine_io/exporters/csv_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index 75b43cdb..d75ecd95 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/excel_writer.py b/spinedb_api/spine_io/exporters/excel_writer.py index 39362126..2c1996d2 100644 --- a/spinedb_api/spine_io/exporters/excel_writer.py +++ b/spinedb_api/spine_io/exporters/excel_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/gdx_writer.py b/spinedb_api/spine_io/exporters/gdx_writer.py index 8e7f9f3c..faae3d53 100644 --- a/spinedb_api/spine_io/exporters/gdx_writer.py +++ b/spinedb_api/spine_io/exporters/gdx_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/sql_writer.py b/spinedb_api/spine_io/exporters/sql_writer.py index c726baac..6065d997 100644 --- a/spinedb_api/spine_io/exporters/sql_writer.py +++ b/spinedb_api/spine_io/exporters/sql_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/exporters/writer.py b/spinedb_api/spine_io/exporters/writer.py index 881ca9c3..26698854 100644 --- a/spinedb_api/spine_io/exporters/writer.py +++ b/spinedb_api/spine_io/exporters/writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/gdx_utils.py b/spinedb_api/spine_io/gdx_utils.py index 9732a22a..229c9338 100644 --- a/spinedb_api/spine_io/gdx_utils.py +++ b/spinedb_api/spine_io/gdx_utils.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/__init__.py b/spinedb_api/spine_io/importers/__init__.py index d59c4b8a..ab3c7b4b 100644 --- a/spinedb_api/spine_io/importers/__init__.py +++ b/spinedb_api/spine_io/importers/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/csv_reader.py b/spinedb_api/spine_io/importers/csv_reader.py index e8e19f4c..b3381bba 100644 --- a/spinedb_api/spine_io/importers/csv_reader.py +++ b/spinedb_api/spine_io/importers/csv_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/datapackage_reader.py b/spinedb_api/spine_io/importers/datapackage_reader.py index a12e9584..8fa1bc5b 100644 --- a/spinedb_api/spine_io/importers/datapackage_reader.py +++ b/spinedb_api/spine_io/importers/datapackage_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/excel_reader.py b/spinedb_api/spine_io/importers/excel_reader.py index 498ecfe0..22cabe0f 100644 --- a/spinedb_api/spine_io/importers/excel_reader.py +++ b/spinedb_api/spine_io/importers/excel_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/gdx_connector.py b/spinedb_api/spine_io/importers/gdx_connector.py index 22eef1f1..61f28935 100644 --- a/spinedb_api/spine_io/importers/gdx_connector.py +++ b/spinedb_api/spine_io/importers/gdx_connector.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/json_reader.py b/spinedb_api/spine_io/importers/json_reader.py index 06011b11..024b98d7 100644 --- a/spinedb_api/spine_io/importers/json_reader.py +++ b/spinedb_api/spine_io/importers/json_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/reader.py b/spinedb_api/spine_io/importers/reader.py index 1d33837b..3a645e96 100644 --- a/spinedb_api/spine_io/importers/reader.py +++ b/spinedb_api/spine_io/importers/reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/spine_io/importers/sqlalchemy_connector.py b/spinedb_api/spine_io/importers/sqlalchemy_connector.py index eb382b75..e187356d 100644 --- a/spinedb_api/spine_io/importers/sqlalchemy_connector.py +++ b/spinedb_api/spine_io/importers/sqlalchemy_connector.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 7bd502a3..79066941 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/__init__.py b/tests/__init__.py index 6516cd9e..f9452d16 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/custom_db_mapping.py b/tests/custom_db_mapping.py index 39880507..ab578e3f 100644 --- a/tests/custom_db_mapping.py +++ b/tests/custom_db_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/__init__.py b/tests/export_mapping/__init__.py index 1eaeee9f..46105c99 100644 --- a/tests/export_mapping/__init__.py +++ b/tests/export_mapping/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 0caa5629..826588ac 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_pivot.py b/tests/export_mapping/test_pivot.py index 5f3e2772..cfe2e12c 100644 --- a/tests/export_mapping/test_pivot.py +++ b/tests/export_mapping/test_pivot.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/export_mapping/test_settings.py b/tests/export_mapping/test_settings.py index bd8b5fc7..fb84d20a 100644 --- a/tests/export_mapping/test_settings.py +++ b/tests/export_mapping/test_settings.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/__init__.py b/tests/filters/__init__.py index 1eaeee9f..46105c99 100644 --- a/tests/filters/__init__.py +++ b/tests/filters/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index b6c20846..a677bbe2 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_execution_filter.py b/tests/filters/test_execution_filter.py index fc819bd1..6a092ee6 100644 --- a/tests/filters/test_execution_filter.py +++ b/tests/filters/test_execution_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_renamer.py b/tests/filters/test_renamer.py index 82599da0..d3bf469e 100644 --- a/tests/filters/test_renamer.py +++ b/tests/filters/test_renamer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 52a92c12..de1be424 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index 69efaab3..a85455da 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_tools.py b/tests/filters/test_tools.py index 834943c0..41e7f03b 100644 --- a/tests/filters/test_tools.py +++ b/tests/filters/test_tools.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index 7b43bd79..3660e85a 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/__init__.py b/tests/import_mapping/__init__.py index 1eaeee9f..46105c99 100644 --- a/tests/import_mapping/__init__.py +++ b/tests/import_mapping/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index ca9d9be3..df0a8dac 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index 0deec7ac..ff2c5e6c 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/import_mapping/test_type_conversion.py b/tests/import_mapping/test_type_conversion.py index 2ac3eac9..c772dabd 100644 --- a/tests/import_mapping/test_type_conversion.py +++ b/tests/import_mapping/test_type_conversion.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/__init__.py b/tests/spine_io/__init__.py index 85466aa1..219c44b8 100644 --- a/tests/spine_io/__init__.py +++ b/tests/spine_io/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/__init__.py b/tests/spine_io/exporters/__init__.py index 99ed4315..a0581eb2 100644 --- a/tests/spine_io/exporters/__init__.py +++ b/tests/spine_io/exporters/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_csv_writer.py b/tests/spine_io/exporters/test_csv_writer.py index ae452acc..9ddf6dfc 100644 --- a/tests/spine_io/exporters/test_csv_writer.py +++ b/tests/spine_io/exporters/test_csv_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index bfe4d5b1..65f8cf02 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 785d5e16..48407fd4 100644 --- a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_sql_writer.py b/tests/spine_io/exporters/test_sql_writer.py index b674e8df..702adeac 100644 --- a/tests/spine_io/exporters/test_sql_writer.py +++ b/tests/spine_io/exporters/test_sql_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/exporters/test_writer.py b/tests/spine_io/exporters/test_writer.py index 9a45f1a4..9a9283c0 100644 --- a/tests/spine_io/exporters/test_writer.py +++ b/tests/spine_io/exporters/test_writer.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/__init__.py b/tests/spine_io/importers/__init__.py index d705009c..bb67bf8a 100644 --- a/tests/spine_io/importers/__init__.py +++ b/tests/spine_io/importers/__init__.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_CSVConnector.py b/tests/spine_io/importers/test_CSVConnector.py index 956c86e2..b8a9d93a 100644 --- a/tests/spine_io/importers/test_CSVConnector.py +++ b/tests/spine_io/importers/test_CSVConnector.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_GdxConnector.py b/tests/spine_io/importers/test_GdxConnector.py index ac0f4b28..7a55c992 100644 --- a/tests/spine_io/importers/test_GdxConnector.py +++ b/tests/spine_io/importers/test_GdxConnector.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_datapackage_reader.py b/tests/spine_io/importers/test_datapackage_reader.py index 4b991f65..75906b29 100644 --- a/tests/spine_io/importers/test_datapackage_reader.py +++ b/tests/spine_io/importers/test_datapackage_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_excel_reader.py b/tests/spine_io/importers/test_excel_reader.py index 5477188c..8c3f59ed 100644 --- a/tests/spine_io/importers/test_excel_reader.py +++ b/tests/spine_io/importers/test_excel_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_json_reader.py b/tests/spine_io/importers/test_json_reader.py index 7b20c007..303bf7ec 100644 --- a/tests/spine_io/importers/test_json_reader.py +++ b/tests/spine_io/importers/test_json_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_reader.py b/tests/spine_io/importers/test_reader.py index c4dc1826..71e90535 100644 --- a/tests/spine_io/importers/test_reader.py +++ b/tests/spine_io/importers/test_reader.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/importers/test_sqlalchemy_connector.py b/tests/spine_io/importers/test_sqlalchemy_connector.py index 50230112..9fab91b1 100644 --- a/tests/spine_io/importers/test_sqlalchemy_connector.py +++ b/tests/spine_io/importers/test_sqlalchemy_connector.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/spine_io/test_excel_integration.py b/tests/spine_io/test_excel_integration.py index d1bbe6a7..7ca1a318 100644 --- a/tests/spine_io/test_excel_integration.py +++ b/tests/spine_io/test_excel_integration.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index f6d03af3..cb2596ce 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_check_integrity.py b/tests/test_check_integrity.py index b4476d83..9d54caae 100644 --- a/tests/test_check_integrity.py +++ b/tests/test_check_integrity.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index e55a4b32..179c1ed8 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index d442e372..3a4e9ccc 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index a9879261..5497ae66 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 17dcdf4c..6cfec085 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_mapping.py b/tests/test_mapping.py index a00aebba..94358de2 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_migration.py b/tests/test_migration.py index d9168e49..c9cf76d1 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index 77e794f1..8ae63ef5 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/tests/test_purge.py b/tests/test_purge.py index d465afe8..f6093e1c 100644 --- a/tests/test_purge.py +++ b/tests/test_purge.py @@ -1,6 +1,5 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium -# Copyright (C) 2023-2024 Mopo project consortium # This file is part of Spine Database API. 
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)

From 9e7abb5ebc78364d08cf4378a4649a5b79afd833 Mon Sep 17 00:00:00 2001
From: Pekka T Savolainen
Date: Tue, 30 Jan 2024 13:11:56 +0200
Subject: [PATCH 235/317] The new version of cx_Oracle is called oracledb

Re spine-tools/Spine-Toolbox#2524
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 47410815..52077d59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
     "chardet >=4.0.0",
     "pymysql >=1.0.2",
     "psycopg2",
-    "cx_Oracle",
+    "oracledb",
 ]
 
 [project.urls]

From 2563b3bbf3c6d2740d32b9772a191e3bb495f73d Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Tue, 30 Jan 2024 13:21:47 +0200
Subject: [PATCH 236/317] Rewrite the pages about metadata in the docs

This replaces the "Metadata description" and "Result metadata description"
pages in the documentation with a single "Metadata" page. The new page
contains a short introduction to metadata and some suggestions on what
kind of information to store there. This is in contrast to the
datapackage-like JSON definition that the two removed pages contained.

Re #329
---
 docs/source/index.rst                        |   3 +-
 docs/source/metadata.rst                     |  60 +++++++++++
 docs/source/metadata_description.rst         | 105 -------------------
 docs/source/results_metadata_description.rst |  43 --------
 4 files changed, 61 insertions(+), 150 deletions(-)
 create mode 100644 docs/source/metadata.rst
 delete mode 100644 docs/source/metadata_description.rst
 delete mode 100644 docs/source/results_metadata_description.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 9815a94c..fbee0830 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -17,8 +17,7 @@ Welcome to Spine Database API's documentation!
    front_matter
    tutorial
    parameter_value_format
-   metadata_description
-   results_metadata_description
+   metadata
    db_mapping_schema
    autoapi/index
 
diff --git a/docs/source/metadata.rst b/docs/source/metadata.rst
new file mode 100644
index 00000000..6446ba2a
--- /dev/null
+++ b/docs/source/metadata.rst
@@ -0,0 +1,60 @@
+********
+Metadata
+********
+
+Metadata can be used to provide additional information about the data in a Spine data structure.
+Every entity and parameter value can have metadata associated with it.
+
+A metadata "item" has a *name* and a *value*, e.g. "authors" and "N.N, M.M et al.".
+The same metadata item can be referenced by multiple entities and parameter values.
+Entities and values can also refer to multiple items of metadata.
+
+.. note::
+
+   Referring to multiple items of metadata from a huge number of entities or parameter values
+   may take up a lot of space in the database.
+   Therefore, it might make more sense, for example,
+   to list all contributors to the data in a single metadata value than
+   to have each contributor as a separate name-value pair.
+
+Choosing suitable names and values for metadata is left up to the user.
+However, some suggestions and recommendations are presented below.
+
+title
+  A one-sentence description of the data.
+
+sources
+  The raw sources of the data.
+
+tools
+  Names and versions of tools that were used to process the data.
+
+contributors
+  The people or organisations who contributed to the data.
+
+created
+  The date this data was created or put together, e.g. in ISO8601 format (YYYY-MM-DDTHH:MM).
+ +description + A more complete description of the data than the title. + +keywords + Keywords that categorize the data. + +homepage + A URL for the home on the web that is related to the data. + +id + Globally unique id, such as UUID or DOI. + +licenses + Licences that apply to the data. + +temporal + Temporal properties of the data. + +spatial + Spatial properties of the data. + +unitOfMeasurement + Unit of measurement. diff --git a/docs/source/metadata_description.rst b/docs/source/metadata_description.rst deleted file mode 100644 index aeccdaba..00000000 --- a/docs/source/metadata_description.rst +++ /dev/null @@ -1,105 +0,0 @@ -******************** -Metadata description -******************** - -This is the metadata description for Spine, edited from ``_. - -Required properties -------------------- - -``title`` - One sentence description for the data. - -``sources`` - The raw sources of the data. Each source must have a ``title`` property and optionally a ``path`` property. - -.. code-block:: - :caption: Example - - "sources": [{ - "title": "World Bank and OECD", - "path": "http://data.worldbank.org/indicator/NY.GDP.MKTP.CD" - }] - -``contributors`` - The people or organisations who contributed to the data. - Must include ``title`` and may include ``path``, ``email``, ``role`` and ``organization``. - ``role`` is one of ``author``, ``publisher``, ``maintainer``, ``wrangler``, or ``contributor``. - -.. code-block:: - :caption: Example - - "contributors": [{ - "title": "Joe Bloggs", - "email": "joe@bloggs.com", - "path": "http://www.bloggs.com", - "role": "author" - }] - -``created`` - The date this data was created or put together, in ISO8601 format (YYYY-MM-DDTHH:MM). - -Optional properties -------------------- - -``description`` - A description of the data. Describe here how the data was collected, how it was processed etc. - The description *must* be markdown formatted – - this also allows for simple plain text as plain text is itself valid markdown. - The first paragraph (up to the first double line break) should be usable as summary information for the package. - -``spine_results_metadata`` - Property contains :ref:`results metadata description`. - -``keywords`` - An array of keywords. - -``homepage`` - A URL for the home on the web that is related to this data package. - -``name`` - Name of the data package, url-usable, all-lowercase string. - -``id`` - Globally unique id, such as UUID or DOI. - -``licenses`` - Licences that apply to the data. - Each item must have a name property (Open Definition license ID) or a path property and may contain title. - -.. code-block:: - :caption: Example - - "licenses": [{ - "name": "ODC-PDDL-1.0", - "path": "http://opendatacommons.org/licenses/pddl/", - "title": "Open Data Commons Public Domain Dedication and License v1.0" - }] - -``temporal`` - Temporal properties of the data (if applicable). - -.. code-block:: - :caption: Example using DCMI Period Encoding Scheme - - "temporal": { - "start": "2000-01-01", - "end": "2000-12-31", - "name": "The first year of the 21st century" - } - -``spatial`` - Spatial properties of the data (if applicable). - -.. code-block:: - :caption: Example using DCMI Point Encoding Scheme - - "spatial": { - "east": 23.766667, - "north": 61.5, - "projection": "geographic coordinates (WGS 84)", - "name": "Tampere, Finland" - } - -``unitOfMeasurement`` - Unit of measurement. Can also be embedded in description. 
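
The "Metadata" page added above describes metadata as reusable name-value items that any number of entities and parameter values can reference. Below is a minimal sketch of how such an item might be created and linked through the mapped-items API of this era; `DatabaseMapping.add_item`, the "metadata" and "entity_metadata" item types, and their field names are assumptions here, and the database URL and all example names are made up:

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite:///example.sqlite", create=True) as db_map:
        db_map.add_item("entity_class", name="unit")
        db_map.add_item("entity", entity_class_name="unit", name="power_plant_a")
        # One metadata item; any number of entities and values may reference it.
        db_map.add_item("metadata", name="contributors", value="N.N., M.M. et al.")
        # Link the entity to the metadata item (field names assumed).
        db_map.add_item(
            "entity_metadata",
            entity_class_name="unit",
            entity_byname=("power_plant_a",),
            metadata_name="contributors",
            metadata_value="N.N., M.M. et al.",
        )
        db_map.commit_session("Add an entity with metadata.")

This mirrors the note on the new page: a single "contributors" item listing all names stays one row in the metadata table no matter how many entities point at it.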
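
On the dependency swap in PATCH 235 above: SQLAlchemy gained a native oracle+oracledb dialect only in version 2.0, so with an older SQLAlchemy the oracledb package is typically used through the compatibility shim documented by python-oracledb, which registers the package under the legacy cx_Oracle name. A minimal sketch, assuming a pre-2.0 SQLAlchemy; the connection URL is an example only:

    import sys

    import oracledb
    from sqlalchemy import create_engine

    # Let SQLAlchemy's cx_Oracle dialect pick up oracledb instead.
    oracledb.version = "8.3.0"  # advertise a cx_Oracle-compatible version
    sys.modules["cx_Oracle"] = oracledb

    engine = create_engine("oracle://user:password@host:1521/?service_name=XE")

On SQLAlchemy 2.0 or later the shim is unnecessary; a URL of the form "oracle+oracledb://..." selects the new driver directly.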
diff --git a/docs/source/results_metadata_description.rst b/docs/source/results_metadata_description.rst deleted file mode 100644 index 034f3aa2..00000000 --- a/docs/source/results_metadata_description.rst +++ /dev/null @@ -1,43 +0,0 @@ -.. _results-metadata-description: - -**************************** -Results metadata description -**************************** - -Required properties -------------------- - -``description`` - A plain text (or Markdown formatted) description what the data is about. - -``author`` - Author details. - -.. code-block:: - :caption: Example - - "author": { - "title": "Joe Bloggs", - "email": "joe@bloggs.com", - } - -``spine_toolbox_version`` - Version string of Spine Toolbox application. - -``created`` - The date these results were created, in ISO8601 format (YYYY-MM-DDTHH:MM). - -Optional properties -------------------- - -``tools`` - An array of records with processing tool names and versions. - -.. code-block:: - :caption: Example - - "tools": [{"name": "Spine Model", - "version": "1.0.2", - "path": "https://github.com/spine-tools/Spine-Model"}, - ... - ] From 0bb592ba30a5cdf4ea9873d7215432fa8d5e9a7e Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 30 Jan 2024 14:09:30 +0200 Subject: [PATCH 237/317] Update copyrights Re spine-tools/Spine-Toolbox#2514 --- bin/update_copyrights.py | 36 ------------------- spinedb_api/__init__.py | 1 + spinedb_api/compatibility.py | 1 + spinedb_api/conflict_resolution.py | 1 + spinedb_api/db_mapping.py | 1 + spinedb_api/db_mapping_base.py | 2 ++ spinedb_api/db_mapping_commit_mixin.py | 1 + spinedb_api/db_mapping_query_mixin.py | 1 + spinedb_api/exception.py | 1 + spinedb_api/export_functions.py | 1 + spinedb_api/export_mapping/__init__.py | 1 + spinedb_api/export_mapping/export_mapping.py | 1 + spinedb_api/export_mapping/generator.py | 1 + spinedb_api/export_mapping/group_functions.py | 1 + spinedb_api/export_mapping/pivot.py | 1 + spinedb_api/export_mapping/settings.py | 1 + spinedb_api/filters/__init__.py | 1 + spinedb_api/filters/alternative_filter.py | 1 + spinedb_api/filters/execution_filter.py | 1 + spinedb_api/filters/renamer.py | 1 + spinedb_api/filters/scenario_filter.py | 1 + spinedb_api/filters/tools.py | 1 + spinedb_api/filters/value_transformer.py | 1 + spinedb_api/graph_layout_generator.py | 1 + spinedb_api/helpers.py | 1 + spinedb_api/import_functions.py | 1 + spinedb_api/import_mapping/__init__.py | 1 + spinedb_api/import_mapping/generator.py | 1 + spinedb_api/import_mapping/import_mapping.py | 1 + .../import_mapping/import_mapping_compat.py | 1 + spinedb_api/import_mapping/type_conversion.py | 1 + spinedb_api/item_id.py | 1 + spinedb_api/item_status.py | 1 + spinedb_api/mapped_items.py | 1 + spinedb_api/mapping.py | 1 + spinedb_api/parameter_value.py | 1 + spinedb_api/perfect_split.py | 1 + spinedb_api/purge.py | 1 + spinedb_api/query.py | 1 + spinedb_api/server_client_helpers.py | 1 + spinedb_api/spine_db_client.py | 1 + spinedb_api/spine_db_server.py | 1 + spinedb_api/spine_io/__init__.py | 1 + spinedb_api/spine_io/exporters/__init__.py | 1 + spinedb_api/spine_io/exporters/csv_writer.py | 1 + spinedb_api/spine_io/exporters/excel.py | 1 + .../spine_io/exporters/excel_writer.py | 1 + spinedb_api/spine_io/exporters/gdx_writer.py | 1 + spinedb_api/spine_io/exporters/sql_writer.py | 1 + spinedb_api/spine_io/exporters/writer.py | 1 + spinedb_api/spine_io/gdx_utils.py | 1 + spinedb_api/spine_io/importers/__init__.py | 1 + spinedb_api/spine_io/importers/csv_reader.py | 1 + 
.../spine_io/importers/datapackage_reader.py | 1 + .../spine_io/importers/excel_reader.py | 1 + .../spine_io/importers/gdx_connector.py | 1 + spinedb_api/spine_io/importers/json_reader.py | 1 + spinedb_api/spine_io/importers/reader.py | 1 + .../importers/sqlalchemy_connector.py | 1 + tests/__init__.py | 1 + tests/custom_db_mapping.py | 1 + tests/export_mapping/__init__.py | 1 + tests/export_mapping/test_export_mapping.py | 1 + tests/export_mapping/test_pivot.py | 1 + tests/export_mapping/test_settings.py | 1 + tests/filters/__init__.py | 1 + tests/filters/test_alternative_filter.py | 1 + tests/filters/test_execution_filter.py | 1 + tests/filters/test_renamer.py | 1 + tests/filters/test_scenario_filter.py | 1 + tests/filters/test_tool_filter.py | 1 + tests/filters/test_tools.py | 1 + tests/filters/test_value_transformer.py | 1 + tests/import_mapping/__init__.py | 1 + tests/import_mapping/test_generator.py | 1 + tests/import_mapping/test_import_mapping.py | 1 + tests/import_mapping/test_type_conversion.py | 1 + tests/spine_io/__init__.py | 1 + tests/spine_io/exporters/__init__.py | 1 + tests/spine_io/exporters/test_csv_writer.py | 1 + tests/spine_io/exporters/test_excel_writer.py | 1 + tests/spine_io/exporters/test_gdx_writer.py | 1 + tests/spine_io/exporters/test_sql_writer.py | 1 + tests/spine_io/exporters/test_writer.py | 1 + tests/spine_io/importers/__init__.py | 1 + tests/spine_io/importers/test_CSVConnector.py | 1 + tests/spine_io/importers/test_GdxConnector.py | 1 + .../importers/test_datapackage_reader.py | 1 + tests/spine_io/importers/test_excel_reader.py | 1 + tests/spine_io/importers/test_json_reader.py | 1 + tests/spine_io/importers/test_reader.py | 1 + .../importers/test_sqlalchemy_connector.py | 1 + tests/spine_io/test_excel_integration.py | 1 + tests/test_DatabaseMapping.py | 1 + tests/test_check_integrity.py | 1 + tests/test_export_functions.py | 1 + tests/test_helpers.py | 1 + tests/test_import_functions.py | 1 + tests/test_item_id.py | 1 + tests/test_mapping.py | 1 + tests/test_migration.py | 1 + tests/test_parameter_value.py | 1 + tests/test_purge.py | 1 + 103 files changed, 103 insertions(+), 36 deletions(-) delete mode 100644 bin/update_copyrights.py diff --git a/bin/update_copyrights.py b/bin/update_copyrights.py deleted file mode 100644 index 6194729c..00000000 --- a/bin/update_copyrights.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python - -from pathlib import Path -import time - - -current_year = time.gmtime().tm_year -root_dir = Path(__file__).parent.parent -project_source_dir = Path(root_dir, "spinedb_api") -test_source_dir = Path(root_dir, "tests") - -expected = f"# Copyright (C) 2017-{current_year} Spine project consortium" - - -def update_copyrights(path, suffix, recursive=True): - for path in path.iterdir(): - if path.suffix == suffix: - i = 0 - with open(path) as python_file: - lines = python_file.readlines() - for i, line in enumerate(lines[1:4]): - if line.startswith("# Copyright (C) "): - lines[i + 1] = lines[i + 1][:21] + str(current_year) + lines[i + 1][25:] - break - if len(lines) <= i + 1 or not lines[i + 1].startswith(expected): - print(f"Confusing or no copyright: {path}") - else: - with open(path, "w") as python_file: - python_file.writelines(lines) - elif recursive and path.is_dir(): - update_copyrights(path, suffix) - - -update_copyrights(root_dir, ".py", recursive=False) -update_copyrights(project_source_dir, ".py") -update_copyrights(test_source_dir, ".py") diff --git a/spinedb_api/__init__.py b/spinedb_api/__init__.py index 
b6b41a6b..c947b13c 100644 --- a/spinedb_api/__init__.py +++ b/spinedb_api/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 92a6783b..04b02804 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/conflict_resolution.py b/spinedb_api/conflict_resolution.py index 9e459b99..aa73f4c9 100644 --- a/spinedb_api/conflict_resolution.py +++ b/spinedb_api/conflict_resolution.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 8992ee8e..ab7b8aad 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index c3253258..30662108 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1,5 +1,7 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index 14f80d4e..ae5985e6 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/db_mapping_query_mixin.py b/spinedb_api/db_mapping_query_mixin.py index c5fd4f37..30831290 100644 --- a/spinedb_api/db_mapping_query_mixin.py +++ b/spinedb_api/db_mapping_query_mixin.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/exception.py b/spinedb_api/exception.py index c2554dab..4a6ad078 100644 --- a/spinedb_api/exception.py +++ b/spinedb_api/exception.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_functions.py b/spinedb_api/export_functions.py index 1040a427..2da0b132 100644 --- a/spinedb_api/export_functions.py +++ b/spinedb_api/export_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/export_mapping/__init__.py b/spinedb_api/export_mapping/__init__.py index c75b202c..08655b98 100644 --- a/spinedb_api/export_mapping/__init__.py +++ b/spinedb_api/export_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index e5ccef3e..69ba2fa4 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/generator.py b/spinedb_api/export_mapping/generator.py index 026454be..1cc551c7 100644 --- a/spinedb_api/export_mapping/generator.py +++ b/spinedb_api/export_mapping/generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/group_functions.py b/spinedb_api/export_mapping/group_functions.py index ce4598fb..edfb7e76 100644 --- a/spinedb_api/export_mapping/group_functions.py +++ b/spinedb_api/export_mapping/group_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/pivot.py b/spinedb_api/export_mapping/pivot.py index afd78344..fd1bfcf3 100644 --- a/spinedb_api/export_mapping/pivot.py +++ b/spinedb_api/export_mapping/pivot.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/export_mapping/settings.py b/spinedb_api/export_mapping/settings.py index ed10ff58..7fb932fe 100644 --- a/spinedb_api/export_mapping/settings.py +++ b/spinedb_api/export_mapping/settings.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/__init__.py b/spinedb_api/filters/__init__.py index 46105c99..5f01a618 100644 --- a/spinedb_api/filters/__init__.py +++ b/spinedb_api/filters/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/alternative_filter.py b/spinedb_api/filters/alternative_filter.py index f406f793..b76ded9f 100644 --- a/spinedb_api/filters/alternative_filter.py +++ b/spinedb_api/filters/alternative_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/execution_filter.py b/spinedb_api/filters/execution_filter.py index 51a9f2db..bbaa4932 100644 --- a/spinedb_api/filters/execution_filter.py +++ b/spinedb_api/filters/execution_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/renamer.py b/spinedb_api/filters/renamer.py index a3970ac7..f591d883 100644 --- a/spinedb_api/filters/renamer.py +++ b/spinedb_api/filters/renamer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 3bbcb5e7..80d264fa 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/tools.py b/spinedb_api/filters/tools.py index 94d60545..0787d515 100644 --- a/spinedb_api/filters/tools.py +++ b/spinedb_api/filters/tools.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/filters/value_transformer.py b/spinedb_api/filters/value_transformer.py index 956de19d..6e444726 100644 --- a/spinedb_api/filters/value_transformer.py +++ b/spinedb_api/filters/value_transformer.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/graph_layout_generator.py b/spinedb_api/graph_layout_generator.py index be0a149a..0d200ee6 100644 --- a/spinedb_api/graph_layout_generator.py +++ b/spinedb_api/graph_layout_generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Engine. # Spine Engine is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index aaba62be..9f26606f 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index c270c3e4..62cb63a0 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/import_mapping/__init__.py b/spinedb_api/import_mapping/__init__.py index 9966601e..08d399e5 100644 --- a/spinedb_api/import_mapping/__init__.py +++ b/spinedb_api/import_mapping/__init__.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index b73037a4..43d7ee03 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index c07d4f3c..625809c5 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index c8170b51..4434334d 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py index 4f054c9c..fea46b4b 100644 --- a/spinedb_api/import_mapping/type_conversion.py +++ b/spinedb_api/import_mapping/type_conversion.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/item_id.py b/spinedb_api/item_id.py index 2aec7f84..441736f4 100644 --- a/spinedb_api/item_id.py +++ b/spinedb_api/item_id.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/item_status.py b/spinedb_api/item_status.py index 7fb8e76f..32e4a6a2 100644 --- a/spinedb_api/item_status.py +++ b/spinedb_api/item_status.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 3f6b8eab..59833e9f 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/mapping.py b/spinedb_api/mapping.py index 5b197b17..8e30318a 100644 --- a/spinedb_api/mapping.py +++ b/spinedb_api/mapping.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 7dbd704d..b2a5aa76 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index 6028e47a..287a4a87 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 472fa5c3..179541c5 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. # Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General # Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) diff --git a/spinedb_api/query.py b/spinedb_api/query.py index db23f1e7..3cbdbe00 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -1,5 +1,6 @@ ###################################################################################################################### # Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors # This file is part of Spine Database API. 
 # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
 # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your

diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py
index 318698a9..d8d62ede 100644
--- a/spinedb_api/server_client_helpers.py
+++ b/spinedb_api/server_client_helpers.py
@@ -1,5 +1,6 @@
 ######################################################################################################################
 # Copyright (C) 2017-2022 Spine project consortium
+# Copyright Spine Database API contributors
 # This file is part of Spine Database API.
 # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
 # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your

The same one-line addition — "+# Copyright Spine Database API contributors" inserted after the consortium copyright line of the license header — is applied by identical hunks to each of the following files:

 spinedb_api/spine_db_client.py
 spinedb_api/spine_db_server.py
 spinedb_api/spine_io/__init__.py
 spinedb_api/spine_io/exporters/__init__.py
 spinedb_api/spine_io/exporters/csv_writer.py
 spinedb_api/spine_io/exporters/excel.py
 spinedb_api/spine_io/exporters/excel_writer.py
 spinedb_api/spine_io/exporters/gdx_writer.py
 spinedb_api/spine_io/exporters/sql_writer.py
 spinedb_api/spine_io/exporters/writer.py
 spinedb_api/spine_io/gdx_utils.py
 spinedb_api/spine_io/importers/__init__.py
 spinedb_api/spine_io/importers/csv_reader.py
 spinedb_api/spine_io/importers/datapackage_reader.py
 spinedb_api/spine_io/importers/excel_reader.py
 spinedb_api/spine_io/importers/gdx_connector.py
 spinedb_api/spine_io/importers/json_reader.py
 spinedb_api/spine_io/importers/reader.py
 spinedb_api/spine_io/importers/sqlalchemy_connector.py
 tests/__init__.py
 tests/custom_db_mapping.py
 tests/export_mapping/__init__.py
 tests/export_mapping/test_export_mapping.py
 tests/export_mapping/test_pivot.py
 tests/export_mapping/test_settings.py
 tests/filters/__init__.py
 tests/filters/test_alternative_filter.py
 tests/filters/test_execution_filter.py
 tests/filters/test_renamer.py
 tests/filters/test_scenario_filter.py
 tests/filters/test_tool_filter.py
 tests/filters/test_tools.py
 tests/filters/test_value_transformer.py
 tests/import_mapping/__init__.py
 tests/import_mapping/test_generator.py
 tests/import_mapping/test_import_mapping.py
 tests/import_mapping/test_type_conversion.py
 tests/spine_io/__init__.py
 tests/spine_io/exporters/__init__.py
 tests/spine_io/exporters/test_csv_writer.py
 tests/spine_io/exporters/test_excel_writer.py
 tests/spine_io/exporters/test_gdx_writer.py
 tests/spine_io/exporters/test_sql_writer.py
 tests/spine_io/exporters/test_writer.py
 tests/spine_io/importers/__init__.py
 tests/spine_io/importers/test_CSVConnector.py
 tests/spine_io/importers/test_GdxConnector.py
 tests/spine_io/importers/test_datapackage_reader.py
 tests/spine_io/importers/test_excel_reader.py
 tests/spine_io/importers/test_json_reader.py
 tests/spine_io/importers/test_reader.py
 tests/spine_io/importers/test_sqlalchemy_connector.py
 tests/spine_io/test_excel_integration.py
 tests/test_DatabaseMapping.py
 tests/test_check_integrity.py
 tests/test_export_functions.py
 tests/test_helpers.py
 tests/test_import_functions.py
 tests/test_item_id.py
 tests/test_mapping.py
 tests/test_migration.py
 tests/test_parameter_value.py
 tests/test_purge.py

From e9d10b4e4b8f81fc9401e914128ec987a9913ff1 Mon Sep 17 00:00:00 2001
From: Pekka T Savolainen
Date: Tue, 30 Jan 2024 14:48:48 +0200
Subject: [PATCH 238/317] Remove cx_oracle from requirements and 'oracle' from
 unsupported dialects

---
 pyproject.toml         | 1 -
 spinedb_api/helpers.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 52077d59..bcfa9d3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,6 @@ dependencies = [
     "chardet >=4.0.0",
     "pymysql >=1.0.2",
     "psycopg2",
-    "oracledb",
 ]
 
 [project.urls]
diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py
index da671683..cc4fa6d8 100644
--- a/spinedb_api/helpers.py
+++ b/spinedb_api/helpers.py
@@ -60,7 +60,6 @@
 UNSUPPORTED_DIALECTS = {
     "mssql": "pyodbc",
     "postgresql": "psycopg2",
-    "oracle": "cx_oracle",
 }
 """Dialects and recommended dbapi that are not supported by DatabaseMapping but are supported by SqlAlchemy."""
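The UNSUPPORTED_DIALECTS table trimmed above maps each dialect that SQLAlchemy supports but DatabaseMapping does not to the DBAPI package one would otherwise reach for. A minimal sketch of how such a table can gate connection URLs early — the guard function, the error type, and the call site are illustrative assumptions, not code from the patch; make_url is SQLAlchemy's URL parser:

    from sqlalchemy.engine.url import make_url

    UNSUPPORTED_DIALECTS = {"mssql": "pyodbc", "postgresql": "psycopg2"}

    def ensure_supported(db_url):
        # Reject URLs whose dialect is known to be unsupported by DatabaseMapping.
        dialect = make_url(db_url).get_dialect().name
        if dialect in UNSUPPORTED_DIALECTS:
            recommended = UNSUPPORTED_DIALECTS[dialect]
            raise RuntimeError(f"unsupported dialect '{dialect}' (DBAPI: {recommended})")

    ensure_supported("sqlite:///model.sqlite")  # passes; an oracle:// URL would now pass this check too

Note that with "oracle" removed from the table an oracle:// URL is no longer rejected here, but it will still fail at connect time since oracledb was dropped from the dependencies.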
From 58f12b35fb9947900d20c762de94051d877e80b9 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Wed, 31 Jan 2024 14:00:53 +0200
Subject: [PATCH 239/317] Revert "Make database mapping work when fetching
 external changes (#333)"

This reverts commit d187d29c32943e462af2a50327070d132f9cb78b, reversing
changes made to 474d3a72946b85b9c7182487b9c71dddfb089c9d.

---
 spinedb_api/conflict_resolution.py     |  105 ---
 spinedb_api/db_mapping.py              |   30 +-
 spinedb_api/db_mapping_base.py         |  346 ++-----
 spinedb_api/db_mapping_commit_mixin.py |   52 +-
 spinedb_api/helpers.py                 |   12 -
 spinedb_api/item_id.py                 |   61 --
 spinedb_api/item_status.py             |   25 -
 spinedb_api/mapped_items.py            |  197 +---
 spinedb_api/server_client_helpers.py   |    2 +-
 spinedb_api/temp_id.py                 |   54 ++
 tests/test_DatabaseMapping.py          | 1147 ++++-------------------
 tests/test_db_mapping_base.py          |   80 ++
 tests/test_helpers.py                  |   74 --
 tests/test_item_id.py                  |   71 --
 14 files changed, 425 insertions(+), 1831 deletions(-)
 delete mode 100644 spinedb_api/conflict_resolution.py
 delete mode 100644 spinedb_api/item_id.py
 delete mode 100644 spinedb_api/item_status.py
 create mode 100644 spinedb_api/temp_id.py
 create mode 100644 tests/test_db_mapping_base.py
 delete mode 100644 tests/test_item_id.py

diff --git a/spinedb_api/conflict_resolution.py b/spinedb_api/conflict_resolution.py
deleted file mode 100644
index aa73f4c9..00000000
--- a/spinedb_api/conflict_resolution.py
+++ /dev/null
@@ -1,105 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# Copyright Spine Database API contributors
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
-######################################################################################################################
-from __future__ import annotations
-from enum import auto, Enum, unique
-from dataclasses import dataclass
-
-from .item_status import Status
-
-
-@unique
-class Resolution(Enum):
-    USE_IN_MEMORY = auto()
-    USE_IN_DB = auto()
-
-
-@dataclass
-class Conflict:
-    in_memory: MappedItemBase
-    in_db: MappedItemBase
-
-
-@dataclass
-class Resolved(Conflict):
-    resolution: Resolution
-
-    def __init__(self, conflict, resolution):
-        self.in_memory = conflict.in_memory
-        self.in_db = conflict.in_db
-        self.resolution = resolution
-
-
-def select_in_memory_item_always(conflicts):
-    return [Resolved(conflict, Resolution.USE_IN_MEMORY) for conflict in conflicts]
-
-
-def select_in_db_item_always(conflicts):
-    return [Resolved(conflict, Resolution.USE_IN_DB) for conflict in conflicts]
-
-
-@dataclass
-class KeepInMemoryAction:
-    in_memory: MappedItemBase
-    set_uncommitted: bool
-
-    def __init__(self, conflict):
-        self.in_memory = conflict.in_memory
-        self.set_uncommitted = not conflict.in_memory.equal_ignoring_ids(conflict.in_db)
-
-
-@dataclass
-class UpdateInMemoryAction:
-    in_memory: MappedItemBase
-    in_db: MappedItemBase
-
-    def __init__(self, conflict):
-        self.in_memory = conflict.in_memory
-        self.in_db = conflict.in_db
-
-
-@dataclass
-class ResurrectAction:
-    in_memory: MappedItemBase
-    in_db: MappedItemBase
-
-    def __init__(self, conflict):
-        self.in_memory = conflict.in_memory
-        self.in_db = conflict.in_db
-
-
-def resolved_conflict_actions(conflicts):
-    for conflict in conflicts:
-        if conflict.resolution == Resolution.USE_IN_MEMORY:
-            yield KeepInMemoryAction(conflict)
-        elif conflict.resolution == Resolution.USE_IN_DB:
-            yield UpdateInMemoryAction(conflict)
-        else:
-            raise RuntimeError(f"unknown conflict resolution")
-
-
-def resurrection_conflicts_from_resolved(conflicts):
-    resurrection_conflicts = []
-    for conflict in conflicts:
-        if conflict.resolution != Resolution.USE_IN_DB or not conflict.in_memory.removed:
-            continue
-        resurrection_conflicts.append(conflict)
-    return resurrection_conflicts
-
-
-def make_changed_in_memory_items_dirty(conflicts):
-    for conflict in conflicts:
-        if conflict.resolution != Resolution.USE_IN_MEMORY:
-            continue
-        if conflict.in_memory.removed:
-            conflict.in_memory.status = Status.to_remove
-        elif conflict.in_memory.asdict_() != conflict.in_db:
-            conflict.in_memory.status = Status.to_update
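For context on what the revert removes: the module deleted above was the pluggable half of the fetch-conflict machinery. Resolver callables turned a list of Conflict(in_memory, in_db) pairs into Resolved decisions, which the fetch layer then applied. A rough sketch of the pre-revert call shape — select_in_db_item_always comes from the deleted module, while the db_map setup around it is assumed:

    # Pre-revert usage sketch; "database wins": items changed by another
    # commit overwrite the in-memory copies instead of being kept as local,
    # uncommitted changes.
    from spinedb_api.conflict_resolution import select_in_db_item_always

    items = db_map.fetch_more("entity_class", resolve_conflicts=select_in_db_item_always)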
diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py
index ab7b8aad..ff2a8edf 100644
--- a/spinedb_api/db_mapping.py
+++ b/spinedb_api/db_mapping.py
@@ -34,7 +34,6 @@
 from alembic.config import Config
 from alembic.util.exc import CommandError
 
-from .conflict_resolution import select_in_memory_item_always
 from .filters.tools import pop_filter_configs, apply_filter_stack, load_filters
 from .spine_db_client import get_db_url_from_server
 from .mapped_items import item_factory
@@ -366,16 +365,13 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs):
             return {}
         return item.public_item
 
-    def get_items(
-        self, item_type, fetch=True, skip_removed=True, resolve_conflicts=select_in_memory_item_always, **kwargs
-    ):
+    def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs):
         """Finds and returns all the items of one type.
 
         Args:
             item_type (str): One of <spine_item_types>.
             fetch (bool, optional): Whether to fetch the DB before returning the items.
             skip_removed (bool, optional): Whether to ignore removed items.
-            resolve_conflicts (Callable): function that resolves fetch conflicts
             **kwargs: Fields and values for one of the unique keys as specified for the item type in :ref:`db_mapping_schema`.
 
@@ -386,7 +382,7 @@ def get_items(
         mapped_table = self.mapped_table(item_type)
         mapped_table.check_fields(kwargs, valid_types=(type(None),))
         if fetch:
-            self.do_fetch_all(item_type, resolve_conflicts=resolve_conflicts, **kwargs)
+            self.do_fetch_all(item_type, **kwargs)
         get_items = mapped_table.valid_values if skip_removed else mapped_table.values
         return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())]
 
@@ -622,7 +618,7 @@ def purge_items(self, item_type):
         """
         return bool(self.remove_items(item_type, Asterisk))
 
-    def fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=select_in_memory_item_always, **kwargs):
+    def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
         """Fetches items from the DB into the in-memory mapping, incrementally.
 
         Args:
@@ -636,12 +632,7 @@ def fetch_more(self, item_type, offset=0, limit=None, **kwargs):
             list(:class:`PublicItem`): The items fetched.
         """
         item_type = self.real_item_type(item_type)
-        return [
-            x.public_item
-            for x in self.do_fetch_more(
-                item_type, offset=offset, limit=limit, resolve_conflicts=resolve_conflicts, **kwargs
-            )
-        ]
+        return [x.public_item for x in self.do_fetch_more(item_type, offset=offset, limit=limit, **kwargs)]
 
     def fetch_all(self, *item_types):
         """Fetches items from the DB into the in-memory mapping.
 
         Args:
@@ -706,18 +697,13 @@ def commit_session(self, comment):
         date = datetime.now(timezone.utc)
         ins = self._metadata.tables["commit"].insert()
         with self.engine.begin() as connection:
-            commit_item = {"user": user, "date": date, "comment": comment}
             try:
-                commit_id = connection.execute(ins, commit_item).inserted_primary_key[0]
+                commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0]
             except DBAPIError as e:
                 raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e
-            commit_item["id"] = commit_id
-            commit_table = self.mapped_table("commit")
-            commit_table.add_item_from_db(commit_item)
-            commit_item_id = commit_table.id_map.item_id(commit_id)
             for tablename, (to_add, to_update, to_remove) in dirty_items:
                 for item in to_add + to_update + to_remove:
-                    item.commit(commit_item_id)
+                    item.commit(commit_id)
                 # Remove before add, to help with keeping integrity constraints
                 self._do_remove_items(connection, tablename, *{x["id"] for x in to_remove})
                 self._do_update_items(connection, tablename, *to_update)
@@ -735,6 +721,10 @@ def rollback_session(self):
         if self._memory:
             self._memory_dirty = False
 
+    def refresh_session(self):
+        """Resets the fetch status so new items from the DB can be retrieved."""
+        self._refresh()
+
     def has_external_commits(self):
         """Tests whether the database has had commits from other sources than this mapping.
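refresh_session(), restored by the hunk above, only forgets fetch progress; uncommitted in-memory changes survive it. A usage sketch under that contract — the URL and item type are placeholders:

    # Assumed usage of the restored refresh_session(): after the call,
    # fetch_more() queries the database again instead of stopping at the
    # previously recorded fetch position.
    db_map = DatabaseMapping("sqlite:///model.sqlite")
    before = db_map.fetch_more("entity_class")
    # ... another process commits new entity classes to the same file ...
    db_map.refresh_session()
    after = db_map.fetch_more("entity_class")  # picks up the external rows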
diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index 30662108..f14a0cd6 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -1,7 +1,6 @@
 ######################################################################################################################
 # Copyright (C) 2017-2022 Spine project consortium
 # Copyright Spine Database API contributors
-# Copyright Spine Database API contributors
 # This file is part of Spine Database API.
 # Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
 # General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
@@ -10,19 +9,10 @@
 # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
 # this program. If not, see <http://www.gnu.org/licenses/>.
 ######################################################################################################################
+
+from enum import Enum, unique, auto
 from difflib import SequenceMatcher
-from enum import auto, Enum, unique
-from typing import Iterable
-
-from .conflict_resolution import (
-    Conflict,
-    KeepInMemoryAction,
-    resolved_conflict_actions,
-    select_in_memory_item_always,
-    UpdateInMemoryAction,
-)
-from .item_id import IdFactory, IdMap
-from .item_status import Status
+from .temp_id import TempId, resolve
 from .exception import SpineDBAPIError
 from .helpers import Asterisk
 
@@ -30,10 +20,14 @@
 
 
 @unique
-class _AddStatus(Enum):
-    ADDED = auto()
-    CONFLICT = auto()
-    DUPLICATE = auto()
+class Status(Enum):
+    """Mapped item status."""
+
+    committed = auto()
+    to_add = auto()
+    to_update = auto()
+    to_remove = auto()
+    added_and_removed = auto()
 
 
 class DatabaseMappingBase:
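The restored Status enum is the whole per-item commit state machine; read together with commit_session() above, each value implies at most one SQL action at commit time. A summary sketch — the mapping itself is an illustrative reading, not code from the patch:

    # Illustrative reading of the restored Status values.
    ACTION_BY_STATUS = {
        Status.committed: None,            # in sync with the database
        Status.to_add: "INSERT",           # queued for _do_add_items()
        Status.to_update: "UPDATE",        # queued for _do_update_items()
        Status.to_remove: "DELETE",        # queued for _do_remove_items()
        Status.added_and_removed: None,    # never made it to the database
    }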
@@ -41,7 +35,7 @@ class DatabaseMappingBase:
 
     This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema.
 
-    When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_sq`.
+    When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_query`.
     """
 
     def __init__(self):
@@ -91,7 +85,7 @@ def item_factory(item_type):
         """
         raise NotImplementedError()
 
-    def make_query(self, item_type, **kwargs):
+    def _make_query(self, item_type, **kwargs):
         """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type.
 
         Args:
@@ -136,15 +130,6 @@ def make_item(self, item_type, **item):
         factory = self.item_factory(item_type)
         return factory(self, item_type, **item)
 
-    def any_uncommitted_items(self):
-        """Returns True if there are uncommitted changes."""
-        available_types = tuple(item_type for item_type in self._sorted_item_types if item_type in self._mapped_tables)
-        return any(
-            not item.is_committed()
-            for item_type in available_types
-            for item in self._mapped_tables[item_type].valid_values()
-        )
-
     def dirty_ids(self, item_type):
         return {
             item["id"]
@@ -220,9 +205,13 @@ def _rollback(self):
         for item_type, to_add in to_add_by_type:
             mapped_table = self.mapped_table(item_type)
             for item in to_add:
-                mapped_table.remove_item(item)
+                if mapped_table.remove_item(item) is not None:
+                    item.invalidate_id()
         return True
 
+    def _refresh(self):
+        """Clears fetch progress, so the DB is queried again."""
+
     def _check_item_type(self, item_type):
         if item_type not in self.all_item_types():
             candidate = max(self.all_item_types(), key=lambda x: SequenceMatcher(None, item_type, x).ratio())
@@ -234,14 +223,6 @@ def mapped_table(self, item_type):
             self._mapped_tables[item_type] = _MappedTable(self, item_type)
         return self._mapped_tables[item_type]
 
-    def find_item_id(self, item_type, db_id):
-        """Searches for item id that corresponds to given database id."""
-        return self.mapped_table(item_type).id_map.item_id(db_id)
-
-    def find_db_id(self, item_type, item_id):
-        """Searches for database id that corresponds to given item id."""
-        return self.mapped_table(item_type).id_map.db_id(item_id) if item_id < 0 else item_id
-
     def reset(self, *item_types):
         """Resets the mapping for given item types as if nothing was fetched from the DB or modified in the mapping.
         Any modifications in the mapping that aren't committed to the DB are lost after this.
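The two finders deleted above were the public face of the pre-revert IdMap, which kept negative in-memory item ids aligned with positive database ids (the `item_id < 0` test above is the tell). A sketch of the contract they implemented, with invented numbers:

    # Pre-revert id translation, reconstructed from the deleted finders;
    # the ids used here are invented for illustration.
    item_id = db_map.find_item_id("entity_class", 42)   # e.g. -7: in-memory id for DB row 42
    db_id = db_map.find_db_id("entity_class", item_id)  # back to 42
    assert db_map.find_db_id("entity_class", 42) == 42  # positive ids pass through unchanged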
Args: @@ -303,51 +275,19 @@ def do_fetch_more(self, item_type, offset=0, limit=None, resolve_conflicts=selec items = [] new_items = [] # Add items first - conflicts = [] for x in chunk: - item, add_status = mapped_table.add_item_from_db(x) - if add_status == _AddStatus.CONFLICT: - fetched_item = self.make_item(item_type, **x) - fetched_item.polish() - conflicts.append(Conflict(item, fetched_item)) - elif add_status == _AddStatus.ADDED: + item, new = mapped_table.add_item_from_db(x) + if new: new_items.append(item) - items.append(item) - elif add_status == _AddStatus.DUPLICATE: - items.append(item) - if conflicts: - resolved = resolve_conflicts(conflicts) - items += self._apply_conflict_resolutions(resolved) + items.append(item) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted for item in new_items: mapped_table.add_unique(item) return items - def do_fetch_all(self, item_type, resolve_conflicts=select_in_memory_item_always, **kwargs): - self.do_fetch_more(item_type, offset=0, limit=None, resolve_conflicts=resolve_conflicts, **kwargs) - - @staticmethod - def _apply_conflict_resolutions(resolved_conflicts): - items = [] - for action in resolved_conflict_actions(resolved_conflicts): - if isinstance(action, KeepInMemoryAction): - item = action.in_memory - items.append(item) - if action.set_uncommitted and item.is_committed(): - if item.removed: - item.status = Status.to_remove - else: - item.status = Status.to_update - elif isinstance(action, UpdateInMemoryAction): - item = action.in_memory - if item.removed: - item.resurrect() - item.update(action.in_db) - items.append(item) - else: - raise RuntimeError("unknown conflict resolution action") - return items + def do_fetch_all(self, item_type, **kwargs): + self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) class _MappedTable(dict): @@ -360,8 +300,6 @@ def __init__(self, db_map, item_type, *args, **kwargs): super().__init__(*args, **kwargs) self._db_map = db_map self._item_type = item_type - self._id_factory = IdFactory() - self.id_map = IdMap() self._id_by_unique_key_value = {} self._temp_id_by_db_id = {} self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @@ -379,9 +317,13 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - item_id = self._id_factory.next_id() - self.id_map.add_item_id(item_id) - return item_id + temp_id = TempId(self._item_type) + + def _callback(db_id): + self._temp_id_by_db_id[db_id] = temp_id + + temp_id.add_resolve_callback(_callback) + return temp_id def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None if not found. 
@@ -395,13 +337,11 @@ def _unique_key_value_to_id(self, key, value, fetch=True): int """ id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - item_id = id_by_unique_value.get(value) - if item_id is None and fetch: + if not id_by_unique_value and fetch: self._db_map.do_fetch_all(self._item_type) id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - item_id = id_by_unique_value.get(value) - return item_id + value = tuple(tuple(x) if isinstance(x, list) else x for x in value) + return id_by_unique_value.get(value) def _unique_key_value_to_item(self, key, value, fetch=True): return self.get(self._unique_key_value_to_id(key, value, fetch=fetch)) @@ -435,19 +375,10 @@ def find_item(self, item, skip_keys=(), fetch=True): return self.find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) def find_item_by_id(self, id_, fetch=True): - if id_ > 0: - try: - id_ = self.id_map.item_id(id_) - except KeyError: - if fetch: - self._db_map.do_fetch_all(self._item_type) - try: - id_ = self.id_map.item_id(id_) - except KeyError: - return {} - else: - return {} current_item = self.get(id_, {}) + if not current_item and fetch: + self._db_map.do_fetch_all(self._item_type) + current_item = self.get(id_, {}) return current_item def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): @@ -486,10 +417,7 @@ def checked_item_and_error(self, item, for_update=False): if error: return None, error valid_types = (type(None),) if for_update else () - self.check_fields_for_addition(candidate_item) self.check_fields(candidate_item._asdict(), valid_types=valid_types) - if not for_update: - candidate_item.convert_dicts_db_ids_to_item_ids(self._item_type, candidate_item, self._db_map) return candidate_item, merge_error def _prepare_item(self, candidate_item, current_item, original_item): @@ -549,16 +477,10 @@ def remove_unique(self, item): def _make_and_add_item(self, item): if not isinstance(item, MappedItemBase): item = self._make_item(item) - item.convert_db_ids_to_item_ids() item.polish() - item_id = self._new_id() - db_id = item.get("id") - if db_id is not None: - self.id_map.set_db_id(item_id, db_id) - else: - self.id_map.add_item_id(item_id) - item["id"] = item_id - self[item_id] = item + if "id" not in item or not item.is_id_valid: + item["id"] = self._new_id() + self[item["id"]] = item return item def add_item_from_db(self, item): @@ -570,46 +492,16 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. 
""" - same_item = False - if current := self.find_item_by_id(item["id"], fetch=False): - same_item = current.same_db_item(item) - if same_item: - return ( - current, - _AddStatus.DUPLICATE - if not current.removed and self._compare_non_unique_fields(current, item) - else _AddStatus.CONFLICT, - ) - self.id_map.remove_db_id(current["id"]) - if not current.removed: - current.status = Status.to_add - if "commit_id" in current: - current["commit_id"] = None - else: - current.status = Status.overwritten - if not same_item: - current = self.find_item_by_unique_key(item, fetch=False, complete=False) - if current: - return ( - current, - _AddStatus.DUPLICATE if self._compare_non_unique_fields(current, item) else _AddStatus.CONFLICT, - ) + current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( + item, fetch=False, complete=False + ) + if current: + return current, False item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. item.cascade_remove(source=self.wildcard_item) - return item, _AddStatus.ADDED - - @staticmethod - def _compare_non_unique_fields(mapped_item, item): - unique_keys = mapped_item.unique_keys() - for key, value in item.items(): - if key not in mapped_item.fields or key in unique_keys: - continue - mapped_value = mapped_item[key] - if value != mapped_value and (not isinstance(mapped_value, tuple) or (mapped_value and value)): - return False - return True + return item, True def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -638,12 +530,6 @@ def _error(key, value, valid_types): if errors: raise SpineDBAPIError("\n".join(errors)) - def check_fields_for_addition(self, item): - factory = self._db_map.item_factory(self._item_type) - for required_field in factory.required_fields: - if required_field not in item: - raise SpineDBAPIError(f"missing keyword argument {required_field}") - def add_item(self, item): item = self._make_and_add_item(item) self.add_unique(item) @@ -667,6 +553,7 @@ def remove_item(self, item): self.remove_unique(current_item) current_item.cascade_remove(source=self.wildcard_item) return self.wildcard_item + self.remove_unique(item) item.cascade_remove() return item @@ -679,6 +566,7 @@ def restore_item(self, id_): return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item: + self.add_unique(current_item) current_item.cascade_restore() return current_item @@ -689,8 +577,6 @@ class MappedItemBase(dict): fields = {} """A dictionary mapping keys to a another dict mapping "type" to a Python type, "value" to a description of the value for the key, and "optional" to a bool.""" - required_fields = () - """A tuple of field names that are required to create new items in addition to unique constraints.""" _defaults = {} """A dictionary mapping keys to their default values""" _unique_keys = () @@ -712,11 +598,6 @@ class MappedItemBase(dict): Keys in _internal_fields are resolved to the reference key of the alternative reference pointed at by the source key. """ - _id_fields = {} - """A dictionary mapping item types to field names that contain database ids. 
- Required for conversion from database ids to item ids and back.""" - _external_id_fields = set() - """A set of external field names that contain database ids.""" _private_fields = set() """A set with fields that should be ignored in validations.""" @@ -733,6 +614,7 @@ def __init__(self, db_map, item_type, **kwargs): self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() + self._is_id_valid = True self._to_remove = False self._removed = False self._corrupted = False @@ -788,11 +670,6 @@ def removed(self): """ return self._removed - def resurrect(self): - """Sets item as not-removed but does not resurrect referrers.""" - self._removed = False - self._removal_source = None - @property def item_type(self): """Returns this item's type @@ -815,86 +692,15 @@ def key(self): return None return (self._item_type, id_) - def same_db_item(self, db_item): - """Tests if database item that has same db id is in fact same as this item. - - Args: - db_item (dict): item fetched from database - - Returns: - bool: True if items are the same, False otherwise - """ - raise NotImplementedError() - - def convert_db_ids_to_item_ids(self): - for item_type, id_fields in self._id_fields.items(): - for id_field in id_fields: - try: - field = self[id_field] - except KeyError: - continue - if field is None: - continue - if isinstance(field, Iterable): - self[id_field] = tuple( - self._find_or_fetch_item_id(item_type, self._item_type, db_id, self._db_map) for db_id in field - ) - else: - self[id_field] = self._find_or_fetch_item_id(item_type, self._item_type, field, self._db_map) - - @staticmethod - def _find_or_fetch_item_id(item_type, requesting_item_type, db_id, db_map): - try: - item_id = db_map.find_item_id(item_type, db_id) - except KeyError: - pass - else: - item = db_map.mapped_table(item_type)[item_id] - if not item.removed: - return item_id - if item_type == requesting_item_type: - # We could be fetching everything already, so fetch only a specific id - # to avoid endless recursion. - db_map.do_fetch_all(item_type, id=db_id) - else: - db_map.do_fetch_all(item_type) - return db_map.find_item_id(item_type, db_id) - - @classmethod - def convert_dicts_db_ids_to_item_ids(cls, item_type, item_dict, db_map): - for field_item_type, id_fields in cls._id_fields.items(): - for id_field in id_fields: - try: - field = item_dict[id_field] - except KeyError: - continue - if field is None: - continue - if isinstance(field, Iterable): - item_dict[id_field] = tuple( - cls._find_or_fetch_item_id(field_item_type, item_type, id_, db_map) if id_ > 0 else id_ - for id_ in field - ) - else: - item_dict[id_field] = ( - cls._find_or_fetch_item_id(field_item_type, item_type, field, db_map) if field > 0 else field - ) + @property + def is_id_valid(self): + return self._is_id_valid - def make_db_item(self, find_db_id): - db_item = dict(self) - db_item["id"] = find_db_id(self._item_type, db_item["id"]) - for item_type, id_fields in self._id_fields.items(): - for id_field in id_fields: - field = db_item[id_field] - if field is None: - continue - if isinstance(field, Iterable): - db_item[id_field] = tuple(find_db_id(item_type, item_id) for item_id in field) - else: - db_item[id_field] = find_db_id(item_type, field) - return db_item + def invalidate_id(self): + """Sets id as invalid.""" + self._is_id_valid = False - def extended(self): + def _extended(self): """Returns a dict from this item's original fields plus all the references resolved statically. 
Returns: @@ -912,16 +718,8 @@ def _asdict(self): """ return dict(self) - def equal_ignoring_ids(self, other): - """Compares the non-id fields for equality. - - Args: - other (MappedItemBase): other item - - Returns: - bool: True if non-id fields are equal, False otherwise - """ - return all(self[field] == other[field] for field in self.fields) + def resolve(self): + return {k: resolve(v) for k, v in self._asdict().items()} def merge(self, other): """Merges this item with another and returns the merged item together with any errors. @@ -934,14 +732,10 @@ def merge(self, other): dict: merged item. str: error description if any. """ - other = {key: value for key, value in other.items() if key not in self._external_id_fields} - self.convert_dicts_db_ids_to_item_ids(self._item_type, other, self._db_map) - if "id" in other: - del other["id"] if not self._something_to_update(other): # Nothing to update, that's fine return None, "" - merged = {**self.extended(), **other} + merged = {**self._extended(), **other} if not isinstance(merged["id"], int): merged["id"] = self["id"] return merged, "" @@ -986,10 +780,6 @@ def _invalid_keys(self): elif not self._get_ref(ref_type, {ref_key: src_val}): yield src_key - @classmethod - def unique_keys(cls): - return set(sum(cls._unique_keys, ())) - @classmethod def unique_values_for_item(cls, item, skip_keys=()): for key in cls._unique_keys: @@ -1181,7 +971,7 @@ def cascade_restore(self, source=None): return if self.status in (Status.added_and_removed, Status.to_remove): self._status = self._status_when_removed - elif self.status == Status.committed or self.status == Status.overwritten: + elif self.status == Status.committed: self._status = Status.to_add else: raise RuntimeError("invalid status for item being restored") @@ -1256,12 +1046,12 @@ def cascade_remove_unique(self): referrer.cascade_remove_unique() def is_committed(self): - """Returns whether this item is committed to the DB. + """Returns whether or not this item is committed to the DB. 
+        """Returns whether this item is committed to the DB.
Returns: bool """ - return self._status == Status.committed or self._status == Status.overwritten + return self._status == Status.committed def commit(self, commit_id): """Sets this item as committed with the given commit id.""" @@ -1271,7 +1061,7 @@ def commit(self, commit_id): def __repr__(self): """Overridden to return a more verbose representation.""" - return f"{self._item_type}{self.extended()}" + return f"{self._item_type}{self._extended()}" def __getattr__(self, name): """Overridden to return the dictionary key named after the attribute, or None if it doesn't exist.""" @@ -1359,8 +1149,8 @@ def is_committed(self): def _asdict(self): return self._mapped_item._asdict() - def extended(self): - return self._mapped_item.extended() + def _extended(self): + return self._mapped_item._extended() def update(self, **kwargs): self._db_map.update_item(self.item_type, id=self["id"], **kwargs) diff --git a/spinedb_api/db_mapping_commit_mixin.py b/spinedb_api/db_mapping_commit_mixin.py index ae5985e6..5d2a78b8 100644 --- a/spinedb_api/db_mapping_commit_mixin.py +++ b/spinedb_api/db_mapping_commit_mixin.py @@ -14,6 +14,7 @@ from sqlalchemy.sql.expression import bindparam from sqlalchemy.exc import DBAPIError from .exception import SpineDBAPIError +from .temp_id import TempId, resolve from .helpers import group_consecutive, Asterisk @@ -36,17 +37,15 @@ def _do_add_items(self, connection, tablename, *items_to_add): if not items_to_add: return try: - item_type = self.real_item_type(tablename) - table = self._metadata.tables[item_type] + table = self._metadata.tables[self.real_item_type(tablename)] id_items, temp_id_items = [], [] - id_map = self.mapped_table(item_type).id_map for item in items_to_add: - if id_map.db_id(item["id"]) is None: + if isinstance(item["id"], TempId): temp_id_items.append(item) else: id_items.append(item) if id_items: - connection.execute(table.insert(), [x.make_db_item(self.find_db_id) for x in id_items]) + connection.execute(table.insert(), [x.resolve() for x in id_items]) if temp_id_items: current_ids = {x["id"] for x in connection.execute(table.select())} next_id = max(current_ids, default=0) + 1 @@ -55,35 +54,35 @@ def _do_add_items(self, connection, tablename, *items_to_add): new_ids = set(range(next_id, next_id + required_id_count)) ids = sorted(available_ids | new_ids) for id_, item in zip(ids, temp_id_items): - id_map.set_db_id(item["id"], id_) - connection.execute(table.insert(), [x.make_db_item(self.find_db_id) for x in temp_id_items]) + temp_id = item["id"] + temp_id.resolve(id_) + connection.execute(table.insert(), [x.resolve() for x in temp_id_items]) for tablename_, items_to_add_ in self._extra_items_to_add_per_table(tablename, items_to_add): if not items_to_add_: continue table = self._metadata.tables[self.real_item_type(tablename_)] - connection.execute(table.insert(), items_to_add_) + connection.execute(table.insert(), [resolve(x) for x in items_to_add_]) except DBAPIError as e: msg = f"DBAPIError while inserting {tablename} items: {e.orig.args}" raise SpineDBAPIError(msg) from e - def _dimensions_for_classes(self, classes): - id_map = self.mapped_table("entity_class").id_map + @staticmethod + def _dimensions_for_classes(classes): return [ - {"entity_class_id": id_map.db_id(x["id"]), "position": position, "dimension_id": id_map.db_id(dimension_id)} + {"entity_class_id": x["id"], "position": position, "dimension_id": dimension_id} for x in classes for position, dimension_id in enumerate(x["dimension_id_list"]) ] - def _elements_for_entities(self, entities): - 
entity_id_map = self.mapped_table("entity").id_map - class_id_map = self.mapped_table("entity_class").id_map + @staticmethod + def _elements_for_entities(entities): return [ { - "entity_id": entity_id_map.db_id(x["id"]), - "entity_class_id": class_id_map.db_id(x["class_id"]), + "entity_id": x["id"], + "entity_class_id": x["class_id"], "position": position, - "element_id": entity_id_map.db_id(element_id), - "dimension_id": class_id_map.db_id(dimension_id), + "element_id": element_id, + "dimension_id": dimension_id, } for x in entities for position, (element_id, dimension_id) in enumerate(zip(x["element_id_list"], x["dimension_id_list"])) @@ -119,12 +118,12 @@ def _do_update_items(self, connection, tablename, *items_to_update): return try: upd = self._make_update_stmt(tablename, items_to_update[0].keys()) - connection.execute(upd, [x.make_db_item(self.find_db_id) for x in items_to_update]) + connection.execute(upd, [x.resolve() for x in items_to_update]) for tablename_, items_to_update_ in self._extra_items_to_update_per_table(tablename, items_to_update): if not items_to_update_: continue upd = self._make_update_stmt(tablename_, items_to_update_[0].keys()) - connection.execute(upd, items_to_update_) + connection.execute(upd, [resolve(x) for x in items_to_update_]) except DBAPIError as e: msg = f"DBAPIError while updating '{tablename}' items: {e.orig.args}" raise SpineDBAPIError(msg) from e @@ -136,13 +135,10 @@ def _do_remove_items(self, connection, tablename, *ids): *ids: ids to remove """ tablename = self.real_item_type(tablename) - id_map = self.mapped_table(tablename).id_map - purging = Asterisk in ids - if not purging: - ids = {id_map.db_id(id_) for id_ in ids} - if tablename == "alternative": - # Do not remove the Base alternative - ids.discard(1) + ids = {resolve(id_) for id_ in ids} + if tablename == "alternative": + # Do not remove the Base alternative + ids.discard(1) if not ids: return tablenames = [tablename] @@ -155,7 +151,7 @@ def _do_remove_items(self, connection, tablename, *ids): for tablename_ in tablenames: table = self._metadata.tables[tablename_] delete = table.delete() - if not purging: + if Asterisk not in ids: id_field = self._id_fields.get(tablename_, "id") id_column = getattr(table.c, id_field) cond = or_(*(and_(id_column >= first, id_column <= last) for first, last in group_consecutive(ids))) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index bba06510..b1b72970 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -881,15 +881,3 @@ def group_consecutive(list_of_numbers): for _k, g in groupby(enumerate(sorted(list_of_numbers)), lambda x: x[0] - x[1]): group = list(map(itemgetter(1), g)) yield group[0], group[-1] - - -def query_byname(entity_row, db_map): - element_ids = entity_row["element_id_list"] - if element_ids is None: - return (entity_row["name"],) - sq = db_map.wide_entity_sq - byname = [] - for element_id in element_ids.split(","): - element_row = db_map.query(sq).filter(sq.c.id == element_id).one() - byname += list(query_byname(element_row, db_map)) - return tuple(byname) diff --git a/spinedb_api/item_id.py b/spinedb_api/item_id.py deleted file mode 100644 index 441736f4..00000000 --- a/spinedb_api/item_id.py +++ /dev/null @@ -1,61 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# Copyright Spine Database API contributors -# This file is part of Spine Database API. 
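For orientation, the DELETE built by _do_remove_items above compresses the id set into inclusive ranges with helpers.group_consecutive, one range per condition. A small runnable sketch of that helper (its body is the one shown in the helpers.py hunk above; the id values are made up):

from itertools import groupby
from operator import itemgetter

def group_consecutive(list_of_numbers):
    # Consecutive values share the same "index minus value" key under
    # enumerate, so each run collapses to its (first, last) endpoints.
    for _k, g in groupby(enumerate(sorted(list_of_numbers)), lambda x: x[0] - x[1]):
        group = list(map(itemgetter(1), g))
        yield group[0], group[-1]

print(list(group_consecutive({1, 2, 3, 7, 9, 10})))  # [(1, 3), (7, 7), (9, 10)]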
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
-######################################################################################################################
-from collections import Counter
-
-
-class IdFactory:
-    def __init__(self):
-        self._next_id = -1
-
-    def next_id(self):
-        item_id = self._next_id
-        self._next_id -= 1
-        return item_id
-
-
-class IdMap:
-    def __init__(self):
-        self._item_id_by_db_id = {}
-        self._db_id_by_item_id = {}
-
-    def add_item_id(self, item_id):
-        self._db_id_by_item_id[item_id] = None
-
-    def remove_item_id(self, item_id):
-        db_id = self._db_id_by_item_id.pop(item_id, None)
-        if db_id is not None:
-            del self._item_id_by_db_id[db_id]
-
-    def set_db_id(self, item_id, db_id):
-        self._db_id_by_item_id[item_id] = db_id
-        self._item_id_by_db_id[db_id] = item_id
-
-    def remove_db_id(self, id_):
-        if id_ > 0:
-            item_id = self._item_id_by_db_id.pop(id_)
-        else:
-            item_id = id_
-            db_id = self._db_id_by_item_id[item_id]
-            del self._item_id_by_db_id[db_id]
-        self._db_id_by_item_id[item_id] = None
-
-    def item_id(self, db_id):
-        return self._item_id_by_db_id[db_id]
-
-    def has_db_id(self, item_id):
-        return item_id in self._db_id_by_item_id
-
-    def db_id(self, item_id):
-        return self._db_id_by_item_id[item_id]
-
-    def db_id_iter(self):
-        yield from self._db_id_by_item_id
diff --git a/spinedb_api/item_status.py b/spinedb_api/item_status.py
deleted file mode 100644
index 32e4a6a2..00000000
--- a/spinedb_api/item_status.py
+++ /dev/null
@@ -1,25 +0,0 @@
-######################################################################################################################
-# Copyright (C) 2017-2022 Spine project consortium
-# Copyright Spine Database API contributors
-# This file is part of Spine Database API.
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
-# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
-# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
-# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
-# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
-# this program. If not, see <http://www.gnu.org/licenses/>.
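For contrast with the TempId mechanism introduced later in this patch, this is roughly how the IdMap removed above tied a negative in-memory id to its database id. A sketch only, with made-up ids; the class is deleted by this very patch, so paste it from the diff above to actually run this:

id_map = IdMap()
item_id = -1
id_map.add_item_id(item_id)    # item exists in memory, no database id yet
id_map.set_db_id(item_id, 42)  # after commit, bind the database id
assert id_map.db_id(item_id) == 42
assert id_map.item_id(42) == item_id

TempId folds both directions of this bookkeeping into the id object itself, so there is no central registry left to keep consistent.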
-###################################################################################################################### - -from enum import auto, Enum, unique - - -@unique -class Status(Enum): - """Mapped item status.""" - - committed = auto() - to_add = auto() - to_update = auto() - to_remove = auto() - added_and_removed = auto() - overwritten = auto() diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 59833e9f..6a0f7104 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -12,7 +12,7 @@ from operator import itemgetter -from .helpers import name_from_elements, query_byname +from .helpers import name_from_elements from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase @@ -50,9 +50,6 @@ class CommitItem(MappedItemBase): _unique_keys = (("date",),) - def same_db_item(self, db_item): - return self["date"].replace(tzinfo=None) == db_item["date"] - def commit(self, commit_id): raise RuntimeError("Commits are created automatically when session is committed.") @@ -85,7 +82,6 @@ class EntityClassItem(MappedItemBase): _external_fields = {"dimension_name_list": ("dimension_id_list", "name")} _alt_references = {("dimension_name_list",): ("entity_class", ("name",))} _internal_fields = {"dimension_id_list": (("dimension_name_list",), "id")} - _id_fields = {"entity_class": ("dimension_id_list",)} _private_fields = {"dimension_count"} def __init__(self, *args, **kwargs): @@ -109,9 +105,6 @@ def polish(self): if "active_by_default" not in self: self["active_by_default"] = bool(dict.get(self, "dimension_id_list")) - def same_db_item(self, db_item): - return self["name"] == db_item["name"] - def merge(self, other): dimension_id_list = other.pop("dimension_id_list", None) error = ( @@ -159,8 +152,6 @@ class EntityItem(MappedItemBase): "class_id": (("entity_class_name",), "id"), "element_id_list": (("dimension_name_list", "element_name_list"), "id"), } - _id_fields = {"entity_class": ("class_id",), "entity": ("element_id_list",), "commit": ("commit_id",)} - _external_id_fields = {"dimension_id_list", "superclass_id"} def __init__(self, *args, **kwargs): element_id_list = kwargs.get("element_id_list") @@ -171,13 +162,6 @@ def __init__(self, *args, **kwargs): kwargs["element_id_list"] = tuple(element_id_list) super().__init__(*args, **kwargs) - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - if self["name"] != db_item["name"]: - return False - return _fields_equal("entity_class", db_item["class_id"], "name", self["entity_class_name"], self._db_map) - @classmethod def unique_values_for_item(cls, item, skip_keys=()): """Overriden to also yield unique values for the superclass.""" @@ -305,8 +289,6 @@ class EntityGroupItem(MappedItemBase): "entity_id": (("entity_class_name", "group_name"), "id"), "member_id": (("entity_class_name", "member_name"), "id"), } - _id_fields = {"entity_class": ("entity_class_id",), "entity": ("entity_id", "member_id")} - _external_id_fields = {"dimension_id_list"} def __getitem__(self, key): if key == "class_id": @@ -315,14 +297,6 @@ def __getitem__(self, key): return self["entity_id"] return super().__getitem__(key) - def same_db_item(self, db_item): - db_map = self._db_map - if not _fields_equal("entity", db_item["entity_id"], "name", self["group_name"], db_map): - return False - if not _fields_equal("entity", db_item["member_id"], "name", self["member_name"], db_map): - return False - return 
_fields_equal("entity_class", db_item["entity_class_id"], "name", self["entity_class_name"], db_map) - def commit(self, _commit_id): super().commit(None) @@ -367,20 +341,6 @@ class EntityAlternativeItem(MappedItemBase): "entity_id": (("entity_class_name", "entity_byname"), "id"), "alternative_id": (("alternative_name",), "id"), } - _id_fields = {"entity": ("entity_id",), "alternative": ("alternative_id",), "commit": ("commit_id",)} - _external_id_fields = {"entity_class_id", "dimension_id_list", "element_id_list"} - - def same_db_item(self, db_item): - if not _commit_ids_equal(self, db_item, self._db_map): - return False - entity_record = self._db_map.make_query("entity", id=db_item["entity_id"]).one() - if not _fields_equal( - "entity_class", entity_record["class_id"], "name", self["entity_class_name"], self._db_map - ): - return False - if query_byname(entity_record, self._db_map) != self["entity_byname"]: - return False - return _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map) class ParsedValueBase(MappedItemBase): @@ -390,9 +350,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parsed_value = None - def same_db_item(self, db_item): - raise NotImplementedError() - @property def parsed_value(self): if self._parsed_value is None: @@ -437,9 +394,6 @@ def _something_to_update(self, other): class ParameterItemBase(ParsedValueBase): - def same_db_item(self, db_item): - raise NotImplementedError() - @property def _value_key(self): raise NotImplementedError() @@ -459,13 +413,12 @@ def ref_types(cls): def list_value_id(self): return self["list_value_id"] - def make_db_item(self, find_db_id): - db_item = super().make_db_item(find_db_id) - list_value_id = db_item.get("list_value_id") + def resolve(self): + d = super().resolve() + list_value_id = d.get("list_value_id") if list_value_id is not None: - list_value_db_id = self._db_map.find_db_id("list_value", list_value_id) - db_item[self._value_key] = to_database(list_value_db_id)[0] - return db_item + d[self._value_key] = to_database(list_value_id)[0] + return d def polish(self): self["list_value_id"] = None @@ -521,12 +474,6 @@ class ParameterDefinitionItem(ParameterItemBase): "entity_class_id": (("entity_class_name",), "id"), "parameter_value_list_id": (("parameter_value_list_name",), "id"), } - _id_fields = { - "entity_class": ("entity_class_id",), - "parameter_value_list": ("parameter_value_list_id",), - "commit": ("commit_id",), - } - _external_id_fields = {"dimension_id_list"} @property def _value_key(self): @@ -553,15 +500,6 @@ def __getitem__(self, key): return dict.get(self, key) return super().__getitem__(key) - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - if self["name"] != db_item["name"]: - return False - return _fields_equal( - "entity_class", db_item["entity_class_id"], "name", self["entity_class_name"], self._db_map - ) - def merge(self, other): other_parameter_value_list_id = other.get("parameter_value_list_id") if ( @@ -595,7 +533,6 @@ class ParameterValueItem(ParameterItemBase): 'type': {'type': str, 'value': 'The value type.', 'optional': True}, 'alternative_name': {'type': str, 'value': "The alternative name - defaults to 'Base'.", 'optional': True}, } - required_fields = ("value", "type") _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { "entity_class_id": ("entity_class", "id"), @@ -628,15 +565,6 @@ class 
ParameterValueItem(ParameterItemBase): "entity_id": (("entity_class_name", "entity_byname"), "id"), "alternative_id": (("alternative_name",), "id"), } - _id_fields = { - "parameter_definition": ("parameter_definition_id",), - "entity_class": ("entity_class_id",), - "entity": ("entity_id",), - "list_value": ("list_value_id",), - "alternative": ("alternative_id",), - "commit": ("commit_id",), - } - _external_id_fields = {"dimension_id_list", "element_id_list", "parameter_value_list_id"} @property def _value_key(self): @@ -657,25 +585,6 @@ def __getitem__(self, key): return self._get_ref("list_value", {"id": list_value_id}, strong=False).get(key) return super().__getitem__(key) - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - if not _fields_equal( - "entity_class", db_item["entity_class_id"], "name", self["entity_class_name"], self._db_map - ): - return False - if not _fields_equal( - "parameter_definition", - db_item["parameter_definition_id"], - "name", - self["parameter_definition_name"], - self._db_map, - ): - return False - if not _fields_equal("entity", db_item["entity_id"], "name", self["entity_name"], self._db_map): - return False - return _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map) - def _value_not_in_list_error(self, parsed_value, list_name): return ( f"value {parsed_value} of {self['parameter_definition_name']} for {self['entity_byname']} " @@ -686,10 +595,6 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueListItem(MappedItemBase): fields = {'name': {'type': str, 'value': 'The parameter value list name.'}} _unique_keys = (("name",),) - _id_fields = {"commit": ("commit_id",)} - - def same_db_item(self, db_item): - return db_item["name"] == self["name"] class ListValueItem(ParsedValueBase): @@ -704,7 +609,6 @@ class ListValueItem(ParsedValueBase): _external_fields = {"parameter_value_list_name": ("parameter_value_list_id", "name")} _alt_references = {("parameter_value_list_name",): ("parameter_value_list", ("name",))} _internal_fields = {"parameter_value_list_id": (("parameter_value_list_name",), "id")} - _id_fields = {"parameter_value_list": ("parameter_value_list_id",), "commit": ("commit_id",)} @property def _value_key(self): @@ -719,19 +623,6 @@ def __getitem__(self, key): return (self["value"], self["type"]) return super().__getitem__(key) - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - if self["index"] != db_item["index"]: - return False - return _fields_equal( - "parameter_value_list", - db_item["parameter_value_list_id"], - "name", - self["parameter_value_list_name"], - self._db_map, - ) - class AlternativeItem(MappedItemBase): fields = { @@ -740,10 +631,6 @@ class AlternativeItem(MappedItemBase): } _defaults = {"description": None} _unique_keys = (("name",),) - _id_fields = {"commit": ("commit_id",)} - - def same_db_item(self, db_item): - return self["name"] == db_item["name"] class ScenarioItem(MappedItemBase): @@ -754,7 +641,6 @@ class ScenarioItem(MappedItemBase): } _defaults = {"active": False, "description": None} _unique_keys = (("name",),) - _id_fields = {"commit": ("commit_id",)} def __getitem__(self, key): if key == "alternative_id_list": @@ -773,9 +659,6 @@ def __getitem__(self, key): ) return super().__getitem__(key) - def same_db_item(self, db_item): - return self["name"] == db_item["name"] - class ScenarioAlternativeItem(MappedItemBase): fields = { @@ -788,7 
+671,6 @@ class ScenarioAlternativeItem(MappedItemBase): _external_fields = {"scenario_name": ("scenario_id", "name"), "alternative_name": ("alternative_id", "name")} _alt_references = {("scenario_name",): ("scenario", ("name",)), ("alternative_name",): ("alternative", ("name",))} _internal_fields = {"scenario_id": (("scenario_name",), "id"), "alternative_id": (("alternative_name",), "id")} - _id_fields = {"scenario": ("scenario_id",), "alternative": ("alternative_id",), "commit": ("commit_id",)} def __getitem__(self, key): # The 'before' is to be interpreted as, this scenario alternative goes *before* the before_alternative. @@ -805,15 +687,6 @@ def __getitem__(self, key): return None return super().__getitem__(key) - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - if not _fields_equal("scenario", db_item["scenario_id"], "name", self["scenario_name"], self._db_map): - return False - if not _fields_equal("alternative", db_item["alternative_id"], "name", self["alternative_name"], self._db_map): - return False - return self["rank"] == db_item["rank"] - class MetadataItem(MappedItemBase): fields = { @@ -821,10 +694,6 @@ class MetadataItem(MappedItemBase): 'value': {'type': str, 'value': 'The metadata entry value.'}, } _unique_keys = (("name", "value"),) - _id_fields = {"commit": ("commit_id",)} - - def same_db_item(self, db_item): - return self["name"] == db_item["name"] and self["value"] == db_item["value"] class EntityMetadataItem(MappedItemBase): @@ -856,18 +725,6 @@ class EntityMetadataItem(MappedItemBase): "entity_id": (("entity_class_name", "entity_byname"), "id"), "metadata_id": (("metadata_name", "metadata_value"), "id"), } - _id_fields = {"entity": ("entity_id",), "metadata": ("metadata_id",), "commit": ("commit_id",)} - - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - entity_record = self._db_map.make_query("entity", id=db_item["entity_id"]).one_or_none() - if not entity_record or query_byname(entity_record, self._db_map) != self["entity_byname"]: - return False - if not _fields_equal("entity_class", db_item["class_id"], "name", self["entity_class_name"], self._db_map): - return False - record = self._db_map.make_query("metadata", id=db_item["metadata_id"]).one_or_none() - return record and self["metadata_name"] == record["name"] and self["metadata_value"] == record["value"] class ParameterValueMetadataItem(MappedItemBase): @@ -915,33 +772,6 @@ class ParameterValueMetadataItem(MappedItemBase): ), "metadata_id": (("metadata_name", "metadata_value"), "id"), } - _id_fields = {"parameter_value": ("parameter_value_id",), "metadata": ("metadata_id",), "commit": ("commit_id",)} - - def same_db_item(self, db_item): - if _commit_ids_equal(self, db_item, self._db_map): - return True - value_record = self._db_map.make_query("parameter_value", id=db_item["parameter_value_id"]).one() - entity_record = self._db_map.make_query("entity", id=value_record["entity_id"]).one() - if query_byname(entity_record, self._db_map) != self["entity_byname"]: - return False - if not _fields_equal( - "entity_class", value_record["entity_class_id"], "name", self["entity_class_name"], self._db_map - ): - return False - if not _fields_equal( - "parameter_definition", - value_record["parameter_definition_id"], - "name", - self["parameter_definition_name"], - self._db_map, - ): - return False - if not _fields_equal( - "alternative", value_record["alternative_id"], "name", self["alternative_name"], self._db_map - 
):
-            return False
-        record = self._db_map.make_query("metadata", id=db_item["metadata_id"]).one()
-        return self["metadata_name"] == record["name"] and self["metadata_value"] == record["value"]


 class SuperclassSubclassItem(MappedItemBase):
@@ -960,7 +790,6 @@ class SuperclassSubclassItem(MappedItemBase):
         ("subclass_name",): ("entity_class", ("name",)),
     }
     _internal_fields = {"superclass_id": (("superclass_name",), "id"), "subclass_id": (("subclass_name",), "id")}
-    _id_fields = {"entity_class": ("superclass_id", "subclass_id")}

     def _subclass_entities(self):
         return self._db_map.get_items("entity", class_id=self["subclass_id"])
@@ -970,19 +799,5 @@ def check_mutability(self):
             return "can't set or modify the superclass for a class that already has entities"
         return super().check_mutability()

-    def same_db_item(self, db_item):
-        return _fields_equal("entity_class", db_item["subclass_id"], "name", self["subclass_name"], self._db_map)
-
     def commit(self, _commit_id):
         super().commit(None)
-
-
-def _commit_ids_equal(item, db_item, db_map):
-    db_commit_id = db_map.find_db_id("commit", item["commit_id"])
-    return db_commit_id == db_item["commit_id"]
-
-
-def _fields_equal(item_type, db_id, field, expected_value, db_map):
-    # Use plain query as we want the raw data from database, not something that may have been conflict resolved.
-    record = db_map.make_query(item_type, id=db_id).one()
-    return expected_value == record[field]
diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py
index d8d62ede..86b33fb3 100644
--- a/spinedb_api/server_client_helpers.py
+++ b/spinedb_api/server_client_helpers.py
@@ -67,7 +67,7 @@ def default(self, o):
         if isinstance(o, SpineDBAPIError):
             return str(o)
         if isinstance(o, PublicItem):
-            return o.extended()
+            return o._extended()
         return super().default(o)

     @property
diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py
new file mode 100644
index 00000000..79066941
--- /dev/null
+++ b/spinedb_api/temp_id.py
@@ -0,0 +1,54 @@
+######################################################################################################################
+# Copyright (C) 2017-2022 Spine project consortium
+# This file is part of Spine Database API.
+# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser
+# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
+# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
+# this program. If not, see <http://www.gnu.org/licenses/>.
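The TempId class added just below makes temporary ids self-resolving: each one is a unique negative int per item type that later learns its database id and notifies registered callbacks, replacing the IdMap registry deleted earlier in this patch. A sketch of the intended behavior (the concrete values are made up, and the first TempId of a type is -1 only in a fresh process):

from spinedb_api.temp_id import TempId, resolve

temp = TempId("entity")        # a fresh negative placeholder id
item = {"id": temp, "class_id": 5}
temp.add_resolve_callback(lambda db_id: print("resolved to", db_id))
temp.resolve(42)               # fires the callback and records the database id
assert temp == 42              # __eq__ matches the resolved database id too
assert resolve(item) == {"id": 42, "class_id": 5}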
+###################################################################################################################### + + +class TempId(int): + _next_id = {} + + def __new__(cls, item_type): + id_ = cls._next_id.setdefault(item_type, -1) + cls._next_id[item_type] -= 1 + return super().__new__(cls, id_) + + def __init__(self, item_type): + super().__init__() + self._item_type = item_type + self._resolve_callbacks = [] + self._db_id = None + + @property + def db_id(self): + return self._db_id + + def __eq__(self, other): + return super().__eq__(other) or (self._db_id is not None and other == self._db_id) + + def __hash__(self): + return int(self) + + def __repr__(self): + return f"TempId({self._item_type}, {super().__repr__()})" + + def add_resolve_callback(self, callback): + self._resolve_callbacks.append(callback) + + def resolve(self, db_id): + self._db_id = db_id + while self._resolve_callbacks: + self._resolve_callbacks.pop(0)(db_id) + + +def resolve(value): + if isinstance(value, dict): + return {k: resolve(v) for k, v in value.items()} + if isinstance(value, TempId): + return value.db_id + return value diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index f0aa98ac..3cdefd4c 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -25,9 +25,7 @@ SpineDBAPIError, SpineIntegrityError, ) -from spinedb_api.conflict_resolution import select_in_db_item_always -from spinedb_api.helpers import Asterisk, name_from_elements -from spinedb_api.mapped_items import EntityItem +from spinedb_api.helpers import name_from_elements from tests.custom_db_mapping import CustomDatabaseMapping @@ -86,28 +84,27 @@ def test_shorthand_filter_query_works(self): class TestDatabaseMapping(unittest.TestCase): - def _assert_success(self, result): - item, error = result - self.assertIsNone(error) - return item - def test_active_by_default_is_initially_false_for_zero_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: - item = self._assert_success(db_map.add_entity_class_item(name="Entity")) + item, error = db_map.add_entity_class_item(name="Entity") + self.assertIsNone(error) self.assertFalse(item["active_by_default"]) def test_active_by_default_is_initially_false_for_multi_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_entity_class_item(name="Dimension") - item = self._assert_success(db_map.add_entity_class_item(name="Entity", dimension_name_list=("Dimension",))) + item, error = db_map.add_entity_class_item(name="Entity", dimension_name_list=("Dimension",)) + self.assertIsNone(error) self.assertTrue(item["active_by_default"]) def test_read_active_by_default_from_database(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as out_db_map: - self._assert_success(out_db_map.add_entity_class_item(name="HiddenStuff", active_by_default=False)) - self._assert_success(out_db_map.add_entity_class_item(name="VisibleStuff", active_by_default=True)) + _, error = out_db_map.add_entity_class_item(name="HiddenStuff", active_by_default=False) + self.assertIsNone(error) + _, error = out_db_map.add_entity_class_item(name="VisibleStuff", active_by_default=True) + self.assertIsNone(error) out_db_map.commit_session("Add entity classes.") entity_classes = out_db_map.query(out_db_map.wide_entity_class_sq).all() self.assertEqual(len(entity_classes), 2) @@ -124,120 +121,29 @@ def 
test_read_active_by_default_from_database(self): with self.subTest(class_name=name): self.assertEqual(activity, expected_activity) - def test_add_parameter_value_without_value_gives_error(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="Widget")) - self._assert_success(db_map.add_entity_item(name="spoon", entity_class_name="Widget")) - self._assert_success(db_map.add_parameter_definition_item(name="size", entity_class_name="Widget")) - self.assertRaises( - SpineDBAPIError, - db_map.add_parameter_value_item, - **dict( - parameter_definition_name="size", - entity_class_name="Widget", - entity_byname=("spoon",), - alternative_name="Base", - type=None, - ) - ) - - def test_add_parameter_value_without_type_gives_error(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="Widget")) - self._assert_success(db_map.add_entity_item(name="spoon", entity_class_name="Widget")) - self._assert_success(db_map.add_parameter_definition_item(name="size", entity_class_name="Widget")) - self.assertRaises( - SpineDBAPIError, - db_map.add_parameter_value_item, - **dict( - parameter_definition_name="size", - entity_class_name="Widget", - entity_byname=("spoon",), - alternative_name="Base", - value=to_database(2.3)[0], - ) - ) - - def test_restore_uncommitted_item(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - item, error = db_map.add_entity_class_item(name="my_class") - self.assertIsNone(error) - self.assertEqual(item["name"], "my_class") - self.assertTrue(item.is_valid()) - self.assertFalse(item.is_committed()) - item.remove() - self.assertFalse(item.is_valid()) - item.restore() - self.assertTrue(item.is_valid()) - - def test_restore_committed_and_removed_item(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - item, error = db_map.add_entity_class_item(name="my_class") - self.assertIsNone(error) - self.assertEqual(item["name"], "my_class") - self.assertTrue(item.is_valid()) - self.assertFalse(item.is_committed()) - db_map.commit_session("Add entity class") - self.assertTrue(item.is_committed()) - entity_classes = db_map.query(db_map.entity_class_sq).all() - self.assertEqual(len(entity_classes), 1) - item.remove() - self.assertFalse(item.is_valid()) - self.assertFalse(item.is_committed()) - db_map.commit_session("Remove entity class") - self.assertFalse(item.is_valid()) - self.assertTrue(item.is_committed()) - entity_classes = db_map.query(db_map.entity_class_sq).all() - self.assertEqual(len(entity_classes), 0) - item.restore() - self.assertTrue(item.is_valid()) - self.assertFalse(item.is_committed()) - db_map.commit_session("Restore entity class") - self.assertTrue(item.is_valid()) - self.assertTrue(item.is_committed()) - entity_classes = db_map.query(db_map.entity_class_sq).all() - self.assertEqual(len(entity_classes), 1) - - def test_add_commit_update_commit(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - item, error = db_map.add_entity_class_item(name="my_class") - self.assertIsNone(error) - self.assertEqual(item["name"], "my_class") - self.assertTrue(item.is_valid()) - self.assertFalse(item.is_committed()) - db_map.commit_session("Add entity class") - self.assertTrue(item.is_committed()) - item.update(name="renamed") - self.assertFalse(item.is_committed()) - db_map.commit_session("Rename entity class") - self.assertTrue(item.is_committed()) - entity_classes = db_map.query(db_map.entity_class_sq).all() - 
self.assertEqual(len(entity_classes), 1) - self.assertEqual(entity_classes[0].name, "renamed") - def test_commit_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success( - db_map.add_item( - "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." - ) + _, error = db_map.add_item("entity_class", name="fish", description="It swims.") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." ) - self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) + self.assertIsNone(error) + _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") + self.assertIsNone(error) value, type_ = to_database("mainly orange") - self._assert_success( - db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - ) + _, error = db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, ) + self.assertIsNone(error) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -254,42 +160,38 @@ def test_commit_multidimensional_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success(db_map.add_item("entity_class", name="cat", description="Eats fish.")) - self._assert_success( - db_map.add_item( - "entity_class", - name="fish__cat", - dimension_name_list=("fish", "cat"), - description="A fish getting eaten by a cat?", - ) - ) - self._assert_success( - db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") - ) - self._assert_success( - db_map.add_item( - "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." - ) + _, error = db_map.add_item("entity_class", name="fish", description="It swims.") + self.assertIsNone(error) + _, error = db_map.add_item("entity_class", name="cat", description="Eats fish.") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity_class", + name="fish__cat", + dimension_name_list=("fish", "cat"), + description="A fish getting eaten by a cat?", ) - self._assert_success( - db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) - ) - self._assert_success( - db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") + self.assertIsNone(error) + _, error = db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." 
) + self.assertIsNone(error) + _, error = db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) + self.assertIsNone(error) + _, error = db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") + self.assertIsNone(error) value, type_ = to_database(0.23) - self._assert_success( - db_map.add_item( - "parameter_value", - entity_class_name="fish__cat", - entity_byname=("Nemo", "Felix"), - parameter_definition_name="rate", - alternative_name="Base", - value=value, - type=type_, - ) + _, error = db_map.add_item( + "parameter_value", + entity_class_name="fish__cat", + entity_byname=("Nemo", "Felix"), + parameter_definition_name="rate", + alternative_name="Base", + value=value, + type=type_, ) + self.assertIsNone(error) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -304,25 +206,25 @@ def test_commit_multidimensional_parameter_value(self): def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map: - self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success( - db_map.add_item( - "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." - ) + _, error = db_map.add_item("entity_class", name="fish", description="It swims.") + self.assertIsNone(error) + _, error = db_map.add_item( + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." ) - self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) + self.assertIsNone(error) + _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") + self.assertIsNone(error) value, type_ = to_database("mainly orange") - self._assert_success( - db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - ) + _, error = db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, ) + self.assertIsNone(error) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -354,22 +256,21 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): def test_update_entity_metadata_by_changing_its_entity(self): with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="entity_1", entity_class_name="my_class")) - entity_2 = self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="my_class")) + entity_class, _ = db_map.add_entity_class_item(name="my_class") + db_map.add_entity_item(name="entity_1", entity_class_name="my_class") + entity_2, _ = db_map.add_entity_item(name="entity_2", entity_class_name="my_class") metadata_value = '{"sources": [], "contributors": []}' - metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - entity_metadata = self._assert_success( - db_map.add_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("entity_1",), - ) + metadata, _ = db_map.add_metadata_item(name="my_metadata", value=metadata_value) + 
entity_metadata, error = db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("entity_1",), ) + self.assertIsNone(error) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( - entity_metadata.extended(), + entity_metadata._extended(), { "entity_class_name": "my_class", "entity_byname": ("entity_2",), @@ -416,47 +317,46 @@ def test_update_entity_metadata_by_changing_its_entity(self): def test_update_parameter_value_metadata_by_changing_its_parameter(self): with DatabaseMapping("sqlite://", create=True) as db_map: - entity_class = self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) - self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + entity_class, _ = db_map.add_entity_class_item(name="my_class") + _, error = db_map.add_parameter_definition_item(name="x", entity_class_name="my_class") + self.assertIsNone(error) + db_map.add_parameter_definition_item(name="y", entity_class_name="my_class") + entity, _ = db_map.add_entity_item(name="my_entity", entity_class_name="my_class") value, value_type = to_database(2.3) - self._assert_success( - db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - ) + _, error = db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, ) + self.assertIsNone(error) value, value_type = to_database(-2.3) - y = self._assert_success( - db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="y", - alternative_name="Base", - value=value, - type=value_type, - ) + y, error = db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, ) + self.assertIsNone(error) metadata_value = '{"sources": [], "contributors": []}' - metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - value_metadata = self._assert_success( - db_map.add_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - ) + metadata, error = db_map.add_metadata_item(name="my_metadata", value=metadata_value) + self.assertIsNone(error) + value_metadata, error = db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", ) + self.assertIsNone(error) value_metadata.update(parameter_definition_name="y") self.assertEqual( - value_metadata.extended(), + value_metadata._extended(), { "entity_class_name": "my_class", "entity_byname": ("my_entity",), @@ -520,14 +420,15 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): def test_fetch_more(self): with DatabaseMapping("sqlite://", create=True) as 
db_map: alternatives = db_map.fetch_more("alternative") - expected = [{"id": -1, "name": "Base", "description": "Base alternative", "commit_id": -1}] + expected = [{"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}] self.assertEqual([a._asdict() for a in alternatives], expected) - def test_fetch_more_after_commit(self): + def test_fetch_more_after_commit_and_refresh(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_item("entity_class", name="Widget") db_map.add_item("entity", entity_class_name="Widget", name="gadget") db_map.commit_session("Add test data.") + db_map.refresh_session() entities = db_map.fetch_more("entity") self.assertEqual([(x["entity_class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) @@ -565,20 +466,18 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="dog")) - self._assert_success(db_map.add_entity_class_item(name="cat")) - self._assert_success(db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat"))) - self._assert_success(db_map.add_entity_item(name="Pulgoso", entity_class_name="dog")) - self._assert_success(db_map.add_entity_item(name="Sylvester", entity_class_name="cat")) - self._assert_success(db_map.add_entity_item(name="Tom", entity_class_name="cat")) + db_map.add_entity_class_item(name="dog") + db_map.add_entity_class_item(name="cat") + db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat")) + db_map.add_entity_item(name="Pulgoso", entity_class_name="dog") + db_map.add_entity_item(name="Sylvester", entity_class_name="cat") + db_map.add_entity_item(name="Tom", entity_class_name="cat") db_map.commit_session("Arf!") with DatabaseMapping(url) as db_map: # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. db_map.get_entity_item(name="Sylvester", entity_class_name="cat").remove() - self._assert_success( - db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") - ) + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". 
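The test_fetch_more_after_commit_and_refresh test above pins down the new contract: after a commit, a mapping does not transparently re-fetch; the caller requests it with refresh_session(), which presumably goes through the _refresh hook added in the db_mapping_base.py hunk earlier to clear the fetch progress. The pattern the test exercises, distilled into a minimal sketch:

from spinedb_api import DatabaseMapping

with DatabaseMapping("sqlite://", create=True) as db_map:
    db_map.add_item("entity_class", name="Widget")
    db_map.add_item("entity", entity_class_name="Widget", name="gadget")
    db_map.commit_session("Add test data.")
    db_map.refresh_session()  # clear fetch progress so the DB is queried again
    entities = db_map.fetch_more("entity")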
@@ -590,19 +489,24 @@ def test_committing_scenario_alternatives(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - item = self._assert_success(db_map.add_alternative_item(name="alt1")) + item, error = db_map.add_alternative_item(name="alt1") + self.assertIsNone(error) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_alternative_item(name="alt2")) + item, error = db_map.add_alternative_item(name="alt2") + self.assertIsNone(error) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_scenario_item(name="my_scenario")) + item, error = db_map.add_scenario_item(name="my_scenario") + self.assertIsNone(error) self.assertIsNotNone(item) - item = self._assert_success( - db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt1", rank=0) + item, error = db_map.add_scenario_alternative_item( + scenario_name="my_scenario", alternative_name="alt1", rank=0 ) + self.assertIsNone(error) self.assertIsNotNone(item) - item = self._assert_success( - db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt2", rank=1) + item, error = db_map.add_scenario_alternative_item( + scenario_name="my_scenario", alternative_name="alt2", rank=1 ) + self.assertIsNone(error) self.assertIsNotNone(item) db_map.commit_session("Add test data.") with DatabaseMapping(url) as db_map: @@ -617,694 +521,32 @@ def test_committing_scenario_alternatives(self): def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) + db_map.add_entity_class_item(name="my_class") db_map.commit_session("Add class.") classes = db_map.get_entity_class_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0].extended()) + self.assertNotIn("commit_id", classes[0]._extended()) def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="high")) - self._assert_success(db_map.add_entity_class_item(name="low")) - self._assert_success(db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low")) + db_map.add_entity_class_item(name="high") + db_map.add_entity_class_item(name="low") + db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low") db_map.commit_session("Add class hierarchy.") classes = db_map.get_superclass_subclass_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0].extended()) + self.assertNotIn("commit_id", classes[0]._extended()) def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="element", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_item(name="container", entity_class_name="my_class")) - self._assert_success( - db_map.add_entity_group_item( - group_name="container", member_name="element", entity_class_name="my_class" - ) - ) + db_map.add_entity_class_item(name="my_class") + db_map.add_entity_item(name="element", entity_class_name="my_class") + db_map.add_entity_item(name="container", entity_class_name="my_class") + 
db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class") db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) - self.assertNotIn("commit_id", groups[0].extended()) - - def test_additive_commit_from_another_db_map_gets_fetched(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - items = db_map.get_items("entity") - self.assertEqual(len(items), 0) - with DatabaseMapping(url) as shadow_db_map: - self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) - self._assert_success(shadow_db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - shadow_db_map.commit_session("Add entity.") - items = db_map.get_items("entity") - self.assertEqual(len(items), 1) - self.assertEqual( - items[0]._asdict(), - { - "id": -1, - "name": "my_entity", - "description": None, - "class_id": -1, - "element_name_list": None, - "element_id_list": (), - "commit_id": -2, - }, - ) - - def test_updating_item_from_another_db_map_is_overwritten_by_default_conflict_resolution(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - value, type_ = to_database(2.3) - original_item = self._assert_success( - db_map.add_parameter_definition_item( - entity_class_name="my_class", name="measurable", default_type=type_, default_value=value - ) - ) - db_map.commit_session("Add initial data.") - self.assertTrue(original_item.is_committed()) - definitions = db_map.query(db_map.parameter_definition_sq).all() - self.assertEqual(len(definitions), 1) - value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) - self.assertEqual(value, 2.3) - with DatabaseMapping(url) as shadow_db_map: - items = shadow_db_map.get_items("parameter_definition") - self.assertEqual(len(items), 1) - value, type_ = to_database(5.0) - items[0].update(default_value=value, default_type=type_) - shadow_db_map.commit_session("Changed default value.") - definitions = shadow_db_map.query(shadow_db_map.parameter_definition_sq).all() - self.assertEqual(len(definitions), 1) - value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) - self.assertEqual(value, 5.0) - items = db_map.get_items("parameter_definition") - self.assertEqual(len(items), 1) - self.assertEqual(items[0], original_item) - self.assertFalse(items[0].is_committed()) - db_map.commit_session("Restore default value back to original.") - definitions = db_map.query(db_map.parameter_definition_sq).all() - self.assertEqual(len(definitions), 1) - value = from_database(definitions[0]["default_value"], definitions[0]["default_type"]) - self.assertEqual(value, 2.3) - - def test_resolve_an_update_conflict_in_favor_of_external_modification(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - db_map.commit_session("Add initial data.") - with DatabaseMapping(url) as shadow_db_map: - items = shadow_db_map.fetch_more("entity") - self.assertEqual(len(items), 1) - updated_item = 
items[0] - updated_item.update(name="renamed_entity") - shadow_db_map.commit_session("Renamed the entity.") - items = db_map.fetch_more("entity", resolve_conflicts=select_in_db_item_always) - self.assertEqual(len(items), 1) - self.assertEqual(items[0].item_type, updated_item.item_type) - for keys, values in EntityItem.unique_values_for_item(items[0]): - for key, value in zip(keys, values): - with self.subTest(key=key): - self.assertEqual(value, updated_item[key]) - - def test_recreating_deleted_item_externally_brings_it_back_if_favored_by_conflict_resolution(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - removed_item = self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - db_map.commit_session("Add initial data.") - removed_item.remove() - db_map.commit_session("Remove entity class.") - self.assertTrue(removed_item.is_committed()) - with DatabaseMapping(url) as shadow_db_map: - items = shadow_db_map.fetch_more("entity_class") - self.assertEqual(len(items), 0) - self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) - shadow_db_map.commit_session("Added entity class back.") - items = db_map.get_items("entity_class", resolve_conflicts=select_in_db_item_always) - self.assertEqual(len(items), 1) - self.assertTrue(items[0].is_valid()) - self.assertFalse(items[0].is_committed()) - - def test_restoring_entity_whose_db_id_has_been_replaced_by_external_db_modification(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - item = self._assert_success(db_map.add_entity_item(entity_class_name="my_class", name="my_entity")) - original_id = item["id"] - db_map.commit_session("Add initial data.") - items = db_map.fetch_more("entity") - self.assertEqual(len(items), 1) - db_map.remove_item("entity", original_id) - db_map.commit_session("Removed entity.") - self.assertEqual(len(db_map.get_entity_items()), 0) - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(entity_class_name="my_class", name="other_entity") - ) - shadow_db_map.commit_session("Add entity with different name, probably reusing previous id.") - items = db_map.fetch_more("entity") - self.assertEqual(len(items), 1) - self.assertEqual(items[0]["name"], "other_entity") - all_items = db_map.get_entity_items() - self.assertEqual(len(all_items), 1) - restored_item = db_map.restore_item("entity", original_id) - self.assertEqual(restored_item["name"], "my_entity") - all_items = db_map.get_entity_items() - self.assertEqual(len(all_items), 2) - - def test_cunning_ways_to_make_external_changes(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="interesting_class")) - self._assert_success(db_map.add_entity_class_item(name="filler_class")) - self._assert_success( - db_map.add_parameter_definition_item(name="quality", entity_class_name="interesting_class") - ) - self._assert_success( - db_map.add_parameter_definition_item(name="quantity", entity_class_name="filler_class") - ) - self._assert_success( - 
db_map.add_entity_item(name="object_of_interest", entity_class_name="interesting_class") - ) - value, value_type = to_database(2.3) - self._assert_success( - db_map.add_parameter_value_item( - parameter_definition_name="quality", - entity_class_name="interesting_class", - entity_byname=("object_of_interest",), - alternative_name="Base", - value=value, - type=value_type, - ) - ) - db_map.commit_session("Add initial data") - removed_item = db_map.get_entity_item(name="object_of_interest", entity_class_name="interesting_class") - removed_item.remove() - db_map.commit_session("Remove object of interest") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="interesting_class") - ) - self._assert_success(shadow_db_map.add_entity_item(name="filler", entity_class_name="filler_class")) - value, value_type = to_database(-2.3) - self._assert_success( - shadow_db_map.add_parameter_value_item( - parameter_definition_name="quantity", - entity_class_name="filler_class", - entity_byname=("filler",), - alternative_name="Base", - value=value, - type=value_type, - ) - ) - value, value_type = to_database(99.9) - self._assert_success( - shadow_db_map.add_parameter_value_item( - parameter_definition_name="quality", - entity_class_name="interesting_class", - entity_byname=("other_entity",), - alternative_name="Base", - value=value, - type=value_type, - ) - ) - shadow_db_map.commit_session("Add entities.") - entity_items = db_map.get_entity_items() - self.assertEqual(len(entity_items), 2) - self.assertEqual( - entity_items[0].extended(), - { - "id": -2, - "name": "other_entity", - "description": None, - "class_id": -1, - "element_id_list": (), - "element_name_list": (), - "commit_id": -4, - "entity_class_name": "interesting_class", - "dimension_id_list": (), - "dimension_name_list": (), - "element_byname_list": (), - "superclass_id": None, - "superclass_name": None, - }, - ) - self.assertEqual( - entity_items[1].extended(), - { - "id": -3, - "name": "filler", - "description": None, - "class_id": -2, - "element_id_list": (), - "element_name_list": (), - "commit_id": -4, - "entity_class_name": "filler_class", - "dimension_id_list": (), - "dimension_name_list": (), - "element_byname_list": (), - "superclass_id": None, - "superclass_name": None, - }, - ) - value_items = db_map.get_parameter_value_items() - self.assertEqual(len(value_items), 2) - self.assertTrue(removed_item.is_committed()) - self.assertEqual( - value_items[0].extended(), - { - "alternative_id": -1, - "alternative_name": "Base", - "commit_id": -4, - "dimension_id_list": (), - "dimension_name_list": (), - "element_id_list": (), - "element_name_list": (), - "entity_byname": ("filler",), - "entity_class_id": -2, - "entity_class_name": "filler_class", - "entity_id": -3, - "entity_name": "filler", - "id": -2, - "list_value_id": None, - "parameter_definition_id": -2, - "parameter_definition_name": "quantity", - "parameter_value_list_id": None, - "parameter_value_list_name": None, - "type": to_database(-2.3)[1], - "value": to_database(-2.3)[0], - }, - ) - self.assertEqual( - value_items[1].extended(), - { - "alternative_id": -1, - "alternative_name": "Base", - "commit_id": -4, - "dimension_id_list": (), - "dimension_name_list": (), - "element_id_list": (), - "element_name_list": (), - "entity_byname": ("other_entity",), - "entity_class_id": -1, - "entity_class_name": "interesting_class", - "entity_id": -2, - "entity_name": "other_entity", - "id": -3, - "list_value_id": 
None, - "parameter_definition_id": -1, - "parameter_definition_name": "quality", - "parameter_value_list_id": None, - "parameter_value_list_name": None, - "type": to_database(99.9)[1], - "value": to_database(99.9)[0], - }, - ) - - def test_update_entity_metadata_externally(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - metadata_value = '{"sources": [], "contributors": []}' - self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - self._assert_success( - db_map.add_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - ) - ) - db_map.commit_session("Add initial data.") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - metadata_item = shadow_db_map.get_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - ) - self.assertTrue(metadata_item) - metadata_item.update(entity_byname=("other_entity",)) - shadow_db_map.commit_session("Move entity metadata to another entity") - metadata_items = db_map.get_entity_metadata_items() - self.assertEqual(len(metadata_items), 2) - self.assertEqual( - metadata_items[0].extended(), - { - "id": -1, - "entity_class_name": "my_class", - "entity_byname": ("my_entity",), - "entity_id": -1, - "metadata_id": -1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "commit_id": None, - }, - ) - self.assertFalse(metadata_items[0].is_committed()) - self.assertEqual( - metadata_items[1].extended(), - { - "id": -2, - "entity_class_name": "my_class", - "entity_byname": ("other_entity",), - "entity_id": -2, - "metadata_id": -1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "commit_id": -3, - }, - ) - self.assertTrue(metadata_items[1].is_committed()) - - def test_update_parameter_value_metadata_externally(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - value, value_type = to_database(2.3) - self._assert_success( - db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - ) - ) - metadata_value = '{"sources": [], "contributors": []}' - self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - self._assert_success( - db_map.add_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - ) - ) - db_map.commit_session("Add initial data.") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - 
shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - value, value_type = to_database(5.0) - self._assert_success( - shadow_db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("other_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - ) - ) - metadata_item = shadow_db_map.get_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - ) - self.assertTrue(metadata_item) - metadata_item.update(entity_byname=("other_entity",)) - shadow_db_map.commit_session("Move parameter value metadata to another entity") - metadata_items = db_map.get_parameter_value_metadata_items() - self.assertEqual(len(metadata_items), 2) - self.assertEqual( - metadata_items[0].extended(), - { - "id": -1, - "entity_class_name": "my_class", - "parameter_definition_name": "x", - "parameter_value_id": -1, - "entity_byname": ("my_entity",), - "metadata_id": -1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "alternative_name": "Base", - "commit_id": None, - }, - ) - self.assertFalse(metadata_items[0].is_committed()) - self.assertEqual( - metadata_items[1].extended(), - { - "id": -2, - "entity_class_name": "my_class", - "parameter_definition_name": "x", - "parameter_value_id": -2, - "entity_byname": ("other_entity",), - "metadata_id": -1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "alternative_name": "Base", - "commit_id": -3, - }, - ) - self.assertTrue(metadata_items[1].is_committed()) - - def test_update_entity_alternative_externally(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - self._assert_success( - db_map.add_entity_alternative_item( - entity_byname=("my_entity",), - entity_class_name="my_class", - alternative_name="Base", - active=False, - ) - ) - db_map.commit_session("Add initial data.") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - entity_alternative = shadow_db_map.get_entity_alternative_item( - entity_class_name="my_class", entity_byname=("my_entity",), alternative_name="Base" - ) - self.assertTrue(entity_alternative) - entity_alternative.update(entity_byname=("other_entity",)) - shadow_db_map.commit_session("Move entity alternative to another entity.") - entity_alternatives = db_map.get_entity_alternative_items() - self.assertEqual(len(entity_alternatives), 2) - self.assertEqual( - entity_alternatives[0].extended(), - { - "id": -1, - "entity_class_name": "my_class", - "entity_class_id": -1, - "entity_byname": ("my_entity",), - "entity_name": "my_entity", - "entity_id": -1, - "dimension_name_list": (), - "dimension_id_list": (), - "element_name_list": (), - "element_id_list": (), - "alternative_name": "Base", - "alternative_id": -1, - "active": False, - "commit_id": None, - }, - ) - self.assertFalse(entity_alternatives[0].is_committed()) - self.assertEqual( - entity_alternatives[1].extended(), - { - "id": -2, - "entity_class_name": "my_class", - "entity_class_id": -1, - "entity_byname": 
("other_entity",), - "entity_name": "other_entity", - "entity_id": -2, - "dimension_name_list": (), - "dimension_id_list": (), - "element_name_list": (), - "element_id_list": (), - "alternative_name": "Base", - "alternative_id": -1, - "active": False, - "commit_id": -3, - }, - ) - self.assertTrue(entity_alternatives[1].is_committed()) - - def test_update_superclass_subclass_externally(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="ceiling")) - self._assert_success(db_map.add_entity_class_item(name="floor")) - self._assert_success(db_map.add_entity_class_item(name="soil")) - self._assert_success( - db_map.add_superclass_subclass_item(superclass_name="ceiling", subclass_name="floor") - ) - db_map.commit_session("Add initial data.") - with DatabaseMapping(url) as shadow_db_map: - superclass_subclass = shadow_db_map.get_superclass_subclass_item(subclass_name="floor") - superclass_subclass.update(subclass_name="soil") - shadow_db_map.commit_session("Changes subclass to another one.") - superclass_subclasses = db_map.get_superclass_subclass_items() - self.assertEqual(len(superclass_subclasses), 2) - self.assertEqual( - superclass_subclasses[0].extended(), - { - "id": -1, - "superclass_name": "ceiling", - "superclass_id": -1, - "subclass_name": "floor", - "subclass_id": -2, - }, - ) - self.assertFalse(superclass_subclasses[0].is_committed()) - self.assertEqual( - superclass_subclasses[1].extended(), - { - "id": -2, - "superclass_name": "ceiling", - "superclass_id": -1, - "subclass_name": "soil", - "subclass_id": -3, - }, - ) - self.assertTrue(superclass_subclasses[1].is_committed()) - - def test_adding_same_parameters_values_to_different_entities_externally(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) - my_entity = self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - value, value_type = to_database(2.3) - self._assert_success( - db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - ) - ) - db_map.commit_session("Add initial data.") - my_entity.remove() - db_map.commit_session("Remove entity.") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - self._assert_success( - shadow_db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("other_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - ) - ) - shadow_db_map.commit_session("Add another entity.") - values = db_map.get_parameter_value_items() - self.assertEqual(len(values), 1) - self.assertEqual( - values[0].extended(), - { - "id": -2, - "entity_class_name": "my_class", - "entity_class_id": -1, - "dimension_name_list": (), - "dimension_id_list": (), - "parameter_definition_name": "x", - "parameter_definition_id": -1, - "entity_byname": ("other_entity",), - "entity_name": "other_entity", - "entity_id": -2, - "element_name_list": (), - 
"element_id_list": (), - "alternative_name": "Base", - "alternative_id": -1, - "parameter_value_list_name": None, - "parameter_value_list_id": None, - "list_value_id": None, - "type": value_type, - "value": value, - "commit_id": -4, - }, - ) - - def test_committing_changes_purged_entity_has_been_overwritten_by_external_change(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="ghost", entity_class_name="my_class")) - db_map.commit_session("Add soon-to-be-removed entity.") - db_map.purge_items("entity") - db_map.commit_session("Purge entities.") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - shadow_db_map.commit_session("Add another entity that steals ghost's id.") - db_map.do_fetch_all("entity") - self.assertFalse(db_map.any_uncommitted_items()) - self._assert_success(db_map.add_entity_item(name="dirty_entity", entity_class_name="my_class")) - self.assertTrue(db_map.any_uncommitted_items()) - db_map.commit_session("Add still uncommitted entity.") - entities = db_map.query(db_map.wide_entity_sq).all() - self.assertEqual(len(entities), 2) - - def test_reset_purging(self): - with TemporaryDirectory() as temp_dir: - url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") - with DatabaseMapping(url, create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - db_map.commit_session("Add entity_class.") - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - db_map.purge_items("entity") - with DatabaseMapping(url) as shadow_db_map: - self._assert_success( - shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") - ) - shadow_db_map.commit_session("Add another entity that should not be purged.") - db_map.reset_purging() - entities = db_map.get_entity_items("entity") - self.assertEqual(len(entities), 1) - self.assertEqual(entities[0]["name"], "other_entity") - - def test_remove_items_by_asterisk(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_alternative_item(name="alt_1")) - self._assert_success(db_map.add_alternative_item(name="alt_2")) - db_map.commit_session("Add alternatives.") - alternatives = db_map.get_alternative_items() - self.assertEqual(len(alternatives), 3) - db_map.remove_items("alternative", Asterisk) - db_map.commit_session("Remove all alternatives.") - alternatives = db_map.get_alternative_items() - self.assertEqual(alternatives, []) + self.assertNotIn("commit_id", groups[0]._extended()) class TestDatabaseMappingLegacy(unittest.TestCase): @@ -2272,7 +1514,6 @@ def test_add_parameter_values(self): "entity_id": nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"orange"', - "type": None, "alternative_id": 1, }, { @@ -2280,7 +1521,6 @@ def test_add_parameter_values(self): "entity_id": nemo__pluto_row.id, "entity_class_id": nemo__pluto_row.class_id, "value": b"125", - "type": None, "alternative_id": 1, }, ) @@ -2332,7 +1572,6 @@ def test_add_same_parameter_value_twice(self): "entity_id": nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"orange"', - "type": None, "alternative_id": 1, }, { @@ -2340,7 +1579,6 @@ def test_add_same_parameter_value_twice(self): "entity_id": 
nemo_row.id, "entity_class_id": nemo_row.class_id, "value": b'"blue"', - "type": None, "alternative_id": 1, }, ) @@ -2650,12 +1888,8 @@ def test_add_entity_to_a_class_with_abstract_dimensions(self): import_functions.import_entity_classes( self._db_map, (("fish", ()), ("dog", ()), ("animal", ()), ("two_animals", ("animal", "animal"))) ) - count, errors = import_functions.import_superclass_subclasses( - self._db_map, (("animal", "fish"), ("animal", "dog")) - ) - self.assertEqual(errors, []) - count, errors = import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) - self.assertEqual(errors, []) + import_functions.import_superclass_subclasses(self._db_map, (("animal", "fish"), ("animal", "dog"))) + import_functions.import_entities(self._db_map, (("fish", "Nemo"), ("dog", "Pulgoso"))) self._db_map.commit_session("Add test data.") item, error = self._db_map.add_item( "entity", entity_class_name="two_animals", element_name_list=("Nemo", "Pulgoso") @@ -2675,18 +1909,13 @@ def setUp(self): def tearDown(self): self._db_map.close() - def _assert_success(self, result): - items, errors = result - self.assertEqual(errors, []) - return items - def test_update_object_classes(self): """Test that updating object classes works.""" self._db_map.add_object_classes({"id": 1, "name": "fish"}, {"id": 2, "name": "dog"}) items, intgr_error_log = self._db_map.update_object_classes( {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} ) - ids = {1, 2} + ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_class_sq object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2699,7 +1928,7 @@ def test_update_objects(self): self._db_map.add_object_classes({"id": 1, "name": "fish"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) - ids = {1, 2} + ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2709,11 +1938,11 @@ def test_update_objects(self): def test_update_committed_object(self): """Test that updating objects works.""" - self._assert_success(self._db_map.add_object_classes({"id": 1, "name": "some_class"})) - self._assert_success(self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1})) + self._db_map.add_object_classes({"id": 1, "name": "some_class"}) + self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) self._db_map.commit_session("update") items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) - ids = {1} + ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2731,7 +1960,7 @@ def test_update_relationship_classes(self): items, intgr_error_log = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} ) - ids = {3, 4} + ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_class_sq rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2740,15 +1969,13 @@ def test_update_relationship_classes(self): self.assertEqual(rel_clss[4], "octopus__dog") def 
test_update_committed_relationship_class(self): - self._assert_success(import_functions.import_object_classes(self._db_map, ("object_class_1",))) - self._assert_success( - import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) - ) + _ = import_functions.import_object_classes(self._db_map, ("object_class_1",)) + _ = import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") items, errors = self._db_map.update_wide_relationship_classes({"id": 2, "name": "renamed"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {-2}) + self.assertEqual(updated_ids, {2}) self._db_map.commit_session("Update data.") classes = self._db_map.query(self._db_map.wide_relationship_class_sq).all() self.assertEqual(len(classes), 1) @@ -2771,38 +1998,29 @@ def test_update_relationship_class_does_not_update_member_class_id(self): def test_update_relationships(self): """Test that updating relationships works.""" - self._assert_success(self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2})) - self._assert_success( - self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) - ) - self._assert_success( - self._db_map.add_objects( - {"name": "nemo", "id": 1, "class_id": 1}, - {"name": "pluto", "id": 2, "class_id": 2}, - {"name": "scooby", "id": 3, "class_id": 2}, - ) + self._db_map.add_object_classes({"name": "fish", "id": 1}, {"name": "dog", "id": 2}) + self._db_map.add_wide_relationship_classes({"name": "fish__dog", "id": 3, "object_class_id_list": [1, 2]}) + self._db_map.add_objects( + {"name": "nemo", "id": 1, "class_id": 1}, + {"name": "pluto", "id": 2, "class_id": 2}, + {"name": "scooby", "id": 3, "class_id": 2}, ) - self._assert_success( - self._db_map.add_wide_relationships( - { - "id": 4, - "name": "nemo__pluto", - "class_id": 3, - "object_id_list": [1, 2], - "object_class_id_list": [1, 2], - } - ) + self._db_map.add_wide_relationships( + {"id": 4, "name": "nemo__pluto", "class_id": 3, "object_id_list": [1, 2], "object_class_id_list": [1, 2]} ) items, intgr_error_log = self._db_map.update_wide_relationships( {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} ) + ids = {x["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_sq - rels = [{"name": x.name, "object_id_list": x.object_id_list} for x in self._db_map.query(sq)] + rels = { + x.id: {"name": x.name, "object_id_list": x.object_id_list} + for x in self._db_map.query(sq).filter(sq.c.id.in_(ids)) + } self.assertEqual(intgr_error_log, []) - self.assertEqual(len(rels), 1) - self.assertEqual(rels[0]["name"], "nemo__scooby") - self.assertEqual(rels[0]["object_id_list"], "1,3") + self.assertEqual(rels[4]["name"], "nemo__scooby") + self.assertEqual(rels[4]["object_id_list"], "1,3") def test_update_committed_relationship(self): import_functions.import_object_classes(self._db_map, ("object_class_1", "object_class_2")) @@ -2818,7 +2036,7 @@ def test_update_committed_relationship(self): items, errors = self._db_map.update_wide_relationships({"id": 4, "name": "renamed", "object_id_list": [2, 3]}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {-4}) + self.assertEqual(updated_ids, {4}) self._db_map.commit_session("Update data.") relationships = 
self._db_map.query(self._db_map.wide_relationship_sq).all() self.assertEqual(len(relationships), 1) @@ -2836,7 +2054,7 @@ def test_update_parameter_value_by_id_only(self): items, errors = self._db_map.update_parameter_values({"id": 1, "value": b"something else"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {-1}) + self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") pvals = self._db_map.query(self._db_map.parameter_value_sq).all() self.assertEqual(len(pvals), 1) @@ -2868,7 +2086,7 @@ def test_update_parameter_definition_by_id_only(self): items, errors = self._db_map.update_parameter_definitions({"id": 1, "name": "parameter2"}) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {-1}) + self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) @@ -2884,7 +2102,7 @@ def test_update_parameter_definition_value_list(self): ) updated_ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(updated_ids, {-1}) + self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") pdefs = self._db_map.query(self._db_map.parameter_definition_sq).all() self.assertEqual(len(pdefs), 1) @@ -2984,7 +2202,7 @@ def test_update_object_metadata_reuses_existing_metadata(self): ) ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {-1}) + self.assertEqual(ids, {1}) self._db_map.remove_unused_metadata() self._db_map.commit_session("Update data") metadata_entries = self._db_map.query(self._db_map.metadata_sq).all() @@ -3090,7 +2308,7 @@ def test_update_metadata(self): items, errors = self._db_map.update_metadata(*({"id": 1, "name": "author", "value": "Prof. T. 
Est"},)) ids = {x["id"] for x in items} self.assertEqual(errors, []) - self.assertEqual(ids, {-1}) + self.assertEqual(ids, {1}) self._db_map.commit_session("Update data") metadata_records = self._db_map.query(self._db_map.metadata_sq).all() self.assertEqual(len(metadata_records), 1) @@ -3224,7 +2442,6 @@ def test_remove_parameter_value(self): self._db_map.add_parameter_values( { "value": b"0", - "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -3247,7 +2464,6 @@ def test_remove_parameter_value_from_committed_session(self): self._db_map.add_parameter_values( { "value": b"0", - "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -3270,7 +2486,6 @@ def test_cascade_remove_object_removes_parameter_value_as_well(self): self._db_map.add_parameter_values( { "value": b"0", - "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -3293,7 +2508,6 @@ def test_cascade_remove_object_from_committed_session_removes_parameter_value_as self._db_map.add_parameter_values( { "value": b"0", - "type": None, "id": 1, "parameter_definition_id": 1, "object_id": 1, @@ -3580,6 +2794,7 @@ def test_refresh_addition(self): import_functions.import_object_classes(self._db_map, ("second_class",)) entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) + self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"my_class", "second_class"}) @@ -3587,9 +2802,10 @@ def test_refresh_addition(self): def test_refresh_removal(self): import_functions.import_object_classes(self._db_map, ("my_class",)) self._db_map.commit_session("test commit") - self._db_map.remove_items("entity_class", -1) + self._db_map.remove_items("entity_class", 1) entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) + self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, set()) @@ -3600,9 +2816,10 @@ def test_refresh_update(self): self._db_map.get_item("entity_class", name="my_class").update(name="new_name") entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} self.assertEqual(entity_class_names, {"new_name"}) + self._db_map.refresh_session() self._db_map.fetch_all() entity_class_names = {x["name"] for x in self._db_map.mapped_table("entity_class").valid_values()} - self.assertEqual(entity_class_names, {"new_name", "my_class"}) + self.assertEqual(entity_class_names, {"new_name"}) def test_cascade_remove_unfetched(self): import_functions.import_object_classes(self._db_map, ("my_class",)) diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py new file mode 100644 index 00000000..179c1ed8 --- /dev/null +++ b/tests/test_db_mapping_base.py @@ -0,0 +1,80 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# This file is part of Spine Database API. 
+# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +import unittest + +from spinedb_api.db_mapping_base import MappedItemBase, DatabaseMappingBase + + +class TestDBMapping(DatabaseMappingBase): + @staticmethod + def item_types(): + return ["cutlery"] + + @staticmethod + def all_item_types(): + return ["cutlery"] + + @staticmethod + def item_factory(item_type): + if item_type == "cutlery": + return MappedItemBase + raise RuntimeError(f"unknown item_type '{item_type}'") + + +class TestDBMappingBase(unittest.TestCase): + def test_rolling_back_new_item_invalidates_its_id(self): + db_map = TestDBMapping() + mapped_table = db_map.mapped_table("cutlery") + item = mapped_table.add_item({}) + self.assertTrue(item.is_id_valid) + self.assertIn("id", item) + id_ = item["id"] + db_map._rollback() + self.assertFalse(item.is_id_valid) + self.assertEqual(item["id"], id_) + + +class TestMappedTable(unittest.TestCase): + def test_readding_item_with_invalid_id_creates_new_id(self): + db_map = TestDBMapping() + mapped_table = db_map.mapped_table("cutlery") + item = mapped_table.add_item({}) + id_ = item["id"] + db_map._rollback() + self.assertFalse(item.is_id_valid) + mapped_table.add_item(item) + self.assertTrue(item.is_id_valid) + self.assertNotEqual(item["id"], id_) + + +class TestMappedItemBase(unittest.TestCase): + def test_id_is_valid_initially(self): + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") + self.assertTrue(item.is_id_valid) + + def test_id_can_be_invalidated(self): + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") + item.invalidate_id() + self.assertFalse(item.is_id_valid) + + def test_setting_new_id_validates_it(self): + db_map = TestDBMapping() + item = MappedItemBase(db_map, "cutlery") + item.invalidate_id() + self.assertFalse(item.is_id_valid) + item["id"] = 23 + self.assertTrue(item.is_id_valid) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 194f8afd..7d3a8b7b 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -19,10 +19,8 @@ get_head_alembic_version, name_from_dimensions, name_from_elements, - query_byname, remove_credentials_from_url, ) -from spinedb_api.db_mapping import DatabaseMapping class TestNameFromElements(unittest.TestCase): @@ -71,78 +69,6 @@ def test_password_with_special_characters(self): self.assertEqual(sanitized, "mysql://example.com/db") -class TestQueryByname(unittest.TestCase): - def _assert_success(self, result): - item, error = result - self.assertIsNone(error) - return item - - def test_zero_dimension_entity(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) - db_map.commit_session("Add 
entity.") - entity_row = db_map.query(db_map.wide_entity_sq).one() - self.assertEqual(query_byname(entity_row, db_map), ("my_entity",)) - - def test_dimensioned_entity(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="class_1")) - self._assert_success(db_map.add_entity_class_item(name="class_2")) - self._assert_success(db_map.add_entity_item(name="entity_1", entity_class_name="class_1")) - self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="class_2")) - self._assert_success( - db_map.add_entity_class_item(name="relationship", dimension_name_list=("class_1", "class_2")) - ) - relationship = self._assert_success( - db_map.add_entity_item(entity_class_name="relationship", element_name_list=("entity_1", "entity_2")) - ) - db_map.commit_session("Add entities") - entity_row = ( - db_map.query(db_map.wide_entity_sq) - .filter(db_map.wide_entity_sq.c.id == db_map.find_db_id("entity", relationship["id"])) - .one() - ) - self.assertEqual(query_byname(entity_row, db_map), ("entity_1", "entity_2")) - - def test_deep_dimensioned_entity(self): - with DatabaseMapping("sqlite://", create=True) as db_map: - self._assert_success(db_map.add_entity_class_item(name="class_1")) - self._assert_success(db_map.add_entity_class_item(name="class_2")) - self._assert_success(db_map.add_entity_item(name="entity_1", entity_class_name="class_1")) - self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="class_2")) - self._assert_success( - db_map.add_entity_class_item(name="relationship_1", dimension_name_list=("class_1", "class_2")) - ) - relationship_1 = self._assert_success( - db_map.add_entity_item(entity_class_name="relationship_1", element_name_list=("entity_1", "entity_2")) - ) - self._assert_success( - db_map.add_entity_class_item(name="relationship_2", dimension_name_list=("class_2", "class_1")) - ) - relationship_2 = self._assert_success( - db_map.add_entity_item(entity_class_name="relationship_2", element_name_list=("entity_2", "entity_1")) - ) - self._assert_success( - db_map.add_entity_class_item( - name="super_relationship", dimension_name_list=("relationship_1", "relationship_2") - ) - ) - superrelationship = self._assert_success( - db_map.add_entity_item( - entity_class_name="super_relationship", - element_name_list=(relationship_1["name"], relationship_2["name"]), - ) - ) - db_map.commit_session("Add entities") - entity_row = ( - db_map.query(db_map.wide_entity_sq) - .filter(db_map.wide_entity_sq.c.id == db_map.find_db_id("entity", superrelationship["id"])) - .one() - ) - self.assertEqual(query_byname(entity_row, db_map), ("entity_1", "entity_2", "entity_2", "entity_1")) - - class TestGetHeadAlembicVersion(unittest.TestCase): def test_returns_latest_version(self): # This test must be updated each time new migration script is added. diff --git a/tests/test_item_id.py b/tests/test_item_id.py deleted file mode 100644 index e1965ae3..00000000 --- a/tests/test_item_id.py +++ /dev/null @@ -1,71 +0,0 @@ -###################################################################################################################### -# Copyright (C) 2017-2022 Spine project consortium -# Copyright Spine Database API contributors -# This file is part of Spine Database API. 
-# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser -# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your -# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; -# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General -# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with -# this program. If not, see . -###################################################################################################################### -import unittest - -from spinedb_api.item_id import IdFactory, IdMap - - -class TestIdFactory(unittest.TestCase): - def test_ids_are_negative_and_consecutive(self): - factory = IdFactory() - self.assertEqual(factory.next_id(), -1) - self.assertEqual(factory.next_id(), -2) - - -class TestIdMap(unittest.TestCase): - def test_add_item_id(self): - id_map = IdMap() - id_map.add_item_id(-2) - self.assertIsNone(id_map.db_id(-2)) - - def test_remove_item_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - id_map.remove_item_id(-2) - self.assertRaises(KeyError, id_map.item_id, 3) - self.assertRaises(KeyError, id_map.db_id, -2) - - def test_set_db_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - self.assertEqual(id_map.db_id(-2), 3) - self.assertEqual(id_map.item_id(3), -2) - - def test_remove_db_id_using_db_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - id_map.remove_db_id(3) - self.assertIsNone(id_map.db_id(-2)) - self.assertRaises(KeyError, id_map.item_id, 3) - - def test_remove_db_id_using_item_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - id_map.remove_db_id(-2) - self.assertIsNone(id_map.db_id(-2)) - self.assertRaises(KeyError, id_map.item_id, 3) - - def test_item_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - self.assertEqual(id_map.item_id(3), -2) - self.assertRaises(KeyError, id_map.item_id, 99) - - def test_db_id(self): - id_map = IdMap() - id_map.set_db_id(-2, 3) - self.assertEqual(id_map.db_id(-2), 3) - self.assertRaises(KeyError, id_map.db_id, -99) - - -if __name__ == '__main__': - unittest.main() From 6cf762c188174916d791d23da2dccfd13fd3fc5b Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 1 Feb 2024 09:00:05 +0100 Subject: [PATCH 240/317] Speed up scenario filtered entity subquery Seems better to compute and join element count earlier in the query rather than later --- spinedb_api/filters/scenario_filter.py | 24 ++++++++++++++++++------ spinedb_api/query.py | 3 +++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index 80d264fa..8b8a1604 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -258,6 +258,14 @@ def _make_scenario_filtered_entity_sq(db_map, state): Alias: a subquery for entity filtered by selected scenario """ ext_entity_sq = _ext_entity_sq(db_map, state) + ext_entity_element_count_sq = ( + db_map.query( + db_map.entity_element_sq.c.entity_id, + func.count(db_map.entity_element_sq.c.element_id).label("element_count"), + ) + .group_by(db_map.entity_element_sq.c.entity_id) + .subquery() + ) ext_entity_class_dimension_count_sq = ( db_map.query( db_map.entity_class_dimension_sq.c.entity_class_id, @@ -281,17 +289,21 @@ def _make_scenario_filtered_entity_sq(db_map, 
state): and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True), ), ) + .outerjoin( + ext_entity_element_count_sq, + ext_entity_element_count_sq.c.entity_id == ext_entity_sq.c.id, + ) .outerjoin( ext_entity_class_dimension_count_sq, ext_entity_class_dimension_count_sq.c.entity_class_id == ext_entity_sq.c.class_id, ) - .outerjoin(db_map.entity_element_sq, ext_entity_sq.c.id == db_map.entity_element_sq.c.entity_id) - .group_by(ext_entity_sq.c.id) - .having( + .filter( or_( - ext_entity_class_dimension_count_sq.c.dimension_count == None, - ext_entity_class_dimension_count_sq.c.dimension_count - == func.count(db_map.entity_element_sq.c.element_id), + and_( + ext_entity_element_count_sq.c.element_count == None, + ext_entity_class_dimension_count_sq.c.dimension_count == None, + ), + ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count, ) ) .subquery() diff --git a/spinedb_api/query.py b/spinedb_api/query.py index 3cbdbe00..345894d9 100644 --- a/spinedb_api/query.py +++ b/spinedb_api/query.py @@ -31,6 +31,9 @@ def __init__(self, bind, *entities): self._select = select(entities) self._from = None + def __str__(self): + return str(self._select) + @property def column_descriptions(self): return [{"name": c.name} for c in self._select.columns] From 3c2c9a416558777cb4baec0f3781d77c168a5b0c Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 10:24:33 +0100 Subject: [PATCH 241/317] Fix commit Re #345 --- spinedb_api/db_mapping.py | 18 ++++++++++++------ spinedb_api/db_mapping_base.py | 9 +-------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index ff2a8edf..43414f32 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -679,6 +679,11 @@ def query(self, *args, **kwargs): """ return Query(self.engine, *args) + def _get_db_lock(self, connection): + if self.sa_url.get_dialect() == "sqlite": + connection.execute("BEGIN IMMEDIATE") + # TODO: Other dialects? Do they need it? + def commit_session(self, comment): """Commits the changes from the in-memory mapping to the database. 
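A minimal sketch of the locking pattern behind `_get_db_lock` (illustration only; it assumes SQLAlchemy's `URL.get_backend_name()`, the accessor that returns the backend name as a plain string): issuing `BEGIN IMMEDIATE` up front makes SQLite reserve its write lock at the start of the transaction, so two concurrent `commit_session()` calls serialize instead of one of them failing halfway through with "database is locked".

    from sqlalchemy import create_engine

    engine = create_engine("sqlite:////tmp/db.sqlite")
    with engine.begin() as connection:
        if engine.url.get_backend_name() == "sqlite":
            # pysqlite emits BEGIN lazily, so this statement runs first and
            # holds the write lock for the rest of the transaction.
            connection.execute("BEGIN IMMEDIATE")
        # ... compute the dirty items and write them back atomically ...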
@@ -690,13 +695,14 @@ def commit_session(self, comment): """ if not comment: raise SpineDBAPIError("Commit message cannot be empty.") - dirty_items = self._dirty_items() - if not dirty_items: - raise SpineDBAPIError("Nothing to commit.") - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert() with self.engine.begin() as connection: + self._get_db_lock(connection) + dirty_items = self._dirty_items() + if not dirty_items: + raise SpineDBAPIError("Nothing to commit.") + user = self.username + date = datetime.now(timezone.utc) + ins = self._metadata.tables["commit"].insert() try: commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] except DBAPIError as e: diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f14a0cd6..5ebecd0d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -148,6 +148,7 @@ def _dirty_items(self): purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} self._add_descendants(purged_item_types) for item_type in self._sorted_item_types: + self.do_fetch_all(item_type) # To fix conflicts in add_item_from_db mapped_table = self.mapped_table(item_type) to_add = [] to_update = [] @@ -164,14 +165,6 @@ def _dirty_items(self): _ = item.is_valid() if item.status == Status.to_remove: to_remove.append(item) - if to_remove: - # Fetch descendants, so that they are validated in next iterations of the loop. - # This ensures cascade removal. - # FIXME: We should also fetch the current item type because of multi-dimensional entities and - # classes which also depend on zero-dimensional ones - for other_item_type in self.item_types(): - if item_type in self.item_factory(other_item_type).ref_types(): - self.fetch_all(other_item_type) if to_add or to_update or to_remove: dirty_items.append((item_type, (to_add, to_update, to_remove))) return dirty_items From 74c00215e2061c2fab18ca13b76ea1240578cdf2 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 10:34:53 +0100 Subject: [PATCH 242/317] Add test --- tests/test_DatabaseMapping.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 3cdefd4c..f7a21a62 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -13,6 +13,7 @@ import os.path from tempfile import TemporaryDirectory import unittest +import threading from unittest import mock from unittest.mock import patch from sqlalchemy.engine.url import make_url, URL @@ -2831,6 +2832,21 @@ def test_cascade_remove_unfetched(self): ents = self._db_map.query(self._db_map.entity_sq).all() self.assertEqual(ents, []) + def test_concurrent_commit(self): + def _commit_on_thread(db_map): + db_map.commit_session("...") + + with CustomDatabaseMapping(IN_MEMORY_DB_URL) as other_db_map: + self._db_map.add_entity_class(name="dog") + self._db_map.add_entity_class(name="cat") + other_db_map.add_entity_class(name="cat") + t1 = threading.Thread(target=_commit_on_thread, args=(self._db_map,)) + t2 = threading.Thread(target=_commit_on_thread, args=(other_db_map,)) + t1.start() + t2.start() + t1.join() + t2.join() + if __name__ == "__main__": unittest.main() From a64128dbc9cecc48a49c1bbc34a7a3f08f129fec Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 2 Feb 2024 10:17:55 +0200 Subject: [PATCH 243/317] Add conversion from is_active default values to active_by_default We now convert is_active default values 
properly to active_by_default on migration. Re #316 --- ...b_add_active_by_default_to_entity_class.py | 2 + spinedb_api/compatibility.py | 88 +++++++++++++++++++ tests/test_DatabaseMapping.py | 37 ++++++++ 3 files changed, 127 insertions(+) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index 17afafc2..35399f3e 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -9,6 +9,7 @@ import sqlalchemy as sa import sqlalchemy.orm +from spinedb_api.compatibility import convert_tool_feature_method_to_active_by_default # revision identifiers, used by Alembic. revision = '8b0eff478bcb' @@ -38,6 +39,7 @@ def upgrade(): class_table.update().where(class_table.c.id == sa.bindparam("target_id")).values(active_by_default=True) ) conn.execute(update_statement, [{"target_id": class_id} for class_id in dimensional_class_ids]) + convert_tool_feature_method_to_active_by_default(conn) def downgrade(): diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 04b02804..89ca17ca 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -15,6 +15,91 @@ import sqlalchemy as sa +def convert_tool_feature_method_to_active_by_default(conn): + """Transforms default parameter values into active_by_default values, whenever the former are used in a tool filter + to control entity activity. + + Args: + conn (Connection) + + Returns: + tuple: list of entity classes to add, update and ids to remove + """ + meta = sa.MetaData(conn) + meta.reflect() + lv_table = meta.tables["list_value"] + pd_table = meta.tables["parameter_definition"] + try: + # Compute list-value id by parameter definition id for all features and methods + tfm_table = meta.tables["tool_feature_method"] + tf_table = meta.tables["tool_feature"] + f_table = meta.tables["feature"] + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) + .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) + .where(tfm_table.c.method_index == lv_table.c.index) + .where(tf_table.c.id == tfm_table.c.tool_feature_id) + .where(f_table.c.id == tf_table.c.feature_id) + ) + } + except KeyError: + # It's a new DB without tool/feature/method + # we take 'is_active' as feature and JSON "yes" and true as methods + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, lv_table.c.value, pd_table.c.id.label("parameter_definition_id")]) + .where(lv_table.c.parameter_value_list_id == pd_table.c.parameter_value_list_id) + .where(pd_table.c.name == "is_active") + .where(lv_table.c.value.in_((b'"yes"', b"true"))) + ) + } + # Collect 'is_active' default values + list_value_id = sa.case( + [(pd_table.c.default_type == "list_value_ref", sa.cast(pd_table.c.default_value, sa.Integer()))], else_=None + ) + is_active_default_vals = [ + {c: x[c] for c in ("entity_class_id", "parameter_definition_id", "list_value_id")} + for x in conn.execute( + sa.select( + [ + pd_table.c.entity_class_id, + pd_table.c.id.label("parameter_definition_id"), + list_value_id.label("list_value_id"), + ] + ).where(pd_table.c.id.in_(lv_id_by_pdef_id)) + ) + ] + # Compute new active_by_default values from 'is_active' default values, + # where 
active_by_default is True if the value of 'is_active' is the one from the tool_feature_method specification + entity_class_items_to_update = { + x["entity_class_id"]: { + "active_by_default": x["list_value_id"] == lv_id_by_pdef_id[x["parameter_definition_id"]], + } + for x in is_active_default_vals + if x["list_value_id"] is not None + } + updated_items = [] + entity_class_table = meta.tables["entity_class"] + update_statement = entity_class_table.update() + for class_id, update in entity_class_items_to_update.items(): + conn.execute(update_statement.where(entity_class_table.c.id == class_id), update) + update["id"] = class_id + updated_items.append(update) + parameter_definitions_to_update = ( + x["parameter_definition_id"] for x in is_active_default_vals if x["list_value_id"] is not None + ) + update_statement = pd_table.update() + for definition_id in parameter_definitions_to_update: + update = {"default_value": None, "default_type": None} + conn.execute(update_statement.where(pd_table.c.id == definition_id), update) + update["id"] = definition_id + updated_items.append(update) + return [], updated_items, [] + + def convert_tool_feature_method_to_entity_alternative(conn): """Transforms parameter_value rows into entity_alternative rows, whenever the former are used in a tool filter to control entity activity. @@ -121,4 +206,7 @@ def compatibility_transformations(connection): transformations.append(("parameter_value", ((), (), pval_ids_removed))) if ea_items_added or ea_items_updated or pval_ids_removed: info.append("Convert entity activity control using tool/feature/method into entity_alternative") + _, ec_items_updated, _ = convert_tool_feature_method_to_active_by_default(connection) + if ec_items_updated: + transformations.append(("entity_class", ((), ec_items_updated, ()))) return transformations, info diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 3cdefd4c..7559ffa3 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -548,6 +548,43 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): self.assertEqual(len(groups), 1) self.assertNotIn("commit_id", groups[0]._extended()) + def test_commit_default_value_for_parameter_called_is_active(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_parameter_value_list_item(name="booleans") + value, value_type = to_database(True) + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + db_map.add_entity_class_item(name="Widget") + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Widget", + parameter_value_list_name="booleans", + default_value=value, + default_type=value_type, + ) + db_map.add_entity_class_item(name="Gadget") + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Gadget", + parameter_value_list_name="booleans", + default_value=value, + default_type=value_type, + ) + db_map.add_entity_class_item(name="NoIsActiveDefault") + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" + ) + db_map.commit_session("Add test data to see if this crashes") + active_by_defaults = { + entity_class["name"]: entity_class["active_by_default"] + for entity_class in db_map.query(db_map.wide_entity_class_sq) + } + self.assertEqual(active_by_defaults, {"Widget": True, "Gadget": True, "NoIsActiveDefault": False}) + defaults = [ + 
from_database(definition["default_value"], definition["default_type"]) + for definition in db_map.query(db_map.parameter_definition_sq) + ] + self.assertEqual(defaults, 3 * [None]) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. pre-entity tests converted to work with the entity structure.""" From 2fa0c3b93fca24dce3dc3475529c05783abb124f Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 2 Feb 2024 14:19:56 +0200 Subject: [PATCH 244/317] Fix updating entity alternatives by compatibility transforms The update statement was formulated incorrectly: it was trying to update the ID of every row in the entity_alternative table. The ID should rather be in the WHERE clause so the DB knows which row to update. Re spine-tools/Spine-Toolbox#2535 --- spinedb_api/compatibility.py | 5 +-- tests/test_DatabaseMapping.py | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 89ca17ca..8f9d9eb4 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -177,8 +177,9 @@ def convert_tool_feature_method_to_entity_alternative(conn): pval_ids_to_remove = [x["id"] for x in is_active_pvals] if ea_items_to_add: conn.execute(ea_table.insert(), ea_items_to_add) - if ea_items_to_update: - conn.execute(ea_table.update(), ea_items_to_update) + ea_update = ea_table.update() + for item in ea_items_to_update: + conn.execute(ea_update.where(ea_table.c.id == item["id"]), {"active": item["active"]}) # Delete pvals 499 at a time to avoid too many sql variables size = 499 for i in range(0, len(pval_ids_to_remove), size): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 7559ffa3..e434cabe 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -548,6 +548,64 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): self.assertEqual(len(groups), 1) self.assertNotIn("commit_id", groups[0]._extended()) + def test_commit_parameter_value_coincidentally_called_is_active(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_parameter_value_list_item(name="booleans") + value, value_type = to_database(True) + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + db_map.add_entity_class_item(name="my_class") + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" + ) + db_map.add_entity_item(name="widget1", entity_class_name="my_class") + db_map.add_entity_item(name="widget2", entity_class_name="my_class") + db_map.add_entity_item(name="no_is_active", entity_class_name="my_class") + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False + ) + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("widget2",), alternative_name="Base", active=False + ) + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False + ) + value, value_type = to_database(True) + db_map.add_parameter_value_item( + entity_class_name="my_class", + parameter_definition_name="is_active", + entity_byname=("widget1",), + alternative_name="Base", + value=value, + type=value_type, + ) + db_map.add_parameter_value_item( + entity_class_name="my_class", + 
parameter_definition_name="is_active", + entity_byname=("widget2",), + alternative_name="Base", + value=value, + type=value_type, + ) + db_map.commit_session("Add test data to see if this crashes.") + entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)} + alternative_names = { + alternative["id"]: alternative["name"] for alternative in db_map.query(db_map.alternative_sq) + } + expected = { + ("widget1", "Base"): True, + ("widget2", "Base"): True, + ("no_is_active", "Base"): False, + } + in_database = {} + entity_alternatives = db_map.query(db_map.entity_alternative_sq) + for entity_alternative in entity_alternatives: + entity_name = entity_names[entity_alternative["entity_id"]] + alternative_name = alternative_names[entity_alternative["alternative_id"]] + in_database[(entity_name, alternative_name)] = entity_alternative["active"] + self.assertEqual(in_database, expected) + self.assertEqual(db_map.query(db_map.parameter_value_sq).all(), []) + + def test_commit_default_value_for_parameter_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_parameter_value_list_item(name="booleans") From 0bc2aacbaf239cfd211a973a63e78b8d728449c9 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 16:27:32 +0100 Subject: [PATCH 245/317] Fix locking, solve conflicts, add tests --- spinedb_api/db_mapping.py | 21 ++++++--------- spinedb_api/db_mapping_base.py | 3 +++ tests/test_DatabaseMapping.py | 47 ++++++++++++++++++++++++---------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 43414f32..b39c1b78 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -679,11 +679,6 @@ def query(self, *args, **kwargs): """ return Query(self.engine, *args) - def _get_db_lock(self, connection): - if self.sa_url.get_dialect() == "sqlite": - connection.execute("BEGIN IMMEDIATE") - # TODO: Other dialects? Do they need it? - def commit_session(self, comment): """Commits the changes from the in-memory mapping to the database. 
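An aside, not part of the patch series: the _get_db_lock helper removed above serialized writers with SQLite's BEGIN IMMEDIATE, which takes the write lock as soon as the transaction starts. The next hunk gets the same serialization by making the insert of the commit row the first statement of the transaction, because SQLite escalates to the write lock on a transaction's first write. A self-contained sketch of both variants, using only the standard-library sqlite3 module and a simplified commit table:

    import sqlite3

    conn = sqlite3.connect("example.sqlite", timeout=30, isolation_level=None)
    conn.execute('CREATE TABLE IF NOT EXISTS "commit" (id INTEGER PRIMARY KEY, user TEXT, date TEXT, comment TEXT)')

    # Variant removed above: take the write lock eagerly, before reading any state.
    conn.execute("BEGIN IMMEDIATE")  # write lock acquired here; concurrent writers now block
    conn.execute('INSERT INTO "commit" (user, date, comment) VALUES (?, ?, ?)', ("alice", "2024-02-02", "one"))
    conn.execute("COMMIT")

    # Variant adopted below: a plain deferred transaction whose first statement is a write.
    conn.execute("BEGIN")  # no lock taken yet
    conn.execute('INSERT INTO "commit" (user, date, comment) VALUES (?, ?, ?)', ("bob", "2024-02-02", "two"))
    # The INSERT above takes the write lock, so the rest of the transaction runs serialized.
    conn.execute("COMMIT")
    conn.close()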
@@ -696,17 +691,17 @@ def commit_session(self, comment): if not comment: raise SpineDBAPIError("Commit message cannot be empty.") with self.engine.begin() as connection: - self._get_db_lock(connection) - dirty_items = self._dirty_items() - if not dirty_items: - raise SpineDBAPIError("Nothing to commit.") - user = self.username - date = datetime.now(timezone.utc) - ins = self._metadata.tables["commit"].insert() + commit = self._metadata.tables["commit"] + commit_item = dict(user=self.username, date=datetime.now(timezone.utc), comment=comment) try: - commit_id = connection.execute(ins, dict(user=user, date=date, comment=comment)).inserted_primary_key[0] + # The below locks the DB in sqlite + commit_id = connection.execute(commit.insert(), commit_item).inserted_primary_key[0] except DBAPIError as e: raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e + dirty_items = self._dirty_items() + if not dirty_items: + connection.execute(commit.delete().where(commit.c.id == commit_id)) + raise SpineDBAPIError("Nothing to commit.") for tablename, (to_add, to_update, to_remove) in dirty_items: for item in to_add + to_update + to_remove: item.commit(commit_id) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 5ebecd0d..8402d35a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -489,6 +489,9 @@ def add_item_from_db(self, item): item, fetch=False, complete=False ) if current: + if current.status == Status.to_add: + current["id"].resolve(item["id"]) + current.status = Status.committed return current, False item = self._make_and_add_item(item) if self.purged: diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index f7a21a62..aa48d71b 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -14,6 +14,7 @@ from tempfile import TemporaryDirectory import unittest import threading +import multiprocessing from unittest import mock from unittest.mock import patch from sqlalchemy.engine.url import make_url, URL @@ -2832,20 +2833,38 @@ def test_cascade_remove_unfetched(self): ents = self._db_map.query(self._db_map.entity_sq).all() self.assertEqual(ents, []) - def test_concurrent_commit(self): - def _commit_on_thread(db_map): - db_map.commit_session("...") - - with CustomDatabaseMapping(IN_MEMORY_DB_URL) as other_db_map: - self._db_map.add_entity_class(name="dog") - self._db_map.add_entity_class(name="cat") - other_db_map.add_entity_class(name="cat") - t1 = threading.Thread(target=_commit_on_thread, args=(self._db_map,)) - t2 = threading.Thread(target=_commit_on_thread, args=(other_db_map,)) - t1.start() - t2.start() - t1.join() - t2.join() + +class TestDatabaseMappingConcurrent(unittest.TestCase): + def test_concurrent_commit_threading(self): + self._do_test_concurrent_commit(threading.Thread) + + def test_concurrent_commit_multiprocessing(self): + self._do_test_concurrent_commit(multiprocessing.Process) + + def _do_test_concurrent_commit(self, make_concurrent): + def _commit_on_thread(db_map, msg): + db_map.commit_session(msg) + + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + + with CustomDatabaseMapping(url, create=True) as db_map1: + with CustomDatabaseMapping(url) as db_map2: + db_map1.add_entity_class_item(name="dog") + db_map1.add_entity_class_item(name="cat") + db_map2.add_entity_class_item(name="cat") + t1 = make_concurrent(target=_commit_on_thread, args=(db_map1, "one")) + t2 = make_concurrent(target=_commit_on_thread, 
args=(db_map2, "two")) + t2.start() + t1.start() + t1.join() + t2.join() + + with CustomDatabaseMapping(url) as db_map: + commit_msgs = {x["comment"] for x in db_map.query(db_map.commit_sq)} + entity_class_names = {x["name"] for x in db_map.query(db_map.entity_class_sq)} + self.assertEqual(commit_msgs, {"Create the database", "one", "two"}) + self.assertEqual(entity_class_names, {"cat", "dog"}) if __name__ == "__main__": From 05e0beb572119249dea280e8dd5e80eb681a944f Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 16:34:04 +0100 Subject: [PATCH 246/317] Fix check for uniqueness --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 8402d35a..ce13c2ad 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -486,7 +486,7 @@ def add_item_from_db(self, item): tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( - item, fetch=False, complete=False + item, fetch=False, complete=True ) if current: if current.status == Status.to_add: From 92608b83a867aa51421ded8fecc721b9455151fe Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:02:31 +0100 Subject: [PATCH 247/317] Fix check for uniqueness better --- spinedb_api/db_mapping.py | 2 +- spinedb_api/db_mapping_base.py | 6 +++--- tests/test_DatabaseMapping.py | 17 +++++++++-------- tests/test_db_mapping_base.py | 3 +++ 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index b39c1b78..6b1d11d5 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -694,7 +694,7 @@ def commit_session(self, comment): commit = self._metadata.tables["commit"] commit_item = dict(user=self.username, date=datetime.now(timezone.utc), comment=comment) try: - # The below locks the DB in sqlite + # TODO: The below locks the DB in sqlite, how about other dialects? commit_id = connection.execute(commit.insert(), commit_item).inserted_primary_key[0] except DBAPIError as e: raise SpineDBAPIError(f"Fail to commit: {e.orig.args}") from e diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index ce13c2ad..02a15c5d 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -365,7 +365,7 @@ def find_item(self, item, skip_keys=(), fetch=True): id_ = item.get("id") if id_ is not None: return self.find_item_by_id(id_, fetch=fetch) - return self.find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) + return self._find_item_by_unique_key(item, skip_keys=skip_keys, fetch=fetch) def find_item_by_id(self, id_, fetch=True): current_item = self.get(id_, {}) @@ -374,7 +374,7 @@ def find_item_by_id(self, id_, fetch=True): current_item = self.get(id_, {}) return current_item - def find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): + def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): for key, value in self._db_map.item_factory(self._item_type).unique_values_for_item(item, skip_keys=skip_keys): current_item = self._unique_key_value_to_item(key, value, fetch=fetch) if current_item: @@ -486,7 +486,7 @@ def add_item_from_db(self, item): tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. 
""" current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( - item, fetch=False, complete=True + item, fetch=False, complete=self._db_map.has_external_commits() ) if current: if current.status == Status.to_add: diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index aa48d71b..753caead 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2853,18 +2853,19 @@ def _commit_on_thread(db_map, msg): db_map1.add_entity_class_item(name="dog") db_map1.add_entity_class_item(name="cat") db_map2.add_entity_class_item(name="cat") - t1 = make_concurrent(target=_commit_on_thread, args=(db_map1, "one")) - t2 = make_concurrent(target=_commit_on_thread, args=(db_map2, "two")) - t2.start() - t1.start() - t1.join() - t2.join() + c1 = make_concurrent(target=_commit_on_thread, args=(db_map1, "one")) + c2 = make_concurrent(target=_commit_on_thread, args=(db_map2, "two")) + c2.start() + c1.start() + c1.join() + c2.join() with CustomDatabaseMapping(url) as db_map: commit_msgs = {x["comment"] for x in db_map.query(db_map.commit_sq)} - entity_class_names = {x["name"] for x in db_map.query(db_map.entity_class_sq)} + entity_class_names = [x["name"] for x in db_map.query(db_map.entity_class_sq)] self.assertEqual(commit_msgs, {"Create the database", "one", "two"}) - self.assertEqual(entity_class_names, {"cat", "dog"}) + self.assertEqual(len(entity_class_names), 2) + self.assertEqual(set(entity_class_names), {"cat", "dog"}) if __name__ == "__main__": diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index 179c1ed8..65d75803 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -28,6 +28,9 @@ def item_factory(item_type): return MappedItemBase raise RuntimeError(f"unknown item_type '{item_type}'") + def _make_query(self, _item_type, **kwargs): + return None + class TestDBMappingBase(unittest.TestCase): def test_rolling_back_new_item_invalidates_its_id(self): From 0339d1d4be87d7e1752e2f1db7d1b21272421bc6 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:05:20 +0100 Subject: [PATCH 248/317] Fix typo --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 02a15c5d..3d6eb2da 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -485,7 +485,7 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. 
""" - current = self.find_item_by_id(item["id"], fetch=False) or self.find_item_by_unique_key( + current = self.find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key( item, fetch=False, complete=self._db_map.has_external_commits() ) if current: From 6b5287ba32b22b8649e73535e06f9fda0d542bff Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:21:39 +0100 Subject: [PATCH 249/317] Skip faulty test on windows --- tests/test_DatabaseMapping.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 6161f3fd..e987cc40 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -607,7 +607,6 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): self.assertEqual(in_database, expected) self.assertEqual(db_map.query(db_map.parameter_value_sq).all(), []) - def test_commit_default_value_for_parameter_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_parameter_value_list_item(name="booleans") @@ -2933,6 +2932,10 @@ class TestDatabaseMappingConcurrent(unittest.TestCase): def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) + @unittest.skipIf( + os.name == 'nt', + "AttributeError: Can't pickle local object 'TestDatabaseMappingConcurrent._do_test_concurrent_commit.._commit_on_thread", + ) def test_concurrent_commit_multiprocessing(self): self._do_test_concurrent_commit(multiprocessing.Process) From 4bf93d7f0a29688d42be2e6f626fed34c8b81fc1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:32:18 +0100 Subject: [PATCH 250/317] Skip tests more aggressively --- tests/test_DatabaseMapping.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index e987cc40..52b61ffa 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2928,14 +2928,11 @@ def test_cascade_remove_unfetched(self): self.assertEqual(ents, []) +@unittest.skipIf(os.name == 'nt') class TestDatabaseMappingConcurrent(unittest.TestCase): def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) - @unittest.skipIf( - os.name == 'nt', - "AttributeError: Can't pickle local object 'TestDatabaseMappingConcurrent._do_test_concurrent_commit.._commit_on_thread", - ) def test_concurrent_commit_multiprocessing(self): self._do_test_concurrent_commit(multiprocessing.Process) From 7fa63d3c06f16fb4e20dd441122b5a375c76ad6d Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:38:28 +0100 Subject: [PATCH 251/317] Fix decorator --- tests/test_DatabaseMapping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 52b61ffa..d06f8271 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2928,7 +2928,7 @@ def test_cascade_remove_unfetched(self): self.assertEqual(ents, []) -@unittest.skipIf(os.name == 'nt') +@unittest.skipIf(os.name == 'nt', "Need to fix") class TestDatabaseMappingConcurrent(unittest.TestCase): def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) From 28b4a68ff85d500b29a665c5de102305d8daec16 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 2 Feb 2024 17:41:25 +0100 Subject: [PATCH 252/317] Trying out things --- spinedb_api/db_mapping_base.py | 74 ++++++++++++++++++++++++++++------ tests/test_DatabaseMapping.py | 
1 + 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 3d6eb2da..8d681395 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -268,8 +268,9 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): items = [] new_items = [] # Add items first + fix_id_conflics = self.has_external_commits() for x in chunk: - item, new = mapped_table.add_item_from_db(x) + item, new = mapped_table.add_item_from_db(x, fix_id_conflics) if new: new_items.append(item) items.append(item) @@ -282,6 +283,32 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): def do_fetch_all(self, item_type, **kwargs): self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) + def update_id(self, mapped_item, new_id): + """Updates the id of the given item to the given new_id, also in all its referees. + + Args: + mapped_item (MappedItemBase) + new_id (int) + """ + old_id = mapped_item["id"] + mapped_item["id"] = new_id + for item_type in self.item_types(): + dirty_fields = { + field + for field, (ref_type, ref_field) in self.item_factory(item_type)._references.items() + if ref_type == mapped_item.item_type and ref_field == "id" + } # Fields that might refer the id of the mapped item + if not dirty_fields: + continue + mapped_table = self.mapped_table(item_type) + for item in mapped_table.values(): + for field in dirty_fields: + value = item[field] + if isinstance(value, tuple) and old_id in value: + item[field] = tuple(new_id if id_ == old_id else id_ for id_ in value) + elif old_id == value: + item[field] = new_id + class _MappedTable(dict): def __init__(self, db_map, item_type, *args, **kwargs): @@ -476,7 +503,7 @@ def _make_and_add_item(self, item): self[item["id"]] = item return item - def add_item_from_db(self, item): + def add_item_from_db(self, item, fix_id_conflics): """Adds an item fetched from the DB. Args: @@ -485,19 +512,39 @@ def add_item_from_db(self, item): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ - current = self.find_item_by_id(item["id"], fetch=False) or self._find_item_by_unique_key( - item, fetch=False, complete=self._db_map.has_external_commits() - ) - if current: - if current.status == Status.to_add: - current["id"].resolve(item["id"]) - current.status = Status.committed - return current, False - item = self._make_and_add_item(item) + mapped_item = self._find_item_by_unique_key(item, fetch=False, complete=True) + print(item, mapped_item) + if mapped_item: + self._solve_id_conflict(mapped_item, item) + return mapped_item, False + mapped_item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. - item.cascade_remove(source=self.wildcard_item) - return item, True + mapped_item.cascade_remove(source=self.wildcard_item) + return mapped_item, True + + def _solve_id_conflict(self, mapped_item, db_item): + """Makes sure that mapped_item and db_item don't have conflicting ids. + Both items are equivalent in the sense they share a unique key, + so there's only room for one of them. Therefore, they must have the same id. + + Args: + current (MappedItemBase): An item in the in-memory item. + item (dict): An item just fetched from the DB. + """ + # NOTE: db_item is more recent (because it's been fetched later) so we need to trust its id. 
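    # An aside, not from this commit: the resolve() call below leans on the TempId
    # pattern, a placeholder id handed out before the DB assigns a real one, which
    # later "resolves" to the concrete id and notifies registered callbacks. A
    # simplified stand-in, illustrative only; spinedb_api's actual TempId differs
    # in detail:
    class ToyTempId:
        def __init__(self):
            self.db_id = None  # filled in once the DB row exists
            self._callbacks = []

        def add_resolve_callback(self, callback):
            self._callbacks.append(callback)

        def resolve(self, db_id):
            # Bind the concrete DB id and let referrers update their references.
            self.db_id = db_id
            for callback in self._callbacks:
                callback(db_id)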
+ mapped_id, db_id = mapped_item["id"], db_item["id"] + if isinstance(mapped_id, TempId): + # mapped_item was added on this session and hasn't been committed. + # Just do as if it was committed and has the id of db_item. + mapped_id.resolve(db_id) + if mapped_item.status == Status.to_add: + mapped_item.status = Status.committed + elif mapped_id != db_id: + # Both mapped_item and db_item have been committed but with a different id (it can happen). + # Change the id of mapped_item to that of db_item. + self._db_map.update_id(mapped_item, db_id) + self[db_id] = self.pop(mapped_id) def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -533,6 +580,7 @@ def add_item(self, item): return item def update_item(self, item): + print("update_item", item) current_item = self.find_item(item) current_item.cascade_remove_unique() current_item.update(item) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index d06f8271..c2b346cf 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2037,6 +2037,7 @@ def test_update_committed_object(self): self._db_map.add_object_classes({"id": 1, "name": "some_class"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) self._db_map.commit_session("update") + print("NOW") items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) ids = {x["id"] for x in items} self._db_map.commit_session("test commit") From 02130ed36a59e12eab1556c3f1ed251e33caf5a8 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 5 Feb 2024 09:34:11 +0200 Subject: [PATCH 253/317] Enhance unit tests; add DatabaseMappingBase.reset_purging() This salvages some things from a reverted commit that may be nice-to-have. --- spinedb_api/db_mapping_base.py | 11 +- tests/test_DatabaseMapping.py | 260 ++++++++++++++++----------------- 2 files changed, 140 insertions(+), 131 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f14a0cd6..f1dba145 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -35,7 +35,7 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_query`. + When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_sq`. """ def __init__(self): @@ -232,6 +232,15 @@ def reset(self, *item_types): for item_type in item_types: self._mapped_tables.pop(item_type, None) + def reset_purging(self): + """Resets purging status for all item types. + + Fetching items of an item type that has been purged will automatically mark those items removed. + Resetting the purge status lets fetched items to be added unmodified. 
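    # An aside, a usage sketch for reset_purging (not from this commit; the URL is
    # hypothetical, DatabaseMapping and purge_items are used as in the tests of
    # this series):
    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite:///db.sqlite", create=True) as db_map:
        db_map.purge_items("entity_class")  # from here on, fetched entity classes are marked removed
        db_map.reset_purging()              # stop the lazy purge: newly fetched items are kept unmodified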
+ """ + for mapped_table in self._mapped_tables.values(): + mapped_table.wildcard_item.status = Status.committed + def _add_descendants(self, item_types): while True: changed = False diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index e434cabe..af8b3284 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -25,7 +25,7 @@ SpineDBAPIError, SpineIntegrityError, ) -from spinedb_api.helpers import name_from_elements +from spinedb_api.helpers import Asterisk, name_from_elements from tests.custom_db_mapping import CustomDatabaseMapping @@ -84,27 +84,28 @@ def test_shorthand_filter_query_works(self): class TestDatabaseMapping(unittest.TestCase): + def _assert_success(self, result): + item, error = result + self.assertIsNone(error) + return item + def test_active_by_default_is_initially_false_for_zero_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: - item, error = db_map.add_entity_class_item(name="Entity") - self.assertIsNone(error) + item = self._assert_success(db_map.add_entity_class_item(name="Entity")) self.assertFalse(item["active_by_default"]) def test_active_by_default_is_initially_false_for_multi_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_entity_class_item(name="Dimension") - item, error = db_map.add_entity_class_item(name="Entity", dimension_name_list=("Dimension",)) - self.assertIsNone(error) + item = self._assert_success(db_map.add_entity_class_item(name="Entity", dimension_name_list=("Dimension",))) self.assertTrue(item["active_by_default"]) def test_read_active_by_default_from_database(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as out_db_map: - _, error = out_db_map.add_entity_class_item(name="HiddenStuff", active_by_default=False) - self.assertIsNone(error) - _, error = out_db_map.add_entity_class_item(name="VisibleStuff", active_by_default=True) - self.assertIsNone(error) + self._assert_success(out_db_map.add_entity_class_item(name="HiddenStuff", active_by_default=False)) + self._assert_success(out_db_map.add_entity_class_item(name="VisibleStuff", active_by_default=True)) out_db_map.commit_session("Add entity classes.") entity_classes = out_db_map.query(out_db_map.wide_entity_class_sq).all() self.assertEqual(len(entity_classes), 2) @@ -125,16 +126,13 @@ def test_commit_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item( + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success(db_map.add_item( "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." 
- ) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") - self.assertIsNone(error) + )) + self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - _, error = db_map.add_item( + self._assert_success(db_map.add_item( "parameter_value", entity_class_name="fish", entity_byname=("Nemo",), @@ -142,8 +140,7 @@ def test_commit_parameter_value(self): alternative_name="Base", value=value, type=type_, - ) - self.assertIsNone(error) + )) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -160,29 +157,22 @@ def test_commit_multidimensional_parameter_value(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item("entity_class", name="cat", description="Eats fish.") - self.assertIsNone(error) - _, error = db_map.add_item( + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success(db_map.add_item("entity_class", name="cat", description="Eats fish.")) + self._assert_success(db_map.add_item( "entity_class", name="fish__cat", dimension_name_list=("fish", "cat"), description="A fish getting eaten by a cat?", - ) - self.assertIsNone(error) - _, error = db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") - self.assertIsNone(error) - _, error = db_map.add_item( + )) + self._assert_success(db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).")) + self._assert_success(db_map.add_item( "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." - ) - self.assertIsNone(error) - _, error = db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") - self.assertIsNone(error) + )) + self._assert_success(db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix"))) + self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate")) value, type_ = to_database(0.23) - _, error = db_map.add_item( + self._assert_success(db_map.add_item( "parameter_value", entity_class_name="fish__cat", entity_byname=("Nemo", "Felix"), @@ -190,8 +180,7 @@ def test_commit_multidimensional_parameter_value(self): alternative_name="Base", value=value, type=type_, - ) - self.assertIsNone(error) + )) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -206,16 +195,13 @@ def test_commit_multidimensional_parameter_value(self): def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map: - _, error = db_map.add_item("entity_class", name="fish", description="It swims.") - self.assertIsNone(error) - _, error = db_map.add_item( + self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) + self._assert_success(db_map.add_item( "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." 
- ) - self.assertIsNone(error) - _, error = db_map.add_item("parameter_definition", entity_class_name="fish", name="color") - self.assertIsNone(error) + )) + self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - _, error = db_map.add_item( + self._assert_success(db_map.add_item( "parameter_value", entity_class_name="fish", entity_byname=("Nemo",), @@ -223,8 +209,7 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): alternative_name="Base", value=value, type=type_, - ) - self.assertIsNone(error) + )) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -256,18 +241,17 @@ def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): def test_update_entity_metadata_by_changing_its_entity(self): with DatabaseMapping("sqlite://", create=True) as db_map: - entity_class, _ = db_map.add_entity_class_item(name="my_class") + self._assert_success(db_map.add_entity_class_item(name="my_class")) db_map.add_entity_item(name="entity_1", entity_class_name="my_class") - entity_2, _ = db_map.add_entity_item(name="entity_2", entity_class_name="my_class") + entity_2 = self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="my_class")) metadata_value = '{"sources": [], "contributors": []}' - metadata, _ = db_map.add_metadata_item(name="my_metadata", value=metadata_value) - entity_metadata, error = db_map.add_entity_metadata_item( + metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + entity_metadata = self._assert_success(db_map.add_entity_metadata_item( metadata_name="my_metadata", metadata_value=metadata_value, entity_class_name="my_class", entity_byname=("entity_1",), - ) - self.assertIsNone(error) + )) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( entity_metadata._extended(), @@ -317,43 +301,38 @@ def test_update_entity_metadata_by_changing_its_entity(self): def test_update_parameter_value_metadata_by_changing_its_parameter(self): with DatabaseMapping("sqlite://", create=True) as db_map: - entity_class, _ = db_map.add_entity_class_item(name="my_class") - _, error = db_map.add_parameter_definition_item(name="x", entity_class_name="my_class") - self.assertIsNone(error) - db_map.add_parameter_definition_item(name="y", entity_class_name="my_class") - entity, _ = db_map.add_entity_item(name="my_entity", entity_class_name="my_class") + entity_class = self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) value, value_type = to_database(2.3) - _, error = db_map.add_parameter_value_item( + self._assert_success(db_map.add_parameter_value_item( entity_class_name="my_class", entity_byname=("my_entity",), parameter_definition_name="x", alternative_name="Base", value=value, type=value_type, - ) - self.assertIsNone(error) + )) value, value_type = to_database(-2.3) - y, error = db_map.add_parameter_value_item( + y = self._assert_success(db_map.add_parameter_value_item( entity_class_name="my_class", entity_byname=("my_entity",), parameter_definition_name="y", alternative_name="Base", value=value, type=value_type, - ) - self.assertIsNone(error) + )) metadata_value = 
'{"sources": [], "contributors": []}' - metadata, error = db_map.add_metadata_item(name="my_metadata", value=metadata_value) - self.assertIsNone(error) - value_metadata, error = db_map.add_parameter_value_metadata_item( + metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + value_metadata = self._assert_success(db_map.add_parameter_value_metadata_item( metadata_name="my_metadata", metadata_value=metadata_value, entity_class_name="my_class", entity_byname=("my_entity",), parameter_definition_name="x", alternative_name="Base", - ) - self.assertIsNone(error) + )) value_metadata.update(parameter_definition_name="y") self.assertEqual( value_metadata._extended(), @@ -425,8 +404,8 @@ def test_fetch_more(self): def test_fetch_more_after_commit_and_refresh(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_item("entity_class", name="Widget") - db_map.add_item("entity", entity_class_name="Widget", name="gadget") + self._assert_success(db_map.add_item("entity_class", name="Widget")) + self._assert_success(db_map.add_item("entity", entity_class_name="Widget", name="gadget")) db_map.commit_session("Add test data.") db_map.refresh_session() entities = db_map.fetch_more("entity") @@ -441,7 +420,7 @@ def test_has_external_commits_returns_true_when_another_db_mapping_has_made_comm url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: with DatabaseMapping(url) as other_db_map: - other_db_map.add_item("entity_class", name="cc") + self._assert_success(other_db_map.add_item("entity_class", name="cc")) other_db_map.commit_session("Added a class") self.assertTrue(db_map.has_external_commits()) @@ -452,7 +431,7 @@ def test_has_external_commits_returns_false_after_commit_session(self): with DatabaseMapping(url) as other_db_map: other_db_map.add_item("entity_class", name="cc") other_db_map.commit_session("Added a class") - db_map.add_item("entity_class", name="omega") + self._assert_success(db_map.add_item("entity_class", name="omega")) db_map.commit_session("Added a class") self.assertFalse(db_map.has_external_commits()) @@ -466,18 +445,18 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - db_map.add_entity_class_item(name="dog") - db_map.add_entity_class_item(name="cat") - db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat")) - db_map.add_entity_item(name="Pulgoso", entity_class_name="dog") - db_map.add_entity_item(name="Sylvester", entity_class_name="cat") - db_map.add_entity_item(name="Tom", entity_class_name="cat") + self._assert_success(db_map.add_entity_class_item(name="dog")) + self._assert_success(db_map.add_entity_class_item(name="cat")) + self._assert_success(db_map.add_entity_class_item(name="dog__cat", dimension_name_list=("dog", "cat"))) + self._assert_success(db_map.add_entity_item(name="Pulgoso", entity_class_name="dog")) + self._assert_success(db_map.add_entity_item(name="Sylvester", entity_class_name="cat")) + self._assert_success(db_map.add_entity_item(name="Tom", entity_class_name="cat")) db_map.commit_session("Arf!") with DatabaseMapping(url) as db_map: # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. 
db_map.get_entity_item(name="Sylvester", entity_class_name="cat").remove() - db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") + self._assert_success(db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat")) db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". @@ -489,24 +468,19 @@ def test_committing_scenario_alternatives(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: - item, error = db_map.add_alternative_item(name="alt1") - self.assertIsNone(error) + item = self._assert_success(db_map.add_alternative_item(name="alt1")) self.assertIsNotNone(item) - item, error = db_map.add_alternative_item(name="alt2") - self.assertIsNone(error) + item = self._assert_success(db_map.add_alternative_item(name="alt2")) self.assertIsNotNone(item) - item, error = db_map.add_scenario_item(name="my_scenario") - self.assertIsNone(error) + item = self._assert_success(db_map.add_scenario_item(name="my_scenario")) self.assertIsNotNone(item) - item, error = db_map.add_scenario_alternative_item( + item = self._assert_success(db_map.add_scenario_alternative_item( scenario_name="my_scenario", alternative_name="alt1", rank=0 - ) - self.assertIsNone(error) + )) self.assertIsNotNone(item) - item, error = db_map.add_scenario_alternative_item( + item = self._assert_success(db_map.add_scenario_alternative_item( scenario_name="my_scenario", alternative_name="alt2", rank=1 - ) - self.assertIsNone(error) + )) self.assertIsNotNone(item) db_map.commit_session("Add test data.") with DatabaseMapping(url) as db_map: @@ -521,7 +495,7 @@ def test_committing_scenario_alternatives(self): def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="my_class") + self._assert_success(db_map.add_entity_class_item(name="my_class")) db_map.commit_session("Add class.") classes = db_map.get_entity_class_items() self.assertEqual(len(classes), 1) @@ -529,9 +503,9 @@ def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="high") - db_map.add_entity_class_item(name="low") - db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low") + self._assert_success(db_map.add_entity_class_item(name="high")) + self._assert_success(db_map.add_entity_class_item(name="low")) + self._assert_success(db_map.add_superclass_subclass_item(superclass_name="high", subclass_name="low")) db_map.commit_session("Add class hierarchy.") classes = db_map.get_superclass_subclass_items() self.assertEqual(len(classes), 1) @@ -539,10 +513,10 @@ def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_entity_class_item(name="my_class") - db_map.add_entity_item(name="element", entity_class_name="my_class") - db_map.add_entity_item(name="container", entity_class_name="my_class") - db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class") + 
self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="element", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="container", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class")) db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) @@ -550,42 +524,42 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): def test_commit_parameter_value_coincidentally_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: - db_map.add_parameter_value_list_item(name="booleans") + self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) - db_map.add_entity_class_item(name="my_class") - db_map.add_parameter_definition_item( + self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item( name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" - ) - db_map.add_entity_item(name="widget1", entity_class_name="my_class") - db_map.add_entity_item(name="widget2", entity_class_name="my_class") - db_map.add_entity_item(name="no_is_active", entity_class_name="my_class") - db_map.add_entity_alternative_item( + )) + self._assert_success(db_map.add_entity_item(name="widget1", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="widget2", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="no_is_active", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_alternative_item( entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False - ) - db_map.add_entity_alternative_item( + )) + self._assert_success(db_map.add_entity_alternative_item( entity_class_name="my_class", entity_byname=("widget2",), alternative_name="Base", active=False - ) - db_map.add_entity_alternative_item( + )) + self._assert_success(db_map.add_entity_alternative_item( entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False - ) + )) value, value_type = to_database(True) - db_map.add_parameter_value_item( + self._assert_success(db_map.add_parameter_value_item( entity_class_name="my_class", parameter_definition_name="is_active", entity_byname=("widget1",), alternative_name="Base", value=value, type=value_type, - ) - db_map.add_parameter_value_item( + )) + self._assert_success(db_map.add_parameter_value_item( entity_class_name="my_class", parameter_definition_name="is_active", entity_byname=("widget2",), alternative_name="Base", value=value, type=value_type, - ) + )) db_map.commit_session("Add test data to see if this crashes.") entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)} alternative_names = { @@ -608,29 +582,29 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): def test_commit_default_value_for_parameter_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: - 
db_map.add_parameter_value_list_item(name="booleans") + self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) - db_map.add_entity_class_item(name="Widget") - db_map.add_parameter_definition_item( + self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success(db_map.add_entity_class_item(name="Widget")) + self._assert_success(db_map.add_parameter_definition_item( name="is_active", entity_class_name="Widget", parameter_value_list_name="booleans", default_value=value, default_type=value_type, - ) - db_map.add_entity_class_item(name="Gadget") - db_map.add_parameter_definition_item( + )) + self._assert_success(db_map.add_entity_class_item(name="Gadget")) + self._assert_success(db_map.add_parameter_definition_item( name="is_active", entity_class_name="Gadget", parameter_value_list_name="booleans", default_value=value, default_type=value_type, - ) - db_map.add_entity_class_item(name="NoIsActiveDefault") - db_map.add_parameter_definition_item( + )) + self._assert_success(db_map.add_entity_class_item(name="NoIsActiveDefault")) + self._assert_success(db_map.add_parameter_definition_item( name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" - ) + )) db_map.commit_session("Add test data to see if this crashes") active_by_defaults = { entity_class["name"]: entity_class["active_by_default"] @@ -643,6 +617,32 @@ def test_commit_default_value_for_parameter_called_is_active(self): ] self.assertEqual(defaults, 3 * [None]) + def test_remove_items_by_asterisk(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_alternative_item(name="alt_1")) + self._assert_success(db_map.add_alternative_item(name="alt_2")) + db_map.commit_session("Add alternatives.") + alternatives = db_map.get_alternative_items() + self.assertEqual(len(alternatives), 3) + db_map.remove_items("alternative", Asterisk) + db_map.commit_session("Remove all alternatives.") + alternatives = db_map.get_alternative_items() + self.assertEqual(alternatives, []) + + def test_reset_purging(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + db_map.add_entity_class_item(name="Widget") + db_map.purge_items("entity_class") + with DatabaseMapping(url) as another_db_map: + another_db_map.add_entity_class_item(name="Gadget") + another_db_map.commit_session("Add another entity class.") + db_map.reset_purging() + entity_classes = db_map.get_entity_class_items() + self.assertEqual(len(entity_classes), 1) + self.assertEqual(entity_classes[0]["name"], "Gadget") + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. 
pre-entity tests converted to work with the entity structure.""" From 7760b311b208605df9ca285711ed81aa2e8e902a Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 5 Feb 2024 12:08:21 +0100 Subject: [PATCH 254/317] Keep removed items's unique key values for conflict resolution --- spinedb_api/db_mapping.py | 2 +- spinedb_api/db_mapping_base.py | 144 +++++++++++++++++---------------- tests/test_DatabaseMapping.py | 1 - 3 files changed, 74 insertions(+), 73 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 6b1d11d5..98a68aa9 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -382,7 +382,7 @@ def get_items(self, item_type, fetch=True, skip_removed=True, **kwargs): mapped_table = self.mapped_table(item_type) mapped_table.check_fields(kwargs, valid_types=(type(None),)) if fetch: - self.do_fetch_all(item_type, **kwargs) + self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) get_items = mapped_table.valid_values if skip_removed else mapped_table.values return [x.public_item for x in get_items() if all(x.get(k) == v for k, v in kwargs.items())] diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 8d681395..488fff39 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -40,6 +40,7 @@ class DatabaseMappingBase: def __init__(self): self._mapped_tables = {} + self._fetched = set() item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -224,6 +225,7 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) + self._fetched.discard(item_type) def _add_descendants(self, item_types): while True: @@ -268,9 +270,8 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): items = [] new_items = [] # Add items first - fix_id_conflics = self.has_external_commits() for x in chunk: - item, new = mapped_table.add_item_from_db(x, fix_id_conflics) + item, new = mapped_table.add_item_from_db(x) if new: new_items.append(item) items.append(item) @@ -280,8 +281,11 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): mapped_table.add_unique(item) return items - def do_fetch_all(self, item_type, **kwargs): - self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) + def do_fetch_all(self, item_type): + if item_type in self._fetched and not self.has_external_commits(): + return + self._fetched.add(item_type) + self.do_fetch_more(item_type, offset=0, limit=None) def update_id(self, mapped_item, new_id): """Updates the id of the given item to the given new_id, also in all its referees. 
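An aside, not part of the patch: the do_fetch_all change above amounts to fetch-once caching, where each item type is queried at most once per session unless another connection has committed in the meantime. Reduced to its essentials (class and parameter names hypothetical):

    class FetchOnceCache:
        """Fetch each item type at most once, unless external commits invalidate the cache."""

        def __init__(self, fetch, has_external_commits):
            self._fetch = fetch  # callable taking an item type and querying the DB
            self._has_external_commits = has_external_commits  # callable returning bool
            self._fetched = set()

        def fetch_all(self, item_type):
            if item_type in self._fetched and not self._has_external_commits():
                return  # nothing new can be in the DB, skip the query
            self._fetched.add(item_type)
            self._fetch(item_type)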
@@ -293,21 +297,15 @@ def update_id(self, mapped_item, new_id): old_id = mapped_item["id"] mapped_item["id"] = new_id for item_type in self.item_types(): - dirty_fields = { - field - for field, (ref_type, ref_field) in self.item_factory(item_type)._references.items() - if ref_type == mapped_item.item_type and ref_field == "id" - } # Fields that might refer the id of the mapped item - if not dirty_fields: - continue mapped_table = self.mapped_table(item_type) - for item in mapped_table.values(): - for field in dirty_fields: - value = item[field] - if isinstance(value, tuple) and old_id in value: - item[field] = tuple(new_id if id_ == old_id else id_ for id_ in value) - elif old_id == value: - item[field] = new_id + for field, (ref_type, ref_field) in self.item_factory(item_type)._references.items(): + if ref_type == mapped_item.item_type and ref_field == "id": + for item in mapped_table.values(): + value = item[field] + if isinstance(value, tuple) and old_id in value: + item[field] = tuple(new_id if id_ == old_id else id_ for id_ in value) + elif old_id == value: + item[field] = new_id class _MappedTable(dict): @@ -320,7 +318,7 @@ def __init__(self, db_map, item_type, *args, **kwargs): super().__init__(*args, **kwargs) self._db_map = db_map self._item_type = item_type - self._id_by_unique_key_value = {} + self._ids_by_unique_key_value = {} self._temp_id_by_db_id = {} self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @@ -346,7 +344,7 @@ def _callback(db_id): return temp_id def _unique_key_value_to_id(self, key, value, fetch=True): - """Returns the id that has the given value for the given unique key, or None if not found. + """Returns the id that has the given value for the given unique key, or None. Args: key (tuple) @@ -354,17 +352,24 @@ def _unique_key_value_to_id(self, key, value, fetch=True): fetch (bool): whether to fetch the DB until found. 
Returns: - int + int or None """ - id_by_unique_value = self._id_by_unique_key_value.get(key, {}) - if not id_by_unique_value and fetch: + ids_by_unique_value = self._ids_by_unique_key_value.get(key, {}) + if not ids_by_unique_value and fetch: self._db_map.do_fetch_all(self._item_type) - id_by_unique_value = self._id_by_unique_key_value.get(key, {}) + ids_by_unique_value = self._ids_by_unique_key_value.get(key, {}) value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - return id_by_unique_value.get(value) + ids = ids_by_unique_value.get(value, []) + return None if not ids else ids[-1] - def _unique_key_value_to_item(self, key, value, fetch=True): - return self.get(self._unique_key_value_to_id(key, value, fetch=fetch)) + def _unique_key_value_to_item(self, key, value, fetch=True, valid_only=True): + id_ = self._unique_key_value_to_id(key, value, fetch=fetch) + mapped_item = self.get(id_) + if mapped_item is None: + return None + if valid_only and not mapped_item.is_valid(): + return None + return mapped_item def valid_values(self): return (x for x in self.values() if x.is_valid()) @@ -401,24 +406,23 @@ def find_item_by_id(self, id_, fetch=True): current_item = self.get(id_, {}) return current_item - def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, complete=True): + def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, valid_only=True): for key, value in self._db_map.item_factory(self._item_type).unique_values_for_item(item, skip_keys=skip_keys): - current_item = self._unique_key_value_to_item(key, value, fetch=fetch) + current_item = self._unique_key_value_to_item(key, value, fetch=fetch, valid_only=valid_only) + if current_item: + return current_item + # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too... + mapped_item = self._make_item(item) + error = mapped_item.resolve_internal_fields(item.keys()) + if error: + return {} + error = mapped_item.polish() + if error: + return {} + for key, value in mapped_item.unique_key_values(skip_keys=skip_keys): + current_item = self._unique_key_value_to_item(key, value, fetch=fetch, valid_only=valid_only) if current_item: return current_item - if complete: - # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too... 
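    # An editorial note, not from this patch: the fallback below exists because
    # callers may identify an item only partially, e.g. by names instead of ids.
    # A scratch MappedItem is built so that resolve_internal_fields() and polish()
    # can derive the missing fields, and the unique keys are then recomputed from
    # the completed item and looked up once more.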
- mapped_item = self._make_item(item) - error = mapped_item.resolve_internal_fields(item.keys()) - if error: - return {} - error = mapped_item.polish() - if error: - return {} - for key, value in mapped_item.unique_key_values(skip_keys=skip_keys): - current_item = self._unique_key_value_to_item(key, value, fetch=fetch) - if current_item: - return current_item return {} def checked_item_and_error(self, item, for_update=False): @@ -485,14 +489,14 @@ def item_to_remove_and_error(self, id_): def add_unique(self, item): id_ = item["id"] for key, value in item.unique_key_values(): - self._id_by_unique_key_value.setdefault(key, {})[value] = id_ + self._ids_by_unique_key_value.setdefault(key, {}).setdefault(value, []).append(id_) def remove_unique(self, item): id_ = item["id"] for key, value in item.unique_key_values(): - id_by_value = self._id_by_unique_key_value.get(key, {}) - if id_by_value.get(value) == id_: - del id_by_value[value] + ids = self._ids_by_unique_key_value.get(key, {}).get(value, []) + if id_ in ids: + ids.remove(id_) def _make_and_add_item(self, item): if not isinstance(item, MappedItemBase): @@ -503,7 +507,7 @@ def _make_and_add_item(self, item): self[item["id"]] = item return item - def add_item_from_db(self, item, fix_id_conflics): + def add_item_from_db(self, item): """Adds an item fetched from the DB. Args: @@ -512,10 +516,14 @@ def add_item_from_db(self, item, fix_id_conflics): Returns: tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. """ - mapped_item = self._find_item_by_unique_key(item, fetch=False, complete=True) - print(item, mapped_item) + mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) if mapped_item: - self._solve_id_conflict(mapped_item, item) + self._fix_id(mapped_item, item["id"]) + return mapped_item, False + mapped_item = self.find_item_by_id(item["id"], fetch=False) + if mapped_item: + # An item from the DB has the same id as a mapped item, but they are not equivalent + # TODO: Fix id clash return mapped_item, False mapped_item = self._make_and_add_item(item) if self.purged: @@ -523,28 +531,27 @@ def add_item_from_db(self, item, fix_id_conflics): mapped_item.cascade_remove(source=self.wildcard_item) return mapped_item, True - def _solve_id_conflict(self, mapped_item, db_item): - """Makes sure that mapped_item and db_item don't have conflicting ids. - Both items are equivalent in the sense they share a unique key, - so there's only room for one of them. Therefore, they must have the same id. + def _fix_id(self, mapped_item, id_): + """Makes sure that mapped_item has the given id_. Args: - current (MappedItemBase): An item in the in-memory item. - item (dict): An item just fetched from the DB. + mapped_item (MappedItemBase): An item in the in-memory item. + id_ (int): The id_ of an equivalent item just fetched from the DB. """ - # NOTE: db_item is more recent (because it's been fetched later) so we need to trust its id. - mapped_id, db_id = mapped_item["id"], db_item["id"] + mapped_id = mapped_item["id"] if isinstance(mapped_id, TempId): # mapped_item was added on this session and hasn't been committed. - # Just do as if it was committed and has the id of db_item. - mapped_id.resolve(db_id) + # But it was committed by somebody else, so we need to accept that. + mapped_id.resolve(id_) if mapped_item.status == Status.to_add: mapped_item.status = Status.committed - elif mapped_id != db_id: - # Both mapped_item and db_item have been committed but with a different id (it can happen). 
- # Change the id of mapped_item to that of db_item. - self._db_map.update_id(mapped_item, db_id) - self[db_id] = self.pop(mapped_id) + elif mapped_id != id_: + # The id of mapped_item has changed in the DB. + self._db_map.update_id(mapped_item, id_) + # TODO: Fix id clash if db_id already in self + # Store the mapped item in the new db_id + self[id_] = mapped_item + del self[mapped_id] def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -580,7 +587,6 @@ def add_item(self, item): return item def update_item(self, item): - print("update_item", item) current_item = self.find_item(item) current_item.cascade_remove_unique() current_item.update(item) @@ -594,10 +600,8 @@ def remove_item(self, item): if item is self.wildcard_item: self.purged = True for current_item in self.valid_values(): - self.remove_unique(current_item) current_item.cascade_remove(source=self.wildcard_item) return self.wildcard_item - self.remove_unique(item) item.cascade_remove() return item @@ -605,12 +609,10 @@ def restore_item(self, id_): if id_ is Asterisk: self.purged = False for current_item in self.values(): - self.add_unique(current_item) current_item.cascade_restore(source=self.wildcard_item) return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item: - self.add_unique(current_item) current_item.cascade_restore() return current_item @@ -1076,7 +1078,7 @@ def call_update_callbacks(self): self.update_callbacks -= obsolete def cascade_add_unique(self): - """Removes item and all its referrers unique keys and ids in cascade.""" + """Adds item and all its referrers unique keys and ids in cascade.""" mapped_table = self._db_map.mapped_table(self._item_type) mapped_table.add_unique(self) for referrer in self._referrers.values(): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index c2b346cf..d06f8271 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -2037,7 +2037,6 @@ def test_update_committed_object(self): self._db_map.add_object_classes({"id": 1, "name": "some_class"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) self._db_map.commit_session("update") - print("NOW") items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) ids = {x["id"] for x in items} self._db_map.commit_session("test commit") From 282d38e86bbd8604f3e80f8026a85feeb56f1754 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 5 Feb 2024 13:24:21 +0200 Subject: [PATCH 255/317] Allow the restoration of purged items one-by-one. Previously it was not possible to call restore() on an item that had been purged. This removes the artificial limitation. 
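
For illustration, a minimal sketch of the now-allowed flow (it mirrors
the new test below; assumes the package-level DatabaseMapping import):

    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        entity_class, error = db_map.add_entity_class_item(name="Gadget")
        db_map.purge_items("entity_class")
        entity_class.restore()  # previously a no-op for purged items
        # The restored item is visible again with its fields intact.
        assert db_map.get_entity_class_item(name="Gadget")["name"] == "Gadget"
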
Re #325 --- spinedb_api/db_mapping_base.py | 4 ++-- tests/test_DatabaseMapping.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 3c5a133c..63e9a60c 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -556,7 +556,7 @@ def remove_item(self, item): self.purged = True for current_item in self.valid_values(): self.remove_unique(current_item) - current_item.cascade_remove(source=self.wildcard_item) + current_item.cascade_remove() return self.wildcard_item self.remove_unique(item) item.cascade_remove() @@ -567,7 +567,7 @@ def restore_item(self, id_): self.purged = False for current_item in self.values(): self.add_unique(current_item) - current_item.cascade_restore(source=self.wildcard_item) + current_item.cascade_restore() return self.wildcard_item current_item = self.find_item({"id": id_}) if current_item: diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 1aa80133..6ad846b0 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -644,6 +644,25 @@ def test_reset_purging(self): self.assertEqual(len(entity_classes), 1) self.assertEqual(entity_classes[0]["name"], "Gadget") + def test_restored_entity_class_item_has_display_icon_field(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + entity_class = self._assert_success(db_map.add_entity_class_item(name="Gadget")) + db_map.purge_items("entity_class") + entity_class.restore() + item = db_map.get_entity_class_item(name="Gadget") + self.assertIsNone(item["display_icon"]) + + def test_trying_to_restore_item_whose_parent_is_removed_fails(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + entity_class = self._assert_success(db_map.add_entity_class_item(name="Object")) + entity = self._assert_success(db_map.add_entity_item(name="knife", entity_class_name="Object")) + entity_class.remove() + self.assertFalse(entity.is_valid()) + entity.restore() + self.assertFalse(entity.is_valid()) + entity_class.restore() + self.assertTrue(entity.is_valid()) + class TestDatabaseMappingLegacy(unittest.TestCase): """'Backward compatibility' tests, i.e. 
pre-entity tests converted to work with the entity structure."""


From eaba7d2df33325218347293fdcc546607c5be186 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Mon, 5 Feb 2024 15:00:59 +0100
Subject: [PATCH 256/317] Introduce _free_id

---
 spinedb_api/db_mapping_base.py | 71 +++++++++++++++++++++++++++-------
 spinedb_api/temp_id.py         |  4 +-
 2 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index d1fa44c7..6697f4c7 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -527,41 +527,69 @@ def add_item_from_db(self, item):
         """
         mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False)
         if mapped_item:
-            self._fix_id(mapped_item, item["id"])
+            self._force_id(mapped_item, item["id"])
             return mapped_item, False
-        mapped_item = self.find_item_by_id(item["id"], fetch=False)
-        if mapped_item:
-            # An item from the DB has the same id as a mapped item, but they are not equivalent
-            # TODO: Fix id clash
+        mapped_item = self.get(item["id"])
+        if mapped_item is not None and mapped_item.is_equal_in_db(item):
             return mapped_item, False
+        self._free_id(item["id"])
         mapped_item = self._make_and_add_item(item)
         if self.purged:
             # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes.
             mapped_item.cascade_remove(source=self.wildcard_item)
         return mapped_item, True

-    def _fix_id(self, mapped_item, id_):
-        """Makes sure that mapped_item has the given id_.
+    def _force_id(self, mapped_item, id_):
+        """Makes sure that mapped_item has the given id_, corresponding to the new id of the item
+        in the DB after some external changes.

         Args:
             mapped_item (MappedItemBase): An item in the in-memory mapping.
-            id_ (int): The id_ of an equivalent item just fetched from the DB.
+            id_ (int): The most recent id_ of the item as fetched from the DB.
         """
         mapped_id = mapped_item["id"]
+        if mapped_id == id_:  # This is True even if mapped_id is a TempId resolved to the given id_
+            return
         if isinstance(mapped_id, TempId):
-            # mapped_item was added on this session and hasn't been committed.
-            # But it was committed by somebody else, so we need to accept that.
+            # Easy, resolve the TempId to the new db id (and commit the item if pending)
             mapped_id.resolve(id_)
             if mapped_item.status == Status.to_add:
                 mapped_item.status = Status.committed
-        elif mapped_id != id_:
-            # The id of mapped_item has changed in the DB.
+        else:
+            # Hard, update the id of the item manually.
+            self._free_id(id_)
+            self.remove_unique(mapped_item)
             self._db_map.update_id(mapped_item, id_)
-            # TODO: Fix id clash if db_id already in self
-            # Store the mapped item in the new db_id
+            self.add_unique(mapped_item)
             self[id_] = mapped_item
             del self[mapped_id]

+    def _free_id(self, id_):
+        """Makes sure the given id_ is free. Fixes conflicts if not.
+
+        Args:
+            id_ (int)
+        """
+        conflicting_item = self.pop(id_, None)
+        if conflicting_item is not None:
+            self._resolve_conflict(conflicting_item)
+
+    def _resolve_conflict(self, conflicting_item):
+        """Does something with conflicting item which has been removed from the DB by an external commit.
+
+        Args:
+            conflicting_item (MappedItemBase): an item in the memory mapping.
+        """
+        # Here we could let the user choose the strategy.
+        # For now, we keep the conflicting_item in memory with a new TempId.
+        # It will be committed in the next call to commit_session.
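+        # Concretely, the steps below: drop the stale unique-key entries, mint
+        # a fresh TempId, rewrite every reference to the old id, re-register
+        # the unique keys, and mark previously committed (or to-be-updated)
+        # items to be added again.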
+ self.remove_unique(conflicting_item) + new_id = self._new_id() + self._db_map.update_id(conflicting_item, new_id) + self.add_unique(conflicting_item) + if conflicting_item.status in (Status.to_update, Status.committed): + conflicting_item.status = Status.to_add + def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -806,6 +834,18 @@ def _convert(x): or self.fields.get(key, {}).get("optional", False) # Ignore mandatory fields that are None ) + def is_equal_in_db(self, other): + """Returns whether this item and other are the same in the DB. + + Args: + other (dict) + + Returns: + bool + """ + this = self._db_map.make_item(self._item_type, **self.backup) if self.status == Status.to_update else self + return not this._something_to_update(other) + def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first reference that cannot be resolved. @@ -1224,3 +1264,6 @@ def add_remove_callback(self, callback): def add_restore_callback(self, callback): self._mapped_item.restore_callbacks.add(callback) + + def resolve(self): + return self._mapped_item.resolve() diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 79066941..de9f302d 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -42,8 +42,8 @@ def add_resolve_callback(self, callback): def resolve(self, db_id): self._db_id = db_id - while self._resolve_callbacks: - self._resolve_callbacks.pop(0)(db_id) + for callback in self._resolve_callbacks: + callback(db_id) def resolve(value): From 02a953ffd5e2e8a348748982e4973741c11931a1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 5 Feb 2024 17:03:05 +0100 Subject: [PATCH 257/317] Add tests --- spinedb_api/db_mapping_base.py | 14 +- tests/test_DatabaseMapping.py | 360 +++++++++++++++++++++------------ 2 files changed, 237 insertions(+), 137 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 6697f4c7..ee0369e2 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -575,14 +575,20 @@ def _free_id(self, id_): self._resolve_conflict(conflicting_item) def _resolve_conflict(self, conflicting_item): - """Does something with conflicting item which has been removed from the DB by an external commit. + """Does something with conflicting_item which has been removed from the DB by an external commit. + + Args: + conflicting_item (MappedItemBase): an item in the memory mapping. + """ + # Here we could let the user choose the strategy. For now we just 'rescue' the item. + self._rescue_item(conflicting_item) + + def _rescue_item(self, conflicting_item): + """Rescues the given conflicting_item which has been removed from the DB by an external commit. Args: conflicting_item (MappedItemBase): an item in the memory mapping. """ - # Here we could let the user choose the strategy. - # For now, we keep the conflicting_item in memory with a new TempId. - # It will be committed in the next call to commit_session. 
self.remove_unique(conflicting_item) new_id = self._new_id() self._db_map.update_id(conflicting_item, new_id) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 1aa80133..8e871e6e 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -129,20 +129,24 @@ def test_commit_parameter_value(self): url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with DatabaseMapping(url, create=True) as db_map: self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success(db_map.add_item( - "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." - )) + self._assert_success( + db_map.add_item( + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." + ) + ) self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) + ) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -161,28 +165,40 @@ def test_commit_multidimensional_parameter_value(self): with DatabaseMapping(url, create=True) as db_map: self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) self._assert_success(db_map.add_item("entity_class", name="cat", description="Eats fish.")) - self._assert_success(db_map.add_item( - "entity_class", - name="fish__cat", - dimension_name_list=("fish", "cat"), - description="A fish getting eaten by a cat?", - )) - self._assert_success(db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).")) - self._assert_success(db_map.add_item( - "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." - )) - self._assert_success(db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix"))) - self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate")) + self._assert_success( + db_map.add_item( + "entity_class", + name="fish__cat", + dimension_name_list=("fish", "cat"), + description="A fish getting eaten by a cat?", + ) + ) + self._assert_success( + db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") + ) + self._assert_success( + db_map.add_item( + "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." 
+ ) + ) + self._assert_success( + db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) + ) + self._assert_success( + db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") + ) value, type_ = to_database(0.23) - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish__cat", - entity_byname=("Nemo", "Felix"), - parameter_definition_name="rate", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish__cat", + entity_byname=("Nemo", "Felix"), + parameter_definition_name="rate", + alternative_name="Base", + value=value, + type=type_, + ) + ) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -198,20 +214,24 @@ def test_commit_multidimensional_parameter_value(self): def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map: self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success(db_map.add_item( - "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." - )) + self._assert_success( + db_map.add_item( + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." + ) + ) self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) + ) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -248,12 +268,14 @@ def test_update_entity_metadata_by_changing_its_entity(self): entity_2 = self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="my_class")) metadata_value = '{"sources": [], "contributors": []}' metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - entity_metadata = self._assert_success(db_map.add_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("entity_1",), - )) + entity_metadata = self._assert_success( + db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("entity_1",), + ) + ) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( entity_metadata._extended(), @@ -308,33 +330,39 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) value, value_type = to_database(2.3) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - )) + self._assert_success( + db_map.add_parameter_value_item( + 
entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) value, value_type = to_database(-2.3) - y = self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="y", - alternative_name="Base", - value=value, - type=value_type, - )) + y = self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) + ) metadata_value = '{"sources": [], "contributors": []}' metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - value_metadata = self._assert_success(db_map.add_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - )) + value_metadata = self._assert_success( + db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + ) value_metadata.update(parameter_definition_name="y") self.assertEqual( value_metadata._extended(), @@ -458,7 +486,9 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. db_map.get_entity_item(name="Sylvester", entity_class_name="cat").remove() - self._assert_success(db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat")) + self._assert_success( + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") + ) db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". 
@@ -476,13 +506,13 @@ def test_committing_scenario_alternatives(self): self.assertIsNotNone(item) item = self._assert_success(db_map.add_scenario_item(name="my_scenario")) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt1", rank=0 - )) + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt1", rank=0) + ) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt2", rank=1 - )) + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt2", rank=1) + ) self.assertIsNotNone(item) db_map.commit_session("Add test data.") with DatabaseMapping(url) as db_map: @@ -518,7 +548,11 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): self._assert_success(db_map.add_entity_class_item(name="my_class")) self._assert_success(db_map.add_entity_item(name="element", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="container", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class")) + self._assert_success( + db_map.add_entity_group_item( + group_name="container", member_name="element", entity_class_name="my_class" + ) + ) db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) @@ -528,40 +562,54 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success( + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + ) self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" + ) + ) self._assert_success(db_map.add_entity_item(name="widget1", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="widget2", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="no_is_active", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_alternative_item( - entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False - )) - self._assert_success(db_map.add_entity_alternative_item( - entity_class_name="my_class", entity_byname=("widget2",), alternative_name="Base", active=False - )) - self._assert_success(db_map.add_entity_alternative_item( - entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False - )) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False + ) + ) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", 
entity_byname=("widget2",), alternative_name="Base", active=False + ) + ) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False + ) + ) value, value_type = to_database(True) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - parameter_definition_name="is_active", - entity_byname=("widget1",), - alternative_name="Base", - value=value, - type=value_type, - )) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - parameter_definition_name="is_active", - entity_byname=("widget2",), - alternative_name="Base", - value=value, - type=value_type, - )) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + parameter_definition_name="is_active", + entity_byname=("widget1",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + parameter_definition_name="is_active", + entity_byname=("widget2",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) db_map.commit_session("Add test data to see if this crashes.") entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)} alternative_names = { @@ -585,27 +633,35 @@ def test_commit_default_value_for_parameter_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success( + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + ) self._assert_success(db_map.add_entity_class_item(name="Widget")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", - entity_class_name="Widget", - parameter_value_list_name="booleans", - default_value=value, - default_type=value_type, - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Widget", + parameter_value_list_name="booleans", + default_value=value, + default_type=value_type, + ) + ) self._assert_success(db_map.add_entity_class_item(name="Gadget")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", - entity_class_name="Gadget", - parameter_value_list_name="booleans", - default_value=value, - default_type=value_type, - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Gadget", + parameter_value_list_name="booleans", + default_value=value, + default_type=value_type, + ) + ) self._assert_success(db_map.add_entity_class_item(name="NoIsActiveDefault")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" + ) + ) db_map.commit_session("Add test data to see if this crashes") active_by_defaults = { entity_class["name"]: entity_class["active_by_default"] @@ -2928,11 +2984,12 @@ def test_cascade_remove_unfetched(self): self.assertEqual(ents, []) -@unittest.skipIf(os.name == 'nt', "Need to 
fix") class TestDatabaseMappingConcurrent(unittest.TestCase): + @unittest.skipIf(os.name == 'nt', "Needs fixing") def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) + @unittest.skipIf(os.name == 'nt', "Needs fixing") def test_concurrent_commit_multiprocessing(self): self._do_test_concurrent_commit(multiprocessing.Process) @@ -2942,7 +2999,6 @@ def _commit_on_thread(db_map, msg): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") - with CustomDatabaseMapping(url, create=True) as db_map1: with CustomDatabaseMapping(url) as db_map2: db_map1.add_entity_class_item(name="dog") @@ -2954,7 +3010,6 @@ def _commit_on_thread(db_map, msg): c1.start() c1.join() c2.join() - with CustomDatabaseMapping(url) as db_map: commit_msgs = {x["comment"] for x in db_map.query(db_map.commit_sq)} entity_class_names = [x["name"] for x in db_map.query(db_map.entity_class_sq)] @@ -2962,6 +3017,45 @@ def _commit_on_thread(db_map, msg): self.assertEqual(len(entity_class_names), 2) self.assertEqual(set(entity_class_names), {"cat", "dog"}) + def test_uncommitted_mapped_items_take_id_from_externally_committed_items(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with CustomDatabaseMapping(url, create=True) as db_map1: + with CustomDatabaseMapping(url) as db_map2: + db_map1.add_entity_class_item(name="widget") + db_map1.add_entity_class_item(name="gadget") + db_map2.add_entity_class_item(name="gadget") + db_map2.add_entity_class_item(name="widget") + db_map2.commit_session("No comment") + committed_resolved_entity_classes = [x.resolve() for x in db_map2.get_items("entity_class")] + committed_resolved_id_by_name = {x["name"]: x["id"] for x in committed_resolved_entity_classes} + uncommitted_entity_classes = db_map1.get_items("entity_class") + uncommitted_resolved_entity_classes = [x.resolve() for x in uncommitted_entity_classes] + uncommitted_resolved_id_by_name = {x["name"]: x["id"] for x in uncommitted_resolved_entity_classes} + self.assertEqual(committed_resolved_id_by_name, uncommitted_resolved_id_by_name) + for mapped_item in uncommitted_entity_classes: + self.assertTrue(mapped_item.is_committed()) + with self.assertRaises(SpineDBAPIError): + db_map1.commit_session("No comment") + + def test_committed_mapped_items_take_id_from_externally_committed_items(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with CustomDatabaseMapping(url, create=True) as db_map0: + db_map0.add_entity_class_item(name="widget") + db_map0.add_entity_class_item(name="gadget") + db_map0.commit_session("No comment") + with CustomDatabaseMapping(url) as db_map1: + with CustomDatabaseMapping(url) as db_map2: + entity_classes_before = db_map1.get_items("entity_class") + db_map2.purge_items("entity_class") + db_map2.add_entity_class_item(name="gadget") + db_map2.add_entity_class_item(name="widget") + db_map2.commit_session("No comment") + entity_classes_after = db_map1.get_items("entity_class") + # TODO: check that entity_classes_after is the same as entity_classes_before + # with exchanged ids + if __name__ == "__main__": unittest.main() From c6ba41d246d2336e4fd023b138eb526c4648fb7b Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 6 Feb 2024 13:54:36 +0100 Subject: [PATCH 258/317] Fix checking if db-item and mapped-item with same id have same uq-keys --- spinedb_api/db_mapping.py | 6 +- spinedb_api/db_mapping_base.py | 104 
++++++++++++++++++++------------- spinedb_api/spine_db_server.py | 5 +- tests/test_db_mapping_base.py | 19 +++--- 4 files changed, 77 insertions(+), 57 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 98a68aa9..930f56aa 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -727,11 +727,7 @@ def refresh_session(self): self._refresh() def has_external_commits(self): - """Tests whether the database has had commits from other sources than this mapping. - - Returns: - bool: True if database has external commits, False otherwise - """ + """See base class.""" return self._commit_count != self.query(self.commit_sq).count() def close(self): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index ee0369e2..58042726 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -35,7 +35,8 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. - When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, and :meth:`_make_sq`. + When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, :meth:`_make_sq`, + and :meth:`has_external_commits`. """ def __init__(self): @@ -86,6 +87,14 @@ def item_factory(item_type): """ raise NotImplementedError() + def has_external_commits(self): + """Tests whether the database has had commits from other sources than this mapping. + + Returns: + bool: True if database has external commits, False otherwise + """ + raise NotImplementedError() + def _make_query(self, item_type, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. @@ -145,6 +154,8 @@ def _dirty_items(self): Returns: list """ + if self.has_external_commits(): + self._refresh() dirty_items = [] purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} self._add_descendants(purged_item_types) @@ -205,6 +216,8 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" + for item_type in self.item_types(): + self._fetched.discard(item_type) def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -275,6 +288,8 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) if not chunk: return [] + if self.has_external_commits(): + self._refresh() mapped_table = self.mapped_table(item_type) items = [] new_items = [] @@ -291,7 +306,7 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): return items def do_fetch_all(self, item_type): - if item_type in self._fetched and not self.has_external_commits(): + if item_type in self._fetched: return self._fetched.add(item_type) self.do_fetch_more(item_type, offset=0, limit=None) @@ -363,12 +378,11 @@ def _unique_key_value_to_id(self, key, value, fetch=True): Returns: int or None """ - ids_by_unique_value = self._ids_by_unique_key_value.get(key, {}) - if not ids_by_unique_value and fetch: - self._db_map.do_fetch_all(self._item_type) - ids_by_unique_value = self._ids_by_unique_key_value.get(key, {}) value = tuple(tuple(x) if isinstance(x, list) else x for x in value) - ids = ids_by_unique_value.get(value, []) + ids = self._ids_by_unique_key_value.get(key, {}).get(value, []) + if not ids and fetch: + self._db_map.do_fetch_all(self._item_type) + ids = self._ids_by_unique_key_value.get(key, {}).get(value, []) return None if not 
ids else ids[-1]

     def _unique_key_value_to_item(self, key, value, fetch=True, valid_only=True):
@@ -422,7 +436,7 @@ def _find_item_by_unique_key(self, item, skip_keys=(), fetch=True, valid_only=Tr
                 return current_item
         # Maybe item is missing some key stuff, so try with a resolved and polished MappedItem too...
         mapped_item = self._make_item(item)
-        error = mapped_item.resolve_internal_fields(item.keys())
+        error = mapped_item.resolve_internal_fields(skip_keys=item.keys())
         if error:
             return {}
         error = mapped_item.polish()
@@ -511,7 +525,7 @@ def _make_and_add_item(self, item):
         if not isinstance(item, MappedItemBase):
             item = self._make_item(item)
             item.polish()
-        if "id" not in item or not item.is_id_valid:
+        if "id" not in item or not item.has_valid_id:
             item["id"] = self._new_id()
         self[item["id"]] = item
         return item
@@ -664,28 +678,31 @@ class MappedItemBase(dict):
     """A dictionary that represents a db item."""

     fields = {}
-    """A dictionary mapping keys to a another dict mapping "type" to a Python type,
+    """A dictionary mapping fields to another dict mapping "type" to a Python type,
     "value" to a description of the value for the key, and "optional" to a bool."""
     _defaults = {}
-    """A dictionary mapping keys to their default values"""
+    """A dictionary mapping fields to their default values"""
     _unique_keys = ()
-    """A tuple where each element is itself a tuple of keys corresponding to a unique constraint"""
+    """A tuple where each element is itself a tuple of fields corresponding to a unique key"""
     _references = {}
-    """A dictionary mapping source keys, to a tuple of reference item type and reference key.
+    """A dictionary mapping source fields, to a tuple of reference item type and reference field.
     Used to access external fields.
     """
     _external_fields = {}
-    """A dictionary mapping keys that are not in the original dictionary, to a tuple of source key and reference key.
-    Keys in _external_fields are accessed via the reference key of the reference pointed at by the source key.
+    """A dictionary mapping fields that are not in the original dictionary, to a tuple of source field
+    and target field.
+    When accessing fields in _external_fields, we first find the reference pointed at by the source field,
+    and then return the target field of that reference.
     """
     _alt_references = {}
-    """A dictionary mapping source keys, to a tuple of reference item type and reference key.
+    """A dictionary mapping source fields, to a tuple of reference item type and reference fields.
     Used only to resolve internal fields at item creation.
     """
     _internal_fields = {}
-    """A dictionary mapping keys that are not in the original dictionary, to a tuple of source key and reference key.
-    Keys in _internal_fields are resolved to the reference key of the alternative reference pointed at by the
-    source key.
+    """A dictionary mapping fields that are not in the original dictionary, to a tuple of source field
+    and target field.
+    When resolving fields in _internal_fields, we first find the alt_reference pointed at by the source field,
+    and then use the target field of that reference.
""" _private_fields = set() """A set with fields that should be ignored in validations.""" @@ -703,7 +720,7 @@ def __init__(self, db_map, item_type, **kwargs): self.restore_callbacks = set() self.update_callbacks = set() self.remove_callbacks = set() - self._is_id_valid = True + self._has_valid_id = True self._to_remove = False self._removed = False self._corrupted = False @@ -782,12 +799,12 @@ def key(self): return (self._item_type, id_) @property - def is_id_valid(self): - return self._is_id_valid + def has_valid_id(self): + return self._has_valid_id def invalidate_id(self): """Sets id as invalid.""" - self._is_id_valid = False + self._has_valid_id = False def _extended(self): """Returns a dict from this item's original fields plus all the references resolved statically. @@ -849,8 +866,14 @@ def is_equal_in_db(self, other): Returns: bool """ - this = self._db_map.make_item(self._item_type, **self.backup) if self.status == Status.to_update else self - return not this._something_to_update(other) + if self.status == Status.to_update: + this = self._db_map.make_item(self._item_type, **self.backup) + this.polish() + else: + this = self + other = self._db_map.make_item(self._item_type, **other) + other.polish() + return dict(this.unique_key_values()) == dict(other.unique_key_values()) def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first reference @@ -963,7 +986,7 @@ def _get_ref(self, ref_type, key_val, strong=True): """Collects a reference from the in-memory mapping. Adds this item to the reference's list of referrers if strong is True; or weak referrers if strong is False. - If the reference is not found, sets some flags. + Sets the self._corrupted and self._removed flags appropriately. Args: ref_type (str): The reference's type @@ -974,14 +997,11 @@ def _get_ref(self, ref_type, key_val, strong=True): MappedItemBase or dict """ mapped_table = self._db_map.mapped_table(ref_type) - ref = mapped_table.find_item(key_val, fetch=False) + ref = mapped_table.find_item(key_val, fetch=True) if not ref: - ref = mapped_table.find_item(key_val, fetch=True) - if not ref: - if strong: - self._corrupted = True - return {} - # Here we have a ref + if strong: + self._corrupted = True + return {} if strong: ref.add_referrer(self) if ref.removed: @@ -1171,20 +1191,20 @@ def __getattr__(self, name): def __getitem__(self, key): """Overridden to return references.""" - ext_val = self._external_fields.get(key) - if ext_val: - src_key, key = ext_val - ref_type, ref_key = self._references[src_key] - src_val = self[src_key] - if isinstance(src_val, tuple): - return tuple(self._get_ref(ref_type, {ref_key: x}).get(key) for x in src_val) - return self._get_ref(ref_type, {ref_key: src_val}).get(key) + source_target_key_tuple = self._external_fields.get(key) + if source_target_key_tuple: + source_key, target_key = source_target_key_tuple + ref_type, ref_key = self._references[source_key] + source_val = self[source_key] + if isinstance(source_val, tuple): + return tuple(self._get_ref(ref_type, {ref_key: x}).get(target_key) for x in source_val) + return self._get_ref(ref_type, {ref_key: source_val}).get(target_key) return super().__getitem__(key) def __setitem__(self, key, value): """Sets id valid if key is 'id'.""" if key == "id": - self._is_id_valid = True + self._has_valid_id = True super().__setitem__(key, value) def get(self, key, default=None): diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 95853922..827e34bf 100644 --- 
a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -49,7 +49,6 @@
 The server is started using :func:`closing_spine_db_server`.
 To control the order of writing you need to provide a queue, that you would obtain by calling
 :func:`db_server_manager`.
-
 The below example illustrates most of the functionality of the module.
 We create two DB servers targeting the same DB, and set the second to write before the first
 (via the ``ordering`` argument to :func:`closing_spine_db_server`).
@@ -76,7 +75,9 @@ def _import_entity_class(server_url, class_name):
     with db_server_manager() as mngr_queue:
         first_ordering = {"id": "second_before_first", "current": "first", "precursors": {"second"}, "part_count": 1}
         second_ordering = {"id": "second_before_first", "current": "second", "precursors": set(), "part_count": 1}
-        with closing_spine_db_server(db_url, server_manager_queue=mngr_queue, ordering=first_ordering) as first_server_url:
+        with closing_spine_db_server(
+            db_url, server_manager_queue=mngr_queue, ordering=first_ordering
+        ) as first_server_url:
             with closing_spine_db_server(
                 db_url, server_manager_queue=mngr_queue, ordering=second_ordering
             ) as second_server_url:
diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py
index 65d75803..dc88d8af 100644
--- a/tests/test_db_mapping_base.py
+++ b/tests/test_db_mapping_base.py
@@ -28,6 +28,9 @@ def item_factory(item_type):
             return MappedItemBase
         raise RuntimeError(f"unknown item_type '{item_type}'")

+    def has_external_commits(self):
+        return False
+
     def _make_query(self, _item_type, **kwargs):
         return None

@@ -37,11 +40,11 @@ def test_rolling_back_new_item_invalidates_its_id(self):
         db_map = TestDBMapping()
         mapped_table = db_map.mapped_table("cutlery")
         item = mapped_table.add_item({})
-        self.assertTrue(item.is_id_valid)
+        self.assertTrue(item.has_valid_id)
         self.assertIn("id", item)
         id_ = item["id"]
         db_map._rollback()
-        self.assertFalse(item.is_id_valid)
+        self.assertFalse(item.has_valid_id)
         self.assertEqual(item["id"], id_)


@@ -52,9 +55,9 @@ def test_readding_item_with_invalid_id_creates_new_id(self):
         item = mapped_table.add_item({})
         id_ = item["id"]
         db_map._rollback()
-        self.assertFalse(item.is_id_valid)
+        self.assertFalse(item.has_valid_id)
         mapped_table.add_item(item)
-        self.assertTrue(item.is_id_valid)
+        self.assertTrue(item.has_valid_id)
         self.assertNotEqual(item["id"], id_)


@@ -62,21 +65,21 @@ class TestMappedItemBase(unittest.TestCase):
     def test_id_is_valid_initially(self):
         db_map = TestDBMapping()
         item = MappedItemBase(db_map, "cutlery")
-        self.assertTrue(item.is_id_valid)
+        self.assertTrue(item.has_valid_id)

     def test_id_can_be_invalidated(self):
         db_map = TestDBMapping()
         item = MappedItemBase(db_map, "cutlery")
         item.invalidate_id()
-        self.assertFalse(item.is_id_valid)
+        self.assertFalse(item.has_valid_id)

     def test_setting_new_id_validates_it(self):
         db_map = TestDBMapping()
         item = MappedItemBase(db_map, "cutlery")
         item.invalidate_id()
-        self.assertFalse(item.is_id_valid)
+        self.assertFalse(item.has_valid_id)
         item["id"] = 23
-        self.assertTrue(item.is_id_valid)
+        self.assertTrue(item.has_valid_id)


if __name__ == '__main__':
From cc2af3fda0cdda9a2bac8ef11960eb9100f63686 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Tue, 6 Feb 2024 15:35:38 +0100
Subject: [PATCH 259/317] Notify clients of updating ids

This is done by calling the remove callbacks (cascade_remove) and then
re-adding the item. Hopefully clients that store item ids (Spine Toolbox)
will be able to notice the update, but we need to test.
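
For illustration, a sketch of how a client might keep an id-keyed model
in sync across such an id update (the callback API is the one exercised
by the concurrency tests; the model dict is hypothetical and the
package-level DatabaseMapping import is assumed):

    from spinedb_api import DatabaseMapping

    model = {}
    with DatabaseMapping("sqlite:///db.sqlite") as db_map:
        for item in db_map.get_items("entity_class"):
            model[item["id"]] = item
            # Fired by cascade_remove; evicts the stale id from the model.
            item.add_remove_callback(lambda x: model.pop(x["id"], None))
        # An external commit may reshuffle ids; refreshing and re-fetching
        # fires the remove callbacks and then re-adds the items, so the
        # model can be rebuilt under the new ids.
        db_map.refresh_session()
        for item in db_map.get_items("entity_class"):
            model[item["id"]] = item
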
--- spinedb_api/db_mapping_base.py | 62 ++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 58042726..e326659f 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -295,10 +295,10 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): new_items = [] # Add items first for x in chunk: - item, new = mapped_table.add_item_from_db(x) - if new: - new_items.append(item) - items.append(item) + for item, new in mapped_table.add_item_from_db(x): + if new: + new_items.append(item) + items.append(item) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted for item in new_items: @@ -536,22 +536,23 @@ def add_item_from_db(self, item): Args: item (dict): item from the DB. - Returns: - tuple(MappedItem,bool): The mapped item and whether it hadn't been added before. + Yields: + tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. """ mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) if mapped_item: - self._force_id(mapped_item, item["id"]) - return mapped_item, False + yield from self._force_id(mapped_item, item["id"]) + return mapped_item = self.get(item["id"]) if mapped_item is not None and mapped_item.is_equal_in_db(item): - return mapped_item, False - self._free_id(item["id"]) + yield mapped_item, False + return + yield from self._free_id(item["id"]) mapped_item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. mapped_item.cascade_remove(source=self.wildcard_item) - return mapped_item, True + yield mapped_item, True def _force_id(self, mapped_item, id_): """Makes sure that mapped_item has the given id_, corresponding to the new id of the item @@ -560,36 +561,46 @@ def _force_id(self, mapped_item, id_): Args: mapped_item (MappedItemBase): An item in the in-memory item. id_ (int): The most recent id_ of the item as fetched from the DB. + + Yields: + tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. """ mapped_id = mapped_item["id"] if mapped_id == id_: # This is True even if mapped_id is a TempId resolved to the given id_ + yield mapped_item, False return if isinstance(mapped_id, TempId): # Easy, resolve the TempId to the new db id (and commit the item if pending) mapped_id.resolve(id_) if mapped_item.status == Status.to_add: mapped_item.status = Status.committed + yield mapped_item, False else: # Hard, update the id of the item manually. - self._free_id(id_) - self.remove_unique(mapped_item) + mapped_item.cascade_remove() + del self[mapped_id] self._db_map.update_id(mapped_item, id_) - self.add_unique(mapped_item) + yield from self._free_id(id_) self[id_] = mapped_item - del self[mapped_id] + yield mapped_item, True def _free_id(self, id_): """Makes sure the given id_ is free. Fix conflicts if not. Args: id_ (int) + + Yields: + tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. 
""" conflicting_item = self.pop(id_, None) - if conflicting_item is not None: - self._resolve_conflict(conflicting_item) + if conflicting_item is None: + return () + self._resolve_id_conflict(conflicting_item) + yield conflicting_item, True - def _resolve_conflict(self, conflicting_item): - """Does something with conflicting_item which has been removed from the DB by an external commit. + def _resolve_id_conflict(self, conflicting_item): + """Does something with conflicting_item whose id now belongs to a different item after an external commit. Args: conflicting_item (MappedItemBase): an item in the memory mapping. @@ -598,16 +609,17 @@ def _resolve_conflict(self, conflicting_item): self._rescue_item(conflicting_item) def _rescue_item(self, conflicting_item): - """Rescues the given conflicting_item which has been removed from the DB by an external commit. + """Rescues the given conflicting_item whose id now belongs to a different item after an external commit. Args: conflicting_item (MappedItemBase): an item in the memory mapping. """ - self.remove_unique(conflicting_item) - new_id = self._new_id() - self._db_map.update_id(conflicting_item, new_id) - self.add_unique(conflicting_item) - if conflicting_item.status in (Status.to_update, Status.committed): + status = conflicting_item.status + conflicting_item.cascade_remove() + id_ = self._new_id() + self._db_map.update_id(conflicting_item, id_) + self[id_] = conflicting_item + if status in (Status.to_update, Status.committed): conflicting_item.status = Status.to_add def check_fields(self, item, valid_types=()): From d934cd073456c774c24275f918f4938b34d07813 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 6 Feb 2024 17:24:56 +0100 Subject: [PATCH 260/317] Fetch referred types whenever there are external commits --- spinedb_api/db_mapping_base.py | 28 ++++++++---- tests/test_DatabaseMapping.py | 80 ++++++++++++++++++++++++---------- 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e326659f..f2966aa8 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -290,6 +290,9 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): return [] if self.has_external_commits(): self._refresh() + for ref_type in self.item_factory(item_type).ref_types(): + if ref_type != item_type: + self.do_fetch_all(ref_type) mapped_table = self.mapped_table(item_type) items = [] new_items = [] @@ -318,6 +321,8 @@ def update_id(self, mapped_item, new_id): mapped_item (MappedItemBase) new_id (int) """ + mapped_item.cascade_call_remove_callbacks() + mapped_item.cascade_remove_unique() old_id = mapped_item["id"] mapped_item["id"] = new_id for item_type in self.item_types(): @@ -330,6 +335,7 @@ def update_id(self, mapped_item, new_id): item[field] = tuple(new_id if id_ == old_id else id_ for id_ in value) elif old_id == value: item[field] = new_id + mapped_item.cascade_add_unique() class _MappedTable(dict): @@ -542,6 +548,7 @@ def add_item_from_db(self, item): mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) if mapped_item: yield from self._force_id(mapped_item, item["id"]) + yield mapped_item, False return mapped_item = self.get(item["id"]) if mapped_item is not None and mapped_item.is_equal_in_db(item): @@ -566,23 +573,19 @@ def _force_id(self, mapped_item, id_): tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. 
""" mapped_id = mapped_item["id"] - if mapped_id == id_: # This is True even if mapped_id is a TempId resolved to the given id_ - yield mapped_item, False + if mapped_id == id_: return if isinstance(mapped_id, TempId): # Easy, resolve the TempId to the new db id (and commit the item if pending) mapped_id.resolve(id_) if mapped_item.status == Status.to_add: mapped_item.status = Status.committed - yield mapped_item, False else: # Hard, update the id of the item manually. - mapped_item.cascade_remove() - del self[mapped_id] self._db_map.update_id(mapped_item, id_) yield from self._free_id(id_) self[id_] = mapped_item - yield mapped_item, True + del self[mapped_id] def _free_id(self, id_): """Makes sure the given id_ is free. Fix conflicts if not. @@ -595,9 +598,9 @@ def _free_id(self, id_): """ conflicting_item = self.pop(id_, None) if conflicting_item is None: - return () + return self._resolve_id_conflict(conflicting_item) - yield conflicting_item, True + yield conflicting_item, False def _resolve_id_conflict(self, conflicting_item): """Does something with conflicting_item whose id now belongs to a different item after an external commit. @@ -615,7 +618,6 @@ def _rescue_item(self, conflicting_item): conflicting_item (MappedItemBase): an item in the memory mapping. """ status = conflicting_item.status - conflicting_item.cascade_remove() id_ = self._new_id() self._db_map.update_id(conflicting_item, id_) self[id_] = conflicting_item @@ -1140,12 +1142,20 @@ def cascade_remove(self, source=None): for referrer in self._referrers.values(): referrer.cascade_remove(source=self) self._update_weak_referrers() + self.call_remove_callbacks() + + def call_remove_callbacks(self): obsolete = set() for callback in list(self.remove_callbacks): if not callback(self): obsolete.add(callback) self.remove_callbacks -= obsolete + def cascade_call_remove_callbacks(self): + for referrer in self._referrers.values(): + referrer.cascade_call_remove_callbacks() + self.call_remove_callbacks() + def cascade_update(self): """Updates this item and all its referrers in cascade. Also, calls items' update callbacks. 
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 8e871e6e..d821fb3f 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -85,12 +85,14 @@ def test_shorthand_filter_query_works(self): db_map.close() -class TestDatabaseMapping(unittest.TestCase): +class AssertSuccessMixin: def _assert_success(self, result): item, error = result self.assertIsNone(error) return item + +class TestDatabaseMapping(AssertSuccessMixin, unittest.TestCase): def test_active_by_default_is_initially_false_for_zero_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: item = self._assert_success(db_map.add_entity_class_item(name="Entity")) @@ -2984,7 +2986,7 @@ def test_cascade_remove_unfetched(self): self.assertEqual(ents, []) -class TestDatabaseMappingConcurrent(unittest.TestCase): +class TestDatabaseMappingConcurrent(AssertSuccessMixin, unittest.TestCase): @unittest.skipIf(os.name == 'nt', "Needs fixing") def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) @@ -3021,40 +3023,74 @@ def test_uncommitted_mapped_items_take_id_from_externally_committed_items(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with CustomDatabaseMapping(url, create=True) as db_map1: + db_map1.add_entity_class_item(name="widget") + db_map1.add_entity_class_item(name="gadget") with CustomDatabaseMapping(url) as db_map2: - db_map1.add_entity_class_item(name="widget") - db_map1.add_entity_class_item(name="gadget") + # Add the same classes in different order db_map2.add_entity_class_item(name="gadget") db_map2.add_entity_class_item(name="widget") db_map2.commit_session("No comment") committed_resolved_entity_classes = [x.resolve() for x in db_map2.get_items("entity_class")] committed_resolved_id_by_name = {x["name"]: x["id"] for x in committed_resolved_entity_classes} - uncommitted_entity_classes = db_map1.get_items("entity_class") - uncommitted_resolved_entity_classes = [x.resolve() for x in uncommitted_entity_classes] - uncommitted_resolved_id_by_name = {x["name"]: x["id"] for x in uncommitted_resolved_entity_classes} - self.assertEqual(committed_resolved_id_by_name, uncommitted_resolved_id_by_name) - for mapped_item in uncommitted_entity_classes: - self.assertTrue(mapped_item.is_committed()) - with self.assertRaises(SpineDBAPIError): - db_map1.commit_session("No comment") + # Verify that the uncommitted classes are now seen as 'committed' + uncommitted_entity_classes = db_map1.get_items("entity_class") + uncommitted_resolved_entity_classes = [x.resolve() for x in uncommitted_entity_classes] + uncommitted_resolved_id_by_name = {x["name"]: x["id"] for x in uncommitted_resolved_entity_classes} + self.assertEqual(committed_resolved_id_by_name, uncommitted_resolved_id_by_name) + for mapped_item in uncommitted_entity_classes: + self.assertTrue(mapped_item.is_committed()) + with self.assertRaises(SpineDBAPIError): + db_map1.commit_session("No comment") def test_committed_mapped_items_take_id_from_externally_committed_items(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") with CustomDatabaseMapping(url, create=True) as db_map0: + # Add widget before gadget db_map0.add_entity_class_item(name="widget") db_map0.add_entity_class_item(name="gadget") db_map0.commit_session("No comment") - with CustomDatabaseMapping(url) as db_map1: - with CustomDatabaseMapping(url) as db_map2: - entity_classes_before = 
db_map1.get_items("entity_class") - db_map2.purge_items("entity_class") - db_map2.add_entity_class_item(name="gadget") - db_map2.add_entity_class_item(name="widget") - db_map2.commit_session("No comment") - entity_classes_after = db_map1.get_items("entity_class") - # TODO: check that entity_classes_after is the same as entity_classes_before - # with exchanged ids + with CustomDatabaseMapping(url) as db_map1: + # Add classes to a model + model = {} + for x in db_map1.get_items("entity_class"): + model[x["id"]] = x + x.add_remove_callback(lambda x: model.pop(x["id"])) + self.assertEqual(len(model), 2) + with CustomDatabaseMapping(url) as db_map2: + # Purge, then add *gadget* before *widget* (swap the order) + # Also add an entity + db_map2.purge_items("entity_class") + db_map2.add_entity_class_item(name="gadget") + db_map2.add_entity_class_item(name="widget") + db_map2.add_entity_item(entity_class_name="gadget", name="phone") + db_map2.commit_session("No comment") + # Check that we see the entity added by the other mapping + phone = db_map1.get_entity_item(entity_class_name="gadget", name="phone") + self.assertIsNotNone(phone) + # Overwritten classes should have been removed from the model + self.assertEqual(len(model), 0) + + def test_fetching_entities_after_external_change_has_renamed_their_classes(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="Widget")) + self._assert_success(db_map.add_entity_class_item(name="Gadget")) + self._assert_success(db_map.add_entity_item(entity_class_name="Widget", name="smart_watch")) + widget = db_map.get_entity_item(entity_class_name="Widget", name="smart_watch") + self.assertEqual(widget["name"], "smart_watch") + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + widget_class = shadow_db_map.get_entity_class_item(name="Widget") + widget_class.update(name="NotAWidget") + gadget_class = shadow_db_map.get_entity_class_item(name="Gadget") + gadget_class.update(name="Widget") + widget_class.update(name="Gadget") + shadow_db_map.commit_session("Swap Widget and Gadget to cause mayhem.") + db_map.refresh_session() + gadget = db_map.get_entity_item(entity_class_name="Gadget", name="smart_watch") + self.assertEqual(gadget["name"], "smart_watch") if __name__ == "__main__": From fb68b5ee6a9f9d022de8aae58242a7ce716d40ca Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 7 Feb 2024 09:15:34 +0100 Subject: [PATCH 261/317] Minimize refetches when DB has external commits --- spinedb_api/db_mapping.py | 10 ++++--- spinedb_api/db_mapping_base.py | 51 ++++++++++++++++++---------------- tests/test_db_mapping_base.py | 4 +-- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 930f56aa..d30ec53b 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -175,11 +175,10 @@ def __init__( self._metadata = MetaData(self.engine) self._metadata.reflect() self._tablenames = [t.name for t in self._metadata.sorted_tables] - self.closed = False if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) - self._commit_count = self.query(self.commit_sq).count() + self._commit_count = self._query_commit_count() def __enter__(self): return self @@ -202,6 +201,9 @@ def all_item_types(): def item_factory(item_type): return item_factory(item_type) + def 
_query_commit_count(self): + return self.query(self.commit_sq).count() + def _make_sq(self, item_type): sq_name = self._sq_name_by_item_type[item_type] return getattr(self, sq_name) @@ -712,7 +714,7 @@ def commit_session(self, comment): if self._memory: self._memory_dirty = True transformation_info = compatibility_transformations(connection) - self._commit_count = self.query(self.commit_sq).count() + self._commit_count = self._query_commit_count() return transformation_info def rollback_session(self): @@ -728,7 +730,7 @@ def refresh_session(self): def has_external_commits(self): """See base class.""" - return self._commit_count != self.query(self.commit_sq).count() + return self._commit_count != self._query_commit_count() def close(self): """Closes this DB mapping. This is only needed if you're keeping a long-lived session. diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f2966aa8..a1f567fe 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -36,12 +36,13 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, :meth:`_make_sq`, - and :meth:`has_external_commits`. + and :meth:`_query_commit_count`. """ def __init__(self): + self.closed = False self._mapped_tables = {} - self._fetched = set() + self._fetched = {} item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -50,6 +51,7 @@ def __init__(self): item_types.append(item_type) else: self._sorted_item_types.append(item_type) + self._refresh() @staticmethod def item_types(): @@ -87,14 +89,6 @@ def item_factory(item_type): """ raise NotImplementedError() - def has_external_commits(self): - """Tests whether the database has had commits from other sources than this mapping. - - Returns: - bool: True if database has external commits, False otherwise - """ - raise NotImplementedError() - def _make_query(self, item_type, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to fetch items of given type. @@ -136,6 +130,16 @@ def _make_sq(self, item_type): """ raise NotImplementedError() + def _query_commit_count(self): + """Returns the number of commits in the DB. + + :meta private: + + Returns: + int + """ + raise NotImplementedError() + def make_item(self, item_type, **item): factory = self.item_factory(item_type) return factory(self, item_type, **item) @@ -154,8 +158,6 @@ def _dirty_items(self): Returns: list """ - if self.has_external_commits(): - self._refresh() dirty_items = [] purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} self._add_descendants(purged_item_types) @@ -216,8 +218,11 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" - for item_type in self.item_types(): - self._fetched.discard(item_type) + self._reset_fetched(self.item_types()) + + def _reset_fetched(self, item_types): + for item_type in item_types: + self._fetched[item_type] = -1 def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -238,7 +243,7 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._fetched.discard(item_type) + self._reset_fetched(item_types) def reset_purging(self): """Resets purging status for all item types. 
@@ -288,11 +293,9 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs):
         chunk = self._get_next_chunk(item_type, offset, limit, **kwargs)
         if not chunk:
             return []
-        if self.has_external_commits():
-            self._refresh()
-            for ref_type in self.item_factory(item_type).ref_types():
-                if ref_type != item_type:
-                    self.do_fetch_all(ref_type)
+        for ref_type in self.item_factory(item_type).ref_types():
+            if ref_type != item_type:
+                self.do_fetch_all(ref_type)
         mapped_table = self.mapped_table(item_type)
         items = []
         new_items = []
@@ -309,10 +312,10 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs):
         return items
 
     def do_fetch_all(self, item_type):
-        if item_type in self._fetched:
-            return
-        self._fetched.add(item_type)
-        self.do_fetch_more(item_type, offset=0, limit=None)
+        commit_count = self._query_commit_count()
+        if self._fetched[item_type] != commit_count:
+            self._fetched[item_type] = commit_count
+            self.do_fetch_more(item_type, offset=0, limit=None)
 
     def update_id(self, mapped_item, new_id):
         """Updates the id of the given item to the given new_id, also in all its referees.
diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py
index dc88d8af..cd67668c 100644
--- a/tests/test_db_mapping_base.py
+++ b/tests/test_db_mapping_base.py
@@ -28,8 +28,8 @@ def item_factory(item_type):
             return MappedItemBase
         raise RuntimeError(f"unknown item_type '{item_type}'")
 
-    def has_external_commits(self):
-        return False
+    def _query_commit_count(self):
+        return -1
 
     def _make_query(self, _item_type, **kwargs):
         return None

From 220ae181e309aa9cd84fde3bfb1a66127cb29d43 Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 7 Feb 2024 10:27:35 +0100
Subject: [PATCH 262/317] Improve check that a db-item equals its mapped-item

---
 spinedb_api/db_mapping_base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index a1f567fe..a3b5bda7 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -549,12 +549,12 @@ def add_item_from_db(self, item):
             tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict.
         """
         mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False)
-        if mapped_item:
+        if mapped_item and mapped_item.is_equal_in_db(item):
             yield from self._force_id(mapped_item, item["id"])
             yield mapped_item, False
             return
         mapped_item = self.get(item["id"])
-        if mapped_item is not None and mapped_item.is_equal_in_db(item):
+        if mapped_item and mapped_item.is_equal_in_db(item):
             yield mapped_item, False
             return
         yield from self._free_id(item["id"])

From 87fa8d8f30262b00e696de325e4e3f61ca3ae9bb Mon Sep 17 00:00:00 2001
From: Manuel
Date: Wed, 7 Feb 2024 12:41:18 +0100
Subject: [PATCH 263/317] Fix performance issue due to distrusting the DB too
 much

We were constantly asking whether the DB had external commits, and
distrusting DB items even when there were none. That degraded
performance; now we only distrust DB items when the DB actually has
external commits.
---
 spinedb_api/db_mapping_base.py | 46 ++++++++++++++++------------------
 tests/test_db_mapping_base.py  |  4 +--
 2 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py
index a3b5bda7..a2c83ea8 100644
--- a/spinedb_api/db_mapping_base.py
+++ b/spinedb_api/db_mapping_base.py
@@ -36,13 +36,13 @@ class DatabaseMappingBase:
 
     This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema.
When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, :meth:`_make_sq`, - and :meth:`_query_commit_count`. + and :meth:`has_external_commits`. """ def __init__(self): self.closed = False self._mapped_tables = {} - self._fetched = {} + self._fetched = set() item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -51,7 +51,6 @@ def __init__(self): item_types.append(item_type) else: self._sorted_item_types.append(item_type) - self._refresh() @staticmethod def item_types(): @@ -130,13 +129,11 @@ def _make_sq(self, item_type): """ raise NotImplementedError() - def _query_commit_count(self): - """Returns the number of commits in the DB. - - :meta private: + def has_external_commits(self): + """Tests whether the database has had commits from other sources than this mapping. Returns: - int + bool: True if database has external commits, False otherwise """ raise NotImplementedError() @@ -158,6 +155,8 @@ def _dirty_items(self): Returns: list """ + if self.has_external_commits(): + self._refresh() dirty_items = [] purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} self._add_descendants(purged_item_types) @@ -218,11 +217,7 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" - self._reset_fetched(self.item_types()) - - def _reset_fetched(self, item_types): - for item_type in item_types: - self._fetched[item_type] = -1 + self._fetched.clear() def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -243,7 +238,7 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._reset_fetched(item_types) + self._fetched.discard(item_type) def reset_purging(self): """Resets purging status for all item types. @@ -293,15 +288,18 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) if not chunk: return [] - for ref_type in self.item_factory(item_type).ref_types(): - if ref_type != item_type: - self.do_fetch_all(ref_type) + is_db_dirty = self.has_external_commits() + if is_db_dirty: + for ref_type in self.item_factory(item_type).ref_types(): + if ref_type != item_type: + self._fetched.discard(ref_type) + self.do_fetch_all(ref_type) mapped_table = self.mapped_table(item_type) items = [] new_items = [] # Add items first for x in chunk: - for item, new in mapped_table.add_item_from_db(x): + for item, new in mapped_table.add_item_from_db(x, not is_db_dirty): if new: new_items.append(item) items.append(item) @@ -312,9 +310,8 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): return items def do_fetch_all(self, item_type): - commit_count = self._query_commit_count() - if self._fetched[item_type] != commit_count: - self._fetched[item_type] = commit_count + if item_type not in self._fetched: + self._fetched.add(item_type) self.do_fetch_more(item_type, offset=0, limit=None) def update_id(self, mapped_item, new_id): @@ -539,22 +536,23 @@ def _make_and_add_item(self, item): self[item["id"]] = item return item - def add_item_from_db(self, item): + def add_item_from_db(self, item, is_db_clean): """Adds an item fetched from the DB. Args: item (dict): item from the DB. + is_db_clean (Bool) Yields: tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. 
""" mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) - if mapped_item and mapped_item.is_equal_in_db(item): + if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): yield from self._force_id(mapped_item, item["id"]) yield mapped_item, False return mapped_item = self.get(item["id"]) - if mapped_item and mapped_item.is_equal_in_db(item): + if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): yield mapped_item, False return yield from self._free_id(item["id"]) diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index cd67668c..dc88d8af 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -28,8 +28,8 @@ def item_factory(item_type): return MappedItemBase raise RuntimeError(f"unknown item_type '{item_type}'") - def _query_commit_count(self): - return -1 + def has_external_commits(self): + return False def _make_query(self, _item_type, **kwargs): return None From 9fd3eb80dfc4a4629553f72d389ad0c641c22944 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 7 Feb 2024 15:16:40 +0200 Subject: [PATCH 264/317] Add tests for DatabaseMapping. --- spinedb_api/db_mapping_base.py | 2 +- tests/test_DatabaseMapping.py | 565 ++++++++++++++++++++++++++++++++- 2 files changed, 558 insertions(+), 9 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index a3b5bda7..f7af7810 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1295,7 +1295,7 @@ def is_committed(self): def _asdict(self): return self._mapped_item._asdict() - def _extended(self): + def extended(self): return self._mapped_item._extended() def update(self, **kwargs): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index d821fb3f..8978c4c2 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -85,14 +85,14 @@ def test_shorthand_filter_query_works(self): db_map.close() -class AssertSuccessMixin: +class AssertSuccessTestCase(unittest.TestCase): def _assert_success(self, result): item, error = result self.assertIsNone(error) return item -class TestDatabaseMapping(AssertSuccessMixin, unittest.TestCase): +class TestDatabaseMapping(AssertSuccessTestCase): def test_active_by_default_is_initially_false_for_zero_dimensional_entity_class(self): with DatabaseMapping("sqlite://", create=True) as db_map: item = self._assert_success(db_map.add_entity_class_item(name="Entity")) @@ -280,7 +280,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): ) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( - entity_metadata._extended(), + entity_metadata.extended(), { "entity_class_name": "my_class", "entity_byname": ("entity_2",), @@ -367,7 +367,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): ) value_metadata.update(parameter_definition_name="y") self.assertEqual( - value_metadata._extended(), + value_metadata.extended(), { "entity_class_name": "my_class", "entity_byname": ("my_entity",), @@ -533,7 +533,7 @@ def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): db_map.commit_session("Add class.") classes = db_map.get_entity_class_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0]._extended()) + self.assertNotIn("commit_id", classes[0].extended()) def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -543,7 +543,7 @@ def 
test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self db_map.commit_session("Add class hierarchy.") classes = db_map.get_superclass_subclass_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0]._extended()) + self.assertNotIn("commit_id", classes[0].extended()) def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -558,7 +558,7 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) - self.assertNotIn("commit_id", groups[0]._extended()) + self.assertNotIn("commit_id", groups[0].extended()) def test_commit_parameter_value_coincidentally_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -2986,7 +2986,7 @@ def test_cascade_remove_unfetched(self): self.assertEqual(ents, []) -class TestDatabaseMappingConcurrent(AssertSuccessMixin, unittest.TestCase): +class TestDatabaseMappingConcurrent(AssertSuccessTestCase): @unittest.skipIf(os.name == 'nt', "Needs fixing") def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) @@ -3092,6 +3092,555 @@ def test_fetching_entities_after_external_change_has_renamed_their_classes(self) gadget = db_map.get_entity_item(entity_class_name="Gadget", name="smart_watch") self.assertEqual(gadget["name"], "smart_watch") + def test_additive_commit_from_another_db_map_gets_fetched(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + items = db_map.get_items("entity") + self.assertEqual(len(items), 0) + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) + self._assert_success(shadow_db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + shadow_db_map.commit_session("Add entity.") + db_map.refresh_session() + items = db_map.get_items("entity") + self.assertEqual(len(items), 1) + self.assertEqual( + items[0]._asdict(), + { + "id": 1, + "name": "my_entity", + "description": None, + "class_id": 1, + "element_name_list": None, + "element_id_list": (), + "commit_id": 2, + }, + ) + + def test_restoring_entity_whose_db_id_has_been_replaced_by_external_db_modification(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + item = self._assert_success(db_map.add_entity_item(entity_class_name="my_class", name="my_entity")) + original_id = item["id"] + db_map.commit_session("Add initial data.") + items = db_map.fetch_more("entity") + self.assertEqual(len(items), 1) + db_map.remove_item("entity", original_id) + db_map.commit_session("Removed entity.") + self.assertEqual(len(db_map.get_entity_items()), 0) + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(entity_class_name="my_class", name="other_entity") + ) + shadow_db_map.commit_session("Add entity with different name, probably reusing previous id.") + db_map.refresh_session() + items = db_map.fetch_more("entity") + self.assertEqual(len(items), 1) + self.assertEqual(items[0]["name"], "other_entity") + all_items = db_map.get_entity_items() + self.assertEqual(len(all_items), 1) 
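+                # Restoring below by the original in-memory id rescues the removed
+                # "my_entity", so it ends up coexisting with the externally committed
+                # "other_entity".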
+ restored_item = db_map.restore_item("entity", original_id) + self.assertEqual(restored_item["name"], "my_entity") + all_items = db_map.get_entity_items() + self.assertEqual(len(all_items), 2) + + def test_cunning_ways_to_make_external_changes(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="interesting_class")) + self._assert_success(db_map.add_entity_class_item(name="filler_class")) + self._assert_success( + db_map.add_parameter_definition_item(name="quality", entity_class_name="interesting_class") + ) + self._assert_success( + db_map.add_parameter_definition_item(name="quantity", entity_class_name="filler_class") + ) + self._assert_success( + db_map.add_entity_item(name="object_of_interest", entity_class_name="interesting_class") + ) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + parameter_definition_name="quality", + entity_class_name="interesting_class", + entity_byname=("object_of_interest",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + db_map.commit_session("Add initial data") + removed_item = db_map.get_entity_item(name="object_of_interest", entity_class_name="interesting_class") + removed_item.remove() + db_map.commit_session("Remove object of interest") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="interesting_class") + ) + self._assert_success(shadow_db_map.add_entity_item(name="filler", entity_class_name="filler_class")) + value, value_type = to_database(-2.3) + self._assert_success( + shadow_db_map.add_parameter_value_item( + parameter_definition_name="quantity", + entity_class_name="filler_class", + entity_byname=("filler",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(99.9) + self._assert_success( + shadow_db_map.add_parameter_value_item( + parameter_definition_name="quality", + entity_class_name="interesting_class", + entity_byname=("other_entity",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + shadow_db_map.commit_session("Add entities.") + db_map.refresh_session() + entity_items = db_map.get_entity_items() + self.assertEqual(len(entity_items), 2) + self.assertEqual( + entity_items[0].extended(), + { + "id": 1, + "name": "other_entity", + "description": None, + "class_id": 1, + "element_id_list": (), + "element_name_list": (), + "commit_id": 4, + "entity_class_name": "interesting_class", + "dimension_id_list": (), + "dimension_name_list": (), + "element_byname_list": (), + "superclass_id": None, + "superclass_name": None, + }, + ) + self.assertEqual( + entity_items[1].extended(), + { + "id": 2, + "name": "filler", + "description": None, + "class_id": 2, + "element_id_list": (), + "element_name_list": (), + "commit_id": 4, + "entity_class_name": "filler_class", + "dimension_id_list": (), + "dimension_name_list": (), + "element_byname_list": (), + "superclass_id": None, + "superclass_name": None, + }, + ) + value_items = db_map.get_parameter_value_items() + self.assertEqual(len(value_items), 2) + self.assertTrue(removed_item.is_committed()) + self.assertEqual( + value_items[0].extended(), + { + "alternative_id": 1, + "alternative_name": "Base", + "commit_id": 4, + "dimension_id_list": (), + "dimension_name_list": (), + "element_id_list": (), + 
"element_name_list": (), + "entity_byname": ("filler",), + "entity_class_id": 2, + "entity_class_name": "filler_class", + "entity_id": 3, + "entity_name": "filler", + "id": 2, + "list_value_id": None, + "parameter_definition_id": 2, + "parameter_definition_name": "quantity", + "parameter_value_list_id": None, + "parameter_value_list_name": None, + "type": to_database(-2.3)[1], + "value": to_database(-2.3)[0], + }, + ) + self.assertEqual( + value_items[1].extended(), + { + "alternative_id": 1, + "alternative_name": "Base", + "commit_id": 4, + "dimension_id_list": (), + "dimension_name_list": (), + "element_id_list": (), + "element_name_list": (), + "entity_byname": ("other_entity",), + "entity_class_id": 1, + "entity_class_name": "interesting_class", + "entity_id": 2, + "entity_name": "other_entity", + "id": 3, + "list_value_id": None, + "parameter_definition_id": 1, + "parameter_definition_name": "quality", + "parameter_value_list_id": None, + "parameter_value_list_name": None, + "type": to_database(99.9)[1], + "value": to_database(99.9)[0], + }, + ) + + def test_update_entity_metadata_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + metadata_value = '{"sources": [], "contributors": []}' + self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + self._assert_success( + db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + ) + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") + ) + metadata_item = shadow_db_map.get_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + ) + self.assertTrue(metadata_item) + metadata_item.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move entity metadata to another entity") + db_map.refresh_session() + metadata_items = db_map.get_entity_metadata_items() + self.assertEqual(len(metadata_items), 2) + self.assertEqual( + metadata_items[0].extended(), + { + "id": 1, + "entity_class_name": "my_class", + "entity_byname": ("my_entity",), + "entity_id": 1, + "metadata_id": 1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "commit_id": 2, + }, + ) + self.assertEqual( + metadata_items[1].extended(), + { + "id": 2, + "entity_class_name": "my_class", + "entity_byname": ("other_entity",), + "entity_id": 2, + "metadata_id": 1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "commit_id": 3, + }, + ) + + def test_update_parameter_value_metadata_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + value, value_type = to_database(2.3) + self._assert_success( + 
db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + metadata_value = '{"sources": [], "contributors": []}' + self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) + self._assert_success( + db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") + ) + value, value_type = to_database(5.0) + self._assert_success( + shadow_db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("other_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + metadata_item = shadow_db_map.get_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + self.assertTrue(metadata_item) + metadata_item.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move parameter value metadata to another entity") + db_map.refresh_session() + metadata_items = db_map.get_parameter_value_metadata_items() + self.assertEqual(len(metadata_items), 2) + self.assertEqual( + metadata_items[0].extended(), + { + "id": 1, + "entity_class_name": "my_class", + "parameter_definition_name": "x", + "parameter_value_id": 1, + "entity_byname": ("my_entity",), + "metadata_id": 1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "alternative_name": "Base", + "commit_id": 2, + }, + ) + self.assertEqual( + metadata_items[1].extended(), + { + "id": 2, + "entity_class_name": "my_class", + "parameter_definition_name": "x", + "parameter_value_id": 2, + "entity_byname": ("other_entity",), + "metadata_id": 1, + "metadata_name": "my_metadata", + "metadata_value": metadata_value, + "alternative_name": "Base", + "commit_id": 3, + }, + ) + + def test_update_entity_alternative_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + self._assert_success( + db_map.add_entity_alternative_item( + entity_byname=("my_entity",), + entity_class_name="my_class", + alternative_name="Base", + active=False, + ) + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") + ) + entity_alternative = shadow_db_map.get_entity_alternative_item( + entity_class_name="my_class", entity_byname=("my_entity",), alternative_name="Base" + ) + self.assertTrue(entity_alternative) + entity_alternative.update(entity_byname=("other_entity",)) + shadow_db_map.commit_session("Move entity alternative to another entity.") + db_map.refresh_session() + entity_alternatives = db_map.get_entity_alternative_items() + 
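+                # After the refresh, both the original and the externally moved
+                # entity alternative should be visible.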
self.assertEqual(len(entity_alternatives), 2) + self.assertEqual( + entity_alternatives[0].extended(), + { + "id": 1, + "entity_class_name": "my_class", + "entity_class_id": 1, + "entity_byname": ("my_entity",), + "entity_name": "my_entity", + "entity_id": 1, + "dimension_name_list": (), + "dimension_id_list": (), + "element_name_list": (), + "element_id_list": (), + "alternative_name": "Base", + "alternative_id": 1, + "active": False, + "commit_id": 2, + }, + ) + self.assertEqual( + entity_alternatives[1].extended(), + { + "id": 2, + "entity_class_name": "my_class", + "entity_class_id": 1, + "entity_byname": ("other_entity",), + "entity_name": "other_entity", + "entity_id": 2, + "dimension_name_list": (), + "dimension_id_list": (), + "element_name_list": (), + "element_id_list": (), + "alternative_name": "Base", + "alternative_id": 1, + "active": False, + "commit_id": 3, + }, + ) + + def test_update_superclass_subclass_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="ceiling")) + self._assert_success(db_map.add_entity_class_item(name="floor")) + self._assert_success(db_map.add_entity_class_item(name="soil")) + self._assert_success( + db_map.add_superclass_subclass_item(superclass_name="ceiling", subclass_name="floor") + ) + db_map.commit_session("Add initial data.") + with DatabaseMapping(url) as shadow_db_map: + superclass_subclass = shadow_db_map.get_superclass_subclass_item(subclass_name="floor") + superclass_subclass.update(subclass_name="soil") + shadow_db_map.commit_session("Changes subclass to another one.") + db_map.refresh_session() + superclass_subclasses = db_map.get_superclass_subclass_items() + self.assertEqual(len(superclass_subclasses), 2) + self.assertEqual( + superclass_subclasses[0].extended(), + { + "id": 1, + "superclass_name": "ceiling", + "superclass_id": 1, + "subclass_name": "floor", + "subclass_id": 2, + }, + ) + self.assertEqual( + superclass_subclasses[1].extended(), + { + "id": 2, + "superclass_name": "ceiling", + "superclass_id": 1, + "subclass_name": "soil", + "subclass_id": 3, + }, + ) + + def test_adding_same_parameters_values_to_different_entities_externally(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_parameter_definition_item(name="x", entity_class_name="my_class")) + my_entity = self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + db_map.commit_session("Add initial data.") + my_entity.remove() + db_map.commit_session("Remove entity.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") + ) + self._assert_success( + shadow_db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("other_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + shadow_db_map.commit_session("Add another 
entity.") + db_map.refresh_session() + values = db_map.get_parameter_value_items() + self.assertEqual(len(values), 1) + self.assertEqual( + values[0].extended(), + { + "id": -2, + "entity_class_name": "my_class", + "entity_class_id": -1, + "dimension_name_list": (), + "dimension_id_list": (), + "parameter_definition_name": "x", + "parameter_definition_id": -1, + "entity_byname": ("other_entity",), + "entity_name": "other_entity", + "entity_id": -2, + "element_name_list": (), + "element_id_list": (), + "alternative_name": "Base", + "alternative_id": -1, + "parameter_value_list_name": None, + "parameter_value_list_id": None, + "list_value_id": None, + "type": value_type, + "value": value, + "commit_id": -4, + }, + ) + + def test_committing_changed_purged_entity_has_been_overwritten_by_external_change(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_item(name="ghost", entity_class_name="my_class")) + db_map.commit_session("Add soon-to-be-removed entity.") + db_map.purge_items("entity") + db_map.commit_session("Purge entities.") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success( + shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") + ) + shadow_db_map.commit_session("Add another entity that steals ghost's id.") + db_map.refresh_session() + db_map.do_fetch_all("entity") + self._assert_success(db_map.add_entity_item(name="dirty_entity", entity_class_name="my_class")) + db_map.commit_session("Add still uncommitted entity.") + entities = db_map.query(db_map.wide_entity_sq).all() + self.assertEqual(len(entities), 2) + if __name__ == "__main__": unittest.main() From 459d3503207b911ea694ad80ee782401c0ef76a1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 7 Feb 2024 17:29:56 +0100 Subject: [PATCH 265/317] Use TempId for all items in the mapping --- spinedb_api/db_mapping_base.py | 153 +++++++++++---------------------- spinedb_api/temp_id.py | 3 + tests/test_DatabaseMapping.py | 149 +++++++------------------------- 3 files changed, 84 insertions(+), 221 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 7ec7e78d..49ef35d2 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -299,10 +299,10 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): new_items = [] # Add items first for x in chunk: - for item, new in mapped_table.add_item_from_db(x, not is_db_dirty): - if new: - new_items.append(item) - items.append(item) + item, new = mapped_table.add_item_from_db(x, not is_db_dirty) + if new: + new_items.append(item) + items.append(item) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted for item in new_items: @@ -314,29 +314,6 @@ def do_fetch_all(self, item_type): self._fetched.add(item_type) self.do_fetch_more(item_type, offset=0, limit=None) - def update_id(self, mapped_item, new_id): - """Updates the id of the given item to the given new_id, also in all its referees. 
- - Args: - mapped_item (MappedItemBase) - new_id (int) - """ - mapped_item.cascade_call_remove_callbacks() - mapped_item.cascade_remove_unique() - old_id = mapped_item["id"] - mapped_item["id"] = new_id - for item_type in self.item_types(): - mapped_table = self.mapped_table(item_type) - for field, (ref_type, ref_field) in self.item_factory(item_type)._references.items(): - if ref_type == mapped_item.item_type and ref_field == "id": - for item in mapped_table.values(): - value = item[field] - if isinstance(value, tuple) and old_id in value: - item[field] = tuple(new_id if id_ == old_id else id_ for id_ in value) - elif old_id == value: - item[field] = new_id - mapped_item.cascade_add_unique() - class _MappedTable(dict): def __init__(self, db_map, item_type, *args, **kwargs): @@ -531,9 +508,11 @@ def _make_and_add_item(self, item): if not isinstance(item, MappedItemBase): item = self._make_item(item) item.polish() - if "id" not in item or not item.has_valid_id: - item["id"] = self._new_id() - self[item["id"]] = item + db_id = item.pop("id", None) if item.has_valid_id else None + item["id"] = new_id = self._new_id() + if db_id is not None: + new_id.resolve(db_id) + self[new_id] = item return item def add_item_from_db(self, item, is_db_clean): @@ -543,50 +522,22 @@ def add_item_from_db(self, item, is_db_clean): item (dict): item from the DB. is_db_clean (Bool) - Yields: + Returns: tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. """ mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): - yield from self._force_id(mapped_item, item["id"]) - yield mapped_item, False - return + mapped_item.force_id(item["id"]) + return mapped_item, False mapped_item = self.get(item["id"]) if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): - yield mapped_item, False - return - yield from self._free_id(item["id"]) + return mapped_item, False + self._free_id(item["id"]) mapped_item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. mapped_item.cascade_remove(source=self.wildcard_item) - yield mapped_item, True - - def _force_id(self, mapped_item, id_): - """Makes sure that mapped_item has the given id_, corresponding to the new id of the item - in the DB after some external changes. - - Args: - mapped_item (MappedItemBase): An item in the in-memory item. - id_ (int): The most recent id_ of the item as fetched from the DB. - - Yields: - tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. - """ - mapped_id = mapped_item["id"] - if mapped_id == id_: - return - if isinstance(mapped_id, TempId): - # Easy, resolve the TempId to the new db id (and commit the item if pending) - mapped_id.resolve(id_) - if mapped_item.status == Status.to_add: - mapped_item.status = Status.committed - else: - # Hard, update the id of the item manually. - self._db_map.update_id(mapped_item, id_) - yield from self._free_id(id_) - self[id_] = mapped_item - del self[mapped_id] + return mapped_item, True def _free_id(self, id_): """Makes sure the given id_ is free. Fix conflicts if not. @@ -597,33 +548,10 @@ def _free_id(self, id_): Yields: tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. 
""" - conflicting_item = self.pop(id_, None) + conflicting_item = self.get(id_) if conflicting_item is None: return - self._resolve_id_conflict(conflicting_item) - yield conflicting_item, False - - def _resolve_id_conflict(self, conflicting_item): - """Does something with conflicting_item whose id now belongs to a different item after an external commit. - - Args: - conflicting_item (MappedItemBase): an item in the memory mapping. - """ - # Here we could let the user choose the strategy. For now we just 'rescue' the item. - self._rescue_item(conflicting_item) - - def _rescue_item(self, conflicting_item): - """Rescues the given conflicting_item whose id now belongs to a different item after an external commit. - - Args: - conflicting_item (MappedItemBase): an item in the memory mapping. - """ - status = conflicting_item.status - id_ = self._new_id() - self._db_map.update_id(conflicting_item, id_) - self[id_] = conflicting_item - if status in (Status.to_update, Status.committed): - conflicting_item.status = Status.to_add + conflicting_item.detach() def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -899,9 +827,11 @@ def first_invalid_key(self): """ return next(self._invalid_keys(), None) + # TODO: Maybe rename this method to reflect its more important task now of replacing fields with TempIds def _invalid_keys(self): """Goes through the ``_references`` class attribute and returns the keys of the ones that cannot be resolved. + Also, replace fields referring to db-ids with TempIds. Yields: str: unresolved keys if any. @@ -913,11 +843,17 @@ def _invalid_keys(self): yield src_key else: if isinstance(src_val, tuple): - for x in src_val: - if not self._get_ref(ref_type, {ref_key: x}): - yield src_key - elif not self._get_ref(ref_type, {ref_key: src_val}): - yield src_key + refs = tuple(self._get_ref(ref_type, {ref_key: x}) for x in src_val) + if not all(refs): + yield src_key + elif ref_key == "id": + self[src_key] = tuple(ref["id"] for ref in refs) + else: + ref = self._get_ref(ref_type, {ref_key: src_val}) + if not self._get_ref(ref_type, {ref_key: src_val}): + yield src_key + elif ref_key == "id": + self[src_key] = ref["id"] @classmethod def unique_values_for_item(cls, item, skip_keys=()): @@ -1143,20 +1079,12 @@ def cascade_remove(self, source=None): for referrer in self._referrers.values(): referrer.cascade_remove(source=self) self._update_weak_referrers() - self.call_remove_callbacks() - - def call_remove_callbacks(self): obsolete = set() for callback in list(self.remove_callbacks): if not callback(self): obsolete.add(callback) self.remove_callbacks -= obsolete - def cascade_call_remove_callbacks(self): - for referrer in self._referrers.values(): - referrer.cascade_call_remove_callbacks() - self.call_remove_callbacks() - def cascade_update(self): """Updates this item and all its referrers in cascade. Also, calls items' update callbacks. @@ -1257,6 +1185,27 @@ def update(self, other): if self._asdict() == self._backup: self._status = Status.committed + def force_id(self, id_): + """Makes sure this item's has the given id_, corresponding to the new id of the item + in the DB after some external changes. + + Args: + id_ (int): The most recent id_ of the item as fetched from the DB. 
+ """ + mapped_id = self["id"] + if mapped_id == id_: + return + # Resolve the TempId to the new db id (and commit the item if pending) + mapped_id.resolve(id_) + if self.status == Status.to_add: + self.status = Status.committed + + def detach(self): + """Detaches this item whose id now belongs to a different item after an external commit.""" + self["id"].unresolve() + if self.status in (Status.to_update, Status.committed): + self.status = Status.to_add + class PublicItem: def __init__(self, db_map, mapped_item): diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index de9f302d..b504c97d 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -45,6 +45,9 @@ def resolve(self, db_id): for callback in self._resolve_callbacks: callback(db_id) + def unresolve(self): + self._db_id = None + def resolve(value): if isinstance(value, dict): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 8978c4c2..0c15dd26 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3136,7 +3136,7 @@ def test_restoring_entity_whose_db_id_has_been_replaced_by_external_db_modificat shadow_db_map.add_entity_item(entity_class_name="my_class", name="other_entity") ) shadow_db_map.commit_session("Add entity with different name, probably reusing previous id.") - db_map.refresh_session() + # db_map.refresh_session() items = db_map.fetch_more("entity") self.assertEqual(len(items), 1) self.assertEqual(items[0]["name"], "other_entity") @@ -3328,35 +3328,14 @@ def test_update_entity_metadata_externally(self): self.assertTrue(metadata_item) metadata_item.update(entity_byname=("other_entity",)) shadow_db_map.commit_session("Move entity metadata to another entity") - db_map.refresh_session() metadata_items = db_map.get_entity_metadata_items() self.assertEqual(len(metadata_items), 2) - self.assertEqual( - metadata_items[0].extended(), - { - "id": 1, - "entity_class_name": "my_class", - "entity_byname": ("my_entity",), - "entity_id": 1, - "metadata_id": 1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "commit_id": 2, - }, - ) - self.assertEqual( - metadata_items[1].extended(), - { - "id": 2, - "entity_class_name": "my_class", - "entity_byname": ("other_entity",), - "entity_id": 2, - "metadata_id": 1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "commit_id": 3, - }, - ) + self.assertNotEqual(metadata_items[0]["id"], metadata_items[1]["id"]) + unique_values = { + (x["entity_class_name"], x["entity_byname"], x["metadata_name"]) for x in metadata_items + } + self.assertIn(("my_class", ("my_entity",), "my_metadata"), unique_values) + self.assertIn(("my_class", ("other_entity",), "my_metadata"), unique_values) def test_update_parameter_value_metadata_externally(self): with TemporaryDirectory() as temp_dir: @@ -3415,39 +3394,21 @@ def test_update_parameter_value_metadata_externally(self): self.assertTrue(metadata_item) metadata_item.update(entity_byname=("other_entity",)) shadow_db_map.commit_session("Move parameter value metadata to another entity") - db_map.refresh_session() metadata_items = db_map.get_parameter_value_metadata_items() self.assertEqual(len(metadata_items), 2) - self.assertEqual( - metadata_items[0].extended(), - { - "id": 1, - "entity_class_name": "my_class", - "parameter_definition_name": "x", - "parameter_value_id": 1, - "entity_byname": ("my_entity",), - "metadata_id": 1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "alternative_name": "Base", - "commit_id": 2, - 
}, - ) - self.assertEqual( - metadata_items[1].extended(), - { - "id": 2, - "entity_class_name": "my_class", - "parameter_definition_name": "x", - "parameter_value_id": 2, - "entity_byname": ("other_entity",), - "metadata_id": 1, - "metadata_name": "my_metadata", - "metadata_value": metadata_value, - "alternative_name": "Base", - "commit_id": 3, - }, - ) + self.assertNotEqual(metadata_items[0]["id"], metadata_items[1]["id"]) + unique_values = { + ( + x["entity_class_name"], + x["parameter_definition_name"], + x["entity_byname"], + x["metadata_name"], + x["alternative_name"], + ) + for x in metadata_items + } + self.assertIn(("my_class", "x", ("my_entity",), "my_metadata", "Base"), unique_values) + self.assertIn(("my_class", "x", ("other_entity",), "my_metadata", "Base"), unique_values) def test_update_entity_alternative_externally(self): with TemporaryDirectory() as temp_dir: @@ -3474,47 +3435,14 @@ def test_update_entity_alternative_externally(self): self.assertTrue(entity_alternative) entity_alternative.update(entity_byname=("other_entity",)) shadow_db_map.commit_session("Move entity alternative to another entity.") - db_map.refresh_session() entity_alternatives = db_map.get_entity_alternative_items() self.assertEqual(len(entity_alternatives), 2) - self.assertEqual( - entity_alternatives[0].extended(), - { - "id": 1, - "entity_class_name": "my_class", - "entity_class_id": 1, - "entity_byname": ("my_entity",), - "entity_name": "my_entity", - "entity_id": 1, - "dimension_name_list": (), - "dimension_id_list": (), - "element_name_list": (), - "element_id_list": (), - "alternative_name": "Base", - "alternative_id": 1, - "active": False, - "commit_id": 2, - }, - ) - self.assertEqual( - entity_alternatives[1].extended(), - { - "id": 2, - "entity_class_name": "my_class", - "entity_class_id": 1, - "entity_byname": ("other_entity",), - "entity_name": "other_entity", - "entity_id": 2, - "dimension_name_list": (), - "dimension_id_list": (), - "element_name_list": (), - "element_id_list": (), - "alternative_name": "Base", - "alternative_id": 1, - "active": False, - "commit_id": 3, - }, - ) + self.assertNotEqual(entity_alternatives[0]["id"], entity_alternatives[1]["id"]) + unique_values = { + (x["entity_class_name"], x["entity_name"], x["alternative_name"]) for x in entity_alternatives + } + self.assertIn(("my_class", "my_entity", "Base"), unique_values) + self.assertIn(("my_class", "other_entity", "Base"), unique_values) def test_update_superclass_subclass_externally(self): with TemporaryDirectory() as temp_dir: @@ -3531,29 +3459,12 @@ def test_update_superclass_subclass_externally(self): superclass_subclass = shadow_db_map.get_superclass_subclass_item(subclass_name="floor") superclass_subclass.update(subclass_name="soil") shadow_db_map.commit_session("Changes subclass to another one.") - db_map.refresh_session() superclass_subclasses = db_map.get_superclass_subclass_items() self.assertEqual(len(superclass_subclasses), 2) - self.assertEqual( - superclass_subclasses[0].extended(), - { - "id": 1, - "superclass_name": "ceiling", - "superclass_id": 1, - "subclass_name": "floor", - "subclass_id": 2, - }, - ) - self.assertEqual( - superclass_subclasses[1].extended(), - { - "id": 2, - "superclass_name": "ceiling", - "superclass_id": 1, - "subclass_name": "soil", - "subclass_id": 3, - }, - ) + self.assertNotEqual(superclass_subclasses[0]["id"], superclass_subclasses[1]["id"]) + unique_values = {(x["superclass_name"], x["subclass_name"]) for x in superclass_subclasses} + self.assertIn(("ceiling", "floor"), 
unique_values)
+                self.assertIn(("ceiling", "soil"), unique_values)
 
     def test_adding_same_parameters_values_to_different_entities_externally(self):

From c33f91bfcb8a2fb44387fab9b283f412bcc11f2a Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Thu, 8 Feb 2024 10:14:34 +0200
Subject: [PATCH 266/317] Fix scenario filter for entity alternatives

In scenario_filter._ext_entity_sq(), we must drop rows that have an
entity alternative for an alternative that is not part of the scenario.
Otherwise, entities that should be filtered out will show up in the
final query.

Re spine-tools/Spine-Toolbox#2504
---
 spinedb_api/filters/scenario_filter.py |   6 +
 tests/filters/test_scenario_filter.py  |  19 ++
 tests/test_DatabaseMapping.py          | 316 +++++++++++++++----------
 3 files changed, 211 insertions(+), 130 deletions(-)

diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py
index 8b8a1604..dd6d47f5 100644
--- a/spinedb_api/filters/scenario_filter.py
+++ b/spinedb_api/filters/scenario_filter.py
@@ -211,6 +211,12 @@ def _ext_entity_sq(db_map, state):
                 db_map.scenario_alternative_sq.c.scenario_id == state.scenario_id,
             )
         )
+        .filter(
+            or_(
+                db_map.entity_alternative_sq.c.alternative_id == None,
+                db_map.entity_alternative_sq.c.alternative_id == db_map.scenario_alternative_sq.c.alternative_id,
+            )
+        )
     ).subquery()
 
 
diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py
index f804d515..3912b14c 100644
--- a/tests/filters/test_scenario_filter.py
+++ b/tests/filters/test_scenario_filter.py
@@ -99,6 +99,25 @@ def test_filter_entities_with_default_activity(self):
         self.assertEqual(entities[0]["name"], "visible")
         self.assertEqual(entities[1]["name"], "visible")
 
+    def test_filter_entity_that_is_not_active_in_scenario(self):
+        with DatabaseMapping("sqlite://", create=True) as db_map:
+            self._assert_success(db_map.add_alternative_item(name="alt"))
+            self._assert_success(db_map.add_scenario_item(name="scen"))
+            self._assert_success(
+                db_map.add_scenario_alternative_item(scenario_name="scen", alternative_name="Base", rank=0)
+            )
+            self._assert_success(db_map.add_entity_class_item(name="Gadget", active_by_default=False))
+            self._assert_success(db_map.add_entity_item(name="fork", entity_class_name="Gadget"))
+            self._assert_success(
+                db_map.add_entity_alternative_item(
+                    entity_byname=("fork",), entity_class_name="Gadget", alternative_name="alt", active=True
+                )
+            )
+            db_map.commit_session("Add test data.")
+            apply_filter_stack(db_map, [scenario_filter_config("scen")])
+            entities = db_map.query(db_map.wide_entity_sq).all()
+            self.assertEqual(len(entities), 0)
+
 
 class TestScenarioFilter(unittest.TestCase):
     _db_url = None
diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py
index 1aa80133..c9dd27a0 100644
--- a/tests/test_DatabaseMapping.py
+++ b/tests/test_DatabaseMapping.py
@@ -129,20 +129,24 @@ def test_commit_parameter_value(self):
             url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite")
             with DatabaseMapping(url, create=True) as db_map:
                 self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims."))
-                self._assert_success(db_map.add_item(
-                    "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away."
-                ))
+                self._assert_success(
+                    db_map.add_item(
+                        "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away."
+ ) + ) self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) + ) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -161,28 +165,40 @@ def test_commit_multidimensional_parameter_value(self): with DatabaseMapping(url, create=True) as db_map: self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) self._assert_success(db_map.add_item("entity_class", name="cat", description="Eats fish.")) - self._assert_success(db_map.add_item( - "entity_class", - name="fish__cat", - dimension_name_list=("fish", "cat"), - description="A fish getting eaten by a cat?", - )) - self._assert_success(db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).")) - self._assert_success(db_map.add_item( - "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." - )) - self._assert_success(db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix"))) - self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate")) + self._assert_success( + db_map.add_item( + "entity_class", + name="fish__cat", + dimension_name_list=("fish", "cat"), + description="A fish getting eaten by a cat?", + ) + ) + self._assert_success( + db_map.add_item("entity", entity_class_name="fish", name="Nemo", description="Lost (soon).") + ) + self._assert_success( + db_map.add_item( + "entity", entity_class_name="cat", name="Felix", description="The wonderful wonderful cat." + ) + ) + self._assert_success( + db_map.add_item("entity", entity_class_name="fish__cat", element_name_list=("Nemo", "Felix")) + ) + self._assert_success( + db_map.add_item("parameter_definition", entity_class_name="fish__cat", name="rate") + ) value, type_ = to_database(0.23) - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish__cat", - entity_byname=("Nemo", "Felix"), - parameter_definition_name="rate", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish__cat", + entity_byname=("Nemo", "Felix"), + parameter_definition_name="rate", + alternative_name="Base", + value=value, + type=type_, + ) + ) db_map.commit_session("Added data") with DatabaseMapping(url) as db_map: color = db_map.get_item( @@ -198,20 +214,24 @@ def test_commit_multidimensional_parameter_value(self): def test_updating_entity_name_updates_the_name_in_parameter_value_too(self): with DatabaseMapping(IN_MEMORY_DB_URL, create=True) as db_map: self._assert_success(db_map.add_item("entity_class", name="fish", description="It swims.")) - self._assert_success(db_map.add_item( - "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." - )) + self._assert_success( + db_map.add_item( + "entity", entity_class_name="fish", name="Nemo", description="Peacefully swimming away." 
+ ) + ) self._assert_success(db_map.add_item("parameter_definition", entity_class_name="fish", name="color")) value, type_ = to_database("mainly orange") - self._assert_success(db_map.add_item( - "parameter_value", - entity_class_name="fish", - entity_byname=("Nemo",), - parameter_definition_name="color", - alternative_name="Base", - value=value, - type=type_, - )) + self._assert_success( + db_map.add_item( + "parameter_value", + entity_class_name="fish", + entity_byname=("Nemo",), + parameter_definition_name="color", + alternative_name="Base", + value=value, + type=type_, + ) + ) color = db_map.get_item( "parameter_value", entity_class_name="fish", @@ -248,12 +268,14 @@ def test_update_entity_metadata_by_changing_its_entity(self): entity_2 = self._assert_success(db_map.add_entity_item(name="entity_2", entity_class_name="my_class")) metadata_value = '{"sources": [], "contributors": []}' metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - entity_metadata = self._assert_success(db_map.add_entity_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("entity_1",), - )) + entity_metadata = self._assert_success( + db_map.add_entity_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("entity_1",), + ) + ) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( entity_metadata._extended(), @@ -308,33 +330,39 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="my_entity", entity_class_name="my_class")) value, value_type = to_database(2.3) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - value=value, - type=value_type, - )) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + value=value, + type=value_type, + ) + ) value, value_type = to_database(-2.3) - y = self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="y", - alternative_name="Base", - value=value, - type=value_type, - )) + y = self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) + ) metadata_value = '{"sources": [], "contributors": []}' metadata = self._assert_success(db_map.add_metadata_item(name="my_metadata", value=metadata_value)) - value_metadata = self._assert_success(db_map.add_parameter_value_metadata_item( - metadata_name="my_metadata", - metadata_value=metadata_value, - entity_class_name="my_class", - entity_byname=("my_entity",), - parameter_definition_name="x", - alternative_name="Base", - )) + value_metadata = self._assert_success( + db_map.add_parameter_value_metadata_item( + metadata_name="my_metadata", + metadata_value=metadata_value, + entity_class_name="my_class", + entity_byname=("my_entity",), + parameter_definition_name="x", + alternative_name="Base", + ) + ) value_metadata.update(parameter_definition_name="y") 
self.assertEqual( value_metadata._extended(), @@ -458,7 +486,9 @@ def test_fetch_entities_that_refer_to_unfetched_entities(self): # Remove the entity in the middle and add a multi-D one referring to the third entity. # The multi-D one will go in the middle. db_map.get_entity_item(name="Sylvester", entity_class_name="cat").remove() - self._assert_success(db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat")) + self._assert_success( + db_map.add_entity_item(element_name_list=("Pulgoso", "Tom"), entity_class_name="dog__cat") + ) db_map.commit_session("Meow!") with DatabaseMapping(url) as db_map: # The ("Pulgoso", "Tom") entity will be fetched before "Tom". @@ -476,13 +506,13 @@ def test_committing_scenario_alternatives(self): self.assertIsNotNone(item) item = self._assert_success(db_map.add_scenario_item(name="my_scenario")) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt1", rank=0 - )) + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt1", rank=0) + ) self.assertIsNotNone(item) - item = self._assert_success(db_map.add_scenario_alternative_item( - scenario_name="my_scenario", alternative_name="alt2", rank=1 - )) + item = self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="my_scenario", alternative_name="alt2", rank=1) + ) self.assertIsNotNone(item) db_map.commit_session("Add test data.") with DatabaseMapping(url) as db_map: @@ -518,7 +548,11 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): self._assert_success(db_map.add_entity_class_item(name="my_class")) self._assert_success(db_map.add_entity_item(name="element", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="container", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_group_item(group_name="container", member_name="element", entity_class_name="my_class")) + self._assert_success( + db_map.add_entity_group_item( + group_name="container", member_name="element", entity_class_name="my_class" + ) + ) db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) @@ -528,40 +562,54 @@ def test_commit_parameter_value_coincidentally_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success( + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + ) self._assert_success(db_map.add_entity_class_item(name="my_class")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="my_class", parameter_value_list_name="booleans" + ) + ) self._assert_success(db_map.add_entity_item(name="widget1", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="widget2", entity_class_name="my_class")) self._assert_success(db_map.add_entity_item(name="no_is_active", entity_class_name="my_class")) - self._assert_success(db_map.add_entity_alternative_item( 
- entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False - )) - self._assert_success(db_map.add_entity_alternative_item( - entity_class_name="my_class", entity_byname=("widget2",), alternative_name="Base", active=False - )) - self._assert_success(db_map.add_entity_alternative_item( - entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False - )) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("widget1",), alternative_name="Base", active=False + ) + ) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("widget2",), alternative_name="Base", active=False + ) + ) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="my_class", entity_byname=("no_is_active",), alternative_name="Base", active=False + ) + ) value, value_type = to_database(True) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - parameter_definition_name="is_active", - entity_byname=("widget1",), - alternative_name="Base", - value=value, - type=value_type, - )) - self._assert_success(db_map.add_parameter_value_item( - entity_class_name="my_class", - parameter_definition_name="is_active", - entity_byname=("widget2",), - alternative_name="Base", - value=value, - type=value_type, - )) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + parameter_definition_name="is_active", + entity_byname=("widget1",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="my_class", + parameter_definition_name="is_active", + entity_byname=("widget2",), + alternative_name="Base", + value=value, + type=value_type, + ) + ) db_map.commit_session("Add test data to see if this crashes.") entity_names = {entity["id"]: entity["name"] for entity in db_map.query(db_map.wide_entity_sq)} alternative_names = { @@ -585,27 +633,35 @@ def test_commit_default_value_for_parameter_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: self._assert_success(db_map.add_parameter_value_list_item(name="booleans")) value, value_type = to_database(True) - self._assert_success(db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0)) + self._assert_success( + db_map.add_list_value_item(parameter_value_list_name="booleans", value=value, type=value_type, index=0) + ) self._assert_success(db_map.add_entity_class_item(name="Widget")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", - entity_class_name="Widget", - parameter_value_list_name="booleans", - default_value=value, - default_type=value_type, - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Widget", + parameter_value_list_name="booleans", + default_value=value, + default_type=value_type, + ) + ) self._assert_success(db_map.add_entity_class_item(name="Gadget")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", - entity_class_name="Gadget", - parameter_value_list_name="booleans", - default_value=value, - default_type=value_type, - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", + entity_class_name="Gadget", + parameter_value_list_name="booleans", + default_value=value, + 
default_type=value_type, + ) + ) self._assert_success(db_map.add_entity_class_item(name="NoIsActiveDefault")) - self._assert_success(db_map.add_parameter_definition_item( - name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" - )) + self._assert_success( + db_map.add_parameter_definition_item( + name="is_active", entity_class_name="NoIsActiveDefault", parameter_value_list_name="booleans" + ) + ) db_map.commit_session("Add test data to see if this crashes") active_by_defaults = { entity_class["name"]: entity_class["active_by_default"] From 8b91a0d9a3b99cf6637eed9600ddacab94e81409 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 8 Feb 2024 10:34:35 +0100 Subject: [PATCH 267/317] Fix MappedItemBase.update to keep the previous id as a TempId --- spinedb_api/db_mapping_base.py | 4 +++- spinedb_api/temp_id.py | 3 ++- tests/test_DatabaseMapping.py | 44 +++++++++++++++++----------------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 45923af9..e1b5cdc8 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -1181,7 +1181,9 @@ def update(self, other): self._invalidate_ref(ref_type, {ref_key: x}) else: self._invalidate_ref(ref_type, {ref_key: src_val}) + id_ = self["id"] super().update(other) + self["id"] = id_ if self._asdict() == self._backup: self._status = Status.committed @@ -1242,7 +1244,7 @@ def is_committed(self): def _asdict(self): return self._mapped_item._asdict() - def extended(self): + def _extended(self): return self._mapped_item._extended() def update(self, **kwargs): diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index b504c97d..28244c26 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -35,7 +35,8 @@ def __hash__(self): return int(self) def __repr__(self): - return f"TempId({self._item_type}, {super().__repr__()})" + resolved_to = f" - resolved to {self._db_id}" if self._db_id is not None else "" + return f"TempId({self._item_type}, {super().__repr__()}{resolved_to})" def add_resolve_callback(self, callback): self._resolve_callbacks.append(callback) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 6733ae62..1d9ae0f7 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -280,7 +280,7 @@ def test_update_entity_metadata_by_changing_its_entity(self): ) entity_metadata.update(entity_byname=("entity_2",)) self.assertEqual( - entity_metadata.extended(), + entity_metadata._extended(), { "entity_class_name": "my_class", "entity_byname": ("entity_2",), @@ -367,7 +367,7 @@ def test_update_parameter_value_metadata_by_changing_its_parameter(self): ) value_metadata.update(parameter_definition_name="y") self.assertEqual( - value_metadata.extended(), + value_metadata._extended(), { "entity_class_name": "my_class", "entity_byname": ("my_entity",), @@ -533,7 +533,7 @@ def test_committing_entity_class_items_doesnt_add_commit_ids_to_them(self): db_map.commit_session("Add class.") classes = db_map.get_entity_class_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0].extended()) + self.assertNotIn("commit_id", classes[0]._extended()) def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -543,7 +543,7 @@ def test_committing_superclass_subclass_items_doesnt_add_commit_ids_to_them(self db_map.commit_session("Add class hierarchy.") classes = 
db_map.get_superclass_subclass_items() self.assertEqual(len(classes), 1) - self.assertNotIn("commit_id", classes[0].extended()) + self.assertNotIn("commit_id", classes[0]._extended()) def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -558,7 +558,7 @@ def test_committing_entity_group_items_doesnt_add_commit_ids_to_them(self): db_map.commit_session("Add entity group.") groups = db_map.get_entity_group_items() self.assertEqual(len(groups), 1) - self.assertNotIn("commit_id", groups[0].extended()) + self.assertNotIn("commit_id", groups[0]._extended()) def test_commit_parameter_value_coincidentally_called_is_active(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -2088,7 +2088,7 @@ def test_update_object_classes(self): items, intgr_error_log = self._db_map.update_object_classes( {"id": 1, "name": "octopus"}, {"id": 2, "name": "god"} ) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_class_sq object_classes = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2101,7 +2101,7 @@ def test_update_objects(self): self._db_map.add_object_classes({"id": 1, "name": "fish"}) self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}, {"id": 2, "name": "dory", "class_id": 1}) items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}, {"id": 2, "name": "squidward"}) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2115,7 +2115,7 @@ def test_update_committed_object(self): self._db_map.add_objects({"id": 1, "name": "nemo", "class_id": 1}) self._db_map.commit_session("update") items, intgr_error_log = self._db_map.update_objects({"id": 1, "name": "klaus"}) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.object_sq objects = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2133,7 +2133,7 @@ def test_update_relationship_classes(self): items, intgr_error_log = self._db_map.update_wide_relationship_classes( {"id": 3, "name": "god__octopus"}, {"id": 4, "name": "octopus__dog"} ) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self._db_map.commit_session("test commit") sq = self._db_map.wide_relationship_class_sq rel_clss = {x.id: x.name for x in self._db_map.query(sq).filter(sq.c.id.in_(ids))} @@ -2146,7 +2146,7 @@ def test_update_committed_relationship_class(self): _ = import_functions.import_relationship_classes(self._db_map, (("my_class", ("object_class_1",)),)) self._db_map.commit_session("Add test data") items, errors = self._db_map.update_wide_relationship_classes({"id": 2, "name": "renamed"}) - updated_ids = {x["id"] for x in items} + updated_ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {2}) self._db_map.commit_session("Update data.") @@ -2184,7 +2184,7 @@ def test_update_relationships(self): items, intgr_error_log = self._db_map.update_wide_relationships( {"id": 4, "name": "nemo__scooby", "class_id": 3, "object_id_list": [1, 3], "object_class_id_list": [1, 2]} ) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self._db_map.commit_session("test commit") sq = 
self._db_map.wide_relationship_sq rels = { @@ -2207,7 +2207,7 @@ def test_update_committed_relationship(self): import_functions.import_relationships(self._db_map, (("my_class", ("object_11", "object_21")),)) self._db_map.commit_session("Add test data") items, errors = self._db_map.update_wide_relationships({"id": 4, "name": "renamed", "object_id_list": [2, 3]}) - updated_ids = {x["id"] for x in items} + updated_ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {4}) self._db_map.commit_session("Update data.") @@ -2225,7 +2225,7 @@ def test_update_parameter_value_by_id_only(self): ) self._db_map.commit_session("Populate with initial data.") items, errors = self._db_map.update_parameter_values({"id": 1, "value": b"something else"}) - updated_ids = {x["id"] for x in items} + updated_ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") @@ -2257,7 +2257,7 @@ def test_update_parameter_definition_by_id_only(self): import_functions.import_object_parameters(self._db_map, (("object_class1", "parameter1"),)) self._db_map.commit_session("Populate with initial data.") items, errors = self._db_map.update_parameter_definitions({"id": 1, "name": "parameter2"}) - updated_ids = {x["id"] for x in items} + updated_ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") @@ -2273,7 +2273,7 @@ def test_update_parameter_definition_value_list(self): items, errors = self._db_map.update_parameter_definitions( {"id": 1, "name": "my_parameter", "parameter_value_list_id": 1} ) - updated_ids = {x["id"] for x in items} + updated_ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(updated_ids, {1}) self._db_map.commit_session("Update data.") @@ -2373,7 +2373,7 @@ def test_update_object_metadata_reuses_existing_metadata(self): items, errors = self._db_map.update_ext_entity_metadata( *[{"id": 1, "metadata_name": "key 2", "metadata_value": "metadata value 2"}] ) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) self._db_map.remove_unused_metadata() @@ -2479,7 +2479,7 @@ def test_update_metadata(self): import_functions.import_metadata(self._db_map, ('{"title": "My metadata."}',)) self._db_map.commit_session("Add test data.") items, errors = self._db_map.update_metadata(*({"id": 1, "name": "author", "value": "Prof. T. 
Est"},)) - ids = {x["id"] for x in items} + ids = {x.resolve()["id"] for x in items} self.assertEqual(errors, []) self.assertEqual(ids, {1}) self._db_map.commit_session("Update data") @@ -3228,7 +3228,7 @@ def test_cunning_ways_to_make_external_changes(self): entity_items = db_map.get_entity_items() self.assertEqual(len(entity_items), 2) self.assertEqual( - entity_items[0].extended(), + entity_items[0]._extended(), { "id": 1, "name": "other_entity", @@ -3246,7 +3246,7 @@ def test_cunning_ways_to_make_external_changes(self): }, ) self.assertEqual( - entity_items[1].extended(), + entity_items[1]._extended(), { "id": 2, "name": "filler", @@ -3267,7 +3267,7 @@ def test_cunning_ways_to_make_external_changes(self): self.assertEqual(len(value_items), 2) self.assertTrue(removed_item.is_committed()) self.assertEqual( - value_items[0].extended(), + value_items[0]._extended(), { "alternative_id": 1, "alternative_name": "Base", @@ -3292,7 +3292,7 @@ def test_cunning_ways_to_make_external_changes(self): }, ) self.assertEqual( - value_items[1].extended(), + value_items[1]._extended(), { "alternative_id": 1, "alternative_name": "Base", @@ -3525,7 +3525,7 @@ def test_adding_same_parameters_values_to_different_entities_externally(self): values = db_map.get_parameter_value_items() self.assertEqual(len(values), 1) self.assertEqual( - values[0].extended(), + values[0]._extended(), { "id": -2, "entity_class_name": "my_class", From ae7f158d700f9735e38f146452f16a702f2194c1 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 8 Feb 2024 14:36:05 +0100 Subject: [PATCH 268/317] Fix tests --- spinedb_api/db_mapping_base.py | 21 +---- spinedb_api/import_functions.py | 16 ++-- spinedb_api/temp_id.py | 4 +- tests/test_DatabaseMapping.py | 158 +++++++------------------------- tests/test_import_functions.py | 5 +- 5 files changed, 49 insertions(+), 155 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e1b5cdc8..e9f77ea4 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -532,27 +532,15 @@ def add_item_from_db(self, item, is_db_clean): mapped_item = self.get(item["id"]) if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): return mapped_item, False - self._free_id(item["id"]) + conflicting_item = self.get(item["id"]) + if conflicting_item is not None: + conflicting_item.detach() mapped_item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. mapped_item.cascade_remove(source=self.wildcard_item) return mapped_item, True - def _free_id(self, id_): - """Makes sure the given id_ is free. Fix conflicts if not. - - Args: - id_ (int) - - Yields: - tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. - """ - conflicting_item = self.get(id_) - if conflicting_item is None: - return - conflicting_item.detach() - def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -1205,8 +1193,7 @@ def force_id(self, id_): def detach(self): """Detaches this item whose id now belongs to a different item after an external commit.""" self["id"].unresolve() - if self.status in (Status.to_update, Status.committed): - self.status = Status.to_add + # TODO: Update item's status. 
class PublicItem: diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 62cb63a0..7bb5d2e7 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -623,14 +623,14 @@ def _get_list_values_for_import(db_map, data, unparse_value): index = index_by_list_name.get(list_name) if index is None: current_list = db_map.mapped_table("parameter_value_list").find_item({"name": list_name}) - index = max( - ( - x["index"] - for x in db_map.mapped_table("list_value").valid_values() - if x["parameter_value_list_id"] == current_list["id"] - ), - default=-1, - ) + list_value_idx_by_val_typ = { + (x["value"], x["type"]): x["index"] + for x in db_map.mapped_table("list_value").valid_values() + if x["parameter_value_list_id"] == current_list["id"] + } + if (value, type_) in list_value_idx_by_val_typ: + continue + index = max((idx for idx in list_value_idx_by_val_typ.values()), default=-1) index += 1 index_by_list_name[list_name] = index yield {"parameter_value_list_name": list_name, "value": value, "type": type_, "index": index} diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 28244c26..f7a067b3 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -35,8 +35,8 @@ def __hash__(self): return int(self) def __repr__(self): - resolved_to = f" - resolved to {self._db_id}" if self._db_id is not None else "" - return f"TempId({self._item_type}, {super().__repr__()}{resolved_to})" + resolved_to = f" resolved to {self._db_id}" if self._db_id is not None else "" + return f"TempId({self._item_type}, {super().__repr__()}){resolved_to}" def add_resolve_callback(self, callback): self._resolve_callbacks.append(callback) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 1d9ae0f7..30eb190f 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3070,12 +3070,6 @@ def test_committed_mapped_items_take_id_from_externally_committed_items(self): db_map0.add_entity_class_item(name="gadget") db_map0.commit_session("No comment") with CustomDatabaseMapping(url) as db_map1: - # Add classes to a model - model = {} - for x in db_map1.get_items("entity_class"): - model[x["id"]] = x - x.add_remove_callback(lambda x: model.pop(x["id"])) - self.assertEqual(len(model), 2) with CustomDatabaseMapping(url) as db_map2: # Purge, then add *gadget* before *widget* (swap the order) # Also add an entity @@ -3087,8 +3081,6 @@ def test_committed_mapped_items_take_id_from_externally_committed_items(self): # Check that we see the entity added by the other mapping phone = db_map1.get_entity_item(entity_class_name="gadget", name="phone") self.assertIsNotNone(phone) - # Overwritten classes should have been removed from the model - self.assertEqual(len(model), 0) def test_fetching_entities_after_external_change_has_renamed_their_classes(self): with TemporaryDirectory() as temp_dir: @@ -3227,94 +3219,26 @@ def test_cunning_ways_to_make_external_changes(self): db_map.refresh_session() entity_items = db_map.get_entity_items() self.assertEqual(len(entity_items), 2) - self.assertEqual( - entity_items[0]._extended(), - { - "id": 1, - "name": "other_entity", - "description": None, - "class_id": 1, - "element_id_list": (), - "element_name_list": (), - "commit_id": 4, - "entity_class_name": "interesting_class", - "dimension_id_list": (), - "dimension_name_list": (), - "element_byname_list": (), - "superclass_id": None, - "superclass_name": None, - }, - ) - self.assertEqual( - entity_items[1]._extended(), - { - "id": 2, - 
"name": "filler", - "description": None, - "class_id": 2, - "element_id_list": (), - "element_name_list": (), - "commit_id": 4, - "entity_class_name": "filler_class", - "dimension_id_list": (), - "dimension_name_list": (), - "element_byname_list": (), - "superclass_id": None, - "superclass_name": None, - }, - ) + unique_values = {(x["name"], x["entity_class_name"]) for x in entity_items} + self.assertIn(("other_entity", "interesting_class"), unique_values) + self.assertIn(("filler", "filler_class"), unique_values) value_items = db_map.get_parameter_value_items() self.assertEqual(len(value_items), 2) self.assertTrue(removed_item.is_committed()) - self.assertEqual( - value_items[0]._extended(), - { - "alternative_id": 1, - "alternative_name": "Base", - "commit_id": 4, - "dimension_id_list": (), - "dimension_name_list": (), - "element_id_list": (), - "element_name_list": (), - "entity_byname": ("filler",), - "entity_class_id": 2, - "entity_class_name": "filler_class", - "entity_id": 3, - "entity_name": "filler", - "id": 2, - "list_value_id": None, - "parameter_definition_id": 2, - "parameter_definition_name": "quantity", - "parameter_value_list_id": None, - "parameter_value_list_name": None, - "type": to_database(-2.3)[1], - "value": to_database(-2.3)[0], - }, - ) - self.assertEqual( - value_items[1]._extended(), - { - "alternative_id": 1, - "alternative_name": "Base", - "commit_id": 4, - "dimension_id_list": (), - "dimension_name_list": (), - "element_id_list": (), - "element_name_list": (), - "entity_byname": ("other_entity",), - "entity_class_id": 1, - "entity_class_name": "interesting_class", - "entity_id": 2, - "entity_name": "other_entity", - "id": 3, - "list_value_id": None, - "parameter_definition_id": 1, - "parameter_definition_name": "quality", - "parameter_value_list_id": None, - "parameter_value_list_name": None, - "type": to_database(99.9)[1], - "value": to_database(99.9)[0], - }, + unique_values = { + ( + x["entity_class_name"], + x["parameter_definition_name"], + x["entity_name"], + x["alternative_name"], + x["value"], + x["type"], + ) + for x in value_items + } + self.assertIn(("filler_class", "quantity", "filler", "Base", *to_database(-2.3)), unique_values) + self.assertIn( + ("interesting_class", "quality", "other_entity", "Base", *to_database(99.9)), unique_values ) def test_update_entity_metadata_externally(self): @@ -3351,10 +3275,11 @@ def test_update_entity_metadata_externally(self): self.assertEqual(len(metadata_items), 2) self.assertNotEqual(metadata_items[0]["id"], metadata_items[1]["id"]) unique_values = { - (x["entity_class_name"], x["entity_byname"], x["metadata_name"]) for x in metadata_items + (x["entity_class_name"], x["entity_byname"], x["metadata_name"], x["metadata_value"]) + for x in metadata_items } - self.assertIn(("my_class", ("my_entity",), "my_metadata"), unique_values) - self.assertIn(("my_class", ("other_entity",), "my_metadata"), unique_values) + self.assertIn(("my_class", ("my_entity",), "my_metadata", metadata_value), unique_values) + self.assertIn(("my_class", ("other_entity",), "my_metadata", metadata_value), unique_values) def test_update_parameter_value_metadata_externally(self): with TemporaryDirectory() as temp_dir: @@ -3423,11 +3348,14 @@ def test_update_parameter_value_metadata_externally(self): x["entity_byname"], x["metadata_name"], x["alternative_name"], + x["metadata_value"], ) for x in metadata_items } - self.assertIn(("my_class", "x", ("my_entity",), "my_metadata", "Base"), unique_values) - self.assertIn(("my_class", "x", 
("other_entity",), "my_metadata", "Base"), unique_values) + self.assertIn(("my_class", "x", ("my_entity",), "my_metadata", "Base", metadata_value), unique_values) + self.assertIn( + ("my_class", "x", ("other_entity",), "my_metadata", "Base", metadata_value), unique_values + ) def test_update_entity_alternative_externally(self): with TemporaryDirectory() as temp_dir: @@ -3521,34 +3449,17 @@ def test_adding_same_parameters_values_to_different_entities_externally(self): ) ) shadow_db_map.commit_session("Add another entity.") - db_map.refresh_session() values = db_map.get_parameter_value_items() self.assertEqual(len(values), 1) - self.assertEqual( - values[0]._extended(), - { - "id": -2, - "entity_class_name": "my_class", - "entity_class_id": -1, - "dimension_name_list": (), - "dimension_id_list": (), - "parameter_definition_name": "x", - "parameter_definition_id": -1, - "entity_byname": ("other_entity",), - "entity_name": "other_entity", - "entity_id": -2, - "element_name_list": (), - "element_id_list": (), - "alternative_name": "Base", - "alternative_id": -1, - "parameter_value_list_name": None, - "parameter_value_list_id": None, - "list_value_id": None, - "type": value_type, - "value": value, - "commit_id": -4, - }, + unique_value = ( + values[0]["entity_class_name"], + values[0]["parameter_definition_name"], + values[0]["entity_name"], + values[0]["alternative_name"], ) + value_and_type = (values[0]["value"], values[0]["type"]) + self.assertEqual(unique_value, ("my_class", "x", "other_entity", "Base")) + self.assertEqual(value_and_type, (value, value_type)) def test_committing_changed_purged_entity_has_been_overwritten_by_external_change(self): with TemporaryDirectory() as temp_dir: @@ -3564,7 +3475,6 @@ def test_committing_changed_purged_entity_has_been_overwritten_by_external_chang shadow_db_map.add_entity_item(name="other_entity", entity_class_name="my_class") ) shadow_db_map.commit_session("Add another entity that steals ghost's id.") - db_map.refresh_session() db_map.do_fetch_all("entity") self._assert_success(db_map.add_entity_item(name="dirty_entity", entity_class_name="my_class")) db_map.commit_session("Add still uncommitted entity.") diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 373ba0f9..23e07710 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1342,10 +1342,7 @@ def test_insert_scenario_alternative_in_the_middle_of_other_alternatives(self): self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative2": 1}}) - count, errors = import_scenario_alternatives( - self._db_map, - [["scenario", "alternative3", "alternative1"]], - ) + count, errors = import_scenario_alternatives(self._db_map, [["scenario", "alternative3", "alternative1"]]) self.assertFalse(errors) self.assertEqual(count, 2) scenario_alternatives = self.scenario_alternatives() From 518f60ac65dc242fc90dab150f84a10dfe06b716 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 8 Feb 2024 14:56:49 +0100 Subject: [PATCH 269/317] Update mapped item status in detach --- spinedb_api/db_mapping_base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index e9f77ea4..0ae7a505 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -659,6 +659,7 @@ def __init__(self, db_map, item_type, **kwargs): self._status = Status.committed 
self._removal_source = None self._status_when_removed = None + self._status_when_committed = None self._backup = None self.public_item = PublicItem(self._db_map, self) @@ -1115,6 +1116,7 @@ def is_committed(self): def commit(self, commit_id): """Sets this item as committed with the given commit id.""" + self._status_when_committed = self._status self._status = Status.committed if commit_id: self["commit_id"] = commit_id @@ -1193,7 +1195,14 @@ def force_id(self, id_): def detach(self): """Detaches this item whose id now belongs to a different item after an external commit.""" self["id"].unresolve() - # TODO: Update item's status. + # TODO: Test if the below works... + if self.is_committed(): + self._status = self._status_when_committed + if self._status == Status.to_update: + self._status = Status.to_add + elif self._status == Status.to_remove: + self._status = Status.committed + self._status_when_removed = Status.to_add class PublicItem: From 53591970c6f8ee388392ce9b03e2e6321c150c1d Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 8 Feb 2024 17:01:08 +0100 Subject: [PATCH 270/317] Fix check for item equivalency --- spinedb_api/db_mapping_base.py | 35 +++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 0ae7a505..26c1d00a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -526,11 +526,11 @@ def add_item_from_db(self, item, is_db_clean): tuple(MappedItem,bool): A mapped item and whether it needs to be added to the unique key values dict. """ mapped_item = self._find_item_by_unique_key(item, fetch=False, valid_only=False) - if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): + if mapped_item and (is_db_clean or self._same_item(mapped_item, item)): mapped_item.force_id(item["id"]) return mapped_item, False mapped_item = self.get(item["id"]) - if mapped_item and (is_db_clean or mapped_item.is_equal_in_db(item)): + if mapped_item and (is_db_clean or self._same_item(mapped_item.db_equivalent(), item)): return mapped_item, False conflicting_item = self.get(item["id"]) if conflicting_item is not None: @@ -541,6 +541,17 @@ def add_item_from_db(self, item, is_db_clean): mapped_item.cascade_remove(source=self.wildcard_item) return mapped_item, True + def _same_item(self, mapped_item, db_item): + """Whether the two given items have the same unique keys. + + Args: + mapped_item (MappedItemBase): an item in the in-memory mapping + db_item (dict): an item just fetched from the DB + """ + db_item = self._db_map.make_item(self._item_type, **db_item) + db_item.polish() + return dict(mapped_item.unique_key_values()) == dict(db_item.unique_key_values()) + def check_fields(self, item, valid_types=()): factory = self._db_map.item_factory(self._item_type) @@ -789,23 +800,17 @@ def _convert(x): or self.fields.get(key, {}).get("optional", False) # Ignore mandatory fields that are None ) - def is_equal_in_db(self, other): - """Returns whether this item and other are the same in the DB. - - Args: - other (dict) + def db_equivalent(self): + """The equivalent of this item in the DB. 
Returns: - bool + MappedItemBase """ if self.status == Status.to_update: - this = self._db_map.make_item(self._item_type, **self.backup) - this.polish() - else: - this = self - other = self._db_map.make_item(self._item_type, **other) - other.polish() - return dict(this.unique_key_values()) == dict(other.unique_key_values()) + db_item = self._db_map.make_item(self._item_type, **self.backup) + db_item.polish() + return db_item + return self def first_invalid_key(self): """Goes through the ``_references`` class attribute and returns the key of the first reference From 27a3c6b1df9f61615aca655a80442b555e0b11b7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Thu, 8 Feb 2024 17:30:41 +0100 Subject: [PATCH 271/317] Rename one method for clarity --- spinedb_api/db_mapping_base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 26c1d00a..f316f7bc 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -819,16 +819,15 @@ def first_invalid_key(self): Returns: str or None: unresolved reference's key if any. """ - return next(self._invalid_keys(), None) + return next(self._resolve_refs(), None) # TODO: Maybe rename this method to reflect its more important task now of replacing fields with TempIds - def _invalid_keys(self): - """Goes through the ``_references`` class attribute and returns the keys of the ones - that cannot be resolved. - Also, replace fields referring to db-ids with TempIds. + def _resolve_refs(self): + """Goes through the ``_references`` class attribute and tries to resolve them. + If successful, replace source fields referring to db-ids with the reference TempId. Yields: - str: unresolved keys if any. + str: the source field of any unresolved reference. """ for src_key, (ref_type, ref_key) in self._references.items(): try: @@ -981,7 +980,7 @@ def is_valid(self): return False self._to_remove = False self._corrupted = False - for _ in self._invalid_keys(): # This sets self._to_remove and self._corrupted + for _ in self._resolve_refs(): # This sets self._to_remove and self._corrupted pass if self._to_remove: self.cascade_remove() From d467319697f2171e2fc7776f2c7484610a196e7d Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 9 Feb 2024 10:46:51 +0100 Subject: [PATCH 272/317] Don't fetch_all the same data twice If the DB had external commits, we would fetch_all the same contents over and over. Now we do it only once per change in the commit count. 
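As a standalone sketch, the scheme introduced here looks roughly like this; the class and attribute names are illustrative rather than the library's internals, and a plain threading lock stands in for the lock used in the actual change:

    from threading import RLock

    class FetchOncePerCommitCount:
        """Sketch: run a full fetch at most once per observed commit count."""

        def __init__(self):
            self._fetched = {}  # item_type -> commit count at the last full fetch
            self._locks = {}    # item_type -> lock guarding the fetch

        def do_fetch_all(self, item_type, commit_count, fetch):
            with self._locks.setdefault(item_type, RLock()):
                # Refetch only if the DB has new commits since the last full fetch.
                if self._fetched.get(item_type, -1) < commit_count:
                    self._fetched[item_type] = commit_count
                    fetch(item_type)

Storing the commit count instead of a plain "already fetched" flag is what lets a later external commit trigger exactly one more full fetch.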
--- spinedb_api/db_mapping.py | 9 ++++-- spinedb_api/db_mapping_base.py | 56 ++++++++++++++++++++++++---------- tests/test_DatabaseMapping.py | 4 +-- tests/test_db_mapping_base.py | 4 +-- 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index d30ec53b..0d387150 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -178,7 +178,6 @@ def __init__( if self._filter_configs is not None: stack = load_filters(self._filter_configs) apply_filter_stack(self, stack) - self._commit_count = self._query_commit_count() def __enter__(self): return self @@ -647,7 +646,7 @@ def fetch_all(self, *item_types): item_types = set(self.item_types()) if not item_types else set(item_types) & set(self.item_types()) for item_type in item_types: item_type = self.real_item_type(item_type) - self.do_fetch_all(item_type) + self.do_fetch_more(item_type) def query(self, *args, **kwargs): """Returns a :class:`~spinedb_api.query.Query` object to execute against the mapped DB. @@ -729,7 +728,11 @@ def refresh_session(self): self._refresh() def has_external_commits(self): - """See base class.""" + """Tests whether the database has had commits from other sources than this mapping. + + Returns: + bool: True if database has external commits, False otherwise + """ return self._commit_count != self._query_commit_count() def close(self): diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f316f7bc..67b593e7 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -10,6 +10,7 @@ # this program. If not, see . ###################################################################################################################### +from multiprocessing import RLock from enum import Enum, unique, auto from difflib import SequenceMatcher from .temp_id import TempId, resolve @@ -36,13 +37,15 @@ class DatabaseMappingBase: This class is not meant to be used directly. Instead, you should subclass it to fit your particular DB schema. When subclassing, you need to implement :meth:`item_types`, :meth:`item_factory`, :meth:`_make_sq`, - and :meth:`has_external_commits`. + and :meth:`_query_commit_count`. """ def __init__(self): self.closed = False self._mapped_tables = {} - self._fetched = set() + self._fetched = {} + self._locks = {} + self._commit_count = None item_types = self.item_types() self._sorted_item_types = [] while item_types: @@ -129,11 +132,11 @@ def _make_sq(self, item_type): """ raise NotImplementedError() - def has_external_commits(self): - """Tests whether the database has had commits from other sources than this mapping. + def _query_commit_count(self): + """Returns the number of rows in the commit table in the DB. 
Returns: - bool: True if database has external commits, False otherwise + int """ raise NotImplementedError() @@ -155,13 +158,12 @@ def _dirty_items(self): Returns: list """ - if self.has_external_commits(): - self._refresh() + real_commit_count = self._query_commit_count() dirty_items = [] purged_item_types = {x for x in self.item_types() if self.mapped_table(x).purged} self._add_descendants(purged_item_types) for item_type in self._sorted_item_types: - self.do_fetch_all(item_type) # To fix conflicts in add_item_from_db + self.do_fetch_all(item_type, commit_count=real_commit_count) # To fix conflicts in add_item_from_db mapped_table = self.mapped_table(item_type) to_add = [] to_update = [] @@ -238,7 +240,7 @@ def reset(self, *item_types): self._add_descendants(item_types) for item_type in item_types: self._mapped_tables.pop(item_type, None) - self._fetched.discard(item_type) + self._fetched.pop(item_type, None) def reset_purging(self): """Resets purging status for all item types. @@ -288,12 +290,13 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): chunk = self._get_next_chunk(item_type, offset, limit, **kwargs) if not chunk: return [] - is_db_dirty = self.has_external_commits() + real_commit_count = self._query_commit_count() + is_db_dirty = self._get_commit_count() != real_commit_count if is_db_dirty: + # We need to fetch the most recent references because their ids might have changed in the DB for ref_type in self.item_factory(item_type).ref_types(): if ref_type != item_type: - self._fetched.discard(ref_type) - self.do_fetch_all(ref_type) + self.do_fetch_all(ref_type, commit_count=real_commit_count) mapped_table = self.mapped_table(item_type) items = [] new_items = [] @@ -309,10 +312,31 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): mapped_table.add_unique(item) return items - def do_fetch_all(self, item_type): - if item_type not in self._fetched: - self._fetched.add(item_type) - self.do_fetch_more(item_type, offset=0, limit=None) + def _get_commit_count(self): + """Returns current commit count. + + Returns: + int + """ + if self._commit_count is None: + self._commit_count = self._query_commit_count() + return self._commit_count + + def do_fetch_all(self, item_type, commit_count=None): + """Fetches all items of given type, but only once for each commit_count. + In other words, the second time this method is called with the same commit_count, it does nothing. + If not specified, commit_count defaults to the result of self._get_commit_count(). 
+ + Args: + item_type (str) + commit_count (int,optional) + """ + if commit_count is None: + commit_count = self._get_commit_count() + with self._locks.setdefault(item_type, RLock()): + if self._fetched.get(item_type, -1) < commit_count: + self._fetched[item_type] = commit_count + self.do_fetch_more(item_type, offset=0, limit=None) class _MappedTable(dict): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 30eb190f..b745ac9b 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -443,9 +443,9 @@ def test_fetch_more_after_commit_and_refresh(self): entities = db_map.fetch_more("entity") self.assertEqual([(x["entity_class_name"], x["name"]) for x in entities], [("Widget", "gadget")]) - def test_has_external_commits_returns_false_initially(self): + def test_has_external_commits_returns_true_initially(self): with DatabaseMapping("sqlite://", create=True) as db_map: - self.assertFalse(db_map.has_external_commits()) + self.assertTrue(db_map.has_external_commits()) def test_has_external_commits_returns_true_when_another_db_mapping_has_made_commits(self): with TemporaryDirectory() as temp_dir: diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index dc88d8af..cd67668c 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -28,8 +28,8 @@ def item_factory(item_type): return MappedItemBase raise RuntimeError(f"unknown item_type '{item_type}'") - def has_external_commits(self): - return False + def _query_commit_count(self): + return -1 def _make_query(self, _item_type, **kwargs): return None From 20ba7add30e8eef4f8bb5efb03d3eb550426a3c7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 9 Feb 2024 15:10:38 +0100 Subject: [PATCH 273/317] Fix import scenario_alternatives --- spinedb_api/import_functions.py | 2 +- tests/test_import_functions.py | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 62cb63a0..06b31f37 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -590,9 +590,9 @@ def _get_scenario_alternatives_for_import(db_map, data): some_added = False for pred, succ in list(succ_by_pred.items()): if succ in alternative_name_list: - i = alternative_name_list.index(succ) if pred in alternative_name_list: alternative_name_list.remove(pred) + i = alternative_name_list.index(succ) alternative_name_list.insert(i, pred) del succ_by_pred[pred] some_added = True diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 373ba0f9..14f5d624 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1351,6 +1351,29 @@ def test_insert_scenario_alternative_in_the_middle_of_other_alternatives(self): scenario_alternatives = self.scenario_alternatives() self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 3, "alternative2": 1, "alternative3": 2}}) + def test_import_inconsistent_scenario_alternatives(self): + import_data(self._db_map, scenarios=["scenario"], alternatives=["alternative1", "alternative2", "alternative3"]) + count, errors = import_scenario_alternatives( + self._db_map, + [["scenario", "alternative3", "alternative1"], ["scenario", "alternative1"]], + ) + self.assertFalse(errors) + self.assertEqual(count, 2) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 2, "alternative3": 1}}) + count, errors = 
import_scenario_alternatives( + self._db_map, + [ + ["scenario", "alternative3", "alternative2"], + ["scenario", "alternative2", "alternative1"], + ["scenario", "alternative1"], + ], + ) + self.assertFalse(errors) + self.assertEqual(count, 2) + scenario_alternatives = self.scenario_alternatives() + self.assertEqual(scenario_alternatives, {"scenario": {"alternative1": 3, "alternative2": 2, "alternative3": 1}}) + def scenario_alternatives(self): self._db_map.commit_session("test") scenario_alternative_qry = ( From a4313d5fdcf658d468ccc6120da98e020f179ca8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 11 Feb 2024 12:18:28 +0100 Subject: [PATCH 274/317] Consistently set the removal source for lazy purge too Re #325 --- spinedb_api/db_mapping_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 63e9a60c..18ea3fc1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -505,7 +505,7 @@ def add_item_from_db(self, item): item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. - item.cascade_remove(source=self.wildcard_item) + item.cascade_remove() return item, True def check_fields(self, item, valid_types=()): From fef05b9c3dd9d0fa4e1d087e7fc806d88ba3dfd7 Mon Sep 17 00:00:00 2001 From: Pekka T Savolainen Date: Mon, 12 Feb 2024 15:32:44 +0200 Subject: [PATCH 275/317] Update Github action --- .github/workflows/run_unit_tests.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/run_unit_tests.yml b/.github/workflows/run_unit_tests.yml index e17189af..b4d8d4cf 100644 --- a/.github/workflows/run_unit_tests.yml +++ b/.github/workflows/run_unit_tests.yml @@ -15,13 +15,13 @@ jobs: os: [ubuntu-22.04, windows-latest] python-version: [3.8, 3.9, "3.10", 3.11] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Version from Git tags run: git describe --tags - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Display Python version @@ -46,4 +46,6 @@ jobs: run: coverage run -m unittest discover --verbose - name: Upload coverage report to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} From 0f336698f13d5b2d8d63c8d8250f530e856cc2d6 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 12 Feb 2024 14:56:07 +0100 Subject: [PATCH 276/317] Accept external changes if they were done after committing ours --- spinedb_api/db_mapping.py | 11 ++++--- spinedb_api/db_mapping_base.py | 58 +++++++++++++++++----------------- spinedb_api/mapped_items.py | 2 +- spinedb_api/temp_id.py | 14 ++++---- tests/test_DatabaseMapping.py | 16 ++++++++++ 5 files changed, 58 insertions(+), 43 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 0d387150..5243ddf0 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -190,7 +190,7 @@ def __del__(self): @staticmethod def item_types(): - return [x for x in DatabaseMapping._sq_name_by_item_type if item_factory(x).fields] + return [x for x in DatabaseMapping._sq_name_by_item_type if not item_factory(x).is_protected] @staticmethod def all_item_types(): @@ -359,10 +359,11 @@ def get_item(self, item_type, fetch=True, skip_removed=True, **kwargs): item_type = 
self.real_item_type(item_type) mapped_table = self.mapped_table(item_type) mapped_table.check_fields(kwargs, valid_types=(type(None),)) - item = mapped_table.find_item(kwargs, fetch=fetch) - if not item: - return {} - if skip_removed and not item.is_valid(): + item = mapped_table.find_item(kwargs) + if not item and fetch: + self.do_fetch_more(item_type, offset=0, limit=None, **kwargs) + item = mapped_table.find_item(kwargs) + if not item or (skip_removed and not item.is_valid()): return {} return item.public_item diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 95f1110b..52acac6a 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -106,6 +106,8 @@ def _make_query(self, item_type, **kwargs): sq = self._make_sq(item_type) qry = self.query(sq) for key, value in kwargs.items(): + if isinstance(value, tuple): + continue if hasattr(sq.c, key): qry = qry.filter(getattr(sq.c, key) == value) elif key in self.item_factory(item_type)._external_fields: @@ -175,6 +177,7 @@ def _dirty_items(self): to_update.append(item) if item_type in purged_item_types: to_remove.append(mapped_table.wildcard_item) + to_remove.extend(mapped_table.values()) else: for item in mapped_table.values(): _ = item.is_valid() @@ -305,6 +308,8 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): item, new = mapped_table.add_item_from_db(x, not is_db_dirty) if new: new_items.append(item) + else: + item.reset_state() items.append(item) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted @@ -366,13 +371,7 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - temp_id = TempId(self._item_type) - - def _callback(db_id): - self._temp_id_by_db_id[db_id] = temp_id - - temp_id.add_resolve_callback(_callback) - return temp_id + return TempId(self._item_type, self._temp_id_by_db_id) def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. @@ -672,6 +671,7 @@ class MappedItemBase(dict): """ _private_fields = set() """A set with fields that should be ignored in validations.""" + is_protected = False def __init__(self, db_map, item_type, **kwargs): """ @@ -687,7 +687,6 @@ def __init__(self, db_map, item_type, **kwargs): self.update_callbacks = set() self.remove_callbacks = set() self._has_valid_id = True - self._to_remove = False self._removed = False self._corrupted = False self._valid = None @@ -698,6 +697,17 @@ def __init__(self, db_map, item_type, **kwargs): self._backup = None self.public_item = PublicItem(self._db_map, self) + def reset_state(self): + """Called when an equivalent item is fetched from the DB. + + If this item is already committed, we assume the one from the DB is newer so we reset the state. + Otherwise we assume *this* is newer and do nothing. + """ + if self.is_committed(): + self._removed = False + self._corrupted = False + self._valid = None + @classmethod def ref_types(cls): """Returns a set of item types that this class refers. @@ -843,34 +853,32 @@ def first_invalid_key(self): Returns: str or None: unresolved reference's key if any. 
""" - return next(self._resolve_refs(), None) + return next((src_key for src_key, ref in self._resolve_refs() if not ref), None) - # TODO: Maybe rename this method to reflect its more important task now of replacing fields with TempIds def _resolve_refs(self): """Goes through the ``_references`` class attribute and tries to resolve them. If successful, replace source fields referring to db-ids with the reference TempId. Yields: - str: the source field of any unresolved reference. + tuple(str,MappedItem or None): the source field and resolved ref. """ for src_key, (ref_type, ref_key) in self._references.items(): try: src_val = self[src_key] except KeyError: - yield src_key + yield src_key, None else: if isinstance(src_val, tuple): refs = tuple(self._get_ref(ref_type, {ref_key: x}) for x in src_val) - if not all(refs): - yield src_key - elif ref_key == "id": + if all(refs) and ref_key == "id": self[src_key] = tuple(ref["id"] for ref in refs) + for ref in refs: + yield src_key, ref else: ref = self._get_ref(ref_type, {ref_key: src_val}) - if not self._get_ref(ref_type, {ref_key: src_val}): - yield src_key - elif ref_key == "id": + if ref and ref_key == "id": self[src_key] = ref["id"] + yield src_key, ref @classmethod def unique_values_for_item(cls, item, skip_keys=()): @@ -954,7 +962,6 @@ def _get_ref(self, ref_type, key_val, strong=True): """Collects a reference from the in-memory mapping. Adds this item to the reference's list of referrers if strong is True; or weak referrers if strong is False. - Sets the self._corrupted and self._removed flags appropriately. Args: ref_type (str): The reference's type @@ -967,13 +974,9 @@ def _get_ref(self, ref_type, key_val, strong=True): mapped_table = self._db_map.mapped_table(ref_type) ref = mapped_table.find_item(key_val, fetch=True) if not ref: - if strong: - self._corrupted = True return {} if strong: ref.add_referrer(self) - if ref.removed: - self._to_remove = True else: ref.add_weak_referrer(self) if ref.removed: @@ -1002,11 +1005,9 @@ def is_valid(self): return self._valid if self._removed or self._corrupted: return False - self._to_remove = False - self._corrupted = False - for _ in self._resolve_refs(): # This sets self._to_remove and self._corrupted - pass - if self._to_remove: + refs = [ref for _, ref in self._resolve_refs()] + self._corrupted = not all(refs) + if any(ref and ref.removed for ref in refs): self.cascade_remove() self._valid = not self._removed and not self._corrupted return self._valid @@ -1090,7 +1091,6 @@ def cascade_remove(self, source=None): raise RuntimeError("invalid status for item being removed") self._removal_source = source self._removed = True - self._to_remove = False self._valid = None # First remove referrers, then this for referrer in self._referrers.values(): diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 6a0f7104..9f73a627 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -47,8 +47,8 @@ class CommitItem(MappedItemBase): 'date': {'type': str, 'value': 'Date and time of the commit in ISO 8601 format.'}, 'user': {'type': str, 'value': 'Username of the committer.'}, } - _unique_keys = (("date",),) + is_protected = True def commit(self, commit_id): raise RuntimeError("Commits are created automatically when session is committed.") diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index f7a067b3..71e7b6a1 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -13,15 +13,15 @@ class TempId(int): _next_id = {} - def __new__(cls, 
item_type): + def __new__(cls, item_type, _id_map): id_ = cls._next_id.setdefault(item_type, -1) cls._next_id[item_type] -= 1 return super().__new__(cls, id_) - def __init__(self, item_type): + def __init__(self, item_type, id_map): super().__init__() self._item_type = item_type - self._resolve_callbacks = [] + self._id_map = id_map self._db_id = None @property @@ -29,22 +29,20 @@ def db_id(self): return self._db_id def __eq__(self, other): + # FIXME: Can we avoid this? return super().__eq__(other) or (self._db_id is not None and other == self._db_id) def __hash__(self): + # FIXME: Can we avoid this? return int(self) def __repr__(self): resolved_to = f" resolved to {self._db_id}" if self._db_id is not None else "" return f"TempId({self._item_type}, {super().__repr__()}){resolved_to}" - def add_resolve_callback(self, callback): - self._resolve_callbacks.append(callback) - def resolve(self, db_id): self._db_id = db_id - for callback in self._resolve_callbacks: - callback(db_id) + self._id_map[db_id] = self def unresolve(self): self._db_id = None diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index b745ac9b..835a39ac 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3481,6 +3481,22 @@ def test_committing_changed_purged_entity_has_been_overwritten_by_external_chang entities = db_map.query(db_map.wide_entity_sq).all() self.assertEqual(len(entities), 2) + def test_db_items_prevail_if_mapped_items_are_committed(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + db_map.commit_session("Add some data") + with DatabaseMapping(url, create=True) as db_map: + db_map.purge_items("entity_class") + db_map.commit_session("Purge all") + with DatabaseMapping(url) as shadow_db_map: + self._assert_success(shadow_db_map.add_entity_class_item(name="my_class")) + shadow_db_map.commit_session("Add same class") + entity_class_item = db_map.get_entity_class_item(name="my_class") + self.assertTrue(entity_class_item) + self.assertEqual(entity_class_item["name"], "my_class") + if __name__ == "__main__": unittest.main() From 7e4c061384fb4ead1df3da0338a2fe6ce3984489 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 12 Feb 2024 16:09:39 +0100 Subject: [PATCH 277/317] Fix TempId.__new__ to make it picklable --- spinedb_api/db_mapping_base.py | 4 +++- spinedb_api/temp_id.py | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 52acac6a..f95e3c91 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -371,7 +371,9 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - return TempId(self._item_type, self._temp_id_by_db_id) + temp_id = TempId(self._item_type) + temp_id.set_id_map(self._temp_id_by_db_id) + return temp_id def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. 
diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 71e7b6a1..b294176a 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -13,17 +13,20 @@ class TempId(int): _next_id = {} - def __new__(cls, item_type, _id_map): + def __new__(cls, item_type): id_ = cls._next_id.setdefault(item_type, -1) cls._next_id[item_type] -= 1 return super().__new__(cls, id_) - def __init__(self, item_type, id_map): + def __init__(self, item_type): super().__init__() self._item_type = item_type - self._id_map = id_map + self._id_map = {} self._db_id = None + def set_id_map(self, id_map): + self._id_map = id_map + @property def db_id(self): return self._db_id From 4178308dab9e87a288f6509d8ee9c510ba0dcb6a Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 12 Feb 2024 17:10:32 +0100 Subject: [PATCH 278/317] Fix issues with pickling TempIds --- spinedb_api/db_mapping_base.py | 7 ++++--- spinedb_api/temp_id.py | 21 +++++++++------------ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index f95e3c91..87740c95 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -356,6 +356,7 @@ def __init__(self, db_map, item_type, *args, **kwargs): self._item_type = item_type self._ids_by_unique_key_value = {} self._temp_id_by_db_id = {} + self._next_id = -1 self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @property @@ -371,9 +372,9 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - temp_id = TempId(self._item_type) - temp_id.set_id_map(self._temp_id_by_db_id) - return temp_id + id_ = self._next_id + self._next_id -= 1 + return TempId(id_, self._item_type, self._temp_id_by_db_id) def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index b294176a..0376bd44 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -11,21 +11,14 @@ class TempId(int): - _next_id = {} + def __new__(cls, value, *args): + return super().__new__(cls, value) - def __new__(cls, item_type): - id_ = cls._next_id.setdefault(item_type, -1) - cls._next_id[item_type] -= 1 - return super().__new__(cls, id_) - - def __init__(self, item_type): + def __init__(self, _value, item_type, id_map): super().__init__() self._item_type = item_type - self._id_map = {} - self._db_id = None - - def set_id_map(self, id_map): self._id_map = id_map + self._db_id = None @property def db_id(self): @@ -33,7 +26,11 @@ def db_id(self): def __eq__(self, other): # FIXME: Can we avoid this? - return super().__eq__(other) or (self._db_id is not None and other == self._db_id) + if super().__eq__(other): + return True + if self._db_id is not None: + return other == self._db_id + return False def __hash__(self): # FIXME: Can we avoid this? From dad322e6156632eed1e080569df1a99b6530b12d Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 12 Feb 2024 18:17:01 +0100 Subject: [PATCH 279/317] Don't show removed stuff after a refresh When the user refreshes session and there are external changes, we mark all committed items as compromised - since we don't know if they are still in the DB. Then as we fetch items, we mark them as uncompromised (that is, back to committed). Compromised items are considered invalid so unless they are refetched, they stay in the dark. 
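A minimal sketch of the intended flow (hypothetical URL, mirroring the unit
test added in this commit):

    with DatabaseMapping("sqlite:///db.sqlite") as db_map:
        db_map.fetch_all("entity_class")  # fetched items are now committed
        # ... another client purges the DB and commits ...
        db_map.refresh_session()  # committed items become compromised
        # Only classes still present in the DB come back; the rest stay hidden.
        names = [x["name"] for x in db_map.get_entity_class_items()]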
This doesn't apply to uncommitted items since these are considered for the future, so they shouldn't be affected by newer state of the DB. Let's see. --- spinedb_api/db_mapping_base.py | 34 +++++++++++++++++++++++++++------- tests/test_DatabaseMapping.py | 19 ++++++++++++++++++- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 87740c95..cba46749 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -29,6 +29,7 @@ class Status(Enum): to_update = auto() to_remove = auto() added_and_removed = auto() + compromised = auto() class DatabaseMappingBase: @@ -222,7 +223,13 @@ def _rollback(self): def _refresh(self): """Clears fetch progress, so the DB is queried again.""" + if self._commit_count == self._query_commit_count(): + return self._fetched.clear() + for item_type in self.item_types(): + mapped_table = self.mapped_table(item_type) + for item in mapped_table.values(): + item.handle_refresh() def _check_item_type(self, item_type): if item_type not in self.all_item_types(): @@ -309,7 +316,7 @@ def do_fetch_more(self, item_type, offset=0, limit=None, **kwargs): if new: new_items.append(item) else: - item.reset_state() + item.handle_refetch() items.append(item) # Once all items are added, add the unique key values # Otherwise items that refer to other items that come later in the query will be seen as corrupted @@ -560,7 +567,7 @@ def add_item_from_db(self, item, is_db_clean): return mapped_item, False conflicting_item = self.get(item["id"]) if conflicting_item is not None: - conflicting_item.detach() + conflicting_item.handle_id_steal() mapped_item = self._make_and_add_item(item) if self.purged: # Lazy purge: instead of fetching all at purge time, we purge stuff as it comes. @@ -700,17 +707,28 @@ def __init__(self, db_map, item_type, **kwargs): self._backup = None self.public_item = PublicItem(self._db_map, self) - def reset_state(self): + def handle_refetch(self): """Called when an equivalent item is fetched from the DB. - If this item is already committed, we assume the one from the DB is newer so we reset the state. - Otherwise we assume *this* is newer and do nothing. + 1. If this item is compromised, then mark it as committed. + 2. If this item is committed, then assume the one from the DB is newer and reset the state. + Otherwise assume *this* is newer and do nothing. """ + if self.status == Status.compromised: + self.status = Status.committed if self.is_committed(): self._removed = False self._corrupted = False self._valid = None + def handle_refresh(self): + """Called when the mapping is refreshed. + + If this item is committed, then set it as compromised. + """ + if self.status == Status.committed: + self.status = Status.compromised + @classmethod def ref_types(cls): """Returns a set of item types that this class refers. @@ -1004,6 +1022,8 @@ def is_valid(self): Returns: bool """ + if self.status == Status.compromised: + return False if self._valid is not None: return self._valid if self._removed or self._corrupted: @@ -1223,8 +1243,8 @@ def force_id(self, id_): if self.status == Status.to_add: self.status = Status.committed - def detach(self): - """Detaches this item whose id now belongs to a different item after an external commit.""" + def handle_id_steal(self): + """Called when a new item is fetched from the DB with this item's id.""" self["id"].unresolve() # TODO: Test if the below works... 
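         # Orientation note (not in the original patch): an external commit has
         # reused this item's db id for a brand-new item, so unresolve() detaches
         # our TempId from that db id and the two items are no longer conflated.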
if self.is_committed(): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 835a39ac..76a3323f 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3216,7 +3216,6 @@ def test_cunning_ways_to_make_external_changes(self): ) ) shadow_db_map.commit_session("Add entities.") - db_map.refresh_session() entity_items = db_map.get_entity_items() self.assertEqual(len(entity_items), 2) unique_values = {(x["name"], x["entity_class_name"]) for x in entity_items} @@ -3497,6 +3496,24 @@ def test_db_items_prevail_if_mapped_items_are_committed(self): self.assertTrue(entity_class_item) self.assertEqual(entity_class_item["name"], "my_class") + def test_remove_items_then_refresh(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="my_class")) + self._assert_success(db_map.add_entity_class_item(name="new_class")) + db_map.commit_session("Add some data") + with DatabaseMapping(url, create=True) as db_map: + db_map.fetch_all("entity_class") + with DatabaseMapping(url) as shadow_db_map: + shadow_db_map.purge_items("entity_class") + self._assert_success(shadow_db_map.add_entity_class_item(name="new_class")) + shadow_db_map.commit_session("Purge then add new class back") + db_map.refresh_session() + entity_class_names = [x["name"] for x in db_map.get_entity_class_items()] + self.assertIn("new_class", entity_class_names) + self.assertNotIn("my_class", entity_class_names) + if __name__ == "__main__": unittest.main() From 2e31fdb9db75df670b876002812472906d14aba3 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 12 Feb 2024 10:22:30 +0200 Subject: [PATCH 280/317] Add Spine Toolbox unit and execution tests to GitHub workflows Re #358 --- .github/workflows/run_unit_tests.yml | 90 +++++++++++++++++++++++++++- 1 file changed, 87 insertions(+), 3 deletions(-) diff --git a/.github/workflows/run_unit_tests.yml b/.github/workflows/run_unit_tests.yml index b4d8d4cf..22a48065 100644 --- a/.github/workflows/run_unit_tests.yml +++ b/.github/workflows/run_unit_tests.yml @@ -4,7 +4,12 @@ name: Unit tests # Run workflow on every push on: - push + push: + paths: + - "**.py" + - "requirements.txt" + - "pyproject.toml" + - ".github/workflows/*.yml" jobs: unit-tests: @@ -36,10 +41,10 @@ jobs: PYTHONUTF8: 1 run: | python -m pip install --upgrade pip - pip install .[dev] + python -m pip install .[dev] - name: List packages run: - pip list + python -m pip list - name: Run tests env: QT_QPA_PLATFORM: offscreen @@ -49,3 +54,82 @@ jobs: uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + toolbox-unit-tests: + name: Spine Toolbox unit tests + runs-on: ${{ matrix.os }} + strategy: + fail-fast: true + matrix: + python-version: [3.8] + os: [ubuntu-22.04] + steps: + - uses: actions/checkout@v4 + with: + repository: spine-tools/Spine-Toolbox + fetch-depth: 0 + # Temporarily fetch the 0.8-dev branch until everything is merged to master + ref: 0.8-dev + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install additional packages for Linux + if: runner.os == 'Linux' + run: | + sudo apt-get update -y + sudo apt-get install -y libegl1 + - name: Install dependencies + env: + PYTHONUTF8: 1 + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + - name: List packages + 
run: + python -m pip list + - name: Install python3 kernelspecs + run: | + python -m pip install ipykernel + python -m ipykernel install --user + - name: Run tests + run: | + if [ "$RUNNER_OS" != "Windows" ]; then + export QT_QPA_PLATFORM=offscreen + fi + python -m unittest discover --verbose + shell: bash + toolbox-execution-tests: + name: Spine Toolbox execution tests + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + python-version: [3.8] + os: [ubuntu-22.04] + steps: + - uses: actions/checkout@v4 + with: + repository: spine-tools/Spine-Toolbox + # Temporarily fetch the 0.8-dev branch until everything is merged to master + ref: 0.8-dev + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install additional packages for Linux + if: runner.os == 'Linux' + run: | + sudo apt-get update -y + sudo apt-get install -y libegl1 + - name: Install dependencies + env: + PYTHONUTF8: 1 + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + - name: List packages + run: + python -m pip list + - name: Run tests + run: + python -m unittest discover --pattern execution_test.py --verbose From bf575a42a8d14013eaca233171c1cf495313af2a Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 13 Feb 2024 14:24:42 +0200 Subject: [PATCH 281/317] Add index_name to Parameter value format documentation Re #359 --- docs/source/parameter_value_format.rst | 66 +++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/docs/source/parameter_value_format.rst b/docs/source/parameter_value_format.rst index ea6f3513..8284c859 100644 --- a/docs/source/parameter_value_format.rst +++ b/docs/source/parameter_value_format.rst @@ -111,6 +111,9 @@ where the accepted values are the following: The ``data`` property must be a JSON object mapping time periods to values. +A time-pattern may have an additional property, ``index_name``. +``index_name`` must be a JSON string. If not specified, a default name 'p' will be used. + Example ~~~~~~~ @@ -143,7 +146,7 @@ Accepted values for the ``data`` property are the following: In this case it is assumed that the time-series begins at the first hour of *any* year, has a resolution of one hour, and repeats cyclically until the *end* of time. -In case of time-series, the specification may have one additional property, ``index``. +In case of time-series, the specification may have two additional properties, ``index`` and ``index_name``. ``index`` must be a JSON object with the following properties, all of them optional: - ``start``: the *first* time-stamp, used in case ``data`` is a one-column array (ignored otherwise). @@ -158,6 +161,8 @@ In case of time-series, the specification may have one additional property, ``in - ``repeat``: a JSON boolean whether or not the time-series should repeat cyclically until the *end* of time. The default is ``false``, unless ``data`` is a one-column array and ``start`` is not given. +``index_name`` must be a JSON string. If not specified, a default name 't' will be used. + Examples ~~~~~~~~ @@ -212,6 +217,21 @@ One-column array with explicit (custom) indices: } } +Two-column array with named indices: + +.. 
code-block:: json + + { + + "type": "time_series", + "data": [ + ["2019-01-01T00:00", 1], + ["2019-01-01T00:30", 2], + ["2019-01-01T02:00", 8] + ], + "index_name": "Time stamps" + } + Array ----- @@ -221,7 +241,7 @@ All values are of the same type which is specified by an optional ``value_type`` If specified, ``value_type`` must be one of the following: ``float``, ``str``, ``duration``, or ``date_time``. If omitted, ``value_type`` defaults to ``float`` -The ``data`` property must be a JSON list. The elements depent on ``value_type``: +The ``data`` property must be a JSON list. The elements depend on ``value_type``: - If ``value_type`` is ``float`` then all elements in ``data`` must be JSON numbers. - If ``value_type`` is ``str`` then all elements in ``data`` must be JSON strings. @@ -229,6 +249,10 @@ The ``data`` property must be a JSON list. The elements depent on ``value_type`` - If ``value_type`` is ``date_time`` then all elements in ``data`` must be JSON strings in the `ISO8601 `_ format. +An array may have an additional property, ``index_name``. +``index_name`` must be a JSON string. If not specified, a default name 'i' will be used. + + Examples ~~~~~~~~ @@ -251,6 +275,17 @@ An array of durations: "data": ["3 months", "2Y", "4 minutes"] } +An array of strings with index name: + +.. code-block:: json + + { + "type": "array", + "data": ["one", "two"], + "index_name": "step" + } + + Map --- @@ -279,6 +314,9 @@ The ``data`` property can be a JSON mapping with the following properties: Optionally, the ``data`` property can be a two-column JSON array where the first element is the key and the second the value. +A map may have an additional property, ``index_name``. +``index_name`` must be a JSON string. If not specified, a default name 'x' will be used. 
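+
+As a minimal sketch (hypothetical data, for illustration only), a flat map with
+a named index could look like this:
+
+.. code-block:: json
+
+   {
+     "type": "map",
+     "index_type": "str",
+     "index_name": "scenario",
+     "data": {"low": 1.0, "high": 2.0}
+   }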
+ Examples ~~~~~~~~ @@ -331,12 +369,28 @@ Forecast time Target time Stochastic scenario Value { "type": "map", "index_type": "date_time", + "index_name": "Forecast time", "data": [ ["2020-04-17T08:00", - {"type": "map", "index_type": "date_time", "data": [ - ["2020-04-17T08:00", {"type": "map", "index_type": "float", "data": [[0, 23.0], [1, 5.5]]}], - ["2020-04-17T09:00", {"type": "map", "index_type": "float", "data": [[0, 24.0], [1, 6.6]]}], - ["2020-04-17T10:00", {"type": "map", "index_type": "float", "data": [[0, 25.0], [1, 7.7]]}] + {"type": "map", "index_type": "date_time", "index_name": "Target time", "data": [ + [ + "2020-04-17T08:00", {"type": "map", + "index_type": "float", + "index_name": "Stochastic scenario", + "data": [[0, 23.0], [1, 5.5]]} + ], + [ + "2020-04-17T09:00", {"type": "map", + "index_type": "float", + "index_name": "Stochastic scenario", + "data": [[0, 24.0], [1, 6.6]]} + ], + [ + "2020-04-17T10:00", {"type": "map", + "index_type": "float", + "index_name": "Stochastic scenario", + "data": [[0, 25.0], [1, 7.7]]} + ] ]} ] ] From eb0d64a8de56feccb0753030269d144338291eb8 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Feb 2024 14:11:36 +0100 Subject: [PATCH 282/317] Fix TempId.__hash__ to follow python rules --- spinedb_api/db_mapping_base.py | 9 ++++----- spinedb_api/mapped_items.py | 2 +- spinedb_api/temp_id.py | 28 ++++++++++++---------------- tests/test_DatabaseMapping.py | 4 ++-- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index cba46749..264629a7 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -363,7 +363,6 @@ def __init__(self, db_map, item_type, *args, **kwargs): self._item_type = item_type self._ids_by_unique_key_value = {} self._temp_id_by_db_id = {} - self._next_id = -1 self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @property @@ -379,9 +378,7 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - id_ = self._next_id - self._next_id -= 1 - return TempId(id_, self._item_type, self._temp_id_by_db_id) + return TempId(self._item_type, self._temp_id_by_db_id) def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. @@ -846,7 +843,9 @@ def merge(self, other): def _something_to_update(self, other): def _convert(x): - return tuple(x) if isinstance(x, list) else x + if isinstance(x, list): + x = tuple(x) + return resolve(x) return not all( _convert(self.get(key)) == _convert(value) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 9f73a627..f0066e6b 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -792,7 +792,7 @@ class SuperclassSubclassItem(MappedItemBase): _internal_fields = {"superclass_id": (("superclass_name",), "id"), "subclass_id": (("subclass_name",), "id")} def _subclass_entities(self): - return self._db_map.get_items("entity", class_id=self["subclass_id"]) + return self._db_map.get_items("entity", class_id=self["subclass_id"], fetch=False) def check_mutability(self): if self._subclass_entities(): diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 0376bd44..bd2e49fe 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -9,13 +9,13 @@ # this program. If not, see . 
###################################################################################################################### +import uuid -class TempId(int): - def __new__(cls, value, *args): - return super().__new__(cls, value) - def __init__(self, _value, item_type, id_map): +class TempId: + def __init__(self, item_type, id_map): super().__init__() + self._id = uuid.uuid4() self._item_type = item_type self._id_map = id_map self._db_id = None @@ -24,21 +24,15 @@ def __init__(self, _value, item_type, id_map): def db_id(self): return self._db_id + def __repr__(self): + resolved_to = f" resolved to {self._db_id}" if self._db_id is not None else "" + return f"TempId({self._item_type}){resolved_to}" + def __eq__(self, other): - # FIXME: Can we avoid this? - if super().__eq__(other): - return True - if self._db_id is not None: - return other == self._db_id - return False + return isinstance(other, TempId) and other._item_type == self._item_type and other._id == self._id def __hash__(self): - # FIXME: Can we avoid this? - return int(self) - - def __repr__(self): - resolved_to = f" resolved to {self._db_id}" if self._db_id is not None else "" - return f"TempId({self._item_type}, {super().__repr__()}){resolved_to}" + return hash((self._item_type, self._id)) def resolve(self, db_id): self._db_id = db_id @@ -49,6 +43,8 @@ def unresolve(self): def resolve(value): + if isinstance(value, tuple): + return tuple(resolve(v) for v in value) if isinstance(value, dict): return {k: resolve(v) for k, v in value.items()} if isinstance(value, TempId): diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 76a3323f..d30745ef 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -432,7 +432,7 @@ def test_fetch_more(self): with DatabaseMapping("sqlite://", create=True) as db_map: alternatives = db_map.fetch_more("alternative") expected = [{"id": 1, "name": "Base", "description": "Base alternative", "commit_id": 1}] - self.assertEqual([a._asdict() for a in alternatives], expected) + self.assertEqual([a.resolve() for a in alternatives], expected) def test_fetch_more_after_commit_and_refresh(self): with DatabaseMapping("sqlite://", create=True) as db_map: @@ -3117,7 +3117,7 @@ def test_additive_commit_from_another_db_map_gets_fetched(self): items = db_map.get_items("entity") self.assertEqual(len(items), 1) self.assertEqual( - items[0]._asdict(), + items[0].resolve(), { "id": 1, "name": "my_entity", From 956ad39da0a9a2f5d8a603c767a52a13ffd29587 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 14 Feb 2024 14:18:54 +0100 Subject: [PATCH 283/317] Add one unit-test for 'difficult' refresh --- tests/test_DatabaseMapping.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 76a3323f..22e7ae38 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3496,7 +3496,7 @@ def test_db_items_prevail_if_mapped_items_are_committed(self): self.assertTrue(entity_class_item) self.assertEqual(entity_class_item["name"], "my_class") - def test_remove_items_then_refresh(self): + def test_remove_items_then_refresh_then_readd(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") with DatabaseMapping(url, create=True) as db_map: @@ -3514,6 +3514,26 @@ def test_remove_items_then_refresh(self): self.assertIn("new_class", entity_class_names) self.assertNotIn("my_class", entity_class_names) + def 
test_remove_items_then_refresh_then_readd2(self): + with TemporaryDirectory() as temp_dir: + url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") + with DatabaseMapping(url, create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="xxx")) + self._assert_success(db_map.add_entity_class_item(name="yyy")) + self._assert_success(db_map.add_entity_class_item(name="zzz")) + db_map.commit_session("Add some data") + with DatabaseMapping(url, create=True) as db_map: + db_map.fetch_all("entity_class") + with DatabaseMapping(url) as shadow_db_map: + shadow_db_map.purge_items("entity_class") + self._assert_success(shadow_db_map.add_entity_class_item(name="zzz")) + self._assert_success(shadow_db_map.add_entity_class_item(name="www")) + shadow_db_map.commit_session("Purge then add one old class and one new class") + db_map.refresh_session() + entity_class_names = [x["name"] for x in db_map.get_entity_class_items()] + self.assertEqual(len(entity_class_names), 2) + self.assertEqual(set(entity_class_names), {"zzz", "www"}) + if __name__ == "__main__": unittest.main() From 50f41d26270fdcd37066e873b63d547f9c274e9a Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 16 Feb 2024 11:09:16 +0100 Subject: [PATCH 284/317] Implement __lt__ for TempId Looks like Spine DB Editor's entity graph view relies on being able to sort the ids for placing the multi-D entities nicely. --- spinedb_api/temp_id.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index bd2e49fe..36857769 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -34,6 +34,9 @@ def __eq__(self, other): def __hash__(self): return hash((self._item_type, self._id)) + def __lt__(self, other): + return self._id < other._id + def resolve(self, db_id): self._db_id = db_id self._id_map[db_id] = self From c3830deb70211a4282673ab4789f486bebf416e2 Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Fri, 16 Feb 2024 12:20:36 +0200 Subject: [PATCH 285/317] Fix graph view not drawing anything Re spine-tools/Spine-Toolbox#2593 --- spinedb_api/temp_id.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index bd2e49fe..948f13ec 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -31,6 +31,12 @@ def __repr__(self): def __eq__(self, other): return isinstance(other, TempId) and other._item_type == self._item_type and other._id == self._id + def __gt__(self, other): + return isinstance(other, TempId) and other._item_type == self._item_type and self._id > other._id + + def __lt__(self, other): + return isinstance(other, TempId) and other._item_type == self._item_type and self._id < other._id + def __hash__(self): return hash((self._item_type, self._id)) From 40d8ab97c93aac24928d6a722638b0a070ce9eda Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Fri, 16 Feb 2024 12:36:14 +0200 Subject: [PATCH 286/317] Remove duplicate of method Re spine-tools/Spine-Toolbox#2593 --- spinedb_api/temp_id.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index fca02bc0..948f13ec 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -40,9 +40,6 @@ def __lt__(self, other): def __hash__(self): return hash((self._item_type, self._id)) - def __lt__(self, other): - return self._id < other._id - def resolve(self, db_id): self._db_id = db_id self._id_map[db_id] = self From 549b2cb0fb656ff8299115b8535d3ff18f252407 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 16 Feb 2024 
11:52:34 +0100 Subject: [PATCH 287/317] Resolve values before setting query filters --- spinedb_api/db_mapping_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 264629a7..7c79ca36 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -109,6 +109,7 @@ def _make_query(self, item_type, **kwargs): for key, value in kwargs.items(): if isinstance(value, tuple): continue + value = resolve(value) if hasattr(sq.c, key): qry = qry.filter(getattr(sq.c, key) == value) elif key in self.item_factory(item_type)._external_fields: From aaa889d591b6d4798467081a09d49180cb2709e4 Mon Sep 17 00:00:00 2001 From: Manuel Date: Sun, 18 Feb 2024 12:43:06 +0100 Subject: [PATCH 288/317] Introduce MappedItemBase.validate This is to make client code more intuitive (before they sometimes needed to call is_valid() and discard the result, just to do the validation). Now they can just call validate(). --- spinedb_api/db_mapping_base.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 7c79ca36..9586dbea 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -182,7 +182,7 @@ def _dirty_items(self): to_remove.extend(mapped_table.values()) else: for item in mapped_table.values(): - _ = item.is_valid() + item.validate() if item.status == Status.to_remove: to_remove.append(item) if to_add or to_update or to_remove: @@ -696,7 +696,6 @@ def __init__(self, db_map, item_type, **kwargs): self.remove_callbacks = set() self._has_valid_id = True self._removed = False - self._corrupted = False self._valid = None self._status = Status.committed self._removal_source = None @@ -716,7 +715,6 @@ def handle_refetch(self): self.status = Status.committed if self.is_committed(): self._removed = False - self._corrupted = False self._valid = None def handle_refresh(self): @@ -878,7 +876,7 @@ def first_invalid_key(self): def _resolve_refs(self): """Goes through the ``_references`` class attribute and tries to resolve them. - If successful, replace source fields referring to db-ids with the reference TempId. + If successful, replace source fields referring to db-ids with the reference's TempId. Yields: tuple(str,MappedItem or None): the source field and resolved ref. @@ -1024,16 +1022,18 @@ def is_valid(self): """ if self.status == Status.compromised: return False + self.validate() + return self._valid + + def validate(self): + """Resolves all references and checks if the item is valid. + The item is valid if it's not removed, has all of its references, and none of them is removed.""" if self._valid is not None: - return self._valid - if self._removed or self._corrupted: - return False + return refs = [ref for _, ref in self._resolve_refs()] - self._corrupted = not all(refs) - if any(ref and ref.removed for ref in refs): + self._valid = not self._removed and all(ref and not ref.removed for ref in refs) + if not self._valid: self.cascade_remove() - self._valid = not self._removed and not self._corrupted - return self._valid def add_referrer(self, referrer): """Adds a strong referrer to this item. 
Strong referrers are removed, updated and restored @@ -1089,6 +1089,7 @@ def cascade_restore(self, source=None): else: raise RuntimeError("invalid status for item being restored") self._removed = False + self._valid = None # First restore this, then referrers obsolete = set() for callback in list(self.restore_callbacks): @@ -1282,6 +1283,9 @@ def __str__(self): def get(self, key, default=None): return self._mapped_item.get(key, default) + def validate(self): + self._mapped_item.validate() + def is_valid(self): return self._mapped_item.is_valid() From 6b488078e5eeb6199e0eb322af1368667c669df3 Mon Sep 17 00:00:00 2001 From: Manuel Date: Mon, 19 Feb 2024 15:26:09 +0100 Subject: [PATCH 289/317] Introduce API to backup DBs --- spinedb_api/db_mapping.py | 65 +++++++++++++++++++++++++++++++++++++-- spinedb_api/helpers.py | 2 +- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 5243ddf0..4881b376 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -126,6 +126,7 @@ def __init__( db_url, username=None, upgrade=False, + backup_url="", codename=None, create=False, apply_filters=True, @@ -139,6 +140,7 @@ def __init__( username (str, optional): A user name. If not given, it gets replaced by the string `anon`. upgrade (bool, optional): Whether the DB at the given `url` should be upgraded to the most recent version. + backup_url (str, optional): A URL to backup the DB before upgrading. codename (str, optional): A name to identify this object in your application. create (bool, optional): Whether to create a new Spine DB at the given `url` if it's not already one. apply_filters (bool, optional): Whether to apply filters in the `url`'s query segment. @@ -165,7 +167,7 @@ def __init__( self._memory = memory self._memory_dirty = False self._original_engine = self.create_engine( - self.sa_url, upgrade=upgrade, create=create, sqlite_timeout=sqlite_timeout + self.sa_url, create=create, upgrade=upgrade, backup_url=backup_url, sqlite_timeout=sqlite_timeout ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine @@ -219,7 +221,63 @@ def _make_codename(self, codename): return hashing.hexdigest() @staticmethod - def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): + def get_upgrade_db_prompt_data(url, create=False): + """Returns data to prompt the user what to do if the DB at the given url is not the latest version. + If it is, then returns None. + + Args: + url (str) + create (bool,optional) + + Returns: + str: The title of the prompt + str: The text of the prompt + dict: Mapping different options, to kwargs to pass to DatabaseMapping constructor in order to apply them + dict or None: Mapping different options, to additional notes + int or None: The preferred option if any + """ + sa_url = make_url(url) + try: + DatabaseMapping.create_engine(sa_url, create=create) + return None + except SpineDBVersionError as v_err: + if v_err.upgrade_available: + title = "Incompatible database version" + text = ( + f"The database at
'{sa_url}' is at revision {v_err.current} "
+                f"and needs to be upgraded to revision {v_err.expected} "
+                "in order to be used with the current version of Spine. "
+                "WARNING: After the upgrade, the database may no longer be used with previous versions."
+            )
+            if sa_url.drivername == "sqlite":
+                folder_name, file_name = os.path.split(sa_url.database)
+                file_name, _ = os.path.splitext(file_name)
+            else:
+                folder_name = os.path.expanduser("~")
+                file_name = sa_url.database
+            database = os.path.join(folder_name, file_name + "." + v_err.current)
+            backup_url = URL("sqlite", database=database)
+            option_to_kwargs = {
+                "Do not upgrade": {},
+                "Just upgrade": dict(upgrade=True),
+                "Backup and upgrade": dict(upgrade=True, backup_url=backup_url),
+            }
+            notes = {"Backup and upgrade": f"The backup will be written at '{backup_url}'"}
+            preferred = 2
+        else:
+            title = "Unsupported database version"
+            text = (
+                f"The database at '{sa_url}' is at revision {v_err.current} "
+                f"while this version of Spine supports revisions up to {v_err.expected}. "
+                "
Please upgrade Spine to use this database." + ) + option_to_kwargs = {} + notes = None + preferred = None + return title, text, option_to_kwargs, notes, preferred + + @staticmethod + def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_timeout=1800): if sa_url.drivername == "sqlite": connect_args = {'timeout': sqlite_timeout} else: @@ -264,6 +322,9 @@ def create_engine(sa_url, upgrade=False, create=False, sqlite_timeout=1800): url=sa_url, current=current, expected=head, upgrade_available=False ) from None raise SpineDBVersionError(url=sa_url, current=current, expected=head) + if backup_url: + dst_engine = create_engine(backup_url) + copy_database_bind(dst_engine, engine) # Upgrade function def upgrade_to_head(rev, context): diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index b1b72970..4fcf0b8c 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -272,7 +272,7 @@ def copy_database_bind(dest_bind, source_bind, overwrite=True, upgrade=False, on try: dest_bind.execute(ins, data) except IntegrityError as e: - warnings.warn("Skipping table {0}: {1}".format(source_table.name, e.orig.args)) + warnings.warn(f"Skipping table {source_table.name}: {e.orig.args}") def custom_generate_relationship(base, direction, return_fn, attrname, local_cls, referred_cls, **kw): From 204edd406b66a71246be421fc5ba4a4b876ba63c Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 20 Feb 2024 07:35:53 +0100 Subject: [PATCH 290/317] Remove 'do not upgrade' as an option --- spinedb_api/db_mapping.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 4881b376..08554bb5 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -258,12 +258,11 @@ def get_upgrade_db_prompt_data(url, create=False): database = os.path.join(folder_name, file_name + "." + v_err.current) backup_url = URL("sqlite", database=database) option_to_kwargs = { - "Do not upgrade": {}, - "Just upgrade": dict(upgrade=True), "Backup and upgrade": dict(upgrade=True, backup_url=backup_url), + "Just upgrade": dict(upgrade=True), } notes = {"Backup and upgrade": f"The backup will be written at '{backup_url}'"} - preferred = 2 + preferred = 0 else: title = "Unsupported database version" text = ( From 4b8ef692282c8e5a93a89aa5123b846ea0cf46f7 Mon Sep 17 00:00:00 2001 From: Manuel Date: Tue, 20 Feb 2024 17:46:39 +0100 Subject: [PATCH 291/317] Lock sqlite DB before upgrade Because some clients might actually upgrade concurrently --- spinedb_api/db_mapping.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 08554bb5..2b9ca230 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -256,7 +256,7 @@ def get_upgrade_db_prompt_data(url, create=False): folder_name = os.path.expanduser("~") file_name = sa_url.database database = os.path.join(folder_name, file_name + "." 
+ v_err.current) - backup_url = URL("sqlite", database=database) + backup_url = str(URL("sqlite", database=database)) option_to_kwargs = { "Backup and upgrade": dict(upgrade=True, backup_url=backup_url), "Just upgrade": dict(upgrade=True), @@ -294,7 +294,10 @@ def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_tim config.set_main_option("script_location", "spinedb_api:alembic") script = ScriptDirectory.from_config(config) head = script.get_current_head() - with engine.connect() as connection: + with engine.begin() as connection: + if sa_url.drivername == "sqlite": + connection.execute("BEGIN IMMEDIATE") + # TODO: Do other dialects need to lock? migration_context = MigrationContext.configure(connection) try: current = migration_context.get_current_revision() From d806612e7fdca6b17711fe9166057a4bc675df5d Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 21 Feb 2024 10:12:04 +0100 Subject: [PATCH 292/317] Fix engine creation for DB mapping when also creating the DB --- spinedb_api/db_mapping.py | 13 +++++++------ spinedb_api/helpers.py | 17 +++++++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 2b9ca230..e5ec771a 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -45,7 +45,7 @@ from .compatibility import compatibility_transformations from .helpers import ( _create_first_spine_database, - create_new_spine_database, + create_new_spine_database_from_bind, compare_schemas, model_meta, copy_database_bind, @@ -290,10 +290,6 @@ def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_tim f"Could not connect to '{sa_url}': {str(e)}. " f"Please make sure that '{sa_url}' is a valid sqlalchemy URL." ) from None - config = Config() - config.set_main_option("script_location", "spinedb_api:alembic") - script = ScriptDirectory.from_config(config) - head = script.get_current_head() with engine.begin() as connection: if sa_url.drivername == "sqlite": connection.execute("BEGIN IMMEDIATE") @@ -313,7 +309,12 @@ def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_tim "Unable to determine db revision. " f"Please check that\n\n\t{sa_url}\n\nis the URL of a valid Spine db." ) - return create_new_spine_database(sa_url) + create_new_spine_database_from_bind(connection) + return engine + config = Config() + config.set_main_option("script_location", "spinedb_api:alembic") + script = ScriptDirectory.from_config(config) + head = script.get_current_head() if current != head: if not upgrade: try: diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 4fcf0b8c..796c9b8f 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -10,6 +10,7 @@ # this program. If not, see . ###################################################################################################################### """ General helper functions. """ + import os import json import warnings @@ -621,21 +622,25 @@ def create_new_spine_database(db_url): engine = create_engine(db_url) except DatabaseError as e: raise SpineDBAPIError(f"Could not connect to '{db_url}': {e.orig.args}") from None + create_new_spine_database_from_bind(engine) + return engine + + +def create_new_spine_database_from_bind(bind): # Drop existing tables. This is a Spine db now... 
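     # Rationale sketch (not in the original patch): create_engine() now holds an
     # open, locked connection while it inspects the schema, so creating the Spine
     # schema through that same bind keeps the check-then-create sequence atomic.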
- meta = MetaData(engine) + meta = MetaData(bind) meta.reflect() meta.drop_all() # Create new tables meta = create_spine_metadata() version = get_head_alembic_version() try: - meta.create_all(engine) - engine.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") - engine.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") - engine.execute(f"INSERT INTO alembic_version VALUES ('{version}')") + meta.create_all(bind) + bind.execute("INSERT INTO `commit` VALUES (1, 'Create the database', CURRENT_TIMESTAMP, 'spinedb_api')") + bind.execute("INSERT INTO alternative VALUES (1, 'Base', 'Base alternative', 1)") + bind.execute(f"INSERT INTO alembic_version VALUES ('{version}')") except DatabaseError as e: raise SpineDBAPIError(f"Unable to create Spine database: {e}") from None - return engine def _create_first_spine_database(db_url): From 38b44a459ad6e12cb767cc40462f8c1d245cf799 Mon Sep 17 00:00:00 2001 From: Manuel Date: Wed, 21 Feb 2024 16:31:39 +0100 Subject: [PATCH 293/317] Only use existing tools when migrating, not when committing Re #355 --- ...b_add_active_by_default_to_entity_class.py | 2 +- ...a82ed59_create_entity_alternative_table.py | 2 +- spinedb_api/compatibility.py | 88 +++++++++++-------- 3 files changed, 52 insertions(+), 40 deletions(-) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index 35399f3e..a90508f6 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -39,7 +39,7 @@ def upgrade(): class_table.update().where(class_table.c.id == sa.bindparam("target_id")).values(active_by_default=True) ) conn.execute(update_statement, [{"target_id": class_id} for class_id in dimensional_class_ids]) - convert_tool_feature_method_to_active_by_default(conn) + convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method=True) def downgrade(): diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py index cb011771..7c3b2dd7 100644 --- a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py +++ b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -47,7 +47,7 @@ def upgrade(): op.drop_table('next_id') except sa.exc.OperationalError: pass - convert_tool_feature_method_to_entity_alternative(op.get_bind()) + convert_tool_feature_method_to_entity_alternative(op.get_bind(), use_existing_tool_feature_method=True) def downgrade(): diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 8f9d9eb4..306628a2 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -15,12 +15,13 @@ import sqlalchemy as sa -def convert_tool_feature_method_to_active_by_default(conn): +def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method): """Transforms default parameter values into active_by_default values, whenever the former are used in a tool filter to control entity activity. Args: conn (Connection) + use_existing_tool_feature_method (Bool): Whether to use existing tool/feature/method definitions. 
Returns: tuple: list of entity classes to add, update and ids to remove @@ -29,23 +30,26 @@ def convert_tool_feature_method_to_active_by_default(conn): meta.reflect() lv_table = meta.tables["list_value"] pd_table = meta.tables["parameter_definition"] - try: - # Compute list-value id by parameter definition id for all features and methods - tfm_table = meta.tables["tool_feature_method"] - tf_table = meta.tables["tool_feature"] - f_table = meta.tables["feature"] - lv_id_by_pdef_id = { - x["parameter_definition_id"]: x["id"] - for x in conn.execute( - sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) - .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) - .where(tfm_table.c.method_index == lv_table.c.index) - .where(tf_table.c.id == tfm_table.c.tool_feature_id) - .where(f_table.c.id == tf_table.c.feature_id) - ) - } - except KeyError: - # It's a new DB without tool/feature/method + if use_existing_tool_feature_method: + try: + # Compute list-value id by parameter definition id for all features and methods + tfm_table = meta.tables["tool_feature_method"] + tf_table = meta.tables["tool_feature"] + f_table = meta.tables["feature"] + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) + .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) + .where(tfm_table.c.method_index == lv_table.c.index) + .where(tf_table.c.id == tfm_table.c.tool_feature_id) + .where(f_table.c.id == tf_table.c.feature_id) + ) + } + except KeyError: + use_existing_tool_feature_method = False + if not use_existing_tool_feature_method: + # It's a new DB without tool/feature/method or we don't want to use them... # we take 'is_active' as feature and JSON "yes" and true as methods lv_id_by_pdef_id = { x["parameter_definition_id"]: x["id"] @@ -100,12 +104,13 @@ def convert_tool_feature_method_to_active_by_default(conn): return [], updated_items, [] -def convert_tool_feature_method_to_entity_alternative(conn): +def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_feature_method): """Transforms parameter_value rows into entity_alternative rows, whenever the former are used in a tool filter to control entity activity. Args: conn (Connection) + use_existing_tool_feature_method (Bool): Whether to use existing tool/feature/method definitions. 
Returns: list: entity_alternative items to add @@ -117,23 +122,26 @@ def convert_tool_feature_method_to_entity_alternative(conn): ea_table = meta.tables["entity_alternative"] lv_table = meta.tables["list_value"] pv_table = meta.tables["parameter_value"] - try: - # Compute list-value id by parameter definition id for all features and methods - tfm_table = meta.tables["tool_feature_method"] - tf_table = meta.tables["tool_feature"] - f_table = meta.tables["feature"] - lv_id_by_pdef_id = { - x["parameter_definition_id"]: x["id"] - for x in conn.execute( - sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) - .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) - .where(tfm_table.c.method_index == lv_table.c.index) - .where(tf_table.c.id == tfm_table.c.tool_feature_id) - .where(f_table.c.id == tf_table.c.feature_id) - ) - } - except KeyError: - # It's a new DB without tool/feature/method + if use_existing_tool_feature_method: + try: + # Compute list-value id by parameter definition id for all features and methods + tfm_table = meta.tables["tool_feature_method"] + tf_table = meta.tables["tool_feature"] + f_table = meta.tables["feature"] + lv_id_by_pdef_id = { + x["parameter_definition_id"]: x["id"] + for x in conn.execute( + sa.select([lv_table.c.id, f_table.c.parameter_definition_id]) + .where(tfm_table.c.parameter_value_list_id == lv_table.c.parameter_value_list_id) + .where(tfm_table.c.method_index == lv_table.c.index) + .where(tf_table.c.id == tfm_table.c.tool_feature_id) + .where(f_table.c.id == tf_table.c.feature_id) + ) + } + except KeyError: + use_existing_tool_feature_method = False + if not use_existing_tool_feature_method: + # It's a new DB without tool/feature/method or we don't want to use them... # we take 'is_active' as feature and JSON "yes" and true as methods pd_table = meta.tables["parameter_definition"] lv_id_by_pdef_id = { @@ -198,7 +206,9 @@ def compatibility_transformations(connection): tuple(list, list): list of tuples (tablename, (items_added, items_updated, ids_removed)), and list of strings indicating the changes """ - ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative(connection) + ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative( + connection, use_existing_tool_feature_method=False + ) transformations = [] info = [] if ea_items_added or ea_items_updated: @@ -207,7 +217,9 @@ def compatibility_transformations(connection): transformations.append(("parameter_value", ((), (), pval_ids_removed))) if ea_items_added or ea_items_updated or pval_ids_removed: info.append("Convert entity activity control using tool/feature/method into entity_alternative") - _, ec_items_updated, _ = convert_tool_feature_method_to_active_by_default(connection) + _, ec_items_updated, _ = convert_tool_feature_method_to_active_by_default( + connection, use_existing_tool_feature_method=False + ) if ec_items_updated: transformations.append(("entity_class", ((), ec_items_updated, ()))) return transformations, info From 30fdba0d58790d4ba04b7170a9e74197f02be6b4 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 23 Feb 2024 13:24:16 +0100 Subject: [PATCH 294/317] Try and make sure TempIds work with the DB Server --- spinedb_api/server_client_helpers.py | 3 +++ spinedb_api/spine_db_server.py | 9 ++++++--- spinedb_api/temp_id.py | 17 ++++++++++++++--- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/spinedb_api/server_client_helpers.py 
b/spinedb_api/server_client_helpers.py index 86b33fb3..1dd44e67 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -13,6 +13,7 @@ import json from .exception import SpineDBAPIError from .db_mapping_base import PublicItem +from .temp_id import TempId # Encode decode server messages _START_OF_TAIL = '\u001f' # Unit separator @@ -68,6 +69,8 @@ def default(self, o): return str(o) if isinstance(o, PublicItem): return o._extended() + if isinstance(o, TempId): + return o.private_id return super().default(o) @property diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 827e34bf..5604823f 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -411,9 +411,12 @@ def _do_export_data(self, **kwargs): return dict(result=export_data(self._db_map, parse_value=_parse_value, **kwargs)) def _do_call_method(self, method_name, *args, **kwargs): - method = getattr(self._db_map, method_name) - result = method(*args, **kwargs) - return dict(result=result) + try: + method = getattr(self._db_map, method_name) + result = method(*args, **kwargs) + return dict(result=result) + except Exception as err: + return dict(error=str(err)) def _do_clear_filters(self): self._db_map.restore_entity_sq_maker() diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 948f13ec..b6ebe6bc 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -9,16 +9,22 @@ # this program. If not, see . ###################################################################################################################### -import uuid - class TempId: + _next_id = {} + def __init__(self, item_type, id_map): super().__init__() - self._id = uuid.uuid4() + self._id = self._next_id.get(item_type, -1) + self._next_id[item_type] = self._id - 1 self._item_type = item_type self._id_map = id_map self._db_id = None + self._id_map[self._id] = self + + @property + def private_id(self): + return self._id @property def db_id(self): @@ -41,10 +47,15 @@ def __hash__(self): return hash((self._item_type, self._id)) def resolve(self, db_id): + self.unresolve() self._db_id = db_id self._id_map[db_id] = self def unresolve(self): + if self._db_id is None: + return + if self._id_map[self._db_id] is self: + del self._id_map[self._db_id] self._db_id = None From 42c3644ebc021901a5aa28d8766d78b062e5ad1d Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 23 Feb 2024 13:58:27 +0100 Subject: [PATCH 295/317] Make sure all ids are replaced by TempId in referrers Fixes spine-tools/Spine-Toolbox#2609 --- spinedb_api/db_mapping_base.py | 49 +++++++++++++++++++--------------- spinedb_api/mapped_items.py | 4 +-- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 9586dbea..fa025df1 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -882,22 +882,27 @@ def _resolve_refs(self): tuple(str,MappedItem or None): the source field and resolved ref. 
""" for src_key, (ref_type, ref_key) in self._references.items(): - try: - src_val = self[src_key] - except KeyError: - yield src_key, None + ref = self._get_full_ref(src_key, ref_type, ref_key) + if isinstance(ref, tuple): + for r in ref: + yield src_key, r else: - if isinstance(src_val, tuple): - refs = tuple(self._get_ref(ref_type, {ref_key: x}) for x in src_val) - if all(refs) and ref_key == "id": - self[src_key] = tuple(ref["id"] for ref in refs) - for ref in refs: - yield src_key, ref - else: - ref = self._get_ref(ref_type, {ref_key: src_val}) - if ref and ref_key == "id": - self[src_key] = ref["id"] - yield src_key, ref + yield src_key, ref + + def _get_full_ref(self, src_key, ref_type, ref_key, strong=True): + try: + src_val = self[src_key] + except KeyError: + return {} + if isinstance(src_val, tuple): + ref = tuple(self._get_ref(ref_type, {ref_key: x}, strong=strong) for x in src_val) + if all(ref) and ref_key == "id": + self[src_key] = tuple(r["id"] for r in ref) + return ref + ref = self._get_ref(ref_type, {ref_key: src_val}, strong=strong) + if ref and ref_key == "id": + self[src_key] = ref["id"] + return ref @classmethod def unique_values_for_item(cls, item, skip_keys=()): @@ -1184,14 +1189,14 @@ def __getattr__(self, name): def __getitem__(self, key): """Overridden to return references.""" - source_target_key_tuple = self._external_fields.get(key) - if source_target_key_tuple: - source_key, target_key = source_target_key_tuple + source_and_target_key = self._external_fields.get(key) + if source_and_target_key: + source_key, target_key = source_and_target_key ref_type, ref_key = self._references[source_key] - source_val = self[source_key] - if isinstance(source_val, tuple): - return tuple(self._get_ref(ref_type, {ref_key: x}).get(target_key) for x in source_val) - return self._get_ref(ref_type, {ref_key: source_val}).get(target_key) + ref = self._get_full_ref(source_key, ref_type, ref_key) + if isinstance(ref, tuple): + return tuple(r.get(target_key) for r in ref) + return ref.get(target_key) return super().__getitem__(key) def __setitem__(self, key, value): diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index f0066e6b..e7f0c878 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -489,9 +489,7 @@ def __getitem__(self, key): if key == "parameter_value_list_id": return dict.get(self, key) if key == "parameter_value_list_name": - return self._get_ref("parameter_value_list", {"id": self["parameter_value_list_id"]}, strong=False).get( - "name" - ) + return self._get_full_ref("parameter_value_list_id", "parameter_value_list", "id", strong=False).get("name") if key in ("default_value", "default_type"): list_value_id = self["list_value_id"] if list_value_id is not None: From 98561f74aa0f8d04e1027b8a04dbc749e1a45325 Mon Sep 17 00:00:00 2001 From: Manuel Date: Fri, 23 Feb 2024 14:27:56 +0100 Subject: [PATCH 296/317] Add some tests for the DB server --- spinedb_api/db_mapping_base.py | 6 +-- spinedb_api/temp_id.py | 12 ++--- tests/test_db_server.py | 82 ++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 tests/test_db_server.py diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index fa025df1..8c4a2ba5 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -363,7 +363,7 @@ def __init__(self, db_map, item_type, *args, **kwargs): self._db_map = db_map self._item_type = item_type self._ids_by_unique_key_value = {} - 
self._temp_id_by_db_id = {} + self._temp_id_lookup = {} self.wildcard_item = MappedItemBase(self._db_map, self._item_type, id=Asterisk) @property @@ -375,11 +375,11 @@ def purged(self, purged): self.wildcard_item.status = Status.to_remove if purged else Status.committed def get(self, id_, default=None): - id_ = self._temp_id_by_db_id.get(id_, id_) + id_ = self._temp_id_lookup.get(id_, id_) return super().get(id_, default) def _new_id(self): - return TempId(self._item_type, self._temp_id_by_db_id) + return TempId(self._item_type, self._temp_id_lookup) def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index b6ebe6bc..90f6fd54 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -13,14 +13,14 @@ class TempId: _next_id = {} - def __init__(self, item_type, id_map): + def __init__(self, item_type, temp_id_lookup): super().__init__() self._id = self._next_id.get(item_type, -1) self._next_id[item_type] = self._id - 1 self._item_type = item_type - self._id_map = id_map + self._temp_id_lookup = temp_id_lookup self._db_id = None - self._id_map[self._id] = self + self._temp_id_lookup[self._id] = self @property def private_id(self): @@ -49,13 +49,13 @@ def __hash__(self): def resolve(self, db_id): self.unresolve() self._db_id = db_id - self._id_map[db_id] = self + self._temp_id_lookup[db_id] = self def unresolve(self): if self._db_id is None: return - if self._id_map[self._db_id] is self: - del self._id_map[self._db_id] + if self._temp_id_lookup[self._db_id] is self: + del self._temp_id_lookup[self._db_id] self._db_id = None diff --git a/tests/test_db_server.py b/tests/test_db_server.py new file mode 100644 index 00000000..cf02ecfd --- /dev/null +++ b/tests/test_db_server.py @@ -0,0 +1,82 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +""" Unit tests for spine_db_server module. 
""" +import os +import unittest +import threading +from tempfile import TemporaryDirectory +from spinedb_api.spine_db_server import db_server_manager, closing_spine_db_server +from spinedb_api.spine_db_client import SpineDBClient +from spinedb_api.db_mapping import DatabaseMapping + + +class TestDBServer(unittest.TestCase): + def test_use_id_from_server_response(self): + with TemporaryDirectory() as temp_dir: + db_url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with DatabaseMapping(db_url, create=True) as db_map: + db_map.add_entity_class_item(name="fish") + db_map.commit_session("Fishing") + with closing_spine_db_server(db_url) as server_url: + client = SpineDBClient.from_server_url(server_url) + fish = client.call_method("get_entity_class_item", name="fish")["result"] + mouse = client.call_method("update_entity_class_item", id=fish["id"], name="mouse") + client.call_method("commit_session", "Mousing") + with DatabaseMapping(db_url) as db_map: + fish = db_map.get_entity_class_item(name="fish") + mouse = db_map.get_entity_class_item(name="mouse") + self.assertFalse(fish) + self.assertTrue(mouse) + self.assertEqual(mouse["name"], "mouse") + + def test_ordering(self): + def _import_entity_class(server_url, class_name): + client = SpineDBClient.from_server_url(server_url) + client.db_checkin() + _answer = client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}") + client.db_checkout() + + with TemporaryDirectory() as temp_dir: + db_url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") + with db_server_manager() as mngr_queue: + first_ordering = { + "id": "second_before_first", + "current": "first", + "precursors": {"second"}, + "part_count": 1, + } + second_ordering = { + "id": "second_before_first", + "current": "second", + "precursors": set(), + "part_count": 1, + } + with closing_spine_db_server( + db_url, server_manager_queue=mngr_queue, ordering=first_ordering + ) as first_server_url: + with closing_spine_db_server( + db_url, server_manager_queue=mngr_queue, ordering=second_ordering + ) as second_server_url: + t1 = threading.Thread(target=_import_entity_class, args=(first_server_url, "monkey")) + t2 = threading.Thread(target=_import_entity_class, args=(second_server_url, "donkey")) + t1.start() + with DatabaseMapping(db_url) as db_map: + assert db_map.get_items("entity_class") == [] # Nothing written yet + t2.start() + t1.join() + t2.join() + with DatabaseMapping(db_url) as db_map: + self.assertEqual([x["name"] for x in db_map.get_items("entity_class")], ["donkey", "monkey"]) + + +if __name__ == "__main__": + unittest.main() From 8580b06636579fbc78cdcee19015bfffcb69a74a Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Mon, 26 Feb 2024 10:26:08 +0200 Subject: [PATCH 297/317] Set active_by_default to true when migrating if is_active is missing When migrating, set all active_by_defaults to true unless is_active (or equivalent) has been defined. This mimics the pre-0.8 behavior where all classes were visible by default. 
Re spine-tools/Spine-Toolbox#2611 --- ...eff478bcb_add_active_by_default_to_entity_class.py | 11 ++--------- spinedb_api/compatibility.py | 5 +++-- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index a90508f6..996cbd24 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -26,19 +26,12 @@ def upgrade(): ), ) conn = op.get_bind() - session = sa.orm.sessionmaker(bind=conn)() metadata = sa.MetaData() metadata.reflect(bind=conn) - dimension_table = metadata.tables["entity_class_dimension"] - dimensional_class_ids = {row.entity_class_id for row in session.query(dimension_table)} - if not dimensional_class_ids: - return metadata.reflect(bind=conn) class_table = metadata.tables["entity_class"] - update_statement = ( - class_table.update().where(class_table.c.id == sa.bindparam("target_id")).values(active_by_default=True) - ) - conn.execute(update_statement, [{"target_id": class_id} for class_id in dimensional_class_ids]) + update_statement = class_table.update().values(active_by_default=True) + conn.execute(update_statement) convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method=True) diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index 306628a2..deef839f 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -80,10 +80,11 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea # where active_by_default is True if the value of 'is_active' is the one from the tool_feature_method specification entity_class_items_to_update = { x["entity_class_id"]: { - "active_by_default": x["list_value_id"] == lv_id_by_pdef_id[x["parameter_definition_id"]], + "active_by_default": False + if x["list_value_id"] is None + else x["list_value_id"] == lv_id_by_pdef_id[x["parameter_definition_id"]], } for x in is_active_default_vals - if x["list_value_id"] is not None } updated_items = [] entity_class_table = meta.tables["entity_class"] From 20ec2ad67c8ccae34be1c1f26431e022462aa546 Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Fri, 8 Mar 2024 13:04:51 +0200 Subject: [PATCH 298/317] Fix time-series db export Now the exported time-series don't show up as python objects in the generated Excel sheets, but instead they are correctly unpacked into a pivot table. 
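A minimal round-trip sketch of the kind of value the new test below asserts on,
using to_database()/from_database() from spinedb_api.parameter_value (the same
helpers the test imports); this is an illustration, not part of the fix itself:

    # Sketch: serialize a fixed-resolution time series and parse it back.
    from spinedb_api.parameter_value import TimeSeriesFixedResolution, from_database, to_database

    series = TimeSeriesFixedResolution("2000-01-01 00:00:00", "1h", [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], False, False)
    db_value, value_type = to_database(series)
    assert from_database(db_value, value_type) == series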
Re #2601 --- tests/test_import_functions.py | 47 +++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 10a930b7..68f5650a 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -36,7 +36,7 @@ import_relationship_parameter_value_metadata, import_data, ) -from spinedb_api.parameter_value import from_database +from spinedb_api.parameter_value import from_database, dump_db_value, TimeSeriesFixedResolution def assert_import_equivalent(test, obs, exp, strict=True): @@ -1135,6 +1135,51 @@ def test_non_existent_relationship_parameter_value_from_value_list_fails_gracefu self.assertEqual(len(errors), 1) db_map.close() + def test_unparse_value_imports_fields_correctly(self): + with DatabaseMapping("sqlite:///", create=True) as db_map: + data = { + 'entity_classes': [('A', (), None, None, False)], + 'entities': [('A', 'aa', None)], + 'parameter_definitions': [('A', 'test1', None, None, None)], + 'parameter_values': [( + 'A', + 'aa', + 'test1', + { + 'type': 'time_series', + 'index': { + 'start': '2000-01-01 00:00:00', + 'resolution': '1h', + 'ignore_year': False, + 'repeat': False + }, + 'data': [0.0, 1.0, 2.0, 4.0, 8.0, 0.0] + }, + 'Base' + )], + 'alternatives': [('Base', 'Base alternative')]} + + count, errors = import_data(db_map, **data, unparse_value=dump_db_value) + self.assertEqual(errors, []) + self.assertEqual(count, 4) + db_map.commit_session("add test data") + value = db_map.query(db_map.entity_parameter_value_sq).one() + self.assertEqual(value.type, "time_series") + self.assertEqual(value.parameter_name, "test1") + self.assertEqual(value.alternative_name, "Base") + self.assertEqual(value.entity_class_name, "A") + self.assertEqual(value.entity_name, "aa") + + time_series = from_database(value.value, value.type) + expected_result = TimeSeriesFixedResolution( + '2000-01-01 00:00:00', + '1h', + [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], + False, + False + ) + self.assertEqual(time_series, expected_result) + class TestImportParameterValueList(unittest.TestCase): def setUp(self): From 909eccd6457e8181a09a40a138ac86f70530cde8 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 8 Mar 2024 16:19:06 +0200 Subject: [PATCH 299/317] Separate unique id creation from TempId initialization This enables initializing arbitrary TempIds making it much more useful. Those needing unique ids should use TempId.new_unique(). --- spinedb_api/db_mapping_base.py | 2 +- spinedb_api/temp_id.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/spinedb_api/db_mapping_base.py b/spinedb_api/db_mapping_base.py index 8c4a2ba5..b1b4ed8e 100644 --- a/spinedb_api/db_mapping_base.py +++ b/spinedb_api/db_mapping_base.py @@ -379,7 +379,7 @@ def get(self, id_, default=None): return super().get(id_, default) def _new_id(self): - return TempId(self._item_type, self._temp_id_lookup) + return TempId.new_unique(self._item_type, self._temp_id_lookup) def _unique_key_value_to_id(self, key, value, fetch=True): """Returns the id that has the given value for the given unique key, or None. 
diff --git a/spinedb_api/temp_id.py b/spinedb_api/temp_id.py index 90f6fd54..bddd5ec1 100644 --- a/spinedb_api/temp_id.py +++ b/spinedb_api/temp_id.py @@ -13,15 +13,20 @@ class TempId: _next_id = {} - def __init__(self, item_type, temp_id_lookup): + def __init__(self, id_, item_type, temp_id_lookup=None): super().__init__() - self._id = self._next_id.get(item_type, -1) - self._next_id[item_type] = self._id - 1 + self._id = id_ self._item_type = item_type - self._temp_id_lookup = temp_id_lookup + self._temp_id_lookup = temp_id_lookup if temp_id_lookup is not None else {} self._db_id = None self._temp_id_lookup[self._id] = self + @staticmethod + def new_unique(item_type, temp_id_lookup): + id_ = TempId._next_id.get(item_type, -1) + TempId._next_id[item_type] = id_ - 1 + return TempId(id_, item_type, temp_id_lookup) + @property def private_id(self): return self._id From 69008957a4baa59c5666e5bdf3525ec72bb7b520 Mon Sep 17 00:00:00 2001 From: Henrik Koski Date: Mon, 11 Mar 2024 13:10:37 +0200 Subject: [PATCH 300/317] Fix entity dimension, value_type -fields in Excel exports Re #2601 --- spinedb_api/spine_io/exporters/excel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index 56798905..80b7bc68 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -58,13 +58,13 @@ def _make_preamble(table_name, title_key): class_name = title_key["entity_class_name"] if table_name.endswith(",group"): return {"sheet_type": "object_group", "class_name": class_name} - object_class_id_list = title_key.get("object_class_id_list") - if object_class_id_list is None: + dimension_id_list = title_key.get("dimension_id_list") + if dimension_id_list is None: entity_type = "object" - entity_dim_count = 1 + entity_dim_count = 0 else: entity_type = "relationship" - entity_dim_count = len(object_class_id_list.split(",")) + entity_dim_count = len(dimension_id_list.split(",")) preamble = { "sheet_type": "entity", "entity_type": entity_type, @@ -73,7 +73,7 @@ def _make_preamble(table_name, title_key): } td = title_key.get("type_and_dimensions") if td is not None: - preamble["value_type"] = td[0] + preamble["value_type"] = td[0] if td[0] else "single_value" preamble["index_dim_count"] = td[1] return preamble From 7fe60ec7c315a8a76869c2b13426f0fd970866bf Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 12 Mar 2024 09:40:04 +0200 Subject: [PATCH 301/317] Make compatibility transform optional on commit Some clients like DB manager may want to apply the compatibility transforms that take place on commit themselves. 
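Usage-wise this adds one keyword argument to commit_session(); a small sketch of
a client opting out (method names as in the diff below):

    # Sketch: commit without applying the compatibility transforms; the
    # returned info still lists them, so the client can apply them itself.
    from spinedb_api import DatabaseMapping

    with DatabaseMapping("sqlite://", create=True) as db_map:
        db_map.add_entity_class_item(name="Object")
        info = db_map.commit_session("Add a class", apply_compatibility_transforms=False)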
Re spine-tools/Spine-Toolbox#2625 --- ...b_add_active_by_default_to_entity_class.py | 2 +- ...a82ed59_create_entity_alternative_table.py | 2 +- spinedb_api/compatibility.py | 44 +++++++++++-------- spinedb_api/db_mapping.py | 5 ++- tests/test_import_functions.py | 41 +++++++++-------- 5 files changed, 50 insertions(+), 44 deletions(-) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index 996cbd24..3cbd6ede 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -32,7 +32,7 @@ def upgrade(): class_table = metadata.tables["entity_class"] update_statement = class_table.update().values(active_by_default=True) conn.execute(update_statement) - convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method=True) + convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method=True, apply=True) def downgrade(): diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py index 7c3b2dd7..ccc873e2 100644 --- a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py +++ b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -47,7 +47,7 @@ def upgrade(): op.drop_table('next_id') except sa.exc.OperationalError: pass - convert_tool_feature_method_to_entity_alternative(op.get_bind(), use_existing_tool_feature_method=True) + convert_tool_feature_method_to_entity_alternative(op.get_bind(), use_existing_tool_feature_method=True, apply=True) def downgrade(): diff --git a/spinedb_api/compatibility.py b/spinedb_api/compatibility.py index deef839f..c70bf796 100644 --- a/spinedb_api/compatibility.py +++ b/spinedb_api/compatibility.py @@ -15,13 +15,14 @@ import sqlalchemy as sa -def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method): +def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_feature_method, apply): """Transforms default parameter values into active_by_default values, whenever the former are used in a tool filter to control entity activity. Args: conn (Connection) - use_existing_tool_feature_method (Bool): Whether to use existing tool/feature/method definitions. + use_existing_tool_feature_method (bool): Whether to use existing tool/feature/method definitions. 
+        apply (bool): if True, apply the transformations
 
     Returns:
         tuple: list of entity classes to add, update and ids to remove
@@ -90,7 +91,8 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea
     entity_class_table = meta.tables["entity_class"]
     update_statement = entity_class_table.update()
     for class_id, update in entity_class_items_to_update.items():
-        conn.execute(update_statement.where(entity_class_table.c.id == class_id), update)
+        if apply:
+            conn.execute(update_statement.where(entity_class_table.c.id == class_id), update)
         update["id"] = class_id
         updated_items.append(update)
     parameter_definitions_to_update = (
@@ -99,19 +101,21 @@ def convert_tool_feature_method_to_active_by_default(conn, use_existing_tool_fea
     update_statement = pd_table.update()
     for definition_id in parameter_definitions_to_update:
         update = {"default_value": None, "default_type": None}
-        conn.execute(update_statement.where(pd_table.c.id == definition_id), update)
+        if apply:
+            conn.execute(update_statement.where(pd_table.c.id == definition_id), update)
         update["id"] = definition_id
         updated_items.append(update)
     return [], updated_items, []
 
 
-def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_feature_method):
+def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_feature_method, apply):
     """Transforms parameter_value rows into entity_alternative rows, whenever the former are used
     in a tool filter to control entity activity.
 
     Args:
         conn (Connection)
-        use_existing_tool_feature_method (Bool): Whether to use existing tool/feature/method definitions.
+        use_existing_tool_feature_method (bool): Whether to use existing tool/feature/method definitions.
+        apply (bool): if True, apply the transformations
 
     Returns:
         list: entity_alternative items to add
@@ -184,31 +188,33 @@ def convert_tool_feature_method_to_entity_alternative(conn, use_existing_tool_fe
         for key in set(new_ea_items) & set(current_ea_ids)
     ]
     pval_ids_to_remove = [x["id"] for x in is_active_pvals]
-    if ea_items_to_add:
-        conn.execute(ea_table.insert(), ea_items_to_add)
-    ea_update = ea_table.update()
-    for item in ea_items_to_update:
-        conn.execute(ea_update.where(ea_table.c.id == item["id"]), {"active": item["active"]})
-    # Delete pvals 499 at a time to avoid too many sql variables
-    size = 499
-    for i in range(0, len(pval_ids_to_remove), size):
-        ids = pval_ids_to_remove[i : i + size]
-        conn.execute(pv_table.delete().where(pv_table.c.id.in_(ids)))
+    if apply:
+        if ea_items_to_add:
+            conn.execute(ea_table.insert(), ea_items_to_add)
+        ea_update = ea_table.update()
+        for item in ea_items_to_update:
+            conn.execute(ea_update.where(ea_table.c.id == item["id"]), {"active": item["active"]})
+        # Delete pvals 499 at a time to avoid too many sql variables
+        size = 499
+        for i in range(0, len(pval_ids_to_remove), size):
+            ids = pval_ids_to_remove[i : i + size]
+            conn.execute(pv_table.delete().where(pv_table.c.id.in_(ids)))
     return ea_items_to_add, ea_items_to_update, set(pval_ids_to_remove)
 
 
-def compatibility_transformations(connection):
+def compatibility_transformations(connection, apply=True):
     """Refits any data having an old format and returns changes made.
Args: connection (Connection) + apply (bool): if True, apply the transformations Returns: tuple(list, list): list of tuples (tablename, (items_added, items_updated, ids_removed)), and list of strings indicating the changes """ ea_items_added, ea_items_updated, pval_ids_removed = convert_tool_feature_method_to_entity_alternative( - connection, use_existing_tool_feature_method=False + connection, use_existing_tool_feature_method=False, apply=apply ) transformations = [] info = [] @@ -219,7 +225,7 @@ def compatibility_transformations(connection): if ea_items_added or ea_items_updated or pval_ids_removed: info.append("Convert entity activity control using tool/feature/method into entity_alternative") _, ec_items_updated, _ = convert_tool_feature_method_to_active_by_default( - connection, use_existing_tool_feature_method=False + connection, use_existing_tool_feature_method=False, apply=apply ) if ec_items_updated: transformations.append(("entity_class", ((), ec_items_updated, ()))) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index e5ec771a..11e9d03a 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -745,11 +745,12 @@ def query(self, *args, **kwargs): """ return Query(self.engine, *args) - def commit_session(self, comment): + def commit_session(self, comment, apply_compatibility_transforms=True): """Commits the changes from the in-memory mapping to the database. Args: comment (str): commit message + apply_compatibility_transforms (bool): if True, apply compatibility transforms Returns: tuple(list, list): compatibility transformations @@ -777,7 +778,7 @@ def commit_session(self, comment): self._do_add_items(connection, tablename, *to_add) if self._memory: self._memory_dirty = True - transformation_info = compatibility_transformations(connection) + transformation_info = compatibility_transformations(connection, apply=apply_compatibility_transforms) self._commit_count = self._query_commit_count() return transformation_info diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 68f5650a..7a8039ad 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -1141,23 +1141,26 @@ def test_unparse_value_imports_fields_correctly(self): 'entity_classes': [('A', (), None, None, False)], 'entities': [('A', 'aa', None)], 'parameter_definitions': [('A', 'test1', None, None, None)], - 'parameter_values': [( - 'A', - 'aa', - 'test1', - { - 'type': 'time_series', - 'index': { - 'start': '2000-01-01 00:00:00', - 'resolution': '1h', - 'ignore_year': False, - 'repeat': False + 'parameter_values': [ + ( + 'A', + 'aa', + 'test1', + { + 'type': 'time_series', + 'index': { + 'start': '2000-01-01 00:00:00', + 'resolution': '1h', + 'ignore_year': False, + 'repeat': False, + }, + 'data': [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], }, - 'data': [0.0, 1.0, 2.0, 4.0, 8.0, 0.0] - }, - 'Base' - )], - 'alternatives': [('Base', 'Base alternative')]} + 'Base', + ) + ], + 'alternatives': [('Base', 'Base alternative')], + } count, errors = import_data(db_map, **data, unparse_value=dump_db_value) self.assertEqual(errors, []) @@ -1172,11 +1175,7 @@ def test_unparse_value_imports_fields_correctly(self): time_series = from_database(value.value, value.type) expected_result = TimeSeriesFixedResolution( - '2000-01-01 00:00:00', - '1h', - [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], - False, - False + '2000-01-01 00:00:00', '1h', [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], False, False ) self.assertEqual(time_series, expected_result) From 56a6a78592297881dc98826dbe6e0f3459838fee 
Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Mon, 25 Mar 2024 16:20:05 +0200
Subject: [PATCH 302/317] Add deep_copy_value() function

We need a function that makes deep copies of parameter values.

Re spine-tools/Spine-Toolbox#2657
---
 spinedb_api/parameter_value.py | 55 +++++++++++++++++++++++++++
 tests/test_parameter_value.py  | 69 ++++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+)

diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py
index b2a5aa76..73971fbc 100644
--- a/spinedb_api/parameter_value.py
+++ b/spinedb_api/parameter_value.py
@@ -1494,6 +1494,10 @@ def __eq__(self, other):
             return NotImplemented
         return other._indexes == self._indexes and other._values == self._values and self.index_name == other.index_name
 
+    @property
+    def index_type(self):
+        return self._index_type
+
     def is_nested(self):
         """Whether any of the values is also a map.
 
@@ -1727,3 +1731,54 @@ def split_value_and_type(value_and_type):
     except (TypeError, json.JSONDecodeError):
         parsed = value_and_type
     return dump_db_value(parsed)
+
+
+def deep_copy_value(value):
+    """Copies a value.
+    The operation is deep, meaning that nested Maps will be copied as well.
+
+    :meta private:
+
+    Args:
+        value (Any): value to copy
+
+    Returns:
+        Any: deep-copied value
+    """
+    if isinstance(value, (Number, str)):
+        return value
+    if isinstance(value, Array):
+        return Array(value.values, value.value_type, value.index_name)
+    if isinstance(value, DateTime):
+        return DateTime(value)
+    if isinstance(value, Duration):
+        return Duration(value)
+    if isinstance(value, Map):
+        return deep_copy_map(value)
+    if isinstance(value, TimePattern):
+        return TimePattern(value.indexes.copy(), value.values.copy(), value.index_name)
+    if isinstance(value, TimeSeriesFixedResolution):
+        return TimeSeriesFixedResolution(
+            value.start, value.resolution, value.values.copy(), value.ignore_year, value.repeat, value.index_name
+        )
+    if isinstance(value, TimeSeriesVariableResolution):
+        return TimeSeriesVariableResolution(
+            value.indexes.copy(), value.values.copy(), value.ignore_year, value.repeat, value.index_name
+        )
+    raise ValueError("unknown value")
+
+
+def deep_copy_map(value):
+    """Deep copies a Map value.
+ + :meta private: + + Args: + value (Map): Map to copy + + Returns: + Map: deep-copied Map + """ + xs = value.indexes.copy() + ys = [deep_copy_value(y) for y in value.values] + return Map(xs, ys, index_type=value.index_type, index_name=value.index_name) diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index d54604c7..deaf7cea 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -26,6 +26,7 @@ convert_containers_to_maps, convert_leaf_maps_to_specialized_containers, convert_map_to_table, + deep_copy_value, duration_to_relativedelta, relativedelta_to_duration, from_database, @@ -1003,6 +1004,74 @@ def convert_map_to_dict(self): nested_map = Map(["A", "B"], [map1, map2]) self.assertEqual(nested_map, {"A": {"a": -3.2, "b": -2.3}, "B": {"c": 3.2, "d": 2.3}}) + def test_deep_copy_value_for_scalars(self): + x = 1.0 + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + x = "y" + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + x = Duration("3h") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + x = DateTime("2024-03-25T15:58:33") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + + def test_deep_copy_for_arrays(self): + x = Array([], value_type=float, index_name="floaters") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + x = Array(["1", "2", "3"], index_name="number likes") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + + def test_deep_copy_time_pattern(self): + x = TimePattern(["M1-12"], [2.3], index_name="moments") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + + def test_deep_copy_time_series_fixed_resolution(self): + x = TimeSeriesFixedResolution( + "2024-03-25T16:08:23", "4h", [2.3, 23.0, 5.0], ignore_year=True, repeat=True, index_name="my times" + ) + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + + def test_deep_copy_time_series_variable_resolution(self): + x = TimeSeriesVariableResolution( + ["2024-03-25T16:10:23", "2024-04-26T16:11:23"], + [2.3, 23.0], + ignore_year=True, + repeat=True, + index_name="your times", + ) + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + + def test_deep_copy_map(self): + x = Map([], [], index_type=str, index_name="first i") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + x = Map(["T1", "T2"], [2.3, 23.0], index_name="our times") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + leaf = Map(["t1"], [2.3], index_name="inner") + x = Map(["T1"], [leaf], index_name="outer") + copy_of_x = deep_copy_value(x) + self.assertEqual(x, copy_of_x) + self.assertIsNot(x, copy_of_x) + self.assertIsNot(x.get_value("T1"), copy_of_x.get_value("T1")) + if __name__ == "__main__": unittest.main() From c5b8d2ffcdef2bd5686659f25449e9c8fc093b79 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 27 Mar 2024 10:56:22 +0200 Subject: [PATCH 303/317] Add some performance benchmarks The benchmarks test updating the default values in parameter definitions. 
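All of the added scripts follow the same pyperf time-function pattern; a
stripped-down sketch with a placeholder workload (not one of the committed
benchmarks):

    # Sketch: bench_time_func passes in the loop count and expects the
    # measured total back, so only the explicitly timed region counts.
    import time
    import pyperf

    def bench(loops):
        total = 0.0
        for _ in range(loops):
            start = time.perf_counter()
            sum(range(10_000))  # placeholder workload, an assumption
            total += time.perf_counter() - start
        return total

    if __name__ == "__main__":
        pyperf.Runner().bench_time_func("placeholder", bench)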
---
 .gitignore                                    |  1 +
 benchmarks/README.md                          | 26 +++++++++
 benchmarks/__init__.py                        |  0
 ...update_default_value_to_different_value.py | 53 +++++++++++++++++++
 .../update_default_value_to_same_value.py     | 41 ++++++++++++++
 benchmarks/utils.py                           | 32 +++++++++++
 pyproject.toml                                |  2 +-
 7 files changed, 154 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/README.md
 create mode 100644 benchmarks/__init__.py
 create mode 100644 benchmarks/update_default_value_to_different_value.py
 create mode 100644 benchmarks/update_default_value_to_same_value.py
 create mode 100644 benchmarks/utils.py

diff --git a/.gitignore b/.gitignore
index 89f5b605..40e54f1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,4 @@
 /htmlcov
 
 spinedb_api/version.py
+benchmarks/*.json
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 00000000..5482382f
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,26 @@
+# Performance benchmarks
+
+This Python package contains performance benchmarks for `spinedb_api`.
+The benchmarks use [`pyperf`](https://pyperf.readthedocs.io/en/latest/index.html)
+which is installed as part of the optional developer dependencies:
+
+```commandline
+python -mpip install .[dev]
+```
+
+Each Python file is an individual script
+that writes the run results into a common `.json` file.
+The file can be inspected by
+
+```commandline
+python -mpyperf show <benchmark file>.json
+```
+
+Benchmarks from e.g. different commits/branches can be compared by
+
+```commandline
+python -mpyperf compare_to <old>.json <new>.json
+```
+
+Check the [`pyperf` documentation](https://pyperf.readthedocs.io/en/latest/index.html)
+for further things you can do with it.
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/benchmarks/update_default_value_to_different_value.py b/benchmarks/update_default_value_to_different_value.py
new file mode 100644
index 00000000..0f27836d
--- /dev/null
+++ b/benchmarks/update_default_value_to_different_value.py
@@ -0,0 +1,53 @@
+"""
+This benchmark tests the performance of updating a parameter definition when
+the update changes the default value from None to a somewhat complex Map.
+""" + +import time +import pyperf +from spinedb_api import DatabaseMapping, to_database +from benchmarks.utils import build_sizeable_map, run_file_name + + +def update_default_value(loops, db_map, first_db_value, first_value_type, second_db_value, second_value_type): + total_time = 0.0 + for counter in range(loops): + start = time.perf_counter() + result = db_map.update_parameter_definition_item( + name="x", entity_class_name="Object", default_value=second_db_value, default_type=second_value_type + ) + finish = time.perf_counter() + error = result[1] + if error: + raise RuntimeError(error) + total_time += finish - start + db_map.update_parameter_definition_item( + name="x", entity_class_name="Object", default_value=first_db_value, default_type=first_value_type + ) + return total_time + + +def run_benchmark(): + first_value, first_type = to_database(None) + second_value, second_type = to_database(build_sizeable_map()) + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class_item(name="Object") + db_map.add_parameter_definition_item( + name="x", entity_class_name="Object", default_value=first_value, default_type=first_type + ) + runner = pyperf.Runner(min_time=0.0001) + benchmark = runner.bench_time_func( + "update_parameter_definition_item[None,Map]", + update_default_value, + db_map, + first_value, + first_type, + second_value, + second_type, + inner_loops=10, + ) + pyperf.add_runs(run_file_name(), benchmark) + + +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/update_default_value_to_same_value.py b/benchmarks/update_default_value_to_same_value.py new file mode 100644 index 00000000..13e7b634 --- /dev/null +++ b/benchmarks/update_default_value_to_same_value.py @@ -0,0 +1,41 @@ +""" +This benchmark tests the performance of updating a parameter definition item when +the default value is somewhat complex Map and the update does not change anything. 
+""" +import time +import pyperf +from spinedb_api import DatabaseMapping, to_database +from benchmarks.utils import build_sizeable_map, run_file_name + + +def update_default_value(loops, db_map, value, value_type): + total_time = 0.0 + for counter in range(loops): + start = time.perf_counter() + result = db_map.update_parameter_definition_item( + name="x", entity_class_name="Object", default_value=value, default_type=value_type + ) + finish = time.perf_counter() + error = result[1] + if error: + raise RuntimeError(error) + total_time += finish - start + return total_time + + +def run_benchmark(): + value, value_type = to_database(build_sizeable_map()) + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class_item(name="Object") + db_map.add_parameter_definition_item( + name="x", entity_class_name="Object", default_value=value, default_type=value_type + ) + runner = pyperf.Runner() + benchmark = runner.bench_time_func( + "update_parameter_definition_item[Map,Map]", update_default_value, db_map, value, value_type, inner_loops=10 + ) + pyperf.add_runs(run_file_name(), benchmark) + + +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/utils.py b/benchmarks/utils.py new file mode 100644 index 00000000..c2285ba2 --- /dev/null +++ b/benchmarks/utils.py @@ -0,0 +1,32 @@ +import datetime +import math +from spinedb_api import __version__, DateTime, Map + + +def build_sizeable_map(): + start = datetime.datetime(year=2024, month=1, day=1) + root_xs = [] + root_ys = [] + i_max = 10 + j_max = 10 + k_max = 10 + total = i_max * j_max * k_max + for i in range(i_max): + root_xs.append(DateTime(start + datetime.timedelta(hours=i))) + leaf_xs = [] + leaf_ys = [] + for j in range(j_max): + leaf_xs.append(DateTime(start + datetime.timedelta(hours=j))) + xs = [] + ys = [] + for k in range(k_max): + xs.append(DateTime(start + datetime.timedelta(hours=k))) + x = float(k + k_max * j + j_max * i) / total + ys.append(math.sin(x * math.pi / 2.0) + (x * j) ** 2 + x * i) + leaf_ys.append(Map(xs, ys)) + root_ys.append(Map(leaf_xs, leaf_ys)) + return Map(root_xs, root_ys) + + +def run_file_name(): + return f"benchmark-{__version__}.json" diff --git a/pyproject.toml b/pyproject.toml index bcfa9d3b..7852c4d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ Repository = "https://github.com/spine-tools/Spine-Database-API" [project.optional-dependencies] -dev = ["coverage[toml]"] +dev = ["coverage[toml]", "pyperf"] [build-system] requires = ["setuptools>=64", "setuptools_scm[toml]>=6.2", "wheel", "build"] From 1585480eb82511bc25656c72b8a4c6d8d35b4804 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 13 Mar 2024 14:37:16 +0200 Subject: [PATCH 304/317] Avoid parsing parameter value when checking need for update from_database() is potentially expensive operation so we should avoid it in ParsedValueBase._something_to_update() if we can. We now check if the types or raw binary blobs differ before parsing the value. 
--- spinedb_api/mapped_items.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index e7f0c878..3fc447ec 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -11,7 +11,7 @@ ###################################################################################################################### from operator import itemgetter - +import time from .helpers import name_from_elements from .parameter_value import to_database, from_database, ParameterValueFormatError from .db_mapping_base import MappedItemBase @@ -380,16 +380,21 @@ def __getitem__(self, key): return super().__getitem__(key) def _something_to_update(self, other): - other = other.copy() if self._value_key in other and self._type_key in other: - try: - other_parsed_value = from_database(other[self._value_key], other[self._type_key]) - if self.parsed_value != other_parsed_value: - return True - _ = other.pop(self._value_key, None) - _ = other.pop(self._type_key, None) - except ParameterValueFormatError: - pass + other_value_type = other[self._type_key] + if self.type != other_value_type: + return True + other_value = other[self._value_key] + if self.value != other_value: + try: + other_parsed_value = from_database(other_value, other_value_type) + if self.parsed_value != other_parsed_value: + return True + other = other.copy() + _ = other.pop(self._value_key, None) + _ = other.pop(self._type_key, None) + except ParameterValueFormatError: + pass return super()._something_to_update(other) From 0106526b53bf326d5a3dfdf8dd8199dd1f515fa5 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 2 Apr 2024 13:37:12 +0300 Subject: [PATCH 305/317] Add a function to create default import mappings Importer needs methods to create blank mappings. Here they are. Re spine-tools/Spine-Toolbox#2662 --- spinedb_api/import_mapping/import_mapping.py | 93 ++++++++++++++++++- .../import_mapping/import_mapping_compat.py | 5 +- tests/import_mapping/test_import_mapping.py | 22 ++++- 3 files changed, 108 insertions(+), 12 deletions(-) diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index 625809c5..c521a55e 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -9,10 +9,7 @@ # Public License for more details. You should have received a copy of the GNU Lesser General Public License along with # this program. If not, see . ###################################################################################################################### -""" -Contains import mappings for database items such as entities, entity classes and parameter values. - -""" +""" Contains import mappings for database items such as entities, entity classes and parameter values. """ from distutils.util import strtobool from enum import auto, Enum, unique @@ -860,6 +857,94 @@ def _import_row(self, source_data, state, mapped_data): mapped_data.setdefault("tools", set()).add(tool) +def default_import_mapping(map_type): + """Creates default mappings for given map type. 
+ + Args: + map_type (str): map type + + Returns: + ImportMapping: root mapping of desired type + """ + make_root_mapping = { + "EntityClass": _default_entity_class_mapping, + "Alternative": _default_alternative_mapping, + "Scenario": _default_scenario_mapping, + "ScenarioAlternative": _default_scenario_alternative_mapping, + "EntityGroup": _default_entity_group_mapping, + "ParameterValueList": _default_parameter_value_list_mapping, + }[map_type] + return make_root_mapping() + + +def _default_entity_class_mapping(): + """Creates default entity class mappings. + + Returns: + EntityClassMapping: root mapping + """ + root_mapping = EntityClassMapping(Position.hidden) + object_mapping = root_mapping.child = EntityMapping(Position.hidden) + object_mapping.child = EntityMetadataMapping(Position.hidden) + return root_mapping + + +def _default_alternative_mapping(): + """Creates default alternative mappings. + + Returns: + AlternativeMapping: root mapping + """ + root_mapping = AlternativeMapping(Position.hidden) + return root_mapping + + +def _default_scenario_mapping(): + """Creates default scenario mappings. + + Returns: + ScenarioMapping: root mapping + """ + root_mapping = ScenarioMapping(Position.hidden) + root_mapping.child = ScenarioActiveFlagMapping(Position.hidden) + return root_mapping + + +def _default_scenario_alternative_mapping(): + """Creates default scenario alternative mappings. + + Returns: + ScenarioAlternativeMapping: root mapping + """ + root_mapping = ScenarioMapping(Position.hidden) + scen_alt_mapping = root_mapping.child = ScenarioAlternativeMapping(Position.hidden) + scen_alt_mapping.child = ScenarioBeforeAlternativeMapping(Position.hidden) + return root_mapping + + +def _default_entity_group_mapping(): + """Creates default entity group mappings. + + Returns: + EntityClassMapping: root mapping + """ + root_mapping = EntityClassMapping(Position.hidden) + object_mapping = root_mapping.child = EntityMapping(Position.hidden) + object_mapping.child = EntityGroupMapping(Position.hidden) + return root_mapping + + +def _default_parameter_value_list_mapping(): + """Creates default parameter value list mappings. + + Returns: + ParameterValueListMapping: root mapping + """ + root_mapping = ParameterValueListMapping(Position.hidden) + root_mapping.child = ParameterValueListValueMapping(Position.hidden) + return root_mapping + + def from_dict(serialized): """ Deserializes mappings. diff --git a/spinedb_api/import_mapping/import_mapping_compat.py b/spinedb_api/import_mapping/import_mapping_compat.py index 4434334d..65ac8f2e 100644 --- a/spinedb_api/import_mapping/import_mapping_compat.py +++ b/spinedb_api/import_mapping/import_mapping_compat.py @@ -10,10 +10,7 @@ # this program. If not, see . ###################################################################################################################### -""" -Functions for creating import mappings from dicts. - -""" +""" Functions for creating import mappings from dicts. """ from .import_mapping import ( Position, EntityClassMapping, diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index d8b24df4..d0f119d8 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -10,15 +10,13 @@ # this program. If not, see . ###################################################################################################################### -""" -Unit tests for import Mappings. - -""" +""" Unit tests for import Mappings. 
""" import unittest from unittest.mock import Mock from spinedb_api.exception import InvalidMapping from spinedb_api.mapping import Position, to_dict as mapping_to_dict, unflatten from spinedb_api.import_mapping.import_mapping import ( + default_import_mapping, ImportMapping, EntityClassMapping, EntityMapping, @@ -2083,5 +2081,21 @@ def test_returns_false_when_position_is_header_and_is_leaf(self): self.assertFalse(mapping.is_pivoted()) +class TestDefaultMappings(unittest.TestCase): + def test_mappings_are_hidden(self): + map_types = ( + "EntityClass", + "Alternative", + "Scenario", + "ScenarioAlternative", + "EntityGroup", + "ParameterValueList", + ) + for map_type in map_types: + root = default_import_mapping(map_type) + flattened = root.flatten() + self.assertTrue(all(m.position == Position.hidden for m in flattened)) + + if __name__ == "__main__": unittest.main() From 09b2edcb23d960821506b5a463ea9c4e993ab008 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 2 Apr 2024 13:39:15 +0300 Subject: [PATCH 306/317] Improve purge_url() and purge() docstrings. --- spinedb_api/purge.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spinedb_api/purge.py b/spinedb_api/purge.py index 179541c5..2e8066d6 100644 --- a/spinedb_api/purge.py +++ b/spinedb_api/purge.py @@ -22,9 +22,11 @@ def purge_url(url, purge_settings, logger=None): """Removes all items of selected types from the database at a given URL. + Purges everything if ``purge_settings`` is None. + Args: url (str): database URL - purge_settings (dict): mapping from item type to a boolean indicating whether to remove them or not + purge_settings (dict, optional): mapping from item type to a boolean indicating whether to remove them or not logger (LoggerInterface, optional): logger Returns: @@ -45,9 +47,11 @@ def purge_url(url, purge_settings, logger=None): def purge(db_map, purge_settings, logger=None): """Removes all items of selected types from a database. + Purges everything if ``purge_settings`` is None. + Args: db_map (DatabaseMapping): target database mapping - purge_settings (dict): mapping from item type to a boolean indicating whether to remove them or not + purge_settings (dict, optional): mapping from item type to a boolean indicating whether to remove them or not logger (LoggerInterface): logger Returns: From 2cccc95417bf9411b9a6e86f5b45cbb689c7d58f Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 2 Apr 2024 14:38:09 +0300 Subject: [PATCH 307/317] Enable string normalization in black Just reformatting, no functional changes. 
Re spine-tools/Spine-Toolbox#2684 --- pyproject.toml | 1 - spinedb_api/alembic/env.py | 2 +- .../0c7d199ae915_add_list_value_table.py | 40 ++--- .../1892adebc00f_create_metadata_tables.py | 4 +- .../1e4997105288_separate_type_from_value.py | 10 +- ...60a11b05_add_alternatives_and_scenarios.py | 14 +- ...63bef2_create_superclass_subclass_table.py | 36 ++-- ...c61_drop_object_and_relationship_tables.py | 120 ++++++------- ..._fix_foreign_key_constraints_in_object_.py | 24 +-- ..._fix_foreign_key_constraints_in_entity_.py | 24 +-- ...b_add_active_by_default_to_entity_class.py | 4 +- ..._replace_values_with_reference_to_list_.py | 4 +- .../9da58d2def22_create_entity_group_table.py | 4 +- ...a82ed59_create_entity_alternative_table.py | 44 ++--- .../defbda3bf2b5_add_tool_feature_tables.py | 4 +- .../fbb540efbf15_add_support_for_mysql.py | 82 ++++----- ...drop_on_update_clauses_from_object_and_.py | 24 +-- spinedb_api/db_mapping.py | 6 +- spinedb_api/helpers.py | 2 +- spinedb_api/import_mapping/import_mapping.py | 4 +- spinedb_api/mapped_items.py | 164 +++++++++--------- spinedb_api/parameter_value.py | 8 +- spinedb_api/perfect_split.py | 4 +- spinedb_api/server_client_helpers.py | 8 +- tests/export_mapping/test_export_mapping.py | 6 +- tests/filters/test_alternative_filter.py | 2 +- tests/filters/test_execution_filter.py | 2 +- tests/filters/test_scenario_filter.py | 2 +- tests/filters/test_tool_filter.py | 2 +- tests/import_mapping/test_generator.py | 44 ++--- tests/import_mapping/test_import_mapping.py | 158 ++++++++--------- tests/spine_io/exporters/test_gdx_writer.py | 2 +- tests/spine_io/exporters/test_writer.py | 2 +- .../importers/test_datapackage_reader.py | 2 +- tests/spine_io/importers/test_json_reader.py | 2 +- tests/spine_io/importers/test_reader.py | 2 +- .../importers/test_sqlalchemy_connector.py | 2 +- tests/test_DatabaseMapping.py | 30 ++-- tests/test_check_integrity.py | 36 ++-- tests/test_db_mapping_base.py | 2 +- tests/test_export_functions.py | 2 +- tests/test_helpers.py | 8 +- tests/test_import_functions.py | 80 ++++----- tests/test_migration.py | 30 ++-- tests/test_parameter_value.py | 16 +- tests/test_purge.py | 2 +- 46 files changed, 535 insertions(+), 536 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7852c4d1..f4dd65bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,5 +65,4 @@ ignore_errors = true [tool.black] line-length = 120 -skip-string-normalization = true exclude = '\.git' diff --git a/spinedb_api/alembic/env.py b/spinedb_api/alembic/env.py index 02973d56..6e228ba8 100644 --- a/spinedb_api/alembic/env.py +++ b/spinedb_api/alembic/env.py @@ -19,7 +19,7 @@ # for 'autogenerate' support import sys -sys.path = ['', '..'] + sys.path[1:] +sys.path = ["", ".."] + sys.path[1:] from spinedb_api.helpers import create_spine_metadata target_metadata = create_spine_metadata() diff --git a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py index 0a4e3491..e348f970 100644 --- a/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py +++ b/spinedb_api/alembic/versions/0c7d199ae915_add_list_value_table.py @@ -12,8 +12,8 @@ from spinedb_api.helpers import LONGTEXT_LENGTH # revision identifiers, used by Alembic. 
-revision = '0c7d199ae915' -down_revision = '7d0b467f2f4e' +revision = "0c7d199ae915" +down_revision = "7d0b467f2f4e" branch_labels = None depends_on = None @@ -36,14 +36,14 @@ def upgrade(): with op.batch_alter_table("next_id") as batch_op: batch_op.add_column(sa.Column("list_value_id", sa.Integer, server_default=sa.null())) op.create_table( - 'list_value', - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('parameter_value_list_id', sa.Integer, sa.ForeignKey("parameter_value_list.id"), nullable=False), - sa.Column('index', sa.Integer, nullable=False), - sa.Column('type', sa.String(255)), - sa.Column('value', sa.LargeBinary(LONGTEXT_LENGTH), server_default=sa.null()), - sa.Column('commit_id', sa.Integer, sa.ForeignKey("commit.id")), - sa.UniqueConstraint('parameter_value_list_id', 'index'), + "list_value", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("parameter_value_list_id", sa.Integer, sa.ForeignKey("parameter_value_list.id"), nullable=False), + sa.Column("index", sa.Integer, nullable=False), + sa.Column("type", sa.String(255)), + sa.Column("value", sa.LargeBinary(LONGTEXT_LENGTH), server_default=sa.null()), + sa.Column("commit_id", sa.Integer, sa.ForeignKey("commit.id")), + sa.UniqueConstraint("parameter_value_list_id", "index"), ) # NOTE: At some point, by mistake, we modified ``helpers.create_new_spine_database`` by specifying a name for the fk # that refers parameter_value_list in tool_feature_method. But since this was just a mistake, we didn't provide a @@ -55,22 +55,22 @@ def upgrade(): x["name"] for x in sa.inspect(conn).get_foreign_keys("tool_feature_method") if x["referred_table"] == "parameter_value_list" - and x["referred_columns"] == ['id', 'value_index'] - and x["constrained_columns"] == ['parameter_value_list_id', 'method_index'] + and x["referred_columns"] == ["id", "value_index"] + and x["constrained_columns"] == ["parameter_value_list_id", "method_index"] ) with op.batch_alter_table("tool_feature_method") as batch_op: - batch_op.drop_constraint(fk_name, type_='foreignkey') + batch_op.drop_constraint(fk_name, type_="foreignkey") with op.batch_alter_table("parameter_value_list") as batch_op: - batch_op.drop_column('value_index') - batch_op.drop_column('value') + batch_op.drop_column("value_index") + batch_op.drop_column("value") with op.batch_alter_table("tool_feature_method") as batch_op: batch_op.create_foreign_key( None, - 'list_value', - ['parameter_value_list_id', 'method_index'], - ['parameter_value_list_id', 'index'], - onupdate='CASCADE', - ondelete='CASCADE', + "list_value", + ["parameter_value_list_id", "method_index"], + ["parameter_value_list_id", "index"], + onupdate="CASCADE", + ondelete="CASCADE", ) # Add rescued data pvl_items = list({x.id: {"id": x.id, "name": x.name, "commit_id": x.commit_id} for x in pvl}.values()) diff --git a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py index 8bd5f19e..614da1b3 100644 --- a/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py +++ b/spinedb_api/alembic/versions/1892adebc00f_create_metadata_tables.py @@ -10,8 +10,8 @@ # revision identifiers, used by Alembic. 
-revision = '1892adebc00f' -down_revision = 'defbda3bf2b5' +revision = "1892adebc00f" +down_revision = "defbda3bf2b5" branch_labels = None depends_on = None diff --git a/spinedb_api/alembic/versions/1e4997105288_separate_type_from_value.py b/spinedb_api/alembic/versions/1e4997105288_separate_type_from_value.py index 30d74c38..7d021cca 100644 --- a/spinedb_api/alembic/versions/1e4997105288_separate_type_from_value.py +++ b/spinedb_api/alembic/versions/1e4997105288_separate_type_from_value.py @@ -14,8 +14,8 @@ # revision identifiers, used by Alembic. -revision = '1e4997105288' -down_revision = 'fbb540efbf15' +revision = "1e4997105288" +down_revision = "fbb540efbf15" branch_labels = None depends_on = None @@ -34,14 +34,14 @@ def upgrade(): pvl_items = _get_pvl_items(session, Base) # Alter tables with op.batch_alter_table("parameter_definition") as batch_op: - batch_op.drop_column('data_type') + batch_op.drop_column("data_type") batch_op.drop_column("default_value") batch_op.add_column(sa.Column("default_value", sa.LargeBinary(LONGTEXT_LENGTH), server_default=sa.null())) - batch_op.add_column(sa.Column('default_type', sa.String(length=255), nullable=True)) + batch_op.add_column(sa.Column("default_type", sa.String(length=255), nullable=True)) with op.batch_alter_table("parameter_value") as batch_op: batch_op.drop_column("value") batch_op.add_column(sa.Column("value", sa.LargeBinary(LONGTEXT_LENGTH), server_default=sa.null())) - batch_op.add_column(sa.Column('type', sa.String(length=255), nullable=True)) + batch_op.add_column(sa.Column("type", sa.String(length=255), nullable=True)) with op.batch_alter_table("parameter_value_list") as batch_op: batch_op.drop_column("value") batch_op.add_column(sa.Column("value", sa.LargeBinary(LONGTEXT_LENGTH), server_default=sa.null())) diff --git a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py index 06331d2a..377bdd12 100644 --- a/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py +++ b/spinedb_api/alembic/versions/39e860a11b05_add_alternatives_and_scenarios.py @@ -13,8 +13,8 @@ # revision identifiers, used by Alembic. 
-revision = '39e860a11b05' -down_revision = '9da58d2def22' +revision = "39e860a11b05" +down_revision = "9da58d2def22" branch_labels = None depends_on = None @@ -85,7 +85,7 @@ def alter_tables_after_update(): op.execute("UPDATE parameter_value SET alternative_id = 1") with op.batch_alter_table("parameter_value") as batch_op: - batch_op.alter_column('alternative_id', nullable=False) + batch_op.alter_column("alternative_id", nullable=False) batch_op.create_foreign_key( None, "alternative", ("alternative_id",), ("id",), onupdate="CASCADE", ondelete="CASCADE" ) @@ -115,9 +115,9 @@ def alter_tables_after_update(): ) with op.batch_alter_table("entity_type") as batch_op: - batch_op.alter_column('commit_id', nullable=False) + batch_op.alter_column("commit_id", nullable=False) with op.batch_alter_table("entity_class_type") as batch_op: - batch_op.alter_column('commit_id', nullable=False) + batch_op.alter_column("commit_id", nullable=False) def upgrade(): @@ -146,6 +146,6 @@ def downgrade(): batch_op.drop_column("scenario_id") batch_op.drop_column("scenario_alternative_id") with op.batch_alter_table("entity_type") as batch_op: - batch_op.alter_column('commit_id', nullable=True) + batch_op.alter_column("commit_id", nullable=True) with op.batch_alter_table("entity_class_type") as batch_op: - batch_op.alter_column('commit_id', nullable=True) + batch_op.alter_column("commit_id", nullable=True) diff --git a/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py index 2fdb006b..8d07d5e1 100644 --- a/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py +++ b/spinedb_api/alembic/versions/5385f063bef2_create_superclass_subclass_table.py @@ -10,34 +10,34 @@ # revision identifiers, used by Alembic. 
-revision = '5385f063bef2' -down_revision = 'ce9faa82ed59' +revision = "5385f063bef2" +down_revision = "ce9faa82ed59" branch_labels = None depends_on = None def upgrade(): op.create_table( - 'superclass_subclass', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('superclass_id', sa.Integer(), nullable=False), - sa.Column('subclass_id', sa.Integer(), nullable=False), + "superclass_subclass", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("superclass_id", sa.Integer(), nullable=False), + sa.Column("subclass_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( - ['subclass_id'], - ['entity_class.id'], - name=op.f('fk_superclass_subclass_subclass_id_entity_class'), - onupdate='CASCADE', - ondelete='CASCADE', + ["subclass_id"], + ["entity_class.id"], + name=op.f("fk_superclass_subclass_subclass_id_entity_class"), + onupdate="CASCADE", + ondelete="CASCADE", ), sa.ForeignKeyConstraint( - ['superclass_id'], - ['entity_class.id'], - name=op.f('fk_superclass_subclass_superclass_id_entity_class'), - onupdate='CASCADE', - ondelete='CASCADE', + ["superclass_id"], + ["entity_class.id"], + name=op.f("fk_superclass_subclass_superclass_id_entity_class"), + onupdate="CASCADE", + ondelete="CASCADE", ), - sa.PrimaryKeyConstraint('id', name=op.f('pk_superclass_subclass')), - sa.UniqueConstraint('subclass_id', name=op.f('uq_superclass_subclass_subclass_id')), + sa.PrimaryKeyConstraint("id", name=op.f("pk_superclass_subclass")), + sa.UniqueConstraint("subclass_id", name=op.f("uq_superclass_subclass_subclass_id")), ) diff --git a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py index afb3adda..b4dc0212 100644 --- a/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py +++ b/spinedb_api/alembic/versions/6b7c994c1c61_drop_object_and_relationship_tables.py @@ -10,68 +10,68 @@ from spinedb_api.helpers import naming_convention # revision identifiers, used by Alembic. 
-revision = '6b7c994c1c61' -down_revision = '989fccf80441' +revision = "6b7c994c1c61" +down_revision = "989fccf80441" branch_labels = None depends_on = None def upgrade(): op.create_table( - 'entity_class_dimension', - sa.Column('entity_class_id', sa.Integer(), nullable=False), - sa.Column('dimension_id', sa.Integer(), nullable=False), - sa.Column('position', sa.Integer(), nullable=False), + "entity_class_dimension", + sa.Column("entity_class_id", sa.Integer(), nullable=False), + sa.Column("dimension_id", sa.Integer(), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( - ['dimension_id'], - ['entity_class.id'], - name=op.f('fk_entity_class_dimension_dimension_id_entity_class'), - onupdate='CASCADE', - ondelete='CASCADE', + ["dimension_id"], + ["entity_class.id"], + name=op.f("fk_entity_class_dimension_dimension_id_entity_class"), + onupdate="CASCADE", + ondelete="CASCADE", ), sa.ForeignKeyConstraint( - ['entity_class_id'], - ['entity_class.id'], - name=op.f('fk_entity_class_dimension_entity_class_id_entity_class'), - onupdate='CASCADE', - ondelete='CASCADE', + ["entity_class_id"], + ["entity_class.id"], + name=op.f("fk_entity_class_dimension_entity_class_id_entity_class"), + onupdate="CASCADE", + ondelete="CASCADE", ), - sa.PrimaryKeyConstraint('entity_class_id', 'dimension_id', 'position', name=op.f('pk_entity_class_dimension')), - sa.UniqueConstraint('entity_class_id', 'dimension_id', 'position', name='uq_entity_class_dimension'), + sa.PrimaryKeyConstraint("entity_class_id", "dimension_id", "position", name=op.f("pk_entity_class_dimension")), + sa.UniqueConstraint("entity_class_id", "dimension_id", "position", name="uq_entity_class_dimension"), ) op.create_table( - 'entity_element', - sa.Column('entity_id', sa.Integer(), nullable=False), - sa.Column('entity_class_id', sa.Integer(), nullable=False), - sa.Column('element_id', sa.Integer(), nullable=False), - sa.Column('dimension_id', sa.Integer(), nullable=False), - sa.Column('position', sa.Integer(), nullable=False), + "entity_element", + sa.Column("entity_id", sa.Integer(), nullable=False), + sa.Column("entity_class_id", sa.Integer(), nullable=False), + sa.Column("element_id", sa.Integer(), nullable=False), + sa.Column("dimension_id", sa.Integer(), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( - ['element_id', 'dimension_id'], - ['entity.id', 'entity.class_id'], - name=op.f('fk_entity_element_element_id_entity'), - onupdate='CASCADE', - ondelete='CASCADE', + ["element_id", "dimension_id"], + ["entity.id", "entity.class_id"], + name=op.f("fk_entity_element_element_id_entity"), + onupdate="CASCADE", + ondelete="CASCADE", ), sa.ForeignKeyConstraint( - ['entity_class_id', 'dimension_id', 'position'], + ["entity_class_id", "dimension_id", "position"], [ - 'entity_class_dimension.entity_class_id', - 'entity_class_dimension.dimension_id', - 'entity_class_dimension.position', + "entity_class_dimension.entity_class_id", + "entity_class_dimension.dimension_id", + "entity_class_dimension.position", ], - name=op.f('fk_entity_element_entity_class_id_entity_class_dimension'), - onupdate='CASCADE', - ondelete='CASCADE', + name=op.f("fk_entity_element_entity_class_id_entity_class_dimension"), + onupdate="CASCADE", + ondelete="CASCADE", ), sa.ForeignKeyConstraint( - ['entity_id', 'entity_class_id'], - ['entity.id', 'entity.class_id'], - name=op.f('fk_entity_element_entity_id_entity'), - onupdate='CASCADE', - ondelete='CASCADE', + ["entity_id", 
"entity_class_id"], + ["entity.id", "entity.class_id"], + name=op.f("fk_entity_element_entity_id_entity"), + onupdate="CASCADE", + ondelete="CASCADE", ), - sa.PrimaryKeyConstraint('entity_id', 'position', name=op.f('pk_entity_element')), + sa.PrimaryKeyConstraint("entity_id", "position", name=op.f("pk_entity_element")), ) _persist_data() # NOTE: some constraints are only created by the create_new_spine_database() function, @@ -79,28 +79,28 @@ def upgrade(): # We should avoid this in the future. entity_class_constraints, entity_constraints = _get_constraints() with op.batch_alter_table("entity", naming_convention=naming_convention) as batch_op: - for cname in ('uq_entity_idclass_id', 'uq_entity_idtype_idclass_id'): + for cname in ("uq_entity_idclass_id", "uq_entity_idtype_idclass_id"): if cname in entity_constraints: - batch_op.drop_constraint(cname, type_='unique') - batch_op.drop_constraint('fk_entity_type_id_entity_type', type_='foreignkey') - batch_op.drop_column('type_id') + batch_op.drop_constraint(cname, type_="unique") + batch_op.drop_constraint("fk_entity_type_id_entity_type", type_="foreignkey") + batch_op.drop_column("type_id") with op.batch_alter_table("entity_class", naming_convention=naming_convention) as batch_op: - for cname in ('uq_entity_class_idtype_id', 'uq_entity_class_type_idname'): + for cname in ("uq_entity_class_idtype_id", "uq_entity_class_type_idname"): if cname in entity_class_constraints: - batch_op.drop_constraint(cname, type_='unique') - batch_op.drop_constraint('fk_entity_class_type_id_entity_class_type', type_='foreignkey') - batch_op.drop_constraint('fk_entity_class_commit_id_commit', type_='foreignkey') - batch_op.drop_column('commit_id') - batch_op.drop_column('type_id') - op.drop_table('object_class') - op.drop_table('entity_class_type') + batch_op.drop_constraint(cname, type_="unique") + batch_op.drop_constraint("fk_entity_class_type_id_entity_class_type", type_="foreignkey") + batch_op.drop_constraint("fk_entity_class_commit_id_commit", type_="foreignkey") + batch_op.drop_column("commit_id") + batch_op.drop_column("type_id") + op.drop_table("object_class") + op.drop_table("entity_class_type") # op.drop_table('next_id') - op.drop_table('object') - op.drop_table('relationship_entity_class') - op.drop_table('relationship') - op.drop_table('entity_type') - op.drop_table('relationship_class') - op.drop_table('relationship_entity') + op.drop_table("object") + op.drop_table("relationship_entity_class") + op.drop_table("relationship") + op.drop_table("entity_type") + op.drop_table("relationship_class") + op.drop_table("relationship_entity") def _get_constraints(): diff --git a/spinedb_api/alembic/versions/738d494a08ac_fix_foreign_key_constraints_in_object_.py b/spinedb_api/alembic/versions/738d494a08ac_fix_foreign_key_constraints_in_object_.py index 6933e94f..a0dce871 100644 --- a/spinedb_api/alembic/versions/738d494a08ac_fix_foreign_key_constraints_in_object_.py +++ b/spinedb_api/alembic/versions/738d494a08ac_fix_foreign_key_constraints_in_object_.py @@ -10,30 +10,30 @@ # revision identifiers, used by Alembic. 
-revision = '738d494a08ac' -down_revision = '1e4997105288' +revision = "738d494a08ac" +down_revision = "1e4997105288" branch_labels = None depends_on = None def upgrade(): with op.batch_alter_table("object", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_object_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_object_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_object_entity_id_entity'), - 'entity', - ['entity_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_object_entity_id_entity"), + "entity", + ["entity_id", "type_id"], + ["id", "type_id"], onupdate="CASCADE", ondelete="CASCADE", ) with op.batch_alter_table("relationship", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_entity_id_entity'), - 'entity', - ['entity_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_relationship_entity_id_entity"), + "entity", + ["entity_id", "type_id"], + ["id", "type_id"], onupdate="CASCADE", ondelete="CASCADE", ) diff --git a/spinedb_api/alembic/versions/7d0b467f2f4e_fix_foreign_key_constraints_in_entity_.py b/spinedb_api/alembic/versions/7d0b467f2f4e_fix_foreign_key_constraints_in_entity_.py index c5dfdfc8..50379da7 100644 --- a/spinedb_api/alembic/versions/7d0b467f2f4e_fix_foreign_key_constraints_in_entity_.py +++ b/spinedb_api/alembic/versions/7d0b467f2f4e_fix_foreign_key_constraints_in_entity_.py @@ -10,29 +10,29 @@ # revision identifiers, used by Alembic. -revision = '7d0b467f2f4e' -down_revision = 'fd542cebf699' +revision = "7d0b467f2f4e" +down_revision = "fd542cebf699" branch_labels = None depends_on = None def upgrade(): with op.batch_alter_table("object_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_object_class_entity_class_id_entity_class', type_='foreignkey') + batch_op.drop_constraint("fk_object_class_entity_class_id_entity_class", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_object_class_entity_class_id_entity_class'), - 'entity_class', - ['entity_class_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_object_class_entity_class_id_entity_class"), + "entity_class", + ["entity_class_id", "type_id"], + ["id", "type_id"], ondelete="CASCADE", ) with op.batch_alter_table("relationship_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_class_entity_class_id_entity_class', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_class_entity_class_id_entity_class", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_class_entity_class_id_entity_class'), - 'entity_class', - ['entity_class_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_relationship_class_entity_class_id_entity_class"), + "entity_class", + ["entity_class_id", "type_id"], + ["id", "type_id"], ondelete="CASCADE", ) diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index 3cbd6ede..bf3538e3 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -12,8 +12,8 @@ from spinedb_api.compatibility import 
convert_tool_feature_method_to_active_by_default # revision identifiers, used by Alembic. -revision = '8b0eff478bcb' -down_revision = '5385f063bef2' +revision = "8b0eff478bcb" +down_revision = "5385f063bef2" branch_labels = None depends_on = None diff --git a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py index 2183c42f..920437ff 100644 --- a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py +++ b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py @@ -15,8 +15,8 @@ # revision identifiers, used by Alembic. -revision = '989fccf80441' -down_revision = '0c7d199ae915' +revision = "989fccf80441" +down_revision = "0c7d199ae915" branch_labels = None depends_on = None diff --git a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py index ef172628..da5bfe79 100644 --- a/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py +++ b/spinedb_api/alembic/versions/9da58d2def22_create_entity_group_table.py @@ -10,8 +10,8 @@ # revision identifiers, used by Alembic. -revision = '9da58d2def22' -down_revision = '070a0eb89e88' +revision = "9da58d2def22" +down_revision = "070a0eb89e88" branch_labels = None depends_on = None diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py index ccc873e2..53816e04 100644 --- a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py +++ b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -11,40 +11,40 @@ # revision identifiers, used by Alembic. 
-revision = 'ce9faa82ed59' -down_revision = '6b7c994c1c61' +revision = "ce9faa82ed59" +down_revision = "6b7c994c1c61" branch_labels = None depends_on = None def upgrade(): op.create_table( - 'entity_alternative', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('entity_id', sa.Integer(), nullable=False), - sa.Column('alternative_id', sa.Integer(), nullable=False), - sa.Column('active', sa.Boolean(name='active'), server_default=sa.text('1'), nullable=False), - sa.Column('commit_id', sa.Integer(), nullable=True), + "entity_alternative", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("entity_id", sa.Integer(), nullable=False), + sa.Column("alternative_id", sa.Integer(), nullable=False), + sa.Column("active", sa.Boolean(name="active"), server_default=sa.text("1"), nullable=False), + sa.Column("commit_id", sa.Integer(), nullable=True), sa.ForeignKeyConstraint( - ['alternative_id'], - ['alternative.id'], - name=op.f('fk_entity_alternative_alternative_id_alternative'), - onupdate='CASCADE', - ondelete='CASCADE', + ["alternative_id"], + ["alternative.id"], + name=op.f("fk_entity_alternative_alternative_id_alternative"), + onupdate="CASCADE", + ondelete="CASCADE", ), - sa.ForeignKeyConstraint(['commit_id'], ['commit.id'], name=op.f('fk_entity_alternative_commit_id_commit')), + sa.ForeignKeyConstraint(["commit_id"], ["commit.id"], name=op.f("fk_entity_alternative_commit_id_commit")), sa.ForeignKeyConstraint( - ['entity_id'], - ['entity.id'], - name=op.f('fk_entity_alternative_entity_id_entity'), - onupdate='CASCADE', - ondelete='CASCADE', + ["entity_id"], + ["entity.id"], + name=op.f("fk_entity_alternative_entity_id_entity"), + onupdate="CASCADE", + ondelete="CASCADE", ), - sa.PrimaryKeyConstraint('id', name=op.f('pk_entity_alternative')), - sa.UniqueConstraint('entity_id', 'alternative_id', name=op.f('uq_entity_alternative_entity_idalternative_id')), + sa.PrimaryKeyConstraint("id", name=op.f("pk_entity_alternative")), + sa.UniqueConstraint("entity_id", "alternative_id", name=op.f("uq_entity_alternative_entity_idalternative_id")), ) try: - op.drop_table('next_id') + op.drop_table("next_id") except sa.exc.OperationalError: pass convert_tool_feature_method_to_entity_alternative(op.get_bind(), use_existing_tool_feature_method=True, apply=True) diff --git a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py index 8c1034fb..3fc1241b 100644 --- a/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py +++ b/spinedb_api/alembic/versions/defbda3bf2b5_add_tool_feature_tables.py @@ -10,8 +10,8 @@ # revision identifiers, used by Alembic. -revision = 'defbda3bf2b5' -down_revision = '39e860a11b05' +revision = "defbda3bf2b5" +down_revision = "39e860a11b05" branch_labels = None depends_on = None diff --git a/spinedb_api/alembic/versions/fbb540efbf15_add_support_for_mysql.py b/spinedb_api/alembic/versions/fbb540efbf15_add_support_for_mysql.py index 4032d62e..5d209cc9 100644 --- a/spinedb_api/alembic/versions/fbb540efbf15_add_support_for_mysql.py +++ b/spinedb_api/alembic/versions/fbb540efbf15_add_support_for_mysql.py @@ -11,8 +11,8 @@ # revision identifiers, used by Alembic. -revision = 'fbb540efbf15' -down_revision = '1892adebc00f' +revision = "fbb540efbf15" +down_revision = "1892adebc00f" branch_labels = None depends_on = None @@ -20,95 +20,95 @@ def upgrade(): # 1. 
Drop referential actions in foreign keys with columns also having a check constraint with op.batch_alter_table("object_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_object_class_entity_class_id_entity_class', type_='foreignkey') + batch_op.drop_constraint("fk_object_class_entity_class_id_entity_class", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_object_class_entity_class_id_entity_class'), - 'entity_class', - ['entity_class_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_object_class_entity_class_id_entity_class"), + "entity_class", + ["entity_class_id", "type_id"], + ["id", "type_id"], ) with op.batch_alter_table("relationship_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_class_entity_class_id_entity_class', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_class_entity_class_id_entity_class", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_class_entity_class_id_entity_class'), - 'entity_class', - ['entity_class_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_relationship_class_entity_class_id_entity_class"), + "entity_class", + ["entity_class_id", "type_id"], + ["id", "type_id"], ) with op.batch_alter_table("relationship_entity_class", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_entity_class_member_class_id_entity_class', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_entity_class_member_class_id_entity_class", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_entity_class_member_class_id_entity_class'), - 'entity_class', - ['member_class_id', 'member_class_type_id'], - ['id', 'type_id'], + op.f("fk_relationship_entity_class_member_class_id_entity_class"), + "entity_class", + ["member_class_id", "member_class_type_id"], + ["id", "type_id"], ) with op.batch_alter_table("object", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_object_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_object_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_object_entity_id_entity'), 'entity', ['entity_id', 'type_id'], ['id', 'type_id'] + op.f("fk_object_entity_id_entity"), "entity", ["entity_id", "type_id"], ["id", "type_id"] ) with op.batch_alter_table("relationship", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_entity_id_entity'), 'entity', ['entity_id', 'type_id'], ['id', 'type_id'] + op.f("fk_relationship_entity_id_entity"), "entity", ["entity_id", "type_id"], ["id", "type_id"] ) # 2.
Add new unique constraints required to make some foreign keys work with op.batch_alter_table("relationship_entity_class", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - 'uq_relationship_entity_class', ['entity_class_id', 'dimension', 'member_class_id'] + "uq_relationship_entity_class", ["entity_class_id", "dimension", "member_class_id"] ) with op.batch_alter_table("relationship", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - op.f('uq_relationship_entity_identity_class_id'), ['entity_id', 'entity_class_id'] + op.f("uq_relationship_entity_identity_class_id"), ["entity_id", "entity_class_id"] ) with op.batch_alter_table("parameter_definition", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - op.f('uq_parameter_definition_idparameter_value_list_id'), ['id', 'parameter_value_list_id'] + op.f("uq_parameter_definition_idparameter_value_list_id"), ["id", "parameter_value_list_id"] ) with op.batch_alter_table("feature", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - op.f('uq_feature_idparameter_value_list_id'), ['id', 'parameter_value_list_id'] + op.f("uq_feature_idparameter_value_list_id"), ["id", "parameter_value_list_id"] ) with op.batch_alter_table("tool_feature", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - op.f('uq_tool_feature_idparameter_value_list_id'), ['id', 'parameter_value_list_id'] + op.f("uq_tool_feature_idparameter_value_list_id"), ["id", "parameter_value_list_id"] ) # 3. Rename constraints having too long names with op.batch_alter_table("parameter_definition_tag", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - 'uq_parameter_definition_tag', ['parameter_definition_id', 'parameter_tag_id'] + "uq_parameter_definition_tag", ["parameter_definition_id", "parameter_tag_id"] ) - batch_op.drop_constraint('uq_parameter_definition_tag_parameter_definition_idparameter_tag_id', type_='unique') + batch_op.drop_constraint("uq_parameter_definition_tag_parameter_definition_idparameter_tag_id", type_="unique") with op.batch_alter_table("parameter_value", naming_convention=naming_convention) as batch_op: batch_op.create_unique_constraint( - 'uq_parameter_value', ['parameter_definition_id', 'entity_id', 'alternative_id'] + "uq_parameter_value", ["parameter_definition_id", "entity_id", "alternative_id"] ) - batch_op.drop_constraint('uq_parameter_value_parameter_definition_identity_idalternative_id', type_='unique') + batch_op.drop_constraint("uq_parameter_value_parameter_definition_identity_idalternative_id", type_="unique") # 4. Extend length of fields holding parameter values with op.batch_alter_table("parameter_definition") as batch_op: - batch_op.alter_column('default_value', type_=Text(LONGTEXT_LENGTH)) + batch_op.alter_column("default_value", type_=Text(LONGTEXT_LENGTH)) with op.batch_alter_table("parameter_value") as batch_op: - batch_op.alter_column('value', type_=Text(LONGTEXT_LENGTH)) + batch_op.alter_column("value", type_=Text(LONGTEXT_LENGTH)) with op.batch_alter_table("parameter_value_list") as batch_op: - batch_op.alter_column('value', type_=Text(LONGTEXT_LENGTH)) + batch_op.alter_column("value", type_=Text(LONGTEXT_LENGTH)) # 5.
Extend length of description fields with op.batch_alter_table("alternative") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("scenario") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("entity_class") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("entity") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("parameter_definition") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("parameter_tag") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("tool") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) with op.batch_alter_table("feature") as batch_op: - batch_op.alter_column('description', type_=Text()) + batch_op.alter_column("description", type_=Text()) def downgrade(): diff --git a/spinedb_api/alembic/versions/fd542cebf699_drop_on_update_clauses_from_object_and_.py b/spinedb_api/alembic/versions/fd542cebf699_drop_on_update_clauses_from_object_and_.py index 4a5a21f9..593587bd 100644 --- a/spinedb_api/alembic/versions/fd542cebf699_drop_on_update_clauses_from_object_and_.py +++ b/spinedb_api/alembic/versions/fd542cebf699_drop_on_update_clauses_from_object_and_.py @@ -10,29 +10,29 @@ # revision identifiers, used by Alembic. 
-revision = 'fd542cebf699' -down_revision = '738d494a08ac' +revision = "fd542cebf699" +down_revision = "738d494a08ac" branch_labels = None depends_on = None def upgrade(): with op.batch_alter_table("object", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_object_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_object_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_object_entity_id_entity'), - 'entity', - ['entity_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_object_entity_id_entity"), + "entity", + ["entity_id", "type_id"], + ["id", "type_id"], ondelete="CASCADE", ) with op.batch_alter_table("relationship", naming_convention=naming_convention) as batch_op: - batch_op.drop_constraint('fk_relationship_entity_id_entity', type_='foreignkey') + batch_op.drop_constraint("fk_relationship_entity_id_entity", type_="foreignkey") batch_op.create_foreign_key( - op.f('fk_relationship_entity_id_entity'), - 'entity', - ['entity_id', 'type_id'], - ['id', 'type_id'], + op.f("fk_relationship_entity_id_entity"), + "entity", + ["entity_id", "type_id"], + ["id", "type_id"], ondelete="CASCADE", ) diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 11e9d03a..8cb17ae7 100644 --- a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -171,7 +171,7 @@ def __init__( ) # NOTE: The NullPool is needed to receive the close event (or any events), for some reason self.engine = create_engine("sqlite://", poolclass=NullPool) if self._memory else self._original_engine - listen(self.engine, 'close', self._receive_engine_close) + listen(self.engine, "close", self._receive_engine_close) if self._memory: copy_database_bind(self.engine, self._original_engine) self._metadata = MetaData(self.engine) @@ -278,7 +278,7 @@ def get_upgrade_db_prompt_data(url, create=False): @staticmethod def create_engine(sa_url, create=False, upgrade=False, backup_url="", sqlite_timeout=1800): if sa_url.drivername == "sqlite": - connect_args = {'timeout': sqlite_timeout} + connect_args = {"timeout": sqlite_timeout} else: connect_args = {} try: @@ -894,7 +894,7 @@ def _uq_fields(factory): def _kwargs(fields): def type_(f_dict): - return f_dict['type'].__name__ + (', optional' if f_dict.get('optional', False) else '') + return f_dict["type"].__name__ + (", optional" if f_dict.get("optional", False) else "") return f"\n{padding}".join( [f"{f_name} ({type_(f_dict)}): {f_dict['value']}" for f_name, f_dict in fields.items()] diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index 796c9b8f..f2e92baa 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -131,7 +131,7 @@ def compile_DOUBLE_mysql_sqlite(element, compiler, **kw): class group_concat(FunctionElement): type = String() - name = 'group_concat' + name = "group_concat" def _parse_group_concat_clauses(clauses): diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py index c521a55e..af20d677 100644 --- a/spinedb_api/import_mapping/import_mapping.py +++ b/spinedb_api/import_mapping/import_mapping.py @@ -149,7 +149,7 @@ def check_for_invalid_column_refs(self, header, table_name): if error: return error if isinstance(self.position, int) and self.position >= len(header) > 0: - msg = f"Column ref {self.position + 1} is out of range for the source table \"{table_name}\"" + msg = f'Column ref {self.position + 1} is out of range for the source table "{table_name}"' return msg return "" @@ -206,7 +206,7 @@ def 
_polish_for_import(self, table_name, source_header, column_count): msg = f"'{self.value}' is not a valid index in header '{source_header}'" raise InvalidMappingComponent(msg) if isinstance(self.position, int) and self.position >= column_count > 0: - msg = f"Column ref {self.position + 1} is out of range for the source table \"{table_name}\"" + msg = f'Column ref {self.position + 1} is out of range for the source table "{table_name}"' raise InvalidMappingComponent(msg) def _polish_for_preview(self, source_header): diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index 3fc447ec..03e4b7f3 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -38,14 +38,14 @@ def item_factory(item_type): }.get(item_type, MappedItemBase) -_ENTITY_BYNAME_VALUE = 'A tuple with the entity name as single element if the entity is zero-dimensional, or the element names if the entity is multi-dimensional.' +_ENTITY_BYNAME_VALUE = "A tuple with the entity name as single element if the entity is zero-dimensional, or the element names if the entity is multi-dimensional." class CommitItem(MappedItemBase): fields = { - 'comment': {'type': str, 'value': 'A comment describing the commit.'}, - 'date': {'type': str, 'value': 'Date and time of the commit in ISO 8601 format.'}, - 'user': {'type': str, 'value': 'Username of the committer.'}, + "comment": {"type": str, "value": "A comment describing the commit."}, + "date": {"type": str, "value": "Date and time of the commit in ISO 8601 format."}, + "user": {"type": str, "value": "Username of the committer."}, } _unique_keys = (("date",),) is_protected = True @@ -56,20 +56,20 @@ def commit(self, commit_id): class EntityClassItem(MappedItemBase): fields = { - 'name': {'type': str, 'value': 'The class name.'}, - 'dimension_name_list': { - 'type': tuple, - 'value': 'The dimension names for a multi-dimensional class.', - 'optional': True, + "name": {"type": str, "value": "The class name."}, + "dimension_name_list": { + "type": tuple, + "value": "The dimension names for a multi-dimensional class.", + "optional": True, }, - 'description': {'type': str, 'value': 'The class description.', 'optional': True}, - 'display_icon': { - 'type': int, - 'value': 'An integer representing an icon within your application.', - 'optional': True, + "description": {"type": str, "value": "The class description.", "optional": True}, + "display_icon": { + "type": int, + "value": "An integer representing an icon within your application.", + "optional": True, }, - 'display_order': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, - 'hidden': {'type': int, 'value': 'Not in use at the moment.', 'optional': True}, + "display_order": {"type": int, "value": "Not in use at the moment.", "optional": True}, + "hidden": {"type": int, "value": "Not in use at the moment.", "optional": True}, "active_by_default": { "type": bool, "value": "Default activity for the entity alternatives of the class.", @@ -121,15 +121,15 @@ def commit(self, _commit_id): class EntityItem(MappedItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'name': {'type': str, 'value': 'The entity name.'}, - 'element_name_list': {'type': tuple, 'value': 'The element names if the entity is multi-dimensional.'}, - 'entity_byname': { - 'type': tuple, - 'value': 'A tuple with the entity name as single element if the entity is zero-dimensional,' - 'or the element names if it is multi-dimensional.', + "entity_class_name": {"type": str, "value": "The 
entity class name."}, + "name": {"type": str, "value": "The entity name."}, + "element_name_list": {"type": tuple, "value": "The element names if the entity is multi-dimensional."}, + "entity_byname": { + "type": tuple, + "value": "A tuple with the entity name as single element if the entity is zero-dimensional," + "or the element names if it is multi-dimensional.", }, - 'description': {'type': str, 'value': 'The entity description.', 'optional': True}, + "description": {"type": str, "value": "The entity description.", "optional": True}, } _defaults = {"description": None} @@ -263,9 +263,9 @@ def polish(self): class EntityGroupItem(MappedItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'group_name': {'type': str, 'value': 'The group entity name.'}, - 'member_name': {'type': str, 'value': 'The member entity name.'}, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "group_name": {"type": str, "value": "The group entity name."}, + "member_name": {"type": str, "value": "The member entity name."}, } _unique_keys = (("entity_class_name", "group_name", "member_name"),) _references = { @@ -303,16 +303,16 @@ def commit(self, _commit_id): class EntityAlternativeItem(MappedItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'entity_byname': { - 'type': tuple, - 'value': _ENTITY_BYNAME_VALUE, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "entity_byname": { + "type": tuple, + "value": _ENTITY_BYNAME_VALUE, }, - 'alternative_name': {'type': str, 'value': 'The alternative name.'}, - 'active': { - 'type': bool, - 'value': 'Whether the entity is active in the alternative - defaults to True.', - 'optional': True, + "alternative_name": {"type": str, "value": "The alternative name."}, + "active": { + "type": bool, + "value": "Whether the entity is active in the alternative - defaults to True.", + "optional": True, }, } _defaults = {"active": True} @@ -452,16 +452,16 @@ def polish(self): class ParameterDefinitionItem(ParameterItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'name': {'type': str, 'value': 'The parameter name.'}, - 'default_value': {'type': bytes, 'value': 'The default value.', 'optional': True}, - 'default_type': {'type': str, 'value': 'The default value type.', 'optional': True}, - 'parameter_value_list_name': { - 'type': str, - 'value': 'The parameter value list name if any.', - 'optional': True, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "name": {"type": str, "value": "The parameter name."}, + "default_value": {"type": bytes, "value": "The default value.", "optional": True}, + "default_type": {"type": str, "value": "The default value type.", "optional": True}, + "parameter_value_list_name": { + "type": str, + "value": "The parameter value list name if any.", + "optional": True, }, - 'description': {'type': str, 'value': 'The parameter description.', 'optional': True}, + "description": {"type": str, "value": "The parameter description.", "optional": True}, } _defaults = {"description": None, "default_value": None, "default_type": None, "parameter_value_list_id": None} _unique_keys = (("entity_class_name", "name"),) @@ -526,15 +526,15 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueItem(ParameterItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'parameter_definition_name': {'type': str, 
'value': 'The parameter name.'}, - 'entity_byname': { - 'type': tuple, - 'value': _ENTITY_BYNAME_VALUE, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "parameter_definition_name": {"type": str, "value": "The parameter name."}, + "entity_byname": { + "type": tuple, + "value": _ENTITY_BYNAME_VALUE, }, - 'value': {'type': bytes, 'value': 'The value.'}, - 'type': {'type': str, 'value': 'The value type.', 'optional': True}, - 'alternative_name': {'type': str, 'value': "The alternative name - defaults to 'Base'.", 'optional': True}, + "value": {"type": bytes, "value": "The value."}, + "type": {"type": str, "value": "The value type.", "optional": True}, + "alternative_name": {"type": str, "value": "The alternative name - defaults to 'Base'.", "optional": True}, } _unique_keys = (("entity_class_name", "parameter_definition_name", "entity_byname", "alternative_name"),) _references = { @@ -596,16 +596,16 @@ def _value_not_in_list_error(self, parsed_value, list_name): class ParameterValueListItem(MappedItemBase): - fields = {'name': {'type': str, 'value': 'The parameter value list name.'}} + fields = {"name": {"type": str, "value": "The parameter value list name."}} _unique_keys = (("name",),) class ListValueItem(ParsedValueBase): fields = { - 'parameter_value_list_name': {'type': str, 'value': 'The parameter value list name.'}, - 'value': {'type': bytes, 'value': 'The value.'}, - 'type': {'type': str, 'value': 'The value type.', 'optional': True}, - 'index': {'type': int, 'value': 'The value index.', 'optional': True}, + "parameter_value_list_name": {"type": str, "value": "The parameter value list name."}, + "value": {"type": bytes, "value": "The value."}, + "type": {"type": str, "value": "The value type.", "optional": True}, + "index": {"type": int, "value": "The value index.", "optional": True}, } _unique_keys = (("parameter_value_list_name", "value_and_type"), ("parameter_value_list_name", "index")) _references = {"parameter_value_list_id": ("parameter_value_list", "id")} @@ -629,8 +629,8 @@ def __getitem__(self, key): class AlternativeItem(MappedItemBase): fields = { - 'name': {'type': str, 'value': 'The alternative name.'}, - 'description': {'type': str, 'value': 'The alternative description.', 'optional': True}, + "name": {"type": str, "value": "The alternative name."}, + "description": {"type": str, "value": "The alternative description.", "optional": True}, } _defaults = {"description": None} _unique_keys = (("name",),) @@ -638,9 +638,9 @@ class AlternativeItem(MappedItemBase): class ScenarioItem(MappedItemBase): fields = { - 'name': {'type': str, 'value': 'The scenario name.'}, - 'description': {'type': str, 'value': 'The scenario description.', 'optional': True}, - 'active': {'type': bool, 'value': 'Not in use at the moment.', 'optional': True}, + "name": {"type": str, "value": "The scenario name."}, + "description": {"type": str, "value": "The scenario description.", "optional": True}, + "active": {"type": bool, "value": "Not in use at the moment.", "optional": True}, } _defaults = {"active": False, "description": None} _unique_keys = (("name",),) @@ -665,9 +665,9 @@ def __getitem__(self, key): class ScenarioAlternativeItem(MappedItemBase): fields = { - 'scenario_name': {'type': str, 'value': 'The scenario name.'}, - 'alternative_name': {'type': str, 'value': 'The alternative name.'}, - 'rank': {'type': int, 'value': 'The rank - higher has precedence.'}, + "scenario_name": {"type": str, "value": "The scenario name."}, + "alternative_name": {"type": str, 
"value": "The alternative name."}, + "rank": {"type": int, "value": "The rank - higher has precedence."}, } _unique_keys = (("scenario_name", "alternative_name"), ("scenario_name", "rank")) _references = {"scenario_id": ("scenario", "id"), "alternative_id": ("alternative", "id")} @@ -693,18 +693,18 @@ def __getitem__(self, key): class MetadataItem(MappedItemBase): fields = { - 'name': {'type': str, 'value': 'The metadata entry name.'}, - 'value': {'type': str, 'value': 'The metadata entry value.'}, + "name": {"type": str, "value": "The metadata entry name."}, + "value": {"type": str, "value": "The metadata entry value."}, } _unique_keys = (("name", "value"),) class EntityMetadataItem(MappedItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'entity_byname': {'type': tuple, 'value': _ENTITY_BYNAME_VALUE}, - 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, - 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "entity_byname": {"type": tuple, "value": _ENTITY_BYNAME_VALUE}, + "metadata_name": {"type": str, "value": "The metadata entry name."}, + "metadata_value": {"type": str, "value": "The metadata entry value."}, } _unique_keys = (("entity_class_name", "entity_byname", "metadata_name", "metadata_value"),) _references = { @@ -732,15 +732,15 @@ class EntityMetadataItem(MappedItemBase): class ParameterValueMetadataItem(MappedItemBase): fields = { - 'entity_class_name': {'type': str, 'value': 'The entity class name.'}, - 'parameter_definition_name': {'type': str, 'value': 'The parameter name.'}, - 'entity_byname': { - 'type': tuple, - 'value': _ENTITY_BYNAME_VALUE, + "entity_class_name": {"type": str, "value": "The entity class name."}, + "parameter_definition_name": {"type": str, "value": "The parameter name."}, + "entity_byname": { + "type": tuple, + "value": _ENTITY_BYNAME_VALUE, }, - 'alternative_name': {'type': str, 'value': 'The alternative name.'}, - 'metadata_name': {'type': str, 'value': 'The metadata entry name.'}, - 'metadata_value': {'type': str, 'value': 'The metadata entry value.'}, + "alternative_name": {"type": str, "value": "The alternative name."}, + "metadata_name": {"type": str, "value": "The metadata entry name."}, + "metadata_value": {"type": str, "value": "The metadata entry value."}, } _unique_keys = ( ( @@ -779,8 +779,8 @@ class ParameterValueMetadataItem(MappedItemBase): class SuperclassSubclassItem(MappedItemBase): fields = { - 'superclass_name': {'type': str, 'value': 'The superclass name.'}, - 'subclass_name': {'type': str, 'value': 'The subclass name.'}, + "superclass_name": {"type": str, "value": "The superclass name."}, + "subclass_name": {"type": str, "value": "The subclass name."}, } _unique_keys = (("subclass_name",),) _references = {"superclass_id": ("entity_class", "id"), "subclass_id": ("entity_class", "id")} diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 73971fbc..6f22264b 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -679,7 +679,7 @@ def _array_from_database(value_dict): try: data = [value_type(x) for x in value_dict["data"]] except (TypeError, ParameterValueFormatError) as error: - raise ParameterValueFormatError(f'Failed to read values for Array: {error}') + raise ParameterValueFormatError(f"Failed to read values for Array: {error}") else: index_name = value_dict.get("index_name", Array.DEFAULT_INDEX_NAME) return Array(data, 
value_type, index_name) @@ -877,11 +877,11 @@ def __array_finalize__(self, obj): if obj is None: return # pylint: disable=attribute-defined-outside-init - self.position_lookup = getattr(obj, 'position_lookup', {}) + self.position_lookup = getattr(obj, "position_lookup", {}) def __setitem__(self, position, index): old_index = self.__getitem__(position) - self.position_lookup[index] = self.position_lookup.pop(old_index, '') + self.position_lookup[index] = self.position_lookup.pop(old_index, "") super().__setitem__(position, index) def __eq__(self, other): @@ -1019,7 +1019,7 @@ def merge(self, other): if not isinstance(other, type(self)): return self new_indexes = np.unique(np.concatenate((self.indexes, other.indexes))) - new_indexes.sort(kind='mergesort') + new_indexes.sort(kind="mergesort") _merge = lambda value, other: other if value is None else merge_parsed(value, other) new_values = [_merge(self.get_value(index), other.get_value(index)) for index in new_indexes] self.indexes = new_indexes diff --git a/spinedb_api/perfect_split.py b/spinedb_api/perfect_split.py index 287a4a87..4f40cab1 100644 --- a/spinedb_api/perfect_split.py +++ b/spinedb_api/perfect_split.py @@ -55,7 +55,7 @@ def perfect_split(input_urls, intersection_url, diff_urls): if intersection_data: db_map_intersection = DatabaseMapping(intersection_url) import_data(db_map_intersection, **intersection_data) - all_db_names = ', '.join(db_names.values()) + all_db_names = ", ".join(db_names.values()) db_map_intersection.commit_session(f"Add intersection of {all_db_names}") db_map_intersection.connection.close() lookup = _make_lookup(intersection_data) @@ -65,7 +65,7 @@ def perfect_split(input_urls, intersection_url, diff_urls): _add_references(diff_data, lookup) import_data(diff_db_map, **diff_data) db_name = db_names[input_url] - other_db_names = ', '.join([name for url, name in db_names.items() if url != input_url]) + other_db_names = ", ".join([name for url, name in db_names.items() if url != input_url]) diff_db_map.commit_session(f"Add differences between {db_name} and {other_db_names}") diff_db_map.close() diff --git a/spinedb_api/server_client_helpers.py b/spinedb_api/server_client_helpers.py index 1dd44e67..863d2fb7 100644 --- a/spinedb_api/server_client_helpers.py +++ b/spinedb_api/server_client_helpers.py @@ -16,15 +16,15 @@ from .temp_id import TempId # Encode decode server messages -_START_OF_TAIL = '\u001f' # Unit separator -_START_OF_ADDRESS = '\u0091' # Private Use 1 -_ADDRESS_SEP = ':' +_START_OF_TAIL = "\u001f" # Unit separator +_START_OF_ADDRESS = "\u0091" # Private Use 1 +_ADDRESS_SEP = ":" class ReceiveAllMixing: _ENCODING = "utf-8" _BUFF_SIZE = 4096 - _EOT = '\u0004' # End of transmission + _EOT = "\u0004" # End of transmission _BEOT = _EOT.encode(_ENCODING) """End of message character""" diff --git a/tests/export_mapping/test_export_mapping.py b/tests/export_mapping/test_export_mapping.py index 00a5609e..0fa6900c 100644 --- a/tests/export_mapping/test_export_mapping.py +++ b/tests/export_mapping/test_export_mapping.py @@ -822,9 +822,9 @@ def test_export_relationships(self): element1_mapping = relationship_mapping.child = ElementMapping(4) element1_mapping.child = ElementMapping(5) expected = [ - ['rc1', 'oc1', '', 'o11__', 'o11', ''], - ['rc2', 'oc2', 'oc1', 'o21__o11', 'o21', 'o11'], - ['rc2', 'oc2', 'oc1', 'o21__o12', 'o21', 'o12'], + ["rc1", "oc1", "", "o11__", "o11", ""], + ["rc2", "oc2", "oc1", "o21__o11", "o21", "o11"], + ["rc2", "oc2", "oc1", "o21__o12", "o21", "o12"], ] 
self.assertEqual(list(rows(relationship_class_mapping, db_map)), expected) db_map.close() diff --git a/tests/filters/test_alternative_filter.py b/tests/filters/test_alternative_filter.py index 11e4b799..2ba9ec15 100644 --- a/tests/filters/test_alternative_filter.py +++ b/tests/filters/test_alternative_filter.py @@ -170,5 +170,5 @@ def test_quoted_alternative_names(self): self.assertEqual(config, {"type": "alternative_filter", "alternatives": ["alt:er:na:ti:ve", "alternative2"]}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/filters/test_execution_filter.py b/tests/filters/test_execution_filter.py index c43a894e..6f5b706a 100644 --- a/tests/filters/test_execution_filter.py +++ b/tests/filters/test_execution_filter.py @@ -31,5 +31,5 @@ def test_import_alternative_after_applying_execution_filter(self): self.assertEqual(scenarios, {"low_on_steam", "wasting_my_time"}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 3912b14c..741e8bc8 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -552,5 +552,5 @@ def _build_data_with_single_scenario(db_map, commit=True): db_map.commit_session("Add test data.") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/filters/test_tool_filter.py b/tests/filters/test_tool_filter.py index 4e90e46d..c471ab99 100644 --- a/tests/filters/test_tool_filter.py +++ b/tests/filters/test_tool_filter.py @@ -217,5 +217,5 @@ def test_object_activity_control_filter(self): self.assertTrue("node1" not in pval_object_names) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index 589210ec..8bcb1e20 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -70,11 +70,11 @@ def test_returns_appropriate_error_if_last_row_is_empty(self): self.assertEqual( mapped_data, { - 'alternatives': {'Base'}, - 'entity_classes': [('Object',)], - 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], - 'parameter_definitions': [('Object', 'Parameter')], - 'entities': [('Object', 'data')], + "alternatives": {"Base"}, + "entity_classes": [("Object",)], + "parameter_values": [["Object", "data", "Parameter", Map(["T1", "T2"], [5.0, 99.0]), "Base"]], + "parameter_definitions": [("Object", "Parameter")], + "entities": [("Object", "data")], }, ) @@ -102,11 +102,11 @@ def test_convert_functions_get_expanded_over_last_defined_column_in_pivoted_data self.assertEqual( mapped_data, { - 'alternatives': {'Base'}, - 'entity_classes': [('Object',)], - 'parameter_values': [['Object', 'data', 'Parameter', Map(["T1", "T2"], [5.0, 99.0]), 'Base']], - 'parameter_definitions': [('Object', 'Parameter')], - 'entities': [('Object', 'data')], + "alternatives": {"Base"}, + "entity_classes": [("Object",)], + "parameter_values": [["Object", "data", "Parameter", Map(["T1", "T2"], [5.0, 99.0]), "Base"]], + "parameter_definitions": [("Object", "Parameter")], + "entities": [("Object", "data")], }, ) @@ -134,10 +134,10 @@ def test_read_start_row_skips_rows_in_pivoted_data(self): self.assertEqual( mapped_data, { - 'entity_classes': [('klass',)], - 'parameter_values': [['klass', 'kloss', 'Parameter_2', Map(["T1", "T2"], [2.3, 23.0])]], - 'parameter_definitions': [('klass', 'Parameter_2')], - 
'entities': [('klass', 'kloss')], + "entity_classes": [("klass",)], + "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], + "parameter_definitions": [("klass", "Parameter_2")], + "entities": [("klass", "kloss")], }, ) @@ -188,7 +188,7 @@ def test_map_without_values_is_ignored_and_not_interpreted_as_null(self): mapped_data, { "alternatives": {"base"}, - 'entity_classes': [("o",)], + "entity_classes": [("o",)], "parameter_definitions": [("o", "parameter_name")], "parameter_values": [], "entities": [("o", "o1")], @@ -224,7 +224,7 @@ def test_import_object_works_with_multiple_relationship_object_imports(self): mapped_data, { "alternatives": {"base"}, - 'entity_classes': [("o",), ("q",), ("o_to_q", ("o", "q"))], + "entity_classes": [("o",), ("q",), ("o_to_q", ("o", "q"))], "entities": [ ("o", "o1"), ("q", "q1"), @@ -268,7 +268,7 @@ def test_default_convert_function_in_column_convert_functions(self): self.assertEqual( mapped_data, { - 'entity_classes': [("klass",)], + "entity_classes": [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], "parameter_definitions": [("klass", "Parameter_2")], "entities": [("klass", "kloss")], @@ -296,7 +296,7 @@ def test_identity_function_is_used_as_convert_function_when_no_convert_functions self.assertEqual( mapped_data, { - 'entity_classes': [("klass",)], + "entity_classes": [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], ["2.3", "23.0"])]], "parameter_definitions": [("klass", "Parameter_2")], "entities": [("klass", "kloss")], @@ -326,7 +326,7 @@ def test_last_convert_function_gets_used_as_default_convert_function_when_no_def self.assertEqual( mapped_data, { - 'entity_classes': [("klass",)], + "entity_classes": [("klass",)], "parameter_values": [["klass", "kloss", "Parameter_2", Map(["T1", "T2"], [2.3, 23.0])]], "parameter_definitions": [("klass", "Parameter_2")], "entities": [("klass", "kloss")], @@ -359,7 +359,7 @@ def test_array_parameters_get_imported_correctly_when_objects_are_in_header(self mapped_data, { "alternatives": {"Base"}, - 'entity_classes': [("class",)], + "entity_classes": [("class",)], "parameter_values": [ ["class", "object_1", "param", Array([-1.1, 1.1]), "Base"], ["class", "object_2", "param", Array([2.3, -2.3]), "Base"], @@ -395,7 +395,7 @@ def test_arrays_get_imported_correctly_when_objects_are_in_header_and_alternativ mapped_data, { "alternatives": {"Base"}, - 'entity_classes': [("Gadget",)], + "entity_classes": [("Gadget",)], "parameter_values": [ ["Gadget", "object_1", "data", Array([-1.1, 1.1]), "Base"], ["Gadget", "object_2", "data", Array([2.3, -2.3]), "Base"], @@ -443,5 +443,5 @@ def test_header_position_is_ignored_in_last_mapping_if_other_mappings_are_in_hea ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index d0f119d8..78f8d323 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -57,9 +57,9 @@ def test_convert_functions_float(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': [('a',)], - 'entities': [('a', 'obj')], - 'parameter_definitions': [('a', 'param', 1.2)], + "entity_classes": [("a",)], + "entities": [("a", "obj")], + "parameter_definitions": [("a", "param", 1.2)], } self.assertEqual(mapped_data, 
expected) @@ -76,9 +76,9 @@ def test_convert_functions_str(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': [('a',)], - 'entities': [('a', 'obj')], - 'parameter_definitions': [('a', 'param', '1111.2222')], + "entity_classes": [("a",)], + "entities": [("a", "obj")], + "parameter_definitions": [("a", "param", "1111.2222")], } self.assertEqual(mapped_data, expected) @@ -95,9 +95,9 @@ def test_convert_functions_bool(self): param_def_mapping.flatten()[-1].position = 1 mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) expected = { - 'entity_classes': [('a',)], - 'entities': [('a', 'obj')], - 'parameter_definitions': [('a', 'param', False)], + "entity_classes": [("a",)], + "entities": [("a", "obj")], + "parameter_definitions": [("a", "param", False)], } self.assertEqual(mapped_data, expected) @@ -193,42 +193,42 @@ def test_object_class_mapping(self): mapping = import_mapping_from_dict({"map_type": "ObjectClass"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['EntityClass', 'Entity', 'EntityMetadata'] + expected = ["EntityClass", "Entity", "EntityMetadata"] self.assertEqual(types, expected) def test_relationship_class_mapping(self): mapping = import_mapping_from_dict({"map_type": "RelationshipClass"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['EntityClass', 'Dimension', 'Entity', 'Element', 'EntityMetadata'] + expected = ["EntityClass", "Dimension", "Entity", "Element", "EntityMetadata"] self.assertEqual(types, expected) def test_object_group_mapping(self): mapping = import_mapping_from_dict({"map_type": "ObjectGroup"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['EntityClass', 'Entity', 'EntityGroup'] + expected = ["EntityClass", "Entity", "EntityGroup"] self.assertEqual(types, expected) def test_alternative_mapping(self): mapping = import_mapping_from_dict({"map_type": "Alternative"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['Alternative'] + expected = ["Alternative"] self.assertEqual(types, expected) def test_scenario_mapping(self): mapping = import_mapping_from_dict({"map_type": "Scenario"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['Scenario', 'ScenarioActiveFlag'] + expected = ["Scenario", "ScenarioActiveFlag"] self.assertEqual(types, expected) def test_scenario_alternative_mapping(self): mapping = import_mapping_from_dict({"map_type": "ScenarioAlternative"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['Scenario', 'ScenarioAlternative', 'ScenarioBeforeAlternative'] + expected = ["Scenario", "ScenarioAlternative", "ScenarioBeforeAlternative"] self.assertEqual(types, expected) def test_tool_mapping(self): @@ -247,7 +247,7 @@ def test_parameter_value_list_mapping(self): mapping = import_mapping_from_dict({"map_type": "ParameterValueList"}) d = mapping_to_dict(mapping) types = [m["map_type"] for m in d] - expected = ['ParameterValueList', 'ParameterValueListValue'] + expected = ["ParameterValueList", "ParameterValueListValue"] self.assertEqual(types, expected) @@ -262,13 +262,13 @@ def test_ObjectClass_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 0}, - {'map_type': 'Entity', 'position': 1}, - {'map_type': 
'EntityMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterDefinition', 'position': 2}, - {'map_type': 'Alternative', 'position': 'hidden'}, - {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterValue', 'position': 3}, + {"map_type": "EntityClass", "position": 0}, + {"map_type": "Entity", "position": 1}, + {"map_type": "EntityMetadata", "position": "hidden"}, + {"map_type": "ParameterDefinition", "position": 2}, + {"map_type": "Alternative", "position": "hidden"}, + {"map_type": "ParameterValueMetadata", "position": "hidden"}, + {"map_type": "ParameterValue", "position": 3}, ] self.assertEqual(out, expected) @@ -277,9 +277,9 @@ def test_ObjectClass_object_from_dict_to_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 0}, - {'map_type': 'Entity', 'position': 1}, - {'map_type': 'EntityMetadata', 'position': 'hidden'}, + {"map_type": "EntityClass", "position": 0}, + {"map_type": "Entity", "position": 1}, + {"map_type": "EntityMetadata", "position": "hidden"}, ] self.assertEqual(out, expected) @@ -288,9 +288,9 @@ def test_ObjectClass_object_from_dict_to_dict2(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'cls'}, - {'map_type': 'Entity', 'position': 'hidden', 'value': 'obj'}, - {'map_type': 'EntityMetadata', 'position': 'hidden'}, + {"map_type": "EntityClass", "position": "hidden", "value": "cls"}, + {"map_type": "Entity", "position": "hidden", "value": "obj"}, + {"map_type": "EntityMetadata", "position": "hidden"}, ] self.assertEqual(out, expected) @@ -305,17 +305,17 @@ def test_RelationshipClassMapping_from_dict_to_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'Dimension', 'position': 0}, - {'map_type': 'Dimension', 'position': 1}, - {'map_type': 'Entity', 'position': 'hidden'}, - {'map_type': 'Element', 'position': 0}, - {'map_type': 'Element', 'position': 1}, - {'map_type': 'EntityMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterDefinition', 'position': 'hidden', 'value': 'pname'}, - {'map_type': 'Alternative', 'position': 'hidden'}, - {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterValue', 'position': 2}, + {"map_type": "EntityClass", "position": "hidden", "value": "unit__node"}, + {"map_type": "Dimension", "position": 0}, + {"map_type": "Dimension", "position": 1}, + {"map_type": "Entity", "position": "hidden"}, + {"map_type": "Element", "position": 0}, + {"map_type": "Element", "position": 1}, + {"map_type": "EntityMetadata", "position": "hidden"}, + {"map_type": "ParameterDefinition", "position": "hidden", "value": "pname"}, + {"map_type": "Alternative", "position": "hidden"}, + {"map_type": "ParameterValueMetadata", "position": "hidden"}, + {"map_type": "ParameterValue", "position": 2}, ] self.assertEqual(out, expected) @@ -329,13 +329,13 @@ def test_RelationshipClassMapping_from_dict_to_dict2(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'Dimension', 'position': 'hidden', 'value': 'cls'}, - {'map_type': 'Dimension', 'position': 0}, - {'map_type': 'Entity', 'position': 'hidden'}, - {'map_type': 'Element', 'position': 
'hidden', 'value': 'obj'}, - {'map_type': 'Element', 'position': 0}, - {'map_type': 'EntityMetadata', 'position': 'hidden'}, + {"map_type": "EntityClass", "position": "hidden", "value": "unit__node"}, + {"map_type": "Dimension", "position": "hidden", "value": "cls"}, + {"map_type": "Dimension", "position": 0}, + {"map_type": "Entity", "position": "hidden"}, + {"map_type": "Element", "position": "hidden", "value": "obj"}, + {"map_type": "Element", "position": 0}, + {"map_type": "EntityMetadata", "position": "hidden"}, ] self.assertEqual(out, expected) @@ -354,18 +354,18 @@ def test_RelationshipClassMapping_from_dict_to_dict3(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 'hidden', 'value': 'unit__node'}, - {'map_type': 'Dimension', 'position': 'hidden'}, - {'map_type': 'Entity', 'position': 'hidden'}, - {'map_type': 'Element', 'position': 'hidden'}, - {'map_type': 'EntityMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterDefinition', 'position': 'hidden', 'value': 'pname'}, - {'map_type': 'Alternative', 'position': 'hidden'}, - {'map_type': 'ParameterValueMetadata', 'position': 'hidden'}, - {'map_type': 'ParameterValueType', 'position': 'hidden', 'value': 'array'}, - {'map_type': 'IndexName', 'position': 'hidden'}, - {'map_type': 'ParameterValueIndex', 'position': 'hidden', 'value': 'dim'}, - {'map_type': 'ExpandedValue', 'position': 2}, + {"map_type": "EntityClass", "position": "hidden", "value": "unit__node"}, + {"map_type": "Dimension", "position": "hidden"}, + {"map_type": "Entity", "position": "hidden"}, + {"map_type": "Element", "position": "hidden"}, + {"map_type": "EntityMetadata", "position": "hidden"}, + {"map_type": "ParameterDefinition", "position": "hidden", "value": "pname"}, + {"map_type": "Alternative", "position": "hidden"}, + {"map_type": "ParameterValueMetadata", "position": "hidden"}, + {"map_type": "ParameterValueType", "position": "hidden", "value": "array"}, + {"map_type": "IndexName", "position": "hidden"}, + {"map_type": "ParameterValueIndex", "position": "hidden", "value": "dim"}, + {"map_type": "ExpandedValue", "position": 2}, ] self.assertEqual(out, expected) @@ -386,9 +386,9 @@ def test_ObjectGroupMapping_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'EntityClass', 'position': 0}, - {'map_type': 'Entity', 'position': 1}, - {'map_type': 'EntityGroup', 'position': 2}, + {"map_type": "EntityClass", "position": 0}, + {"map_type": "Entity", "position": 1}, + {"map_type": "EntityGroup", "position": 2}, ] self.assertEqual(out, expected) @@ -396,7 +396,7 @@ def test_Alternative_to_dict_from_dict(self): mapping = {"map_type": "Alternative", "name": 0} mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) - expected = [{'map_type': 'Alternative', 'position': 0}] + expected = [{"map_type": "Alternative", "position": 0}] self.assertEqual(out, expected) def test_Scenario_to_dict_from_dict(self): @@ -404,8 +404,8 @@ def test_Scenario_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'Scenario', 'position': 0}, - {'map_type': 'ScenarioActiveFlag', 'position': 'hidden', 'value': 'false'}, + {"map_type": "Scenario", "position": 0}, + {"map_type": "ScenarioActiveFlag", "position": "hidden", "value": "false"}, ] self.assertEqual(out, expected) @@ -419,9 +419,9 @@ def 
test_ScenarioAlternative_to_dict_from_dict(self): mapping = import_mapping_from_dict(mapping) out = mapping_to_dict(mapping) expected = [ - {'map_type': 'Scenario', 'position': 0}, - {'map_type': 'ScenarioAlternative', 'position': 1}, - {'map_type': 'ScenarioBeforeAlternative', 'position': 2}, + {"map_type": "Scenario", "position": 0}, + {"map_type": "ScenarioAlternative", "position": 1}, + {"map_type": "ScenarioBeforeAlternative", "position": 2}, ] self.assertEqual(out, expected) @@ -461,10 +461,10 @@ def test_MapValueMapping_from_dict_to_dict(self): parameter_mapping = parameter_value_mapping_from_dict(mapping_dict) out = mapping_to_dict(parameter_mapping) expected = [ - {'map_type': 'ParameterValueType', 'position': 'hidden', 'value': 'map', 'compress': True}, - {'map_type': 'IndexName', 'position': 'hidden'}, - {'map_type': 'ParameterValueIndex', 'position': 'fifth column'}, - {'map_type': 'ExpandedValue', 'position': -24}, + {"map_type": "ParameterValueType", "position": "hidden", "value": "map", "compress": True}, + {"map_type": "IndexName", "position": "hidden"}, + {"map_type": "ParameterValueIndex", "position": "fifth column"}, + {"map_type": "ExpandedValue", "position": -24}, ] self.assertEqual(out, expected) @@ -479,14 +479,14 @@ def test_TimeSeriesValueMapping_from_dict_to_dict(self): out = mapping_to_dict(parameter_mapping) expected = [ { - 'map_type': 'ParameterValueType', - 'position': 'hidden', - 'value': 'time_series', - 'options': {'repeat': True, 'ignore_year': False, 'fixed_resolution': False}, + "map_type": "ParameterValueType", + "position": "hidden", + "value": "time_series", + "options": {"repeat": True, "ignore_year": False, "fixed_resolution": False}, }, - {'map_type': 'IndexName', 'position': 'hidden'}, - {'map_type': 'ParameterValueIndex', 'position': 'fifth column'}, - {'map_type': 'ExpandedValue', 'position': -24}, + {"map_type": "IndexName", "position": "hidden"}, + {"map_type": "ParameterValueIndex", "position": "fifth column"}, + {"map_type": "ExpandedValue", "position": -24}, ] self.assertEqual(out, expected) diff --git a/tests/spine_io/exporters/test_gdx_writer.py b/tests/spine_io/exporters/test_gdx_writer.py index 5fa97328..33cd6c39 100644 --- a/tests/spine_io/exporters/test_gdx_writer.py +++ b/tests/spine_io/exporters/test_gdx_writer.py @@ -309,5 +309,5 @@ def test_special_value_conversions(self): db_map.close() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/spine_io/exporters/test_writer.py b/tests/spine_io/exporters/test_writer.py index 3a150b9b..892278ea 100644 --- a/tests/spine_io/exporters/test_writer.py +++ b/tests/spine_io/exporters/test_writer.py @@ -87,5 +87,5 @@ def test_max_rows_with_filter(self): self.assertEqual(writer.tables, {None: [["class2", "obj6"]]}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/spine_io/importers/test_datapackage_reader.py b/tests/spine_io/importers/test_datapackage_reader.py index 58fabb3f..3fe43486 100644 --- a/tests/spine_io/importers/test_datapackage_reader.py +++ b/tests/spine_io/importers/test_datapackage_reader.py @@ -93,5 +93,5 @@ def test_datapackage(rows): yield package_path -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/spine_io/importers/test_json_reader.py b/tests/spine_io/importers/test_json_reader.py index 58c2754b..bdd0e416 100644 --- a/tests/spine_io/importers/test_json_reader.py +++ b/tests/spine_io/importers/test_json_reader.py @@ -41,5 +41,5 @@ def 
test_file_iterator_works_with_empty_options(self): self.assertEqual(rows, [["a", 1], ["b", "c", 2]]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/spine_io/importers/test_reader.py b/tests/spine_io/importers/test_reader.py index 4cb9026f..48b775e2 100644 --- a/tests/spine_io/importers/test_reader.py +++ b/tests/spine_io/importers/test_reader.py @@ -39,5 +39,5 @@ def failing_iterator(): self.assertEqual(errors, ["error in iterator"]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/spine_io/importers/test_sqlalchemy_connector.py b/tests/spine_io/importers/test_sqlalchemy_connector.py index d4e4ff18..67705376 100644 --- a/tests/spine_io/importers/test_sqlalchemy_connector.py +++ b/tests/spine_io/importers/test_sqlalchemy_connector.py @@ -26,5 +26,5 @@ def test_connector_is_picklable(self): self.assertTrue(pickled) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 4813ec93..b659de7d 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -991,22 +991,22 @@ def tearDown(self): self._db_map.close() def create_object_classes(self): - obj_classes = ['class1', 'class2'] + obj_classes = ["class1", "class2"] import_functions.import_object_classes(self._db_map, obj_classes) return obj_classes def create_objects(self): - objects = [('class1', 'obj11'), ('class1', 'obj12'), ('class2', 'obj21')] + objects = [("class1", "obj11"), ("class1", "obj12"), ("class2", "obj21")] import_functions.import_objects(self._db_map, objects) return objects def create_relationship_classes(self): - relationship_classes = [('rel1', ['class1']), ('rel2', ['class1', 'class2'])] + relationship_classes = [("rel1", ["class1"]), ("rel2", ["class1", "class2"])] import_functions.import_relationship_classes(self._db_map, relationship_classes) return relationship_classes def create_relationships(self): - relationships = [('rel1', ['obj11']), ('rel2', ['obj11', 'obj21'])] + relationships = [("rel1", ["obj11"]), ("rel2", ["obj11", "obj21"])] import_functions.import_relationships(self._db_map, relationships) return relationships @@ -1181,8 +1181,8 @@ def test_entity_parameter_definition_sq_for_relationship_class(self): def test_entity_parameter_definition_sq_with_multiple_relationship_classes_but_single_parameter(self): self.create_object_classes() self.create_relationship_classes() - obj_parameter_definitions = [('class1', 'par1a'), ('class1', 'par1b')] - rel_parameter_definitions = [('rel1', 'rpar1a')] + obj_parameter_definitions = [("class1", "par1a"), ("class1", "par1b")] + rel_parameter_definitions = [("rel1", "rpar1a")] import_functions.import_object_parameters(self._db_map, obj_parameter_definitions) import_functions.import_relationship_parameters(self._db_map, rel_parameter_definitions) self._db_map.commit_session("test") @@ -1198,18 +1198,18 @@ def test_entity_parameter_values(self): self.create_objects() self.create_relationship_classes() self.create_relationships() - obj_parameter_definitions = [('class1', 'par1a'), ('class1', 'par1b'), ('class2', 'par2a')] - rel_parameter_definitions = [('rel1', 'rpar1a'), ('rel2', 'rpar2a')] + obj_parameter_definitions = [("class1", "par1a"), ("class1", "par1b"), ("class2", "par2a")] + rel_parameter_definitions = [("rel1", "rpar1a"), ("rel2", "rpar2a")] import_functions.import_object_parameters(self._db_map, obj_parameter_definitions) 
import_functions.import_relationship_parameters(self._db_map, rel_parameter_definitions) object_parameter_values = [ - ('class1', 'obj11', 'par1a', 123), - ('class1', 'obj11', 'par1b', 333), - ('class2', 'obj21', 'par2a', 'empty'), + ("class1", "obj11", "par1a", 123), + ("class1", "obj11", "par1b", 333), + ("class2", "obj21", "par2a", "empty"), ] _, errors = import_functions.import_object_parameter_values(self._db_map, object_parameter_values) self.assertFalse(errors) - relationship_parameter_values = [('rel1', ['obj11'], 'rpar1a', 1.1), ('rel2', ['obj11', 'obj21'], 'rpar2a', 42)] + relationship_parameter_values = [("rel1", ["obj11"], "rpar1a", 1.1), ("rel2", ["obj11", "obj21"], "rpar2a", 42)] _, errors = import_functions.import_relationship_parameter_values(self._db_map, relationship_parameter_values) self.assertFalse(errors) self._db_map.commit_session("test") @@ -1222,7 +1222,7 @@ def test_entity_parameter_values(self): if row.object_name: # This is an object parameter self.assertEqual(row.object_name, par_val[1]) else: # This is a relationship parameter - self.assertEqual(row.object_name_list, ','.join(par_val[1])) + self.assertEqual(row.object_name_list, ",".join(par_val[1])) self.assertEqual(row.parameter_name, par_val[2]) self.assertEqual(from_database(row.value, row.type), par_val[3]) @@ -3006,11 +3006,11 @@ def test_cascade_remove_unfetched(self): class TestDatabaseMappingConcurrent(AssertSuccessTestCase): - @unittest.skipIf(os.name == 'nt', "Needs fixing") + @unittest.skipIf(os.name == "nt", "Needs fixing") def test_concurrent_commit_threading(self): self._do_test_concurrent_commit(threading.Thread) - @unittest.skipIf(os.name == 'nt', "Needs fixing") + @unittest.skipIf(os.name == "nt", "Needs fixing") def test_concurrent_commit_multiprocessing(self): self._do_test_concurrent_commit(multiprocessing.Process) diff --git a/tests/test_check_integrity.py b/tests/test_check_integrity.py index afdd4c2d..d4ab0802 100644 --- a/tests/test_check_integrity.py +++ b/tests/test_check_integrity.py @@ -27,21 +27,21 @@ def _val_dict(val): class TestCheckIntegrity(unittest.TestCase): def setUp(self): self.data = [ - (bool, (b'"TRUE"', b'"FALSE"', b'"T"', b'"True"', b'"False"'), (b'true', b'false')), - (int, (b'32', b'3.14'), (b'42', b'-2')), + (bool, (b'"TRUE"', b'"FALSE"', b'"T"', b'"True"', b'"False"'), (b"true", b"false")), + (int, (b"32", b"3.14"), (b"42", b"-2")), (str, (b'"FOO"', b'"bar"'), (b'"foo"', b'"Bar"', b'"BAZ"')), ] self.value_type = {bool: 1, int: 2, str: 3} self.db_map = DatabaseMapping("sqlite://", create=True) - self.db_map.add_items("entity_class", {"id": 1, 'name': 'cat'}) + self.db_map.add_items("entity_class", {"id": 1, "name": "cat"}) self.db_map.add_items( "entity", - {"id": 1, 'name': 'Tom', "class_id": 1}, - {"id": 2, 'name': 'Felix', "class_id": 1}, - {"id": 3, 'name': 'Jansson', "class_id": 1}, + {"id": 1, "name": "Tom", "class_id": 1}, + {"id": 2, "name": "Felix", "class_id": 1}, + {"id": 3, "name": "Jansson", "class_id": 1}, ) self.db_map.add_items( - "parameter_value_list", {"id": 1, 'name': 'list1'}, {"id": 2, 'name': 'list2'}, {"id": 3, 'name': 'list3'} + "parameter_value_list", {"id": 1, "name": "list1"}, {"id": 2, "name": "list2"}, {"id": 3, "name": "list3"} ) self.db_map.add_items( "list_value", @@ -55,21 +55,21 @@ def setUp(self): ) self.db_map.add_items( "parameter_definition", - {"id": 1, 'name': 'par1', 'entity_class_id': 1, 'parameter_value_list_id': 1}, - {"id": 2, 'name': 'par2', 'entity_class_id': 1, 'parameter_value_list_id': 2}, - {"id": 3, 'name': 
'par3', 'entity_class_id': 1, 'parameter_value_list_id': 3}, + {"id": 1, "name": "par1", "entity_class_id": 1, "parameter_value_list_id": 1}, + {"id": 2, "name": "par2", "entity_class_id": 1, "parameter_value_list_id": 2}, + {"id": 3, "name": "par3", "entity_class_id": 1, "parameter_value_list_id": 3}, ) @staticmethod def get_item(id_: int, val: bytes, entity_id: int): return { - 'id': 1, - 'parameter_definition_id': id_, - 'entity_class_id': 1, - 'entity_id': entity_id, - 'value': val, - 'type': None, - 'alternative_id': 1, + "id": 1, + "parameter_definition_id": id_, + "entity_class_id": 1, + "entity_id": entity_id, + "value": val, + "type": None, + "alternative_id": 1, } def test_parameter_values_and_default_values_with_list_references(self): @@ -81,7 +81,7 @@ def test_parameter_values_and_default_values_with_list_references(self): item = self.get_item(id_, value, 1) _, errors = self.db_map.add_items("parameter_value", item) self.assertEqual(len(errors), 1) - parsed_value = json.loads(value.decode('utf8')) + parsed_value = json.loads(value.decode("utf8")) if isinstance(parsed_value, Number): parsed_value = float(parsed_value) self.assertEqual(errors[0], f"value {parsed_value} of par{id_} for ('Tom',) is not in list{id_}") diff --git a/tests/test_db_mapping_base.py b/tests/test_db_mapping_base.py index cd67668c..7b2e7802 100644 --- a/tests/test_db_mapping_base.py +++ b/tests/test_db_mapping_base.py @@ -82,5 +82,5 @@ def test_setting_new_id_validates_it(self): self.assertTrue(item.has_valid_id) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_export_functions.py b/tests/test_export_functions.py index 2a8aaba9..289f53af 100644 --- a/tests/test_export_functions.py +++ b/tests/test_export_functions.py @@ -148,5 +148,5 @@ def test_export_data(self): self.assertEqual(exported["scenario_alternatives"], [("scenario", "alternative", None)]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 7d3a8b7b..4ed0abf2 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -41,13 +41,13 @@ def test_multiple_dimension(self): class TestCreateNewSpineEngine(unittest.TestCase): def test_same_schema(self): - engine1 = create_new_spine_database('sqlite://') - engine2 = create_new_spine_database('sqlite://') + engine1 = create_new_spine_database("sqlite://") + engine2 = create_new_spine_database("sqlite://") self.assertTrue(compare_schemas(engine1, engine2)) def test_different_schema(self): - engine1 = create_new_spine_database('sqlite://') - engine2 = create_new_spine_database('sqlite://') + engine1 = create_new_spine_database("sqlite://") + engine2 = create_new_spine_database("sqlite://") engine2.execute("drop table entity") self.assertFalse(compare_schemas(engine1, engine2)) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 7a8039ad..9f7e2a6b 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -86,9 +86,9 @@ def test_import_data_integration(self): rel_parameters = [["example_rel_class", "rel_parameter"]] # 1 item object_p_values = [["example_class", "example_object", "example_parameter", 3.14]] # 1 item rel_p_values = [["example_rel_class", ["example_object", "other_object"], "rel_parameter", 2.718]] # 1 - alternatives = [['example_alternative', 'An example']] - scenarios = [['example_scenario', True, 'An example']] - scenario_alternatives = [['example_scenario', 'example_alternative']] + alternatives = 
[["example_alternative", "An example"]] + scenarios = [["example_scenario", True, "An example"]] + scenario_alternatives = [["example_scenario", "example_alternative"]] num_imports, errors = import_data( db_map, @@ -807,7 +807,7 @@ def test_import_existing_object_parameter_value_on_conflict_keep(self): db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) - self.assertEqual(['2000-01-01T01:00:00', '2000-01-01T02:00:00'], [str(x) for x in value.indexes]) + self.assertEqual(["2000-01-01T01:00:00", "2000-01-01T02:00:00"], [str(x) for x in value.indexes]) self.assertEqual([1.0, 2.0], list(value.values)) db_map.close() @@ -824,7 +824,7 @@ def test_import_existing_object_parameter_value_on_conflict_replace(self): db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) - self.assertEqual(['2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in value.indexes]) + self.assertEqual(["2000-01-01T02:00:00", "2000-01-01T03:00:00"], [str(x) for x in value.indexes]) self.assertEqual([3.0, 4.0], list(value.values)) db_map.close() @@ -842,7 +842,7 @@ def test_import_existing_object_parameter_value_on_conflict_merge(self): pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() value = from_database(pv.value, pv.type) self.assertEqual( - ['2000-01-01T01:00:00', '2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in value.indexes] + ["2000-01-01T01:00:00", "2000-01-01T02:00:00", "2000-01-01T03:00:00"], [str(x) for x in value.indexes] ) self.assertEqual([1.0, 3.0, 4.0], list(value.values)) db_map.close() @@ -868,10 +868,10 @@ def test_import_existing_object_parameter_value_on_conflict_merge_map(self): db_map.commit_session("test") pv = db_map.query(db_map.object_parameter_value_sq).filter_by(object_name="object1").first() map_ = from_database(pv.value, pv.type) - self.assertEqual(['xxx'], [str(x) for x in map_.indexes]) - ts = map_.get_value('xxx') + self.assertEqual(["xxx"], [str(x) for x in map_.indexes]) + ts = map_.get_value("xxx") self.assertEqual( - ['2000-01-01T01:00:00', '2000-01-01T02:00:00', '2000-01-01T03:00:00'], [str(x) for x in ts.indexes] + ["2000-01-01T01:00:00", "2000-01-01T02:00:00", "2000-01-01T03:00:00"], [str(x) for x in ts.indexes] ) self.assertEqual([1.0, 3.0, 4.0], list(ts.values)) db_map.close() @@ -1138,28 +1138,28 @@ def test_non_existent_relationship_parameter_value_from_value_list_fails_gracefu def test_unparse_value_imports_fields_correctly(self): with DatabaseMapping("sqlite:///", create=True) as db_map: data = { - 'entity_classes': [('A', (), None, None, False)], - 'entities': [('A', 'aa', None)], - 'parameter_definitions': [('A', 'test1', None, None, None)], - 'parameter_values': [ + "entity_classes": [("A", (), None, None, False)], + "entities": [("A", "aa", None)], + "parameter_definitions": [("A", "test1", None, None, None)], + "parameter_values": [ ( - 'A', - 'aa', - 'test1', + "A", + "aa", + "test1", { - 'type': 'time_series', - 'index': { - 'start': '2000-01-01 00:00:00', - 'resolution': '1h', - 'ignore_year': False, - 'repeat': False, + "type": "time_series", + "index": { + "start": "2000-01-01 00:00:00", + "resolution": "1h", + "ignore_year": False, + "repeat": False, }, - 'data': [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], + "data": [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], }, - 'Base', + "Base", ) ], - 'alternatives': 
[('Base', 'Base alternative')], + "alternatives": [("Base", "Base alternative")], } count, errors = import_data(db_map, **data, unparse_value=dump_db_value) @@ -1175,7 +1175,7 @@ def test_unparse_value_imports_fields_correctly(self): time_series = from_database(value.value, value.type) expected_result = TimeSeriesFixedResolution( - '2000-01-01 00:00:00', '1h', [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], False, False + "2000-01-01 00:00:00", "1h", [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], False, False ) self.assertEqual(time_series, expected_result) @@ -1360,16 +1360,16 @@ def test_importing_existing_scenario_alternative_does_not_alter_scenario_alterna self.assertEqual(count, 0) def test_import_scenario_alternatives_in_arbitrary_order(self): - count, errors = import_scenarios(self._db_map, [('A (1)', False, '')]) + count, errors = import_scenarios(self._db_map, [("A (1)", False, "")]) self.assertEqual(errors, []) self.assertEqual(count, 1) count, errors = import_alternatives( - self._db_map, [('Base', 'Base alternative'), ('b', ''), ('c', ''), ('d', '')] + self._db_map, [("Base", "Base alternative"), ("b", ""), ("c", ""), ("d", "")] ) self.assertEqual(errors, []) self.assertEqual(count, 3) count, errors = import_scenario_alternatives( - self._db_map, [('A (1)', 'c', 'd'), ('A (1)', 'd', None), ('A (1)', 'Base', 'b'), ('A (1)', 'b', 'c')] + self._db_map, [("A (1)", "c", "d"), ("A (1)", "d", None), ("A (1)", "Base", "b"), ("A (1)", "b", "c")] ) self.assertEqual(errors, []) self.assertEqual(count, 4) @@ -1481,13 +1481,13 @@ def test_import_metadata_with_nested_list(self): self.assertEqual(count, 2) self.assertFalse(errors) self.assertEqual(len(metadata), 2) - self.assertIn(('contributors', "{'name': 'John'}"), metadata) - self.assertIn(('contributors', "{'name': 'Charly'}"), metadata) + self.assertIn(("contributors", "{'name': 'John'}"), metadata) + self.assertIn(("contributors", "{'name': 'Charly'}"), metadata) db_map.close() def test_import_unformatted_metadata(self): db_map = create_db_map() - count, errors = import_metadata(db_map, ['not a JSON object']) + count, errors = import_metadata(db_map, ["not a JSON object"]) db_map.commit_session("test") metadata = [(x.name, x.value) for x in db_map.query(db_map.metadata_sq)] self.assertEqual(count, 1) @@ -1527,10 +1527,10 @@ def test_import_object_metadata(self): (x.entity_name, x.metadata_name, x.metadata_value) for x in db_map.query(db_map.ext_entity_metadata_sq) ] self.assertEqual(len(metadata), 4) - self.assertIn(('object1', 'co-author', 'John'), metadata) - self.assertIn(('object1', 'age', '90'), metadata) - self.assertIn(('object1', 'co-author', 'Charly'), metadata) - self.assertIn(('object1', 'age', '17'), metadata) + self.assertIn(("object1", "co-author", "John"), metadata) + self.assertIn(("object1", "age", "90"), metadata) + self.assertIn(("object1", "co-author", "Charly"), metadata) + self.assertIn(("object1", "age", "17"), metadata) db_map.close() def test_import_relationship_metadata(self): @@ -1548,10 +1548,10 @@ def test_import_relationship_metadata(self): db_map.commit_session("test") metadata = [(x.metadata_name, x.metadata_value) for x in db_map.query(db_map.ext_entity_metadata_sq)] self.assertEqual(len(metadata), 4) - self.assertIn(('co-author', 'John'), metadata) - self.assertIn(('age', '90'), metadata) - self.assertIn(('co-author', 'Charly'), metadata) - self.assertIn(('age', '17'), metadata) + self.assertIn(("co-author", "John"), metadata) + self.assertIn(("age", "90"), metadata) + self.assertIn(("co-author", "Charly"), metadata) + 
self.assertIn(("age", "17"), metadata) db_map.close() diff --git a/tests/test_migration.py b/tests/test_migration.py index 27c1a882..3540ede1 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -128,19 +128,19 @@ def test_upgrade_content(self): self.assertTrue(len(rel_par_defs), 1) self.assertTrue(len(obj_par_vals), 2) self.assertTrue(len(rel_par_vals), 2) - self.assertTrue('dog' in object_classes.values()) - self.assertTrue('fish' in object_classes.values()) - self.assertTrue(('dog', 'pluto') in objects.values()) - self.assertTrue(('dog', 'scooby') in objects.values()) - self.assertTrue(('fish', 'nemo') in objects.values()) - self.assertTrue(('dog__fish', 'dog,fish') in rel_clss.values()) - self.assertTrue(('dog__fish', 'pluto__nemo', 'pluto,nemo') in rels.values()) - self.assertTrue(('dog__fish', 'scooby__nemo', 'scooby,nemo') in rels.values()) - self.assertTrue(('dog', 'breed') in obj_par_defs.values()) - self.assertTrue(('fish', 'water') in obj_par_defs.values()) - self.assertTrue(('dog__fish', 'relative_speed') in rel_par_defs.values()) - self.assertTrue(('breed', 'scooby', b'"big dane"') in obj_par_vals) - self.assertTrue(('breed', 'pluto', b'"labrador"') in obj_par_vals) - self.assertTrue(('relative_speed', 'pluto__nemo', b'100') in rel_par_vals) - self.assertTrue(('relative_speed', 'scooby__nemo', b'-1') in rel_par_vals) + self.assertTrue("dog" in object_classes.values()) + self.assertTrue("fish" in object_classes.values()) + self.assertTrue(("dog", "pluto") in objects.values()) + self.assertTrue(("dog", "scooby") in objects.values()) + self.assertTrue(("fish", "nemo") in objects.values()) + self.assertTrue(("dog__fish", "dog,fish") in rel_clss.values()) + self.assertTrue(("dog__fish", "pluto__nemo", "pluto,nemo") in rels.values()) + self.assertTrue(("dog__fish", "scooby__nemo", "scooby,nemo") in rels.values()) + self.assertTrue(("dog", "breed") in obj_par_defs.values()) + self.assertTrue(("fish", "water") in obj_par_defs.values()) + self.assertTrue(("dog__fish", "relative_speed") in rel_par_defs.values()) + self.assertTrue(("breed", "scooby", b'"big dane"') in obj_par_vals) + self.assertTrue(("breed", "pluto", b'"labrador"') in obj_par_vals) + self.assertTrue(("relative_speed", "pluto__nemo", b"100") in rel_par_vals) + self.assertTrue(("relative_speed", "scooby__nemo", b"-1") in rel_par_vals) db_map.close() diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index deaf7cea..1f4de335 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -626,13 +626,13 @@ def test_from_database_Map_two_column_array_format(self): self.assertEqual(value.index_name, "x") def test_from_database_Map_nested_maps(self): - database_value = b''' + database_value = b""" { "index_type": "duration", "data":[["1 hour", {"type": "map", "index_type": "date_time", "data": {"2020-01-01T12:00": {"type":"duration", "data":"3 hours"}}}]] - }''' + }""" value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [Duration("1 hour")]) nested_map = value.values[0] @@ -641,14 +641,14 @@ def test_from_database_Map_nested_maps(self): self.assertEqual(nested_map.values, [Duration("3 hours")]) def test_from_database_Map_with_TimeSeries_values(self): - database_value = b''' + database_value = b""" { "index_type": "duration", "data":[["1 hour", {"type": "time_series", "data": [["2020-01-01T12:00", -3.0], ["2020-01-02T12:00", -9.3]] } ]] - }''' + }""" value = from_database(database_value, type_="map") self.assertEqual(value.indexes, 
[Duration("1 hour")]) self.assertEqual( @@ -657,21 +657,21 @@ def test_from_database_Map_with_TimeSeries_values(self): ) def test_from_database_Map_with_Array_values(self): - database_value = b''' + database_value = b""" { "index_type": "duration", "data":[["1 hour", {"type": "array", "data": [-3.0, -9.3]}]] - }''' + }""" value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [Duration("1 hour")]) self.assertEqual(value.values, [Array([-3.0, -9.3])]) def test_from_database_Map_with_TimePattern_values(self): - database_value = b''' + database_value = b""" { "index_type": "float", "data":[["2.3", {"type": "time_pattern", "data": {"M1-2": -9.3, "M3-12": -3.9}}]] - }''' + }""" value = from_database(database_value, type_="map") self.assertEqual(value.indexes, [2.3]) self.assertEqual(value.values, [TimePattern(["M1-2", "M3-12"], [-9.3, -3.9])]) diff --git a/tests/test_purge.py b/tests/test_purge.py index 8c08c81a..504bb241 100644 --- a/tests/test_purge.py +++ b/tests/test_purge.py @@ -103,5 +103,5 @@ def test_purge_externally(self): self.assertFalse(db_map.get_items("entity_class")) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 5bd4c459e5997a9523672f341061fc5ab843064b Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 25 Jan 2024 17:01:02 +0100 Subject: [PATCH 308/317] Lift Python version upper limit Re spine-tools/Spine-Toolbox#2522 --- .github/workflows/run_unit_tests.yml | 2 +- CHANGELOG.md | 1 + pyproject.toml | 2 +- spinedb_api/helpers.py | 25 +++++++++++++++++++ spinedb_api/import_mapping/import_mapping.py | 5 ++-- spinedb_api/import_mapping/type_conversion.py | 10 +++----- tests/test_helpers.py | 22 ++++++++++++++++ 7 files changed, 57 insertions(+), 10 deletions(-) diff --git a/.github/workflows/run_unit_tests.yml b/.github/workflows/run_unit_tests.yml index 22a48065..ac42522b 100644 --- a/.github/workflows/run_unit_tests.yml +++ b/.github/workflows/run_unit_tests.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: os: [ubuntu-22.04, windows-latest] - python-version: [3.8, 3.9, "3.10", 3.11] + python-version: [3.8, 3.9, "3.10", 3.11, 3.12] steps: - uses: actions/checkout@v4 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ce2b6bf..d24c84cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ but those functions and methods are pending deprecation. ### Changed +- Python 3.12 is now supported. - Objects and relationships have been replaced by *entities*. Zero-dimensional entities correspond to objects while multidimensional entities to relationships. diff --git a/pyproject.toml b/pyproject.toml index f4dd65bc..3969813e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ classifiers = [ "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", "Operating System :: OS Independent", ] -requires-python = ">=3.8.1, <3.12" +requires-python = ">=3.8.1" dependencies = [ # v1.4 does not pass tests "sqlalchemy >=1.3, <1.4", diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index f2e92baa..bbbaeb17 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -886,3 +886,28 @@ def group_consecutive(list_of_numbers): for _k, g in groupby(enumerate(sorted(list_of_numbers)), lambda x: x[0] - x[1]): group = list(map(itemgetter(1), g)) yield group[0], group[-1] + + +_TRUTHS = {s.casefold() for s in ("yes", "true", "y", "t", "1")} +_FALSES = {s.casefold() for s in ("no", "false", "n", "f", "0")} + + +def string_to_bool(string): + """Converts string to boolean. 
+
+    Recognizes "yes", "true", "y", "t" and "1" as True, "no", "false", "n", "f" and "0" as False.
+    Case insensitive.
+    Raises ValueError if the value is not recognized.
+
+    Args:
+        string (str): string to convert
+
+    Returns:
+        bool: True or False
+    """
+    string = string.casefold()
+    if string in _TRUTHS:
+        return True
+    if string in _FALSES:
+        return False
+    raise ValueError(string)
diff --git a/spinedb_api/import_mapping/import_mapping.py b/spinedb_api/import_mapping/import_mapping.py
index af20d677..8fc5b955 100644
--- a/spinedb_api/import_mapping/import_mapping.py
+++ b/spinedb_api/import_mapping/import_mapping.py
@@ -11,8 +11,9 @@
 ######################################################################################################################
 """ Contains import mappings for database items such as entities, entity classes and parameter values. """
-from distutils.util import strtobool
 from enum import auto, Enum, unique
+
+from spinedb_api.helpers import string_to_bool
 from spinedb_api.mapping import Mapping, Position, unflatten, is_pivoted
 from spinedb_api.exception import InvalidMappingComponent
 
@@ -808,7 +809,7 @@ class ScenarioActiveFlagMapping(ImportMapping):
 
     def _import_row(self, source_data, state, mapped_data):
         scenario = state[ImportKey.SCENARIO_NAME]
-        active = bool(strtobool(str(source_data)))
+        active = string_to_bool(str(source_data))
         mapped_data.setdefault("scenarios", set()).add((scenario, active))
 
 
diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py
index fea46b4b..5200ccaa 100644
--- a/spinedb_api/import_mapping/type_conversion.py
+++ b/spinedb_api/import_mapping/type_conversion.py
@@ -10,13 +10,11 @@
 # this program. If not, see <http://www.gnu.org/licenses/>.
 ######################################################################################################################
 
-"""
-Type conversion functions.
-
-"""
+""" Type conversion functions. 
""" import re -from distutils.util import strtobool + +from spinedb_api.helpers import string_to_bool from spinedb_api.parameter_value import DateTime, Duration, ParameterValueFormatError @@ -81,7 +79,7 @@ class BooleanConvertSpec(ConvertSpec): RETURN_TYPE = bool def __call__(self, value): - return self.RETURN_TYPE(strtobool(str(value))) + return self.RETURN_TYPE(string_to_bool(str(value))) class IntegerSequenceDateTimeConvertSpec(ConvertSpec): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 4ed0abf2..676c6d06 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -20,6 +20,7 @@ name_from_dimensions, name_from_elements, remove_credentials_from_url, + string_to_bool, ) @@ -75,5 +76,26 @@ def test_returns_latest_version(self): self.assertEqual(get_head_alembic_version(), "8b0eff478bcb") +class TestStringToBool(unittest.TestCase): + def test_truths(self): + self.assertTrue(string_to_bool("yes")) + self.assertTrue(string_to_bool("YES")) + self.assertTrue(string_to_bool("y")) + self.assertTrue(string_to_bool("true")) + self.assertTrue(string_to_bool("t")) + self.assertTrue(string_to_bool("1")) + + def test_falses(self): + self.assertFalse(string_to_bool("NO")) + self.assertFalse(string_to_bool("no")) + self.assertFalse(string_to_bool("n")) + self.assertFalse(string_to_bool("false")) + self.assertFalse(string_to_bool("f")) + self.assertFalse(string_to_bool("0")) + + def test_raises_value_error(self): + self.assertRaises(ValueError, string_to_bool, "no truth in this") + + if __name__ == "__main__": unittest.main() From a6fb8297efd8ad890b33256ce39eb173944143cb Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Fri, 5 Apr 2024 10:51:20 +0300 Subject: [PATCH 309/317] Make deep_copy_value() accept None --- spinedb_api/parameter_value.py | 2 +- tests/test_parameter_value.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 6f22264b..51a16ca1 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -1745,7 +1745,7 @@ def deep_copy_value(value): Returns: Any: deep-copied value """ - if isinstance(value, (Number, str)): + if isinstance(value, (Number, str)) or value is None: return value if isinstance(value, Array): return Array(value.values, value.value_type, value.index_name) diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index 1f4de335..fa75a838 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -1005,6 +1005,9 @@ def convert_map_to_dict(self): self.assertEqual(nested_map, {"A": {"a": -3.2, "b": -2.3}, "B": {"c": 3.2, "d": 2.3}}) def test_deep_copy_value_for_scalars(self): + x = None + copy_of_x = deep_copy_value(x) + self.assertIsNone(copy_of_x) x = 1.0 copy_of_x = deep_copy_value(x) self.assertEqual(x, copy_of_x) From b3051846526ef75021ea417aae02b8d526ba7e42 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Tue, 9 Apr 2024 13:25:03 +0300 Subject: [PATCH 310/317] Fix typo in benchmarks/README.md link --- benchmarks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 5482382f..816ec167 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -22,5 +22,5 @@ Benchmarks from e.g. 
different commits/branches can be compared by
 python -mpyperf compare_to 
 ```
 
-Check the [`pyperf` documentation]((https://pyperf.readthedocs.io/en/latest/index.html))
+Check the [`pyperf` documentation](https://pyperf.readthedocs.io/en/latest/index.html)
 for further things you can do with it.

From 770d6db1ca9bf5fae673c6653b24584f4a37953b Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Tue, 9 Apr 2024 16:35:34 +0300
Subject: [PATCH 311/317] Fix benchmarks and add new ones

Benchmarks were giving 20x smaller results due to a misunderstanding of
how the inner_loops parameter of pyperf works. The parameter is no
longer used.

Also, added two new benchmarks for from_database() in the case of Map.
---
 benchmarks/map_from_database.py | 40 ++++++++++++++++
 ...update_default_value_to_different_value.py | 7 ++-
 .../update_default_value_to_same_value.py | 6 +--
 benchmarks/utils.py | 47 ++++++++++---------
 4 files changed, 70 insertions(+), 30 deletions(-)
 create mode 100644 benchmarks/map_from_database.py

diff --git a/benchmarks/map_from_database.py b/benchmarks/map_from_database.py
new file mode 100644
index 00000000..9768cb6b
--- /dev/null
+++ b/benchmarks/map_from_database.py
@@ -0,0 +1,40 @@
+"""
+This benchmark tests the performance of reading a Map type value from the database.
+"""
+
+import time
+import pyperf
+from spinedb_api import from_database, to_database
+from benchmarks.utils import build_even_map, run_file_name
+
+
+def value_from_database(loops, db_value, value_type):
+    duration = 0.0
+    for _ in range(loops):
+        start = time.perf_counter()
+        from_database(db_value, value_type)
+        duration += (time.perf_counter() - start)
+    return duration
+
+
+def run_benchmark():
+    file_name = run_file_name()
+    runner = pyperf.Runner(loops=3)
+    runs = {
+        "value_from_database[Map(10, 10, 100)]": {"dimensions": (10, 10, 100)},
+        "value_from_database[Map(1000)]": {"dimensions": (10000,)},
+    }
+    for name, parameters in runs.items():
+        db_value, value_type = to_database(build_even_map(parameters["dimensions"]))
+        benchmark = runner.bench_time_func(
+            name,
+            value_from_database,
+            db_value,
+            value_type,
+        )
+        if benchmark is not None:
+            pyperf.add_runs(file_name, benchmark)
+
+
+if __name__ == "__main__":
+    run_benchmark()
diff --git a/benchmarks/update_default_value_to_different_value.py b/benchmarks/update_default_value_to_different_value.py
index 0f27836d..ca1a6be4 100644
--- a/benchmarks/update_default_value_to_different_value.py
+++ b/benchmarks/update_default_value_to_different_value.py
@@ -6,7 +6,7 @@
 import time
 import pyperf
 from spinedb_api import DatabaseMapping, to_database
-from benchmarks.utils import build_sizeable_map, run_file_name
+from benchmarks.utils import build_even_map, run_file_name
 
 
 def update_default_value(loops, db_map, first_db_value, first_value_type, second_db_value, second_value_type):
@@ -29,13 +29,13 @@ def update_default_value(loops, db_map, first_db_value, first_value_type, second
 
 def run_benchmark():
     first_value, first_type = to_database(None)
-    second_value, second_type = to_database(build_sizeable_map())
+    second_value, second_type = to_database(build_even_map())
     with DatabaseMapping("sqlite://", create=True) as db_map:
         db_map.add_entity_class_item(name="Object")
         db_map.add_parameter_definition_item(
             name="x", entity_class_name="Object", default_value=first_value, default_type=first_type
         )
-    runner = pyperf.Runner(min_time=0.0001)
+    runner = pyperf.Runner(min_time=0.001)
     benchmark = runner.bench_time_func(
         "update_parameter_definition_item[None,Map]",
         update_default_value, 
@@ -44,7 +44,6 @@ def run_benchmark():
         first_type,
         second_value,
         second_type,
-        inner_loops=10,
     )
     pyperf.add_runs(run_file_name(), benchmark)
 
diff --git a/benchmarks/update_default_value_to_same_value.py b/benchmarks/update_default_value_to_same_value.py
index 13e7b634..cca9b102 100644
--- a/benchmarks/update_default_value_to_same_value.py
+++ b/benchmarks/update_default_value_to_same_value.py
@@ -5,7 +5,7 @@
 import time
 import pyperf
 from spinedb_api import DatabaseMapping, to_database
-from benchmarks.utils import build_sizeable_map, run_file_name
+from benchmarks.utils import build_even_map, run_file_name
 
 
 def update_default_value(loops, db_map, value, value_type):
@@ -24,7 +24,7 @@ def update_default_value(loops, db_map, value, value_type):
 
 
 def run_benchmark():
-    value, value_type = to_database(build_sizeable_map())
+    value, value_type = to_database(build_even_map())
     with DatabaseMapping("sqlite://", create=True) as db_map:
         db_map.add_entity_class_item(name="Object")
         db_map.add_parameter_definition_item(
@@ -32,7 +32,7 @@
         )
         runner = pyperf.Runner()
         benchmark = runner.bench_time_func(
-            "update_parameter_definition_item[Map,Map]", update_default_value, db_map, value, value_type, inner_loops=10
+            "update_parameter_definition_item[Map,Map]", update_default_value, db_map, value, value_type
         )
         pyperf.add_runs(run_file_name(), benchmark)
 
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
index c2285ba2..072f1d47 100644
--- a/benchmarks/utils.py
+++ b/benchmarks/utils.py
@@ -1,32 +1,33 @@
 import datetime
 import math
+from typing import Sequence
 from spinedb_api import __version__, DateTime, Map
 
 
-def build_sizeable_map():
+def build_map(size: int) -> Map:
     start = datetime.datetime(year=2024, month=1, day=1)
-    root_xs = []
-    root_ys = []
-    i_max = 10
-    j_max = 10
-    k_max = 10
-    total = i_max * j_max * k_max
-    for i in range(i_max):
-        root_xs.append(DateTime(start + datetime.timedelta(hours=i)))
-        leaf_xs = []
-        leaf_ys = []
-        for j in range(j_max):
-            leaf_xs.append(DateTime(start + datetime.timedelta(hours=j)))
-            xs = []
-            ys = []
-            for k in range(k_max):
-                xs.append(DateTime(start + datetime.timedelta(hours=k)))
-                x = float(k + k_max * j + j_max * i) / total
-                ys.append(math.sin(x * math.pi / 2.0) + (x * j) ** 2 + x * i)
-            leaf_ys.append(Map(xs, ys))
-        root_ys.append(Map(leaf_xs, leaf_ys))
-    return Map(root_xs, root_ys)
+    xs = []
+    ys = []
+    for i in range(size):
+        xs.append(DateTime(start + datetime.timedelta(hours=i)))
+        x = i / size
+        ys.append(math.sin(x * math.pi / 2.0) + x)
+    return Map(xs, ys)
 
 
-def run_file_name():
+def build_even_map(shape: Sequence[int] = (10, 10, 10)) -> Map:
+    if not shape:
+        return Map([], [], index_type=DateTime)
+    if len(shape) == 1:
+        return build_map(shape[0])
+    xs = []
+    ys = []
+    for i in range(shape[0]):
+        start = datetime.datetime(year=2024, month=1, day=1)
+        xs.append(DateTime(start + datetime.timedelta(hours=i)))
+        ys.append(build_even_map(shape[1:]))
+    return Map(xs, ys)
+
+
+def run_file_name() -> str:
     return f"benchmark-{__version__}.json"

From d54ebe9261430ccc09f99beb8d08738b4454ba7b Mon Sep 17 00:00:00 2001
From: Henrik Koski
Date: Wed, 10 Apr 2024 10:21:56 +0300
Subject: [PATCH 312/317] Remove the entity_type field from Excel exports

Re spine-tools/Spine-Toolbox#2601
---
 spinedb_api/spine_io/exporters/excel.py | 3 ---
 spinedb_api/spine_io/importers/excel_reader.py | 9 +++------
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py
index 
80b7bc68..9cc4ecd9 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -60,14 +60,11 @@ def _make_preamble(table_name, title_key): return {"sheet_type": "object_group", "class_name": class_name} dimension_id_list = title_key.get("dimension_id_list") if dimension_id_list is None: - entity_type = "object" entity_dim_count = 0 else: - entity_type = "relationship" entity_dim_count = len(dimension_id_list.split(",")) preamble = { "sheet_type": "entity", - "entity_type": entity_type, "class_name": class_name, "entity_dim_count": entity_dim_count, } diff --git a/spinedb_api/spine_io/importers/excel_reader.py b/spinedb_api/spine_io/importers/excel_reader.py index 0e143cf3..c98a01c5 100644 --- a/spinedb_api/spine_io/importers/excel_reader.py +++ b/spinedb_api/spine_io/importers/excel_reader.py @@ -241,19 +241,16 @@ def _get_header(ws, header_row, index_dim_count): def _create_entity_mappings(metadata, header, index_dim_count): class_name = metadata["class_name"] - entity_type = metadata["entity_type"] + entity_dim_count = int(metadata["entity_dim_count"]) map_dict = {"name": class_name} ent_alt_map_type = "row" if index_dim_count else "column" - if entity_type == "object": + if entity_dim_count == 0: map_dict["map_type"] = "ObjectClass" map_dict["objects"] = {"map_type": ent_alt_map_type, "reference": 0} - elif entity_type == "relationship": - entity_dim_count = int(metadata["entity_dim_count"]) + else: map_dict["map_type"] = "RelationshipClass" map_dict["object_classes"] = header[:entity_dim_count] map_dict["objects"] = [{"map_type": ent_alt_map_type, "reference": i} for i in range(entity_dim_count)] - else: - return None, None value_type = metadata.get("value_type") if value_type is not None: value = {"value_type": value_type} From b219ceb9b8a47c4bb2d65536277e2ff26e18d6ba Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 10 Apr 2024 12:39:49 +0300 Subject: [PATCH 313/317] Try datetime.fromisoformat() before falling back to dateutil.parser.parse() fromisoformat() makes e.g. constructing DateTime objects from ISO 8601 time stamps ~10x speedier. --- benchmarks/datetime_from_database.py | 56 ++++++++++++++++++++++++++++ benchmarks/map_from_database.py | 2 +- spinedb_api/parameter_value.py | 28 ++++++++++---- 3 files changed, 77 insertions(+), 9 deletions(-) create mode 100644 benchmarks/datetime_from_database.py diff --git a/benchmarks/datetime_from_database.py b/benchmarks/datetime_from_database.py new file mode 100644 index 00000000..3a00ed65 --- /dev/null +++ b/benchmarks/datetime_from_database.py @@ -0,0 +1,56 @@ +""" +This benchmark tests the performance of reading a DateTime value from database. 
+""" + +import datetime +import time +from typing import Any, Sequence, Tuple +import pyperf +from spinedb_api import DateTime, from_database, to_database +from benchmarks.utils import run_file_name + + +def build_datetimes(count: int) -> Sequence[DateTime]: + datetimes = [] + year = 2024 + month = 1 + day = 1 + hour = 0 + while len(datetimes) != count: + datetimes.append(DateTime(datetime.datetime(year, month, day, hour))) + hour += 1 + if hour == 24: + hour = 1 + day += 1 + if day == 29: + day = 1 + month += 1 + if month == 13: + month = 1 + year += 1 + return datetimes + + +def value_from_database(loops: int, db_values_and_types: Sequence[Tuple[Any, str]]) -> float: + duration = 0.0 + for _ in range(loops): + for db_value, db_type in db_values_and_types: + start = time.perf_counter() + from_database(db_value, db_type) + duration += time.perf_counter() - start + return duration + + +def run_benchmark(): + file_name = run_file_name() + runner = pyperf.Runner() + inner_loops = 100 + db_values_and_types = [to_database(x) for x in build_datetimes(inner_loops)] + benchmark = runner.bench_time_func( + "from_database[DateTime]", value_from_database, db_values_and_types, inner_loops=inner_loops + ) + pyperf.add_runs(file_name, benchmark) + + +if __name__ == "__main__": + run_benchmark() diff --git a/benchmarks/map_from_database.py b/benchmarks/map_from_database.py index 9768cb6b..d4324001 100644 --- a/benchmarks/map_from_database.py +++ b/benchmarks/map_from_database.py @@ -13,7 +13,7 @@ def value_from_database(loops, db_value, value_type): for _ in range(loops): start = time.perf_counter() from_database(db_value, value_type) - duration += (time.perf_counter() - start) + duration += time.perf_counter() - start return duration diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 51a16ca1..9cecc1fc 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -389,9 +389,12 @@ def _break_dictionary(data): def _datetime_from_database(value): """Converts a datetime database value into a DateTime object.""" try: - stamp = dateutil.parser.parse(value) + stamp = datetime.fromisoformat(value) except ValueError: - raise ParameterValueFormatError(f'Could not parse datetime from "{value}"') + try: + stamp = dateutil.parser.parse(value) + except ValueError: + raise ParameterValueFormatError(f'Could not parse datetime from "{value}"') return DateTime(stamp) @@ -517,9 +520,12 @@ def _time_series_from_single_column(value_dict): duration = str(duration) + _TIME_SERIES_PLAIN_INDEX_UNIT relativedeltas.append(duration_to_relativedelta(duration)) try: - start = dateutil.parser.parse(start) + start = datetime.fromisoformat(start) except ValueError: - raise ParameterValueFormatError(f'Could not decode start value "{start}"') + try: + start = dateutil.parser.parse(start) + except ValueError: + raise ParameterValueFormatError(f'Could not decode start value "{start}"') values = np.array(value_dict["data"]) return TimeSeriesFixedResolution( start, relativedeltas, values, ignore_year, repeat, value_dict.get("index_name", "") @@ -744,9 +750,12 @@ def __init__(self, value=None): value = datetime(year=2000, month=1, day=1) elif isinstance(value, str): try: - value = dateutil.parser.parse(value) + value = datetime.fromisoformat(value) except ValueError: - raise ParameterValueFormatError(f'Could not parse datetime from "{value}"') + try: + value = dateutil.parser.parse(value) + except ValueError: + raise ParameterValueFormatError(f'Could not parse datetime from "{value}"') 
elif isinstance(value, DateTime): value = copy(value._value) elif not isinstance(value, datetime): @@ -1348,9 +1357,12 @@ def start(self, start): """ if isinstance(start, str): try: - self._start = dateutil.parser.parse(start) + self._start = datetime.fromisoformat(start) except ValueError: - raise ParameterValueFormatError(f'Cannot parse start time "{start}"') + try: + self._start = dateutil.parser.parse(start) + except ValueError: + raise ParameterValueFormatError(f'Cannot parse start time "{start}"') elif isinstance(start, np.datetime64): self._start = start.tolist() else: From 9263b66e37f86924a3d64ad0d82a4201f423d1b6 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 11 Apr 2024 12:26:46 +0300 Subject: [PATCH 314/317] Add more benchmarks --- benchmarks/datetime_from_database.py | 6 +-- benchmarks/mapped_item_getitem.py | 62 ++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) create mode 100644 benchmarks/mapped_item_getitem.py diff --git a/benchmarks/datetime_from_database.py b/benchmarks/datetime_from_database.py index 3a00ed65..5a70bdb8 100644 --- a/benchmarks/datetime_from_database.py +++ b/benchmarks/datetime_from_database.py @@ -20,7 +20,7 @@ def build_datetimes(count: int) -> Sequence[DateTime]: datetimes.append(DateTime(datetime.datetime(year, month, day, hour))) hour += 1 if hour == 24: - hour = 1 + hour = 0 day += 1 if day == 29: day = 1 @@ -43,8 +43,8 @@ def value_from_database(loops: int, db_values_and_types: Sequence[Tuple[Any, str def run_benchmark(): file_name = run_file_name() - runner = pyperf.Runner() - inner_loops = 100 + runner = pyperf.Runner(loops=10) + inner_loops = 1000 db_values_and_types = [to_database(x) for x in build_datetimes(inner_loops)] benchmark = runner.bench_time_func( "from_database[DateTime]", value_from_database, db_values_and_types, inner_loops=inner_loops diff --git a/benchmarks/mapped_item_getitem.py b/benchmarks/mapped_item_getitem.py new file mode 100644 index 00000000..84afb732 --- /dev/null +++ b/benchmarks/mapped_item_getitem.py @@ -0,0 +1,62 @@ +""" +This benchmark tests the performance of the MappedItemBase.__getitem__() method. 
+""" + +import pyperf +import time +from typing import Dict +from spinedb_api import DatabaseMapping +from spinedb_api.db_mapping_base import PublicItem +from benchmarks.utils import run_file_name + + +def use_subscript_operator(loops: int, items: PublicItem, field: Dict): + duration = 0.0 + for _ in range(loops): + for item in items: + start = time.perf_counter() + value = item[field] + duration += time.perf_counter() - start + return duration + + +def run_benchmark(): + runner = pyperf.Runner() + inner_loops = 1000 + object_class_names = [str(i) for i in range(inner_loops)] + relationship_class_names = [f"r{dimension}" for dimension in object_class_names] + with DatabaseMapping("sqlite://", create=True) as db_map: + object_classes = [] + for name in object_class_names: + item, error = db_map.add_entity_class_item(name=name) + assert error is None + object_classes.append(item) + relationship_classes = [] + for name, dimension in zip(relationship_class_names, object_classes): + item, error = db_map.add_entity_class_item(name, dimension_name_list=(dimension["name"],)) + assert error is None + relationship_classes.append(item) + benchmarks = [ + runner.bench_time_func( + "PublicItem subscript['name' in EntityClassItem]", + use_subscript_operator, + object_classes, + "name", + inner_loops=inner_loops, + ), + runner.bench_time_func( + "PublicItem subscript['dimension_name_list' in EntityClassItem]", + use_subscript_operator, + relationship_classes, + "dimension_name_list", + inner_loops=inner_loops, + ), + ] + file_name = run_file_name() + for benchmark in benchmarks: + if benchmark is not None: + pyperf.add_runs(file_name, benchmark) + + +if __name__ == "__main__": + run_benchmark() From 7e940a73b73a7852310ceb8fd856abb71faa2055 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Thu, 11 Apr 2024 15:58:32 +0300 Subject: [PATCH 315/317] Fix example in docstring, improve tutorial Re #376 --- docs/source/tutorial.rst | 9 +++++---- spinedb_api/parameter_value.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index f4adfa56..85282d7c 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -25,7 +25,7 @@ To create a :class:`.DatabaseMapping`, we just pass the URL of the DB to the cla import spinedb_api as api from spinedb_api import DatabaseMapping - url = "mysql://spine_db" # The URL of an existing Spine DB + url = "mysql+pymysql://spine_db" # The URL of an existing Spine DB with DatabaseMapping(url) as db_map: # Do something with db_map @@ -55,7 +55,8 @@ We can remediate this by creating a SQLite DB (which is just a file in your syst The above will create a file called ``first.sqlite`` in your current working directoy. Note that we pass the keyword argument ``create=True`` to :class:`.DatabaseMapping` to explicitly say -that we want the DB to be created at the given URL. +that we want the DB to be created at the given URL +if it does not exists already. .. 
note::
 
@@ -151,9 +152,9 @@ Now let's retrieve our parameter value::
         alternative_name="Base"
     )
 
-We use :func:`.from_database` to convert the value and type from the parameter value into our original value::
+We can use the ``"parsed_value"`` field to access our original value::
 
-    nemo_color = api.from_database(nemo_color_item["value"], nemo_color_item["type"])
+    nemo_color = nemo_color_item["parsed_value"]
     assert nemo_color == "mainly orange"
 
 To retrieve all the items of a given type, we use :meth:`~.DatabaseMapping.get_items`::
diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py
index 9cecc1fc..1561e20b 100644
--- a/spinedb_api/parameter_value.py
+++ b/spinedb_api/parameter_value.py
@@ -40,7 +40,7 @@
 
     # Create the Python object
     parsed_value = TimeSeriesFixedResolution(
-        datetime("2023-01-01T00:00"),  # start
+        "2023-01-01T00:00",  # start
         "1D",  # resolution
         [9, 8, 7, 6, 5, 4, 3, 2, 1, 0],  # values
         ignore_year=False,
@@ -60,7 +60,7 @@
     )
     db_map.commit_session("Tom is living one day at a time")
 
-Similarly, to read a parameter value from the DB into a Python object::
+The value can be accessed as a Python object using the ``parsed_value`` field::
 
     # Get the parameter_value from the DB
     with DatabaseMapping(url) as db_map:
@@ -70,11 +70,15 @@
             parameter_definition_name="number_of_lives",
             alternative_name="Base",
         )
+        value = pval_item["parsed_value"]
+
+In the rare case where a manual conversion from a DB value to a Python object is needed,
+use :func:`.from_database`::
+
     # Obtain value and type
     value, type_ = pval_item["value"], pval_item["type"]
-    # Translate value and type to a Python object
+    # Translate value and type to a Python object manually
     parsed_value = from_database(value, type_)
-
 """
 from collections.abc import Sequence

From 0ce27d41f56cc7a90a4d42e6f1bb5bfb1c16718d Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Tue, 23 Apr 2024 16:22:38 +0300
Subject: [PATCH 316/317] Bump Spine DB Server and client versions from 6 to 7

With all the 0.8 changes, the interface is no longer compatible with
version 6.
---
 spinedb_api/spine_db_client.py | 2 +-
 spinedb_api/spine_db_server.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py
index 017af3e8..759c329e 100644
--- a/spinedb_api/spine_db_client.py
+++ b/spinedb_api/spine_db_client.py
@@ -19,7 +19,7 @@
 from sqlalchemy.engine.url import URL
 from .server_client_helpers import ReceiveAllMixing, encode, decode
 
-client_version = 6
+client_version = 7
 
 
 class SpineDBClient(ReceiveAllMixing):
diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py
index 5604823f..26bca284 100644
--- a/spinedb_api/spine_db_server.py
+++ b/spinedb_api/spine_db_server.py
@@ -117,7 +117,7 @@ def _import_entity_class(server_url, class_name):
 from .filters.tools import apply_filter_stack
 from .spine_db_client import SpineDBClient
 
-_current_server_version = 6
+_current_server_version = 7
 
 
 def get_current_server_version():

From 29e89b04beaee01cce97e08c8d356f3e614accc9 Mon Sep 17 00:00:00 2001
From: Antti Soininen
Date: Wed, 24 Apr 2024 09:59:41 +0300
Subject: [PATCH 317/317] Don't forcefully write benchmark results on disk

It is better if you can store the result file wherever you want just by
editing the benchmark script. 
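
As an illustration, the shared pattern now looks roughly like this
(a minimal sketch, not the code of any single benchmark script; the
"noop" function and the result-file name below are placeholders):

    import pyperf

    def run_benchmark(file_name: str):
        runner = pyperf.Runner()
        # The real scripts time DB operations; a no-op stands in here.
        benchmark = runner.bench_func("noop", lambda: None)
        if file_name and benchmark is not None:  # write results only on request
            pyperf.add_runs(file_name, benchmark)

    if __name__ == "__main__":
        # Edit the argument, e.g. to "benchmark-results.json", to keep the runs.
        run_benchmark("")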
---
 benchmarks/README.md                             | 13 ++++++++-----
 benchmarks/datetime_from_database.py             |  9 ++++-----
 benchmarks/map_from_database.py                  |  9 ++++-----
 benchmarks/mapped_item_getitem.py                | 13 ++++++-------
 .../update_default_value_to_different_value.py   |  9 +++++----
 benchmarks/update_default_value_to_same_value.py | 12 +++++++-----
 benchmarks/utils.py                              |  6 +-----
 7 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/benchmarks/README.md b/benchmarks/README.md
index 816ec167..12611d72 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -2,21 +2,24 @@
 
 This Python package contains performance benchmarks for `spinedb_api`.
 The benchmarks use [`pyperf`](https://pyperf.readthedocs.io/en/latest/index.html)
-which can be installed by installing the optional developer dependencies:
+which is installed as part of the optional developer dependencies:
 
 ```commandline
-python -mpip install .[dev]
+python -mpip install -e .[dev]
 ```
 
-Each Python file is an individual script
-that writes the run results into a common `.json` file.
+Each Python file is a self-contained script
+that benchmarks some aspect of the DB API.
+Benchmark results can be optionally written into a `.json` file
+by modifying the script.
+This may be handy for comparing e.g. different branches or commits.
 The file can be inspected by
 
 ```commandline
 python -mpyperf show <file name>
 ```
 
-Benchmarks from e.g. different commits/branches can be compared by
+Benchmark files from e.g. different commits/branches can be compared by
 
 ```commandline
 python -mpyperf compare_to <file name 1> <file name 2>
 ```
diff --git a/benchmarks/datetime_from_database.py b/benchmarks/datetime_from_database.py
index 5a70bdb8..3b34383b 100644
--- a/benchmarks/datetime_from_database.py
+++ b/benchmarks/datetime_from_database.py
@@ -7,7 +7,6 @@
 from typing import Any, Sequence, Tuple
 import pyperf
 from spinedb_api import DateTime, from_database, to_database
-from benchmarks.utils import run_file_name
 
 
 def build_datetimes(count: int) -> Sequence[DateTime]:
@@ -41,16 +40,16 @@ def value_from_database(loops: int, db_values_and_types: Sequence[Tuple[Any, str
     return duration
 
 
-def run_benchmark():
-    file_name = run_file_name()
+def run_benchmark(file_name):
     runner = pyperf.Runner(loops=10)
     inner_loops = 1000
     db_values_and_types = [to_database(x) for x in build_datetimes(inner_loops)]
     benchmark = runner.bench_time_func(
         "from_database[DateTime]", value_from_database, db_values_and_types, inner_loops=inner_loops
     )
-    pyperf.add_runs(file_name, benchmark)
+    if file_name:
+        pyperf.add_runs(file_name, benchmark)
 
 
 if __name__ == "__main__":
-    run_benchmark()
+    run_benchmark("")
diff --git a/benchmarks/map_from_database.py b/benchmarks/map_from_database.py
index d4324001..9037ac57 100644
--- a/benchmarks/map_from_database.py
+++ b/benchmarks/map_from_database.py
@@ -5,7 +5,7 @@
 import time
 import pyperf
 from spinedb_api import from_database, to_database
-from benchmarks.utils import build_even_map, run_file_name
+from benchmarks.utils import build_even_map
 
 
 def value_from_database(loops, db_value, value_type):
@@ -17,8 +17,7 @@ def value_from_database(loops, db_value, value_type):
     return duration
 
 
-def run_benchmark():
-    file_name = run_file_name()
+def run_benchmark(file_name):
     runner = pyperf.Runner(loops=3)
     runs = {
         "value_from_database[Map(10, 10, 100)]": {"dimensions": (10, 10, 100)},
@@ -32,9 +31,9 @@
             db_value,
             value_type,
         )
-        if benchmark is not None:
+        if file_name and benchmark is not None:
             pyperf.add_runs(file_name, benchmark)
 
 
 if __name__ == "__main__":
-    run_benchmark()
+    run_benchmark("")
run_benchmark("") diff --git a/benchmarks/mapped_item_getitem.py b/benchmarks/mapped_item_getitem.py index 84afb732..09fc75c5 100644 --- a/benchmarks/mapped_item_getitem.py +++ b/benchmarks/mapped_item_getitem.py @@ -7,7 +7,6 @@ from typing import Dict from spinedb_api import DatabaseMapping from spinedb_api.db_mapping_base import PublicItem -from benchmarks.utils import run_file_name def use_subscript_operator(loops: int, items: PublicItem, field: Dict): @@ -20,7 +19,7 @@ def use_subscript_operator(loops: int, items: PublicItem, field: Dict): return duration -def run_benchmark(): +def run_benchmark(file_name): runner = pyperf.Runner() inner_loops = 1000 object_class_names = [str(i) for i in range(inner_loops)] @@ -52,11 +51,11 @@ def run_benchmark(): inner_loops=inner_loops, ), ] - file_name = run_file_name() - for benchmark in benchmarks: - if benchmark is not None: - pyperf.add_runs(file_name, benchmark) + if file_name: + for benchmark in benchmarks: + if benchmark is not None: + pyperf.add_runs(file_name, benchmark) if __name__ == "__main__": - run_benchmark() + run_benchmark("") diff --git a/benchmarks/update_default_value_to_different_value.py b/benchmarks/update_default_value_to_different_value.py index ca1a6be4..27039168 100644 --- a/benchmarks/update_default_value_to_different_value.py +++ b/benchmarks/update_default_value_to_different_value.py @@ -6,7 +6,7 @@ import time import pyperf from spinedb_api import DatabaseMapping, to_database -from benchmarks.utils import build_even_map, run_file_name +from benchmarks.utils import build_even_map def update_default_value(loops, db_map, first_db_value, first_value_type, second_db_value, second_value_type): @@ -27,7 +27,7 @@ def update_default_value(loops, db_map, first_db_value, first_value_type, second return total_time -def run_benchmark(): +def run_benchmark(file_name: str): first_value, first_type = to_database(None) second_value, second_type = to_database(build_even_map()) with DatabaseMapping("sqlite://", create=True) as db_map: @@ -45,8 +45,8 @@ def run_benchmark(): second_value, second_type, ) - pyperf.add_runs(run_file_name(), benchmark) + pyperf.add_runs(file_name, benchmark) if __name__ == "__main__": - run_benchmark() + run_benchmark("") diff --git a/benchmarks/update_default_value_to_same_value.py b/benchmarks/update_default_value_to_same_value.py index cca9b102..501cf30c 100644 --- a/benchmarks/update_default_value_to_same_value.py +++ b/benchmarks/update_default_value_to_same_value.py @@ -3,12 +3,13 @@ the default value is somewhat complex Map and the update does not change anything. 
""" import time +from typing import Optional import pyperf from spinedb_api import DatabaseMapping, to_database -from benchmarks.utils import build_even_map, run_file_name +from benchmarks.utils import build_even_map -def update_default_value(loops, db_map, value, value_type): +def update_default_value(loops: int, db_map: DatabaseMapping, value: bytes, value_type: Optional[str]) -> float: total_time = 0.0 for counter in range(loops): start = time.perf_counter() @@ -23,7 +24,7 @@ def update_default_value(loops, db_map, value, value_type): return total_time -def run_benchmark(): +def run_benchmark(file_name: str): value, value_type = to_database(build_even_map()) with DatabaseMapping("sqlite://", create=True) as db_map: db_map.add_entity_class_item(name="Object") @@ -34,8 +35,9 @@ def run_benchmark(): benchmark = runner.bench_time_func( "update_parameter_definition_item[Map,Map]", update_default_value, db_map, value, value_type ) - pyperf.add_runs(run_file_name(), benchmark) + if file_name: + pyperf.add_runs(file_name, benchmark) if __name__ == "__main__": - run_benchmark() + run_benchmark("") diff --git a/benchmarks/utils.py b/benchmarks/utils.py index 072f1d47..d860f9eb 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -1,7 +1,7 @@ import datetime import math from typing import Sequence -from spinedb_api import __version__, DateTime, Map +from spinedb_api import DateTime, Map def build_map(size: int) -> Map: @@ -27,7 +27,3 @@ def build_even_map(shape: Sequence[int] = (10, 10, 10)) -> Map: xs.append(DateTime(start + datetime.timedelta(hours=i))) ys.append(build_even_map(shape[1:])) return Map(xs, ys) - - -def run_file_name() -> str: - return f"benchmark-{__version__}.json"