diff --git a/decoimpact/business/entities/rule_based_model.py b/decoimpact/business/entities/rule_based_model.py index c805f100..5dbad9e0 100644 --- a/decoimpact/business/entities/rule_based_model.py +++ b/decoimpact/business/entities/rule_based_model.py @@ -97,7 +97,6 @@ def initialize(self, logger: ILogger) -> None: self._output_dataset = _du.create_composed_dataset( self._input_datasets, self._make_output_variables_list(), self._mappings ) - self._rule_processor = RuleProcessor(self._rules, self._output_dataset) success = self._rule_processor.initialize(logger) @@ -109,7 +108,9 @@ def execute(self, logger: ILogger) -> None: if self._rule_processor is None: raise RuntimeError("Processor is not set, please initialize model.") - self._rule_processor.process_rules(self._output_dataset, logger) + self._output_dataset = self._rule_processor.process_rules( + self._output_dataset, logger + ) def finalize(self, logger: ILogger) -> None: """Finalizes the model""" diff --git a/decoimpact/business/entities/rule_processor.py b/decoimpact/business/entities/rule_processor.py index 7070ff49..5e608547 100644 --- a/decoimpact/business/entities/rule_processor.py +++ b/decoimpact/business/entities/rule_processor.py @@ -62,7 +62,9 @@ def initialize(self, logger: ILogger) -> bool: return success - def process_rules(self, output_dataset: _xr.Dataset, logger: ILogger) -> None: + def process_rules( + self, output_dataset: _xr.Dataset, logger: ILogger + ) -> _xr.Dataset: """Processes the rules defined in the initialize method and adds the results to the provided output_dataset. @@ -89,7 +91,16 @@ def process_rules(self, output_dataset: _xr.Dataset, logger: ILogger) -> None: rule_result.dims, rule_result.values, rule_result.attrs, + rule_result.coords, ) + for coord_key in rule_result.coords: + # the coord_key is overwritten in case we don't have the if + # statement below + if coord_key not in output_dataset.coords: + output_dataset = output_dataset.assign_coords( + {coord_key: rule_result[coord_key]} + ) + return output_dataset def _create_rule_sets( self, diff --git a/decoimpact/business/entities/rules/time_aggregation_rule.py b/decoimpact/business/entities/rules/time_aggregation_rule.py index 27ba39ba..f0d2cb3e 100644 --- a/decoimpact/business/entities/rules/time_aggregation_rule.py +++ b/decoimpact/business/entities/rules/time_aggregation_rule.py @@ -81,6 +81,12 @@ def execute(self, value_array: _xr.DataArray, logger: ILogger) -> _xr.DataArray: if value: result[result_time_dim_name].attrs[key] = value + result = result.assign_coords({ + result_time_dim_name: result[result_time_dim_name] + }) + result[result_time_dim_name].attrs['long_name'] = result_time_dim_name + result[result_time_dim_name].attrs['standard_name'] = result_time_dim_name + return result def _perform_operation(self, aggregated_values: DataArrayResample) -> _xr.DataArray: diff --git a/tests/business/entities/test_rule_processor.py b/tests/business/entities/test_rule_processor.py index e7f1abd1..3625eb53 100644 --- a/tests/business/entities/test_rule_processor.py +++ b/tests/business/entities/test_rule_processor.py @@ -18,7 +18,9 @@ IMultiArrayBasedRule, ) from decoimpact.business.entities.rules.i_rule import IRule +from decoimpact.business.entities.rules.time_aggregation_rule import TimeAggregationRule from decoimpact.crosscutting.i_logger import ILogger +from decoimpact.data.api.i_time_aggregation_rule_data import ITimeAggregationRuleData def _create_test_rules() -> List[IRule]: @@ -384,6 +386,54 @@ def test_process_rules_throws_exception_for_unsupported_rule(): assert exception_raised.args[0] == expected_message +def test_process_rules_copies_multi_coords_correctly(): + """Tests if during processing the coords are copied to the output array + and there are no duplicates.""" + + # Arrange + output_dataset = _xr.Dataset() + output_dataset["test"] = _xr.DataArray([32, 94, 9]) + + logger = Mock(ILogger) + rule = Mock(IArrayBasedRule) + rule_2 = Mock(IArrayBasedRule) + + result_array = _xr.DataArray([27, 45, 93]) + result_array = result_array.assign_coords({"test": _xr.DataArray([2, 4, 5])}) + + result_array_2 = _xr.DataArray([1, 2, 93]) + result_array_2 = result_array.assign_coords({"test": _xr.DataArray([2, 4, 5])}) + + rule.input_variable_names = ["test"] + rule.output_variable_name = "output" + rule.execute.return_value = result_array + + rule_2.input_variable_names = ["test"] + rule_2.output_variable_name = "output_2" + rule_2.execute.return_value = result_array_2 + + processor = RuleProcessor([rule, rule_2], output_dataset) + + # Act + assert processor.initialize(logger) + result_dataset = processor.process_rules(output_dataset, logger) + + # Assert + assert "test" in result_dataset.coords + # compare coords at the level of variable + result_array_coords = result_array.coords["test"] + result_output_var_coords = result_dataset.output.coords["test"] # output variable + assert (result_output_var_coords == result_array_coords).all() + + # compare coords at the level of dataset / + # check if the coordinates are correctly copied to the dataset + result_dataset_coords = result_dataset.coords["test"] + assert (result_output_var_coords == result_dataset_coords).all() + + # check if havnig an extra rule with coordinates then they are not copy pasted too + assert len(result_dataset.output.coords) == 1 + + def test_execute_rule_throws_error_for_unknown_input_variable(): """Tests that trying to execute a rule with an unknown input variable throws an error, and the error message."""