diff --git a/captum/metrics/_core/infidelity.py b/captum/metrics/_core/infidelity.py
index ce7087fe24..23e2cbc78f 100644
--- a/captum/metrics/_core/infidelity.py
+++ b/captum/metrics/_core/infidelity.py
@@ -120,6 +120,7 @@ def infidelity(
     target: TargetType = None,
     n_perturb_samples: int = 10,
     max_examples_per_batch: int = None,
+    normalize: bool = False,
 ) -> Tensor:
     r"""
     Explanation infidelity represents the expected mean-squared error
@@ -347,6 +348,20 @@ def infidelity(
                 `input batch size * n_perturb_samples`.

                 Default: None
+        normalize (bool, optional): Normalize the dot product of the input
+                perturbation and the attribution so the infidelity value is
+                invariant to constant scaling of the attribution values. The
+                normalization factor beta is defined as the ratio of two mean
+                values:
+
+                $$ \beta = \frac{
+                    \mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ]
+                }{
+                    \mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ]
+                } $$
+
+                Please refer to the original paper for the meaning of the symbols.
+                The same normalization can be found in the paper's official
+                implementation: https://github.com/chihkuanyeh/saliency_evaluation
+
+                Default: False

     Returns:

         infidelities (tensor): A tensor of scalar infidelity scores per
@@ -439,7 +454,9 @@ def _validate_inputs_and_perturbations(
                     is: {}"""
         ).format(perturb[0].shape, input_perturbed[0].shape)

-    def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
+    def _next_infidelity_tensors(
+        current_n_perturb_samples: int,
+    ) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]:
         perturbations, inputs_perturbed = _generate_perturbations(
             current_n_perturb_samples
         )
@@ -474,11 +491,12 @@ def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
         inputs_fwd = torch.repeat_interleave(
             inputs_fwd, current_n_perturb_samples, dim=0
         )
-        inputs_minus_perturb = inputs_fwd - inputs_perturbed_fwd
+        perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
         attributions_expanded = tuple(
             torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
             for attribution in attributions
         )
+
         attributions_times_perturb = tuple(
             (attribution_expanded * perturbation).view(attribution_expanded.size(0), -1)
             for attribution_expanded, perturbation in zip(
@@ -486,19 +504,31 @@ def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
             )
         )

-        attribution_times_perturb_sums = sum(
-            [
-                torch.sum(attribution_times_perturb, dim=1)
-                for attribution_times_perturb in attributions_times_perturb
-            ]
+        attr_times_perturb_sums = sum(
+            torch.sum(attribution_times_perturb, dim=1)
+            for attribution_times_perturb in attributions_times_perturb
         )
+        attr_times_perturb_sums = cast(Tensor, attr_times_perturb_sums)

-        return torch.sum(
-            torch.pow(
-                attribution_times_perturb_sums - inputs_minus_perturb.view(-1), 2
-            ).view(bsz, -1),
-            dim=1,
-        )
+        # reshape as Tensor(bsz, current_n_perturb_samples)
+        attr_times_perturb_sums = attr_times_perturb_sums.view(bsz, -1)
+        perturbed_fwd_diffs = perturbed_fwd_diffs.view(bsz, -1)
+
+        if normalize:
+            # in order to normalize, we have to aggregate the following tensors
+            # to calculate MSE in its polynomial expansion:
+            # (a - b)^2 = a^2 - 2ab + b^2
+            return (
+                attr_times_perturb_sums.pow(2).sum(-1),
+                (attr_times_perturb_sums * perturbed_fwd_diffs).sum(-1),
+                perturbed_fwd_diffs.pow(2).sum(-1),
+            )
+        else:
+            # return (a - b)^2 directly if there is no need to normalize
+            return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)
+
+    def _sum_infidelity_tensors(agg_tensors, tensors):
+        return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

     # perform argument formattings
     inputs = _format_input(inputs)  # type: ignore
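A note on the hunk above: `_next_infidelity_tensors` no longer returns a finished MSE. When `normalize=True` it returns the three additive terms of the squared error's polynomial expansion, so per-slice results produced under `max_examples_per_batch` can simply be summed by `_sum_infidelity_tensors` and normalized once at the end. A minimal sketch of the identity this relies on (the names `a` and `b` follow the comment in the patch; the tensors and shapes are illustrative, not taken from the code):

    import torch

    a = torch.randn(8, 5)  # stands in for I^T * attribution per sample, (bsz, n_perturb_samples)
    b = torch.randn(8, 5)  # stands in for f(x) - f(x - I) per sample

    # Direct per-example squared-error sum, as the non-normalized branch computes it:
    direct = (a - b).pow(2).sum(-1)

    # The three expansion terms; each is additive across perturbation slices,
    # which is why they can be accumulated independently and combined later:
    a2, ab, b2 = a.pow(2).sum(-1), (a * b).sum(-1), b.pow(2).sum(-1)
    assert torch.allclose(direct, a2 - 2 * ab + b2)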
@@ -522,10 +552,32 @@ def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
     bsz = inputs[0].size(0)
     with torch.no_grad():
-        metrics_sum = _divide_and_aggregate_metrics(
+        # if not normalize, directly return aggregated MSE ((a - b)^2,)
+        # else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
+        agg_tensors = _divide_and_aggregate_metrics(
             cast(Tuple[Tensor, ...], inputs),
             n_perturb_samples,
-            _next_infidelity,
+            _next_infidelity_tensors,
+            agg_func=_sum_infidelity_tensors,
             max_examples_per_batch=max_examples_per_batch,
         )
-    return metrics_sum * 1 / n_perturb_samples
+
+    if normalize:
+        beta_num = agg_tensors[1]
+        beta_denorm = agg_tensors[0]
+
+        beta = safe_div(
+            beta_num,
+            beta_denorm,
+            torch.tensor(1.0, dtype=beta_denorm.dtype, device=beta_denorm.device),
+        )
+
+        infidelity_values = (
+            beta ** 2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
+        )
+    else:
+        infidelity_values = agg_tensors[0]
+
+    infidelity_values /= n_perturb_samples
+
+    return infidelity_values
diff --git a/captum/metrics/_utils/batching.py b/captum/metrics/_utils/batching.py
index 0f441ea5c5..ee3b38f58e 100644
--- a/captum/metrics/_utils/batching.py
+++ b/captum/metrics/_utils/batching.py
@@ -16,7 +16,7 @@ def _divide_and_aggregate_metrics(
 ) -> Tensor:
     r"""
     This function is used to slice large number of samples `n_perturb_samples` per
-    input example into smaller pieces, computing the metics for each small piece and
+    input example into smaller pieces, computing the metrics for each small piece and
     aggregating the results across all `n_perturb_samples` per example. The function
     returns overall aggregated metric per sample. The size of each slice is determined
     by the `max_examples_per_batch` input parameter.
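In the final combination above, `beta` is the least-squares scaling of the attribution dot products onto the model's forward differences; substituting it into the expansion `beta^2 * Σa^2 - 2 * beta * Σab + Σb^2` makes the result invariant to any constant rescaling of the attributions, since scaling `a` by `c` scales `beta` by `1/c`. A small self-contained check of that invariance (plain division stands in for captum's `safe_div`, which additionally guards against a zero denominator):

    import torch

    def normalized_infidelity(a, b):
        # a: attribution dot products, b: forward differences, both (bsz, n_perturb_samples)
        a2, ab, b2 = a.pow(2).sum(-1), (a * b).sum(-1), b.pow(2).sum(-1)
        beta = ab / a2  # the patch uses safe_div(ab, a2, default) here
        return (beta ** 2 * a2 - 2 * beta * ab + b2) / a.size(-1)

    a, b = torch.randn(8, 5), torch.randn(8, 5)
    # scaling the attributions by a constant leaves the value unchanged
    assert torch.allclose(normalized_infidelity(a, b), normalized_infidelity(100 * a, b))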
diff --git a/tests/metrics/test_infidelity.py b/tests/metrics/test_infidelity.py
index 63009edd2a..4f8a85b647 100644
--- a/tests/metrics/test_infidelity.py
+++ b/tests/metrics/test_infidelity.py
@@ -294,6 +294,29 @@ def perturbed_func3(inputs, baselines):
         assertTensorAlmostEqual(self, infid, delta * delta)
         assertTensorAlmostEqual(self, infid, infid2)

+    def test_basic_infidelity_multiple_with_normalize(self) -> None:
+        input1 = torch.tensor([3.0] * 3)
+        input2 = torch.tensor([1.0] * 3)
+        inputs = (input1, input2)
+        expected = torch.zeros(3)
+
+        model = BasicModel2()
+        ig = IntegratedGradients(model)
+        attrs = ig.attribute(inputs)
+        scaled_attrs = tuple(attr * 100 for attr in attrs)
+
+        infid = self.infidelity_assert(model, attrs, inputs, expected, normalize=True)
+        scaled_infid = self.infidelity_assert(
+            model,
+            scaled_attrs,
+            inputs,
+            expected,
+            normalize=True,
+        )
+
+        # scaling attr should not change normalized infidelity
+        assertTensorAlmostEqual(self, infid, scaled_infid)
+
     def test_sensitivity_n_ig(self) -> None:
         model = BasicModel_MultiLayer()
         ig = IntegratedGradients(model)
@@ -384,6 +407,7 @@ def basic_model_assert(
         max_batch_size: int = None,
         perturb_func: Callable = _local_perturb_func,
         multiply_by_inputs: bool = False,
+        normalize: bool = False,
     ) -> Tensor:
         ig = IntegratedGradients(model)
         if multiply_by_inputs:
@@ -404,6 +428,7 @@ def basic_model_assert(
             n_perturb_samples=n_perturb_samples,
             max_batch_size=max_batch_size,
             perturb_func=perturb_func,
+            normalize=normalize,
         )

     def basic_model_global_assert(
@@ -417,6 +442,7 @@ def basic_model_global_assert(
         n_perturb_samples: int = 10,
         max_batch_size: int = None,
         perturb_func: Callable = _global_perturb_func1,
+        normalize: bool = False,
     ) -> Tensor:
         attrs = attr_algo.attribute(
             inputs, additional_forward_args=additional_args, target=target
@@ -431,6 +457,7 @@ def basic_model_global_assert(
             target=target,
             n_perturb_samples=n_perturb_samples,
             max_batch_size=max_batch_size,
+            normalize=normalize,
         )
         return infid

@@ -447,6 +474,7 @@ def infidelity_assert(
         max_batch_size: int = None,
         multi_input: bool = True,
         perturb_func: Callable = _local_perturb_func,
+        normalize: bool = False,
         **kwargs: Any
     ) -> Tensor:
         infid = infidelity(
@@ -459,6 +487,7 @@ def infidelity_assert(
             baselines=baselines,
             n_perturb_samples=n_perturb_samples,
             max_examples_per_batch=max_batch_size,
+            normalize=normalize,
         )
         assertTensorAlmostEqual(self, infid, expected, 0.05)
         return infid
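The new test exercises the invariance through the `infidelity_assert` helper. Against the public API, the same behavior can be reproduced roughly as follows; the toy model and `perturb_fn` here are illustrative stand-ins, not part of the patch, and seeding before each call keeps both runs on identical perturbations so the comparison is exact:

    import torch
    import torch.nn as nn
    from captum.attr import IntegratedGradients
    from captum.metrics import infidelity

    model = nn.Linear(3, 1)
    inputs = torch.rand(4, 3)

    def perturb_fn(inputs):
        # infidelity expects (perturbation I, perturbed input x - I)
        noise = torch.randn_like(inputs) * 0.01
        return noise, inputs - noise

    attrs = IntegratedGradients(model).attribute(inputs, target=0)

    torch.manual_seed(0)
    infid = infidelity(model, perturb_fn, inputs, attrs, target=0, normalize=True)
    torch.manual_seed(0)
    scaled = infidelity(model, perturb_fn, inputs, attrs * 100, target=0, normalize=True)

    # with normalize=True, scaling the attributions does not change the score
    assert torch.allclose(infid, scaled)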