diff --git a/spinedb_api/filters/scenario_filter.py b/spinedb_api/filters/scenario_filter.py index d48a798b..f6415473 100644 --- a/spinedb_api/filters/scenario_filter.py +++ b/spinedb_api/filters/scenario_filter.py @@ -360,6 +360,22 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state): Alias: a subquery for parameter value filtered by selected scenario """ ext_entity_sq = _ext_entity_sq(db_map, state) + ext_entity_element_count_sq = ( + db_map.query( + db_map.entity_element_sq.c.entity_id, + func.count(db_map.entity_element_sq.c.element_id).label("element_count"), + ) + .group_by(db_map.entity_element_sq.c.entity_id) + .subquery() + ) + ext_entity_class_dimension_count_sq = ( + db_map.query( + db_map.entity_class_dimension_sq.c.entity_class_id, + func.count(db_map.entity_class_dimension_sq.c.dimension_id).label("dimension_count"), + ) + .group_by(db_map.entity_class_dimension_sq.c.entity_class_id) + .subquery() + ) ext_parameter_value_sq = ( db_map.query( state.original_parameter_value_sq, @@ -387,6 +403,22 @@ def _make_scenario_filtered_parameter_value_sq(db_map, state): and_(ext_entity_sq.c.active == None, ext_entity_sq.c.active_by_default == True), ), ) + .outerjoin( + ext_entity_element_count_sq, ext_entity_element_count_sq.c.entity_id == ext_parameter_value_sq.c.entity_id + ) + .outerjoin( + ext_entity_class_dimension_count_sq, + ext_entity_class_dimension_count_sq.c.entity_class_id == ext_parameter_value_sq.c.entity_class_id, + ) + .filter( + or_( + and_( + ext_entity_element_count_sq.c.element_count == None, + ext_entity_class_dimension_count_sq.c.dimension_count == None, + ), + ext_entity_element_count_sq.c.element_count == ext_entity_class_dimension_count_sq.c.dimension_count, + ) + ) .subquery() ) diff --git a/tests/filters/test_scenario_filter.py b/tests/filters/test_scenario_filter.py index 3b2d712f..e7ff5250 100644 --- a/tests/filters/test_scenario_filter.py +++ b/tests/filters/test_scenario_filter.py @@ -965,6 +965,62 @@ def test_parameter_values_for_entities_that_swim_against_active_by_default(self) self.assertEqual(len(values), 1) self.assertEqual(from_database(values[0]["value"], values[0]["type"]), -2.3) + def test_parameter_values_of_multidimensional_entity_whose_elements_have_entity_alternatives(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_scenario_item(name="base")) + self._assert_success( + db_map.add_scenario_alternative_item(scenario_name="base", alternative_name="Base", rank=1) + ) + self._assert_success(db_map.add_entity_class_item(name="Object")) + self._assert_success(db_map.add_entity_item(name="visible", entity_class_name="Object")) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="Object", entity_byname=("visible",), alternative_name="Base", active=True + ) + ) + self._assert_success(db_map.add_entity_item(name="invisible", entity_class_name="Object")) + self._assert_success( + db_map.add_entity_alternative_item( + entity_class_name="Object", entity_byname=("invisible",), alternative_name="Base", active=False + ) + ) + self._assert_success(db_map.add_entity_class_item(name="Relationship", dimension_name_list=("Object",))) + self._assert_success( + db_map.add_entity_item(element_name_list=("visible",), entity_class_name="Relationship") + ) + self._assert_success( + db_map.add_entity_item(element_name_list=("invisible",), entity_class_name="Relationship") + ) + self._assert_success(db_map.add_parameter_definition_item(name="y", entity_class_name="Relationship")) + value, value_type = to_database(2.3) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="Relationship", + entity_byname=("visible",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="Relationship", + entity_byname=("invisible",), + parameter_definition_name="y", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + db_map.commit_session("Add test data") + config = scenario_filter_config("base") + scenario_filter_from_dict(db_map, config) + values = db_map.query(db_map.parameter_value_sq).all() + self.assertEqual(len(values), 1) + self.assertEqual(from_database(values[0].value, values[0].type), 2.3) + class TestScenarioFilterUtils(DataBuilderTestCase): def test_scenario_filter_config(self):