From 87237063b4ef1e328f28dcffb0b1815b11d264a6 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 25 Mar 2025 11:59:19 -0700 Subject: [PATCH 1/2] Support multiple perturbations per eval when masking across tensors (#1530) Summary: This was supported in the old path (when constructing ablated inputs over each input tensor individually) to improve compute efficiency by optionally passing in multiple perturbed inputs to the model fwd function. Reviewed By: craymichael Differential Revision: D71435704 --- captum/attr/_core/feature_ablation.py | 208 ++++++++++++++++------- captum/attr/_core/feature_permutation.py | 46 +++-- tests/attr/test_feature_ablation.py | 57 +++++++ 3 files changed, 240 insertions(+), 71 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index d7f2570c9b..c6a47417e4 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -353,10 +353,12 @@ def attribute( formatted_feature_mask, attr_progress, flattened_initial_eval, + initial_eval, n_outputs, total_attrib, weights, attrib_type, + perturbations_per_eval, **kwargs, ) else: @@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks( formatted_feature_mask: Tuple[Tensor, ...], attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], flattened_initial_eval: Tensor, + initial_eval: Tensor, n_outputs: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, **kwargs: Any, ) -> Tuple[List[Tensor], List[Tensor]]: feature_idx_to_tensor_idx: Dict[int, List[int]] = {} @@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks( if feature_idx.item() not in feature_idx_to_tensor_idx: feature_idx_to_tensor_idx[feature_idx.item()] = [] feature_idx_to_tensor_idx[feature_idx.item()].append(i) + all_feature_idxs = list(feature_idx_to_tensor_idx.keys()) + + additional_args_repeated: object + if perturbations_per_eval > 1: + # Repeat features and additional args for batch size. + all_features_repeated = tuple( + torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0) + for j in range(len(formatted_inputs)) + ) + additional_args_repeated = ( + _expand_additional_forward_args( + formatted_additional_forward_args, perturbations_per_eval + ) + if formatted_additional_forward_args is not None + else None + ) + target_repeated = _expand_target(target, perturbations_per_eval) + else: + all_features_repeated = formatted_inputs + additional_args_repeated = formatted_additional_forward_args + target_repeated = target + num_examples = formatted_inputs[0].shape[0] + + current_additional_args: object + if isinstance(baselines, tuple): + reshaped = False + reshaped_baselines: list[Union[Tensor, int, float]] = [] + for baseline in baselines: + if isinstance(baseline, Tensor): + reshaped = True + reshaped_baselines.append( + baseline.reshape((1,) + tuple(baseline.shape)) + ) + else: + reshaped_baselines.append(baseline) + baselines = tuple(reshaped_baselines) if reshaped else baselines + for i in range(0, len(all_feature_idxs), perturbations_per_eval): + current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval] + current_num_ablated_features = min( + perturbations_per_eval, len(current_feature_idxs) + ) + + # Store appropriate inputs and additional args based on batch size. + if current_num_ablated_features != perturbations_per_eval: + current_additional_args = ( + _expand_additional_forward_args( + formatted_additional_forward_args, current_num_ablated_features + ) + if formatted_additional_forward_args is not None + else None + ) + current_target = _expand_target(target, current_num_ablated_features) + expanded_inputs = tuple( + feature_repeated[0 : current_num_ablated_features * num_examples] + for feature_repeated in all_features_repeated + ) + else: + current_additional_args = additional_args_repeated + current_target = target_repeated + expanded_inputs = all_features_repeated + + current_inputs, current_masks = ( + self._construct_ablated_input_across_tensors( + expanded_inputs, + formatted_feature_mask, + baselines, + current_feature_idxs, + feature_idx_to_tensor_idx, + current_num_ablated_features, + ) + ) - for ( - current_inputs, - current_mask, - ) in self._ablation_generator( - formatted_inputs, - baselines, - formatted_feature_mask, - feature_idx_to_tensor_idx, - **kwargs, - ): # modified_eval has (n_feature_perturbed * n_outputs) elements # shape: # agg mode: (*initial_eval.shape) @@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks( modified_eval = _run_forward( self.forward_func, current_inputs, - target, - formatted_additional_forward_args, + current_target, + current_additional_args, ) if attr_progress is not None: @@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks( total_attrib, weights = self._process_ablated_out_full( modified_eval, - current_mask, + current_masks, flattened_initial_eval, - formatted_inputs, + initial_eval, + current_inputs, n_outputs, + num_examples, total_attrib, weights, attrib_type, + perturbations_per_eval, ) return total_attrib, weights - def _ablation_generator( - self, - inputs: Tuple[Tensor, ...], - baselines: BaselineType, - input_mask: Tuple[Tensor, ...], - feature_idx_to_tensor_idx: Dict[int, List[int]], - **kwargs: Any, - ) -> Generator[ - Tuple[ - Tuple[Tensor, ...], - Tuple[Optional[Tensor], ...], - ], - None, - None, - ]: - if isinstance(baselines, torch.Tensor): - baselines = baselines.reshape((1,) + tuple(baselines.shape)) - - # Process one feature per time, rather than processing every input tensor - for feature_idx in feature_idx_to_tensor_idx.keys(): - ablated_inputs, current_masks = ( - self._construct_ablated_input_across_tensors( - inputs, - input_mask, - baselines, - feature_idx, - feature_idx_to_tensor_idx[feature_idx], - ) - ) - yield ablated_inputs, current_masks - def _construct_ablated_input_across_tensors( self, inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: - ablated_inputs = [] current_masks: List[Optional[Tensor]] = [] + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + for i, input_tensor in enumerate(inputs): if i not in tensor_idxs: ablated_inputs.append(input_tensor) current_masks.append(None) continue - tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + tensor_mask = [] + ablated_input = input_tensor.clone() baseline = baselines[i] if isinstance(baselines, tuple) else baselines - if isinstance(baseline, torch.Tensor): - baseline = baseline.reshape( - (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape) + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features ) - assert baseline is not None, "baseline must be provided" - ablated_input = ( - input_tensor * (1 - tensor_mask).to(input_tensor.dtype) - ) + (baseline * tensor_mask.to(input_tensor.dtype)) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + + assert baseline is not None, "baseline must be provided" + ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * ( + 1 - mask + ) + (baseline * mask.to(input_tensor.dtype)) + current_masks.append(torch.stack(tensor_mask, dim=0)) ablated_inputs.append(ablated_input) - current_masks.append(tensor_mask) + return tuple(ablated_inputs), tuple(current_masks) def _initial_eval_to_processed_initial_eval_fut( @@ -784,7 +839,7 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - int(sum(feature_counts)) + math.ceil(int(sum(feature_counts)) / perturbations_per_eval) if enable_cross_tensor_attribution else sum( math.ceil(count / perturbations_per_eval) for count in feature_counts @@ -1194,13 +1249,46 @@ def _process_ablated_out_full( modified_eval: Tensor, current_mask: Tuple[Optional[Tensor], ...], flattened_initial_eval: Tensor, + initial_eval: Tensor, inputs: TensorOrTupleOfTensorsGeneric, n_outputs: int, + num_examples: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, ) -> Tuple[List[Tensor], List[Tensor]]: modified_eval = self._parse_forward_out(modified_eval) + # if perturbations_per_eval > 1, the output shape must grow with + # input and not be aggregated + current_batch_size = inputs[0].shape[0] + + # number of perturbation, which is not the same as + # perturbations_per_eval when not enough features to perturb + n_perturb = current_batch_size / num_examples + if perturbations_per_eval > 1 and not self._is_output_shape_valid: + + current_output_shape = modified_eval.shape + + # use initial_eval as the forward of perturbations_per_eval = 1 + initial_output_shape = initial_eval.shape + + assert ( + # check if the output is not a scalar + current_output_shape + and initial_output_shape + # check if the output grow in same ratio, i.e., not agg + and current_output_shape[0] == n_perturb * initial_output_shape[0] + ), ( + "When perturbations_per_eval > 1, forward_func's output " + "should be a tensor whose 1st dim grow with the input " + f"batch size: when input batch size is {num_examples}, " + f"the output shape is {initial_output_shape}; " + f"when input batch size is {current_batch_size}, " + f"the output shape is {current_output_shape}" + ) + + self._is_output_shape_valid = True # reshape the leading dim for n_feature_perturbed # flatten each feature's eval outputs into 1D of (n_outputs) @@ -1209,9 +1297,6 @@ def _process_ablated_out_full( eval_diff = flattened_initial_eval - modified_eval eval_diff_shape = eval_diff.shape - # append the shape of one input example - # to make it broadcastable to mask - if self.use_weights: for weight, mask in zip(weights, current_mask): if mask is not None: @@ -1224,6 +1309,7 @@ def _process_ablated_out_full( ) eval_diff = eval_diff.to(total_attrib[i].device) total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) + return total_attrib, weights def _fut_tuple_to_accumulate_fut_list( diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 3657c00fc2..0d64f1d8b0 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors( inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: current_masks: List[Optional[Tensor]] = [] - for i, mask in enumerate(input_mask): - if i in tensor_idxs: - current_masks.append((mask == feature_idx).to(inputs[0].device)) - else: + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + permuted_inputs = [] + for i, input_tensor in enumerate(inputs): + if i not in tensor_idxs: current_masks.append(None) - feature_masks = tuple(current_masks) - permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks) - return permuted_outputs, feature_masks + permuted_inputs.append(input_tensor) + continue + tensor_mask = [] + permuted_input = input_tensor.clone() + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features + ) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + permuted_input[start_idx:end_idx] = self.perm_func( + input_tensor[start_idx:end_idx], mask + ) + current_masks.append(torch.stack(tensor_mask, dim=0)) + permuted_inputs.append(permuted_input) + + return tuple(permuted_inputs), tuple(current_masks) diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index c8f9802d6a..5c3101ad01 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -164,6 +164,19 @@ def test_multi_sample_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_sample_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer()) + ablation_algo.use_weights = True + inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) + mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) + self._ablation_test_assert( + ablation_algo, + inp, + [[41.0, 41.0, 12.0], [280.0, 280.0, 120.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) @@ -207,6 +220,50 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_input_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) + ablation_algo.use_weights = True + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) + expected = ( + [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], + [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], + [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2), + expected[0:1], + additional_input=(inp3, 1), + feature_mask=(mask1, mask2), + perturbations_per_eval=(1, 2, 3), + ) + expected_with_baseline = ( + [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], + [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], + [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_with_baseline, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) From 8269ba55d3098d68c914b4be0aea2202d91e70b4 Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 25 Mar 2025 11:59:19 -0700 Subject: [PATCH 2/2] Adjust indices in LayerAttributor mask for individual neurons (#1531) Summary: With `enable_cross_tensor_attribution=True` for `FeatureAblation`/`FeaturePermutation`, ids/indices in the masks are now "global" Reviewed By: cyrjano Differential Revision: D71778355 --- captum/_utils/common.py | 5 +++-- captum/testing/helpers/basic_models.py | 10 ++++++++++ tests/attr/neuron/test_neuron_ablation.py | 4 ++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 054cd61b1e..9ebcdd6f2e 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -194,17 +194,18 @@ def _is_mask_valid(mask: Tensor, inp: Tensor) -> bool: def _format_feature_mask( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]], inputs: Tuple[Tensor, ...], + start_idx: int = 0, ) -> Tuple[Tensor, ...]: """ Format a feature mask into a tuple of tensors. The `inputs` should be correctly formatted first If `feature_mask` is None, assign each non-batch dimension with a consecutive - integer from 0. + integer from `start_idx`. If `feature_mask` is a tensor, wrap it in a tuple. """ if feature_mask is None: formatted_mask = [] - current_num_features = 0 + current_num_features = start_idx for inp in inputs: # the following can handle empty tensor where numel is 0 # empty tensor will be added to the feature mask diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 2833bad6bc..73bbad038a 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -662,6 +662,16 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int): return self.model(scale * (x1 + x2 + x3)) +class BasicModel_MultiLayer_TupleInput(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = BasicModel_MultiLayer() + + @no_type_check + def forward(self, x: Tuple[Tensor, Tensor, Tensor]) -> Tensor: + return self.model(x[0] + x[1] + x[2]) + + class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/neuron/test_neuron_ablation.py b/tests/attr/neuron/test_neuron_ablation.py index b316400feb..02b92cb65b 100644 --- a/tests/attr/neuron/test_neuron_ablation.py +++ b/tests/attr/neuron/test_neuron_ablation.py @@ -83,8 +83,8 @@ def test_multi_input_ablation_with_mask(self) -> None: inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) expected = ( [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]],