Skip to content

Commit

Permalink
unit8co#1545 implement _query_explainability_result() helper to avoid…
Browse files Browse the repository at this point in the history
… code duplication
  • Loading branch information
Rijk van der Meulen committed Mar 12, 2023
1 parent 55e9941 commit 7176e11
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 29 deletions.
69 changes: 41 additions & 28 deletions darts/explainability/explainability_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from abc import ABC
from typing import Dict, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import shap
from numpy import integer
Expand Down Expand Up @@ -55,14 +55,39 @@ def get_explanation(
The component for which to return the explanation. Does not
need to be specified for univariate series.
"""
return self._query_explainability_result(
self.explained_forecasts, horizon, component
)

def _query_explainability_result(
self,
attr: Union[
Dict[integer, Dict[str, Any]], Sequence[Dict[integer, Dict[str, Any]]]
],
horizon: int,
component: Optional[str] = None,
) -> Any:
"""
Helper that extracts and returns the explainability result attribute for a specified horizon and component from
the input attribute.
Parameters
----------
attr
An explainability result attribute from which to extract the content for a certain horizon and component.
horizon
The horizon for which to return the content of the attribute.
component
The component for which to return the content of the attribute. Does not
need to be specified for univariate series.
"""
self._validate_input_for_querying_explainability_result(horizon, component)
if isinstance(self.explained_forecasts, list):
return [
self.explained_forecasts[i][horizon][component]
for i in range(len(self.explained_forecasts))
]
if component is None:
component = self.available_components[0]
if isinstance(attr, list):
return [attr[i][horizon][component] for i in range(len(attr))]
else:
return self.explained_forecasts[horizon][component]
return attr[horizon][component]

def _validate_input_for_querying_explainability_result(
self, horizon: int, component: Optional[str] = None
Expand Down Expand Up @@ -123,11 +148,9 @@ def __init__(
Dict[integer, Dict[str, TimeSeries]],
Sequence[Dict[integer, Dict[str, TimeSeries]]],
],
shap_explanation_object: Optional[
Union[
Dict[integer, Dict[str, shap.Explanation]],
Sequence[Dict[integer, Dict[str, shap.Explanation]]],
]
shap_explanation_object: Union[
Dict[integer, Dict[str, shap.Explanation]],
Sequence[Dict[integer, Dict[str, shap.Explanation]]],
],
):
super().__init__(explained_forecasts)
Expand All @@ -149,14 +172,9 @@ def get_feature_values(
The component for which to return the feature values. Does not
need to be specified for univariate series.
"""
self._validate_input_for_querying_explainability_result(horizon, component)
if isinstance(self.feature_values, list):
return [
self.feature_values[i][horizon][component]
for i in range(len(self.feature_values))
]
else:
return self.feature_values[horizon][component]
return self._query_explainability_result(
self.feature_values, horizon, component
)

def get_shap_explanation_object(
self, horizon: int, component: Optional[str] = None
Expand All @@ -172,11 +190,6 @@ def get_shap_explanation_object(
The component for which to return the `shap.Explanation` object. Does not
need to be specified for univariate series.
"""
self._validate_input_for_querying_explainability_result(horizon, component)
if isinstance(self.shap_explanation_object, list):
return [
self.shap_explanation_object[i][horizon][component]
for i in range(len(self.shap_explanation_object))
]
else:
return self.shap_explanation_object[horizon][component]
return self._query_explainability_result(
self.shap_explanation_object, horizon, component
)
2 changes: 1 addition & 1 deletion darts/tests/explainability/test_shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def test_feature_values_align_with_raw_output_shap(self):
).data

assert_array_equal(feature_values.values(), comparison)
self.assertTrue(
self.assertEqual(
feature_values.values().shape,
explanation_results.get_explanation(horizon=1, component="price")
.values()
Expand Down

0 comments on commit 7176e11

Please sign in to comment.