Skip to content

Commit

Permalink
API: consistently refer to all mean/std calculations as "aggregation" (
Browse files Browse the repository at this point in the history
  • Loading branch information
j-ittner authored Mar 24, 2021
1 parent d0a9ee2 commit 204234e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 85 deletions.
82 changes: 43 additions & 39 deletions src/facet/inspection/_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ class LearnerInspector(
specified in the underlying training sample.
"""

#: constant for "mean" aggregation method, to be passed as arg ``aggregation``
#: to :class:`.LearnerInspector` methods that implement it
AGG_MEAN = "mean"

#: constant for "std" aggregation method, to be passed as arg ``aggregation``
#: to :class:`.LearnerInspector` methods that implement it
AGG_STD = "std"

#: Name for feature importance series or column.
COL_IMPORTANCE = "importance"

Expand Down Expand Up @@ -394,7 +402,7 @@ def features_(self) -> List[str]:
return self.crossfit_.pipeline.feature_names_out_.to_list()

def shap_values(
self, consolidate: Optional[str] = "mean"
self, aggregation: Optional[str] = AGG_MEAN
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""
Calculate the SHAP values for all observations and features.
Expand All @@ -405,31 +413,31 @@ def shap_values(
By default, one SHAP value is returned for each observation and feature; this
value is calculated as the mean SHAP value across all crossfits.
The ``consolidate`` argument can be used to disable or change the consolidation
The ``aggregation`` argument can be used to disable or change the aggregation
of SHAP values:
- passing ``consolidate=None`` will disable SHAP value consolidation,
- passing ``aggregation=None`` will disable SHAP value aggregation,
generating one row for every crossfit and observation (identified by
a hierarchical index with two levels)
- passing ``consolidate="mean"`` (the default) will calculate the mean SHAP
- passing ``aggregation="mean"`` (the default) will calculate the mean SHAP
values across all crossfits
- passing ``consolidate="std"`` will calculate the standard deviation of SHAP
- passing ``aggregation="std"`` will calculate the standard deviation of SHAP
values across all crossfits, as the basis for determining the uncertainty
of SHAP calculations
:param consolidate: consolidate SHAP values across splits;
:param aggregation: aggregation SHAP values across splits;
permissible values are ``"mean"`` (calculate the mean), ``"std"``
(calculate the standard deviation), or ``None`` to prevent consolidation
(calculate the standard deviation), or ``None`` to prevent aggregation
(default: ``"mean"``)
:return: a data frame with SHAP values
"""
self._ensure_fitted()
return self.__split_multi_output_df(
self._shap_calculator.get_shap_values(consolidate=consolidate)
self._shap_calculator.get_shap_values(aggregation=aggregation)
)

def shap_interaction_values(
self, consolidate: Optional[str] = "mean"
self, aggregation: Optional[str] = AGG_MEAN
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""
Calculate the SHAP interaction values for all observations and pairs of
Expand All @@ -443,28 +451,28 @@ def shap_interaction_values(
feature pairing; this value is calculated as the mean SHAP interaction value
across all crossfits.
The ``consolidate`` argument can be used to disable or change the consolidation
The ``aggregation`` argument can be used to disable or change the aggregation
of SHAP interaction values:
- passing ``consolidate=None`` will disable SHAP interaction value
consolidation, generating one row for every crossfit, observation and
- passing ``aggregation=None`` will disable SHAP interaction value
aggregation, generating one row for every crossfit, observation and
feature (identified by a hierarchical index with three levels)
- passing ``consolidate="mean"`` (the default) will calculate the mean SHAP
- passing ``aggregation="mean"`` (the default) will calculate the mean SHAP
interaction values across all crossfits
- passing ``consolidate="std"`` will calculate the standard deviation of SHAP
- passing ``aggregation="std"`` will calculate the standard deviation of SHAP
interaction values across all crossfits, as the basis for determining the
uncertainty of SHAP calculations
:param consolidate: consolidate SHAP interaction values across splits;
:param aggregation: aggregate SHAP interaction values across splits;
permissible values are ``"mean"`` (calculate the mean), ``"std"``
(calculate the standard deviation), or ``None`` to prevent consolidation
(calculate the standard deviation), or ``None`` to prevent aggregation
(default: ``"mean"``)
:return: a data frame with SHAP interaction values
"""
self._ensure_fitted()
return self.__split_multi_output_df(
self.__shap_interaction_values_calculator.get_shap_interaction_values(
consolidate=consolidate
aggregation=aggregation
)
)

Expand Down Expand Up @@ -493,7 +501,7 @@ def feature_importance(
raise ValueError(f'arg method="{method}" must be one of {methods}')

shap_matrix: pd.DataFrame = self._shap_calculator.get_shap_values(
consolidate="mean"
aggregation="mean"
)
weight: Optional[pd.Series] = self.sample_.weight

Expand Down Expand Up @@ -530,7 +538,7 @@ def feature_synergy_matrix(
*,
absolute: bool = False,
symmetrical: bool = False,
std: bool = False,
aggregation: Optional[str] = AGG_MEAN,
clustered: bool = True,
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""
Expand Down Expand Up @@ -558,9 +566,8 @@ def feature_synergy_matrix(
mutual synergy; if ``False``, return an asymmetrical matrix quantifying
unilateral synergy of the features represented by rows with the
features represented by columns (default: ``False``)
:param std: if ``True``, return standard deviations instead of (mean) values;
return ``None`` if only a single matrix had been calculated and
thus the standard deviation is not known
:param aggregation: if ``mean``, return mean values across all models in the
crossfit; additional aggregation methods will be added in future releases
:param clustered: if ``True``, reorder the rows and columns of the matrix
such that synergy between adjacent rows and columns is maximised; if
``False``, keep rows and columns in the original features order
Expand All @@ -574,9 +581,7 @@ def feature_synergy_matrix(
return self.__feature_affinity_matrix(
affinity_matrices=(
explainer.to_frames(
explainer.synergy(
symmetrical=symmetrical, absolute=absolute, std=std
)
explainer.synergy(symmetrical=symmetrical, absolute=absolute)
)
),
affinity_symmetrical=explainer.synergy(
Expand All @@ -590,7 +595,7 @@ def feature_redundancy_matrix(
*,
absolute: bool = False,
symmetrical: bool = False,
std: bool = False,
aggregation: Optional[str] = AGG_MEAN,
clustered: bool = True,
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""
Expand Down Expand Up @@ -618,9 +623,8 @@ def feature_redundancy_matrix(
mutual redundancy; if ``False``, return an asymmetrical matrix quantifying
unilateral redundancy of the features represented by rows with the
features represented by columns (default: ``False``)
:param std: if ``True``, return standard deviations instead of (mean) values;
return ``None`` if only a single matrix had been calculated and
thus the standard deviation is not known
:param aggregation: if ``mean``, return mean values across all models in the
crossfit; additional aggregation methods will be added in future releases
:param clustered: if ``True``, reorder the rows and columns of the matrix
such that redundancy between adjacent rows and columns is maximised; if
``False``, keep rows and columns in the original features order
Expand All @@ -634,9 +638,7 @@ def feature_redundancy_matrix(
return self.__feature_affinity_matrix(
affinity_matrices=(
explainer.to_frames(
explainer.redundancy(
symmetrical=symmetrical, absolute=absolute, std=std
)
explainer.redundancy(symmetrical=symmetrical, absolute=absolute)
)
),
affinity_symmetrical=explainer.redundancy(
Expand All @@ -650,7 +652,7 @@ def feature_association_matrix(
*,
absolute: bool = False,
symmetrical: bool = False,
std: bool = False,
aggregation: Optional[str] = AGG_MEAN,
clustered: bool = True,
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""
Expand Down Expand Up @@ -681,9 +683,8 @@ def feature_association_matrix(
with the features represented by columns;
if ``True``, return a symmetrical matrix quantifying mutual association
(default: ``False``)
:param std: if ``True``, return standard deviations instead of (mean) values;
return ``None`` if only a single matrix had been calculated and
thus the standard deviation is not known
:param aggregation: if ``mean``, return mean values across all models in the
crossfit; additional aggregation methods will be added in future releases
:param clustered: if ``True``, reorder the rows and columns of the matrix
such that association between adjacent rows and columns is maximised; if
``False``, keep rows and columns in the original features order
Expand All @@ -693,12 +694,15 @@ def feature_association_matrix(
"""
self._ensure_fitted()

if aggregation != LearnerInspector.AGG_MEAN:
raise ValueError(f"unknown aggregation method: aggregation={aggregation}")

global_explainer = self._shap_global_explainer
return self.__feature_affinity_matrix(
affinity_matrices=(
global_explainer.to_frames(
global_explainer.association(
absolute=absolute, symmetrical=symmetrical, std=std
absolute=absolute, symmetrical=symmetrical
)
)
),
Expand Down Expand Up @@ -828,7 +832,7 @@ def feature_interaction_matrix(self) -> Union[pd.DataFrame, List[pd.DataFrame]]:
# (n_observations, n_outputs, n_features, n_features)
# where the innermost feature x feature arrays are symmetrical
im_matrix_per_observation_and_output = (
self.shap_interaction_values(consolidate=None)
self.shap_interaction_values(aggregation=None)
.values.reshape((-1, n_features, n_outputs, n_features))
.swapaxes(1, 2)
)
Expand Down Expand Up @@ -902,7 +906,7 @@ def shap_plot_data(self) -> ShapPlotData:
"""

shap_values: Union[pd.DataFrame, List[pd.DataFrame]] = self.shap_values(
consolidate="mean"
aggregation="mean"
)

output_names: List[str] = self.output_names_
Expand Down
56 changes: 32 additions & 24 deletions src/facet/inspection/_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ class ShapCalculator(
in a data frame.
"""

#: constant for "mean" aggregation method, to be passed as arg ``aggregation``
#: to :class:`.ShapCalculator` methods that implement it
AGG_MEAN = "mean"

#: constant for "std" aggregation method, to be passed as arg ``aggregation``
#: to :class:`.ShapCalculator` methods that implement it
AGG_STD = "std"

#: name of index level indicating the split ID
IDX_SPLIT = "split"

Expand Down Expand Up @@ -159,23 +167,23 @@ def fit(
return self

@abstractmethod
def get_shap_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""
The resulting consolidated shap values as a data frame,
The resulting aggregated shap values as a data frame,
aggregated to averaged SHAP contributions per feature and observation.
:param consolidate: consolidation method, or ``None`` for no consolidation
:param aggregation: aggregation method, or ``None`` for no aggregation
:return: SHAP contribution values with shape
(n_observations, n_outputs * n_features)
"""

@abstractmethod
def get_shap_interaction_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_interaction_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""
The resulting consolidated shap interaction values as a data frame,
The resulting aggregated shap interaction values as a data frame,
aggregated to averaged SHAP interaction values per observation.
:param consolidate: consolidation method, or ``None`` for no consolidation
:param aggregation: aggregation method, or ``None`` for no aggregation
:return: SHAP contribution values with shape
(n_observations * n_features, n_outputs * n_features)
:raise TypeError: this SHAP calculator does not support interaction values
Expand Down Expand Up @@ -281,7 +289,7 @@ def _concatenate_splits(
pass

@staticmethod
def _consolidate_splits(
def _aggregate_splits(
shap_all_splits_df: pd.DataFrame, method: Optional[str]
) -> pd.DataFrame:
# Group SHAP values by observation ID, aggregate SHAP values using mean or std,
Expand All @@ -298,14 +306,14 @@ def _consolidate_splits(

level = 1 if n_levels == 2 else tuple(range(1, n_levels))

if method == "mean":
shap_consolidated = shap_all_splits_df.mean(level=level)
elif method == "std":
shap_consolidated = shap_all_splits_df.std(level=level)
if method == ShapCalculator.AGG_MEAN:
shap_aggregated = shap_all_splits_df.mean(level=level)
elif method == ShapCalculator.AGG_STD:
shap_aggregated = shap_all_splits_df.std(level=level)
else:
raise ValueError(f"unknown consolidation method: {method}")
raise ValueError(f"unknown aggregation method: {method}")

return shap_consolidated
return shap_aggregated

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -411,18 +419,18 @@ class ShapValuesCalculator(
Base class for calculating SHAP contribution values.
"""

def get_shap_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""[see superclass]"""
self._ensure_fitted()
return ShapCalculator._consolidate_splits(
shap_all_splits_df=self.shap_, method=consolidate
return ShapCalculator._aggregate_splits(
shap_all_splits_df=self.shap_, method=aggregation
)

def get_shap_interaction_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_interaction_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""
Not implemented.
:param consolidate: (ignored)
:param aggregation: (ignored)
:return: (never returns)
:raise TypeError: always raises this - SHAP interaction values are not supported
"""
Expand Down Expand Up @@ -487,18 +495,18 @@ class ShapInteractionValuesCalculator(
Base class for calculating SHAP interaction values.
"""

def get_shap_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""[see superclass]"""
self._ensure_fitted()
return ShapCalculator._consolidate_splits(
shap_all_splits_df=self.shap_.sum(level=(0, 1)), method=consolidate
return ShapCalculator._aggregate_splits(
shap_all_splits_df=self.shap_.sum(level=(0, 1)), method=aggregation
)

def get_shap_interaction_values(self, consolidate: Optional[str]) -> pd.DataFrame:
def get_shap_interaction_values(self, aggregation: Optional[str]) -> pd.DataFrame:
"""[see superclass]"""
self._ensure_fitted()
return ShapCalculator._consolidate_splits(
shap_all_splits_df=self.shap_, method=consolidate
return ShapCalculator._aggregate_splits(
shap_all_splits_df=self.shap_, method=aggregation
)

def get_diagonals(self) -> pd.DataFrame:
Expand Down
10 changes: 5 additions & 5 deletions src/facet/inspection/_shap_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self) -> None:
self.association_rel_asymmetric_: Optional[np.ndarray] = None

def association(
self, absolute: bool, symmetrical: bool, std: bool
self, absolute: bool, symmetrical: bool, std: bool = False
) -> Optional[np.ndarray]:
"""[see superclass]"""
if absolute:
Expand All @@ -94,7 +94,7 @@ def _fit(self, shap_calculator: ShapCalculator) -> None:
# basic definitions
#

shap_values: pd.DataFrame = shap_calculator.get_shap_values(consolidate="mean")
shap_values: pd.DataFrame = shap_calculator.get_shap_values(aggregation="mean")
n_outputs: int = len(shap_calculator.output_names_)
n_features: int = len(shap_calculator.feature_index_)
n_observations: int = len(shap_values)
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(self, min_direct_synergy: Optional[float] = None) -> None:
"""

def synergy(
self, symmetrical: bool, absolute: bool, std: bool
self, symmetrical: bool, absolute: bool, std: bool = False
) -> Optional[np.ndarray]:
"""[see superclass]"""
if absolute:
Expand All @@ -278,7 +278,7 @@ def synergy(
return self.synergy_rel_ if symmetrical else self.synergy_rel_asymmetric_

def redundancy(
self, symmetrical: bool, absolute: bool, std: bool
self, symmetrical: bool, absolute: bool, std: bool = False
) -> Optional[np.ndarray]:
"""[see superclass]"""
if absolute:
Expand All @@ -297,7 +297,7 @@ def _fit(self, shap_calculator: ShapInteractionValuesCalculator) -> None:
# basic definitions
#
shap_values: pd.DataFrame = shap_calculator.get_shap_interaction_values(
consolidate="mean"
aggregation="mean"
)
features: pd.Index = shap_calculator.feature_index_
outputs: List[str] = shap_calculator.output_names_
Expand Down
Loading

0 comments on commit 204234e

Please sign in to comment.