Support normalize in the infidelity metric (#639)
Summary:
#613

Support normalizing the infidelity metric, following the authors' reference implementation: https://github.com/chihkuanyeh/saliency_evaluation/blob/44a66e2531f30b803be3bf5b0786971b7e7f72a1/infid_sen_utils.py#L295
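
For context, a minimal usage sketch of the new flag; the model, inputs, and perturbation function below are hypothetical stand-ins, and only the normalize=True argument is what this change adds:

```python
import torch
from captum.attr import IntegratedGradients
from captum.metrics import infidelity

# hypothetical toy model and inputs, used only to illustrate the call
model = torch.nn.Linear(2, 1)
inputs = torch.randn(4, 2)

ig = IntegratedGradients(model)
attributions = ig.attribute(inputs, target=0)

def perturb_fn(inputs):
    # small Gaussian perturbation; returns (perturbations, perturbed inputs)
    noise = torch.randn_like(inputs) * 0.01
    return noise, inputs - noise

# with normalize=True, rescaling the attributions by a constant should
# leave the infidelity score (approximately) unchanged
infid = infidelity(model, perturb_fn, inputs, attributions,
                   target=0, normalize=True)
infid_scaled = infidelity(model, perturb_fn, inputs, attributions * 100,
                          target=0, normalize=True)
```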

Pull Request resolved: #639

Reviewed By: vivekmig

Differential Revision: D27293213

Pulled By: aobo-y

fbshipit-source-id: d06c57a8b81a32e1509874f50e47950104139214
aobo-y authored and facebook-github-bot committed Apr 19, 2021
1 parent e31bf38 commit ed4b9ab
Showing 3 changed files with 98 additions and 17 deletions.
84 changes: 68 additions & 16 deletions captum/metrics/_core/infidelity.py
@@ -120,6 +120,7 @@ def infidelity(
target: TargetType = None,
n_perturb_samples: int = 10,
max_examples_per_batch: int = None,
normalize: bool = False,
) -> Tensor:
r"""
Explanation infidelity represents the expected mean-squared error
@@ -347,6 +348,20 @@ def infidelity(
`input batch size * n_perturb_samples`.
Default: None
normalize (bool, optional): Normalize the dot product of the input
perturbation and the attribution so the infidelity value is invariant
to constant scaling of the attribution values. The normalization factor
beta is defined as the ratio of two mean values:
$$ \beta = \frac{
\mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ]
}{
\mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ]
} $$.
Please refer to the original paper for the meaning of the symbols. The same
normalization can be found in the paper's official implementation:
https://github.com/chihkuanyeh/saliency_evaluation
Default: False
Returns:
infidelities (tensor): A tensor of scalar infidelity scores per
@@ -439,7 +454,9 @@ def _validate_inputs_and_perturbations(
is: {}"""
).format(perturb[0].shape, input_perturbed[0].shape)

def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
def _next_infidelity_tensors(
current_n_perturb_samples: int,
) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]:
perturbations, inputs_perturbed = _generate_perturbations(
current_n_perturb_samples
)
@@ -474,31 +491,44 @@ def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
inputs_fwd = torch.repeat_interleave(
inputs_fwd, current_n_perturb_samples, dim=0
)
inputs_minus_perturb = inputs_fwd - inputs_perturbed_fwd
perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
attributions_expanded = tuple(
torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
for attribution in attributions
)

attributions_times_perturb = tuple(
(attribution_expanded * perturbation).view(attribution_expanded.size(0), -1)
for attribution_expanded, perturbation in zip(
attributions_expanded, perturbations
)
)

attribution_times_perturb_sums = sum(
[
torch.sum(attribution_times_perturb, dim=1)
for attribution_times_perturb in attributions_times_perturb
]
attr_times_perturb_sums = sum(
torch.sum(attribution_times_perturb, dim=1)
for attribution_times_perturb in attributions_times_perturb
)
attr_times_perturb_sums = cast(Tensor, attr_times_perturb_sums)

return torch.sum(
torch.pow(
attribution_times_perturb_sums - inputs_minus_perturb.view(-1), 2
).view(bsz, -1),
dim=1,
)
# reshape as Tensor(bsz, current_n_perturb_samples)
attr_times_perturb_sums = attr_times_perturb_sums.view(bsz, -1)
perturbed_fwd_diffs = perturbed_fwd_diffs.view(bsz, -1)

if normalize:
# in order to normalize, we have to aggregate the following tensors
# to calculate MSE in its polynomial expansion:
# (a-b)^2 = a^2 - 2ab + b^2
return (
attr_times_perturb_sums.pow(2).sum(-1),
(attr_times_perturb_sums * perturbed_fwd_diffs).sum(-1),
perturbed_fwd_diffs.pow(2).sum(-1),
)
else:
# returns (a-b)^2 if no need to normalize
return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)

def _sum_infidelity_tensors(agg_tensors, tensors):
return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

# perform argument formatting
inputs = _format_input(inputs) # type: ignore
@@ -522,10 +552,32 @@ def _next_infidelity(current_n_perturb_samples: int) -> Tensor:

bsz = inputs[0].size(0)
with torch.no_grad():
metrics_sum = _divide_and_aggregate_metrics(
# if not normalize, directly return aggregated MSE ((a-b)^2,)
# else return aggregated MSE's polynomial expansion tensors (a^2, ab, b^2)
agg_tensors = _divide_and_aggregate_metrics(
cast(Tuple[Tensor, ...], inputs),
n_perturb_samples,
_next_infidelity,
_next_infidelity_tensors,
agg_func=_sum_infidelity_tensors,
max_examples_per_batch=max_examples_per_batch,
)
return metrics_sum * 1 / n_perturb_samples

if normalize:
beta_num = agg_tensors[1]
beta_denorm = agg_tensors[0]

beta = safe_div(
beta_num,
beta_denorm,
torch.tensor(1.0, dtype=beta_denorm.dtype, device=beta_denorm.device),
)

infidelity_values = (
beta ** 2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
)
else:
infidelity_values = agg_tensors[0]

infidelity_values /= n_perturb_samples

return infidelity_values
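
To make the polynomial-expansion aggregation above concrete, here is a self-contained numeric sketch; the tensor names, shapes, and values are illustrative only, and the final division by n_perturb_samples done in the real code is omitted since it only rescales the result:

```python
import torch

# illustrative per-example quantities: one row per input example,
# one column per perturbation sample
a = torch.randn(3, 10)  # dot products of perturbation and attribution, I^T Phi
b = torch.randn(3, 10)  # forward differences, f(x) - f(x - I)

# aggregated terms of the polynomial expansion (a - b)^2 = a^2 - 2ab + b^2
a2 = a.pow(2).sum(-1)
ab = (a * b).sum(-1)
b2 = b.pow(2).sum(-1)

# normalization factor beta = E[ab] / E[a^2]; using sums is equivalent
# because the 1/n factors cancel in the ratio
beta = ab / a2

# normalized infidelity (up to the final division by n_perturb_samples)
infid = beta.pow(2) * a2 - 2 * beta * ab + b2

# same value computed directly as the sum over samples of (beta * a - b)^2
direct = (beta.unsqueeze(-1) * a - b).pow(2).sum(-1)
assert torch.allclose(infid, direct, atol=1e-5)
```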
2 changes: 1 addition & 1 deletion captum/metrics/_utils/batching.py
@@ -16,7 +16,7 @@ def _divide_and_aggregate_metrics(
) -> Tensor:
r"""
This function is used to slice large number of samples `n_perturb_samples` per
input example into smaller pieces, computing the metics for each small piece and
input example into smaller pieces, computing the metrics for each small piece and
aggregating the results across all `n_perturb_samples` per example. The function
returns overall aggregated metric per sample. The size of each slice is determined
by the `max_examples_per_batch` input parameter.
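
The slicing-and-aggregation behavior described in this docstring can be sketched roughly as follows; this is a simplified, hypothetical stand-in rather than Captum's actual _divide_and_aggregate_metrics, shown only to illustrate how an agg_func such as the new _sum_infidelity_tensors plugs in:

```python
from typing import Callable, Optional, Tuple
from torch import Tensor

def divide_and_aggregate(
    n_perturb_samples: int,
    metric_func: Callable[[int], Tuple[Tensor, ...]],
    agg_func: Callable[[Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]],
    max_samples_per_call: Optional[int] = None,
) -> Tuple[Tensor, ...]:
    """Compute metric_func over n_perturb_samples in slices and fold the
    partial results together with agg_func (simplified sketch)."""
    step = n_perturb_samples if max_samples_per_call is None else max_samples_per_call
    done = min(step, n_perturb_samples)
    agg = metric_func(done)
    while done < n_perturb_samples:
        cur = min(step, n_perturb_samples - done)
        agg = agg_func(agg, metric_func(cur))
        done += cur
    return agg

# the new infidelity code aggregates tuples of partial tensors element-wise:
def sum_tuples(agg, new):
    return tuple(a + n for a, n in zip(agg, new))
```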
29 changes: 29 additions & 0 deletions tests/metrics/test_infidelity.py
@@ -294,6 +294,29 @@ def perturbed_func3(inputs, baselines):
assertTensorAlmostEqual(self, infid, delta * delta)
assertTensorAlmostEqual(self, infid, infid2)

def test_basic_infidelity_multiple_with_normalize(self) -> None:
input1 = torch.tensor([3.0] * 3)
input2 = torch.tensor([1.0] * 3)
inputs = (input1, input2)
expected = torch.zeros(3)

model = BasicModel2()
ig = IntegratedGradients(model)
attrs = ig.attribute(inputs)
scaled_attrs = tuple(attr * 100 for attr in attrs)

infid = self.infidelity_assert(model, attrs, inputs, expected, normalize=True)
scaled_infid = self.infidelity_assert(
model,
scaled_attrs,
inputs,
expected,
normalize=True,
)

# scaling attr should not change normalized infidelity
assertTensorAlmostEqual(self, infid, scaled_infid)

def test_sensitivity_n_ig(self) -> None:
model = BasicModel_MultiLayer()
ig = IntegratedGradients(model)
@@ -384,6 +407,7 @@ def basic_model_assert(
max_batch_size: int = None,
perturb_func: Callable = _local_perturb_func,
multiply_by_inputs: bool = False,
normalize: bool = False,
) -> Tensor:
ig = IntegratedGradients(model)
if multiply_by_inputs:
@@ -404,6 +428,7 @@ def basic_model_assert(
n_perturb_samples=n_perturb_samples,
max_batch_size=max_batch_size,
perturb_func=perturb_func,
normalize=normalize,
)

def basic_model_global_assert(
@@ -417,6 +442,7 @@ def basic_model_global_assert(
n_perturb_samples: int = 10,
max_batch_size: int = None,
perturb_func: Callable = _global_perturb_func1,
normalize: bool = False,
) -> Tensor:
attrs = attr_algo.attribute(
inputs, additional_forward_args=additional_args, target=target
@@ -431,6 +457,7 @@ def basic_model_global_assert(
target=target,
n_perturb_samples=n_perturb_samples,
max_batch_size=max_batch_size,
normalize=normalize,
)
return infid

@@ -447,6 +474,7 @@ def infidelity_assert(
max_batch_size: int = None,
multi_input: bool = True,
perturb_func: Callable = _local_perturb_func,
normalize: bool = False,
**kwargs: Any
) -> Tensor:
infid = infidelity(
@@ -459,6 +487,7 @@ def infidelity_assert(
baselines=baselines,
n_perturb_samples=n_perturb_samples,
max_examples_per_batch=max_batch_size,
normalize=normalize,
)
assertTensorAlmostEqual(self, infid, expected, 0.05)
return infid
