diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 054cd61b1e..9ebcdd6f2e 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -194,17 +194,18 @@ def _is_mask_valid(mask: Tensor, inp: Tensor) -> bool: def _format_feature_mask( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]], inputs: Tuple[Tensor, ...], + start_idx: int = 0, ) -> Tuple[Tensor, ...]: """ Format a feature mask into a tuple of tensors. The `inputs` should be correctly formatted first If `feature_mask` is None, assign each non-batch dimension with a consecutive - integer from 0. + integer from `start_idx`. If `feature_mask` is a tensor, wrap it in a tuple. """ if feature_mask is None: formatted_mask = [] - current_num_features = 0 + current_num_features = start_idx for inp in inputs: # the following can handle empty tensor where numel is 0 # empty tensor will be added to the feature mask diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index d7f2570c9b..c6a47417e4 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -353,10 +353,12 @@ def attribute( formatted_feature_mask, attr_progress, flattened_initial_eval, + initial_eval, n_outputs, total_attrib, weights, attrib_type, + perturbations_per_eval, **kwargs, ) else: @@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks( formatted_feature_mask: Tuple[Tensor, ...], attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]], flattened_initial_eval: Tensor, + initial_eval: Tensor, n_outputs: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, **kwargs: Any, ) -> Tuple[List[Tensor], List[Tensor]]: feature_idx_to_tensor_idx: Dict[int, List[int]] = {} @@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks( if feature_idx.item() not in feature_idx_to_tensor_idx: feature_idx_to_tensor_idx[feature_idx.item()] = [] feature_idx_to_tensor_idx[feature_idx.item()].append(i) + all_feature_idxs = list(feature_idx_to_tensor_idx.keys()) + + additional_args_repeated: object + if perturbations_per_eval > 1: + # Repeat features and additional args for batch size. + all_features_repeated = tuple( + torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0) + for j in range(len(formatted_inputs)) + ) + additional_args_repeated = ( + _expand_additional_forward_args( + formatted_additional_forward_args, perturbations_per_eval + ) + if formatted_additional_forward_args is not None + else None + ) + target_repeated = _expand_target(target, perturbations_per_eval) + else: + all_features_repeated = formatted_inputs + additional_args_repeated = formatted_additional_forward_args + target_repeated = target + num_examples = formatted_inputs[0].shape[0] + + current_additional_args: object + if isinstance(baselines, tuple): + reshaped = False + reshaped_baselines: list[Union[Tensor, int, float]] = [] + for baseline in baselines: + if isinstance(baseline, Tensor): + reshaped = True + reshaped_baselines.append( + baseline.reshape((1,) + tuple(baseline.shape)) + ) + else: + reshaped_baselines.append(baseline) + baselines = tuple(reshaped_baselines) if reshaped else baselines + for i in range(0, len(all_feature_idxs), perturbations_per_eval): + current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval] + current_num_ablated_features = min( + perturbations_per_eval, len(current_feature_idxs) + ) + + # Store appropriate inputs and additional args based on batch size. + if current_num_ablated_features != perturbations_per_eval: + current_additional_args = ( + _expand_additional_forward_args( + formatted_additional_forward_args, current_num_ablated_features + ) + if formatted_additional_forward_args is not None + else None + ) + current_target = _expand_target(target, current_num_ablated_features) + expanded_inputs = tuple( + feature_repeated[0 : current_num_ablated_features * num_examples] + for feature_repeated in all_features_repeated + ) + else: + current_additional_args = additional_args_repeated + current_target = target_repeated + expanded_inputs = all_features_repeated + + current_inputs, current_masks = ( + self._construct_ablated_input_across_tensors( + expanded_inputs, + formatted_feature_mask, + baselines, + current_feature_idxs, + feature_idx_to_tensor_idx, + current_num_ablated_features, + ) + ) - for ( - current_inputs, - current_mask, - ) in self._ablation_generator( - formatted_inputs, - baselines, - formatted_feature_mask, - feature_idx_to_tensor_idx, - **kwargs, - ): # modified_eval has (n_feature_perturbed * n_outputs) elements # shape: # agg mode: (*initial_eval.shape) @@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks( modified_eval = _run_forward( self.forward_func, current_inputs, - target, - formatted_additional_forward_args, + current_target, + current_additional_args, ) if attr_progress is not None: @@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks( total_attrib, weights = self._process_ablated_out_full( modified_eval, - current_mask, + current_masks, flattened_initial_eval, - formatted_inputs, + initial_eval, + current_inputs, n_outputs, + num_examples, total_attrib, weights, attrib_type, + perturbations_per_eval, ) return total_attrib, weights - def _ablation_generator( - self, - inputs: Tuple[Tensor, ...], - baselines: BaselineType, - input_mask: Tuple[Tensor, ...], - feature_idx_to_tensor_idx: Dict[int, List[int]], - **kwargs: Any, - ) -> Generator[ - Tuple[ - Tuple[Tensor, ...], - Tuple[Optional[Tensor], ...], - ], - None, - None, - ]: - if isinstance(baselines, torch.Tensor): - baselines = baselines.reshape((1,) + tuple(baselines.shape)) - - # Process one feature per time, rather than processing every input tensor - for feature_idx in feature_idx_to_tensor_idx.keys(): - ablated_inputs, current_masks = ( - self._construct_ablated_input_across_tensors( - inputs, - input_mask, - baselines, - feature_idx, - feature_idx_to_tensor_idx[feature_idx], - ) - ) - yield ablated_inputs, current_masks - def _construct_ablated_input_across_tensors( self, inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: - ablated_inputs = [] current_masks: List[Optional[Tensor]] = [] + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + for i, input_tensor in enumerate(inputs): if i not in tensor_idxs: ablated_inputs.append(input_tensor) current_masks.append(None) continue - tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + tensor_mask = [] + ablated_input = input_tensor.clone() baseline = baselines[i] if isinstance(baselines, tuple) else baselines - if isinstance(baseline, torch.Tensor): - baseline = baseline.reshape( - (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape) + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features ) - assert baseline is not None, "baseline must be provided" - ablated_input = ( - input_tensor * (1 - tensor_mask).to(input_tensor.dtype) - ) + (baseline * tensor_mask.to(input_tensor.dtype)) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).long() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + + assert baseline is not None, "baseline must be provided" + ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * ( + 1 - mask + ) + (baseline * mask.to(input_tensor.dtype)) + current_masks.append(torch.stack(tensor_mask, dim=0)) ablated_inputs.append(ablated_input) - current_masks.append(tensor_mask) + return tuple(ablated_inputs), tuple(current_masks) def _initial_eval_to_processed_initial_eval_fut( @@ -784,7 +839,7 @@ def _attribute_progress_setup( formatted_inputs, feature_mask, **kwargs ) total_forwards = ( - int(sum(feature_counts)) + math.ceil(int(sum(feature_counts)) / perturbations_per_eval) if enable_cross_tensor_attribution else sum( math.ceil(count / perturbations_per_eval) for count in feature_counts @@ -1194,13 +1249,46 @@ def _process_ablated_out_full( modified_eval: Tensor, current_mask: Tuple[Optional[Tensor], ...], flattened_initial_eval: Tensor, + initial_eval: Tensor, inputs: TensorOrTupleOfTensorsGeneric, n_outputs: int, + num_examples: int, total_attrib: List[Tensor], weights: List[Tensor], attrib_type: dtype, + perturbations_per_eval: int, ) -> Tuple[List[Tensor], List[Tensor]]: modified_eval = self._parse_forward_out(modified_eval) + # if perturbations_per_eval > 1, the output shape must grow with + # input and not be aggregated + current_batch_size = inputs[0].shape[0] + + # number of perturbation, which is not the same as + # perturbations_per_eval when not enough features to perturb + n_perturb = current_batch_size / num_examples + if perturbations_per_eval > 1 and not self._is_output_shape_valid: + + current_output_shape = modified_eval.shape + + # use initial_eval as the forward of perturbations_per_eval = 1 + initial_output_shape = initial_eval.shape + + assert ( + # check if the output is not a scalar + current_output_shape + and initial_output_shape + # check if the output grow in same ratio, i.e., not agg + and current_output_shape[0] == n_perturb * initial_output_shape[0] + ), ( + "When perturbations_per_eval > 1, forward_func's output " + "should be a tensor whose 1st dim grow with the input " + f"batch size: when input batch size is {num_examples}, " + f"the output shape is {initial_output_shape}; " + f"when input batch size is {current_batch_size}, " + f"the output shape is {current_output_shape}" + ) + + self._is_output_shape_valid = True # reshape the leading dim for n_feature_perturbed # flatten each feature's eval outputs into 1D of (n_outputs) @@ -1209,9 +1297,6 @@ def _process_ablated_out_full( eval_diff = flattened_initial_eval - modified_eval eval_diff_shape = eval_diff.shape - # append the shape of one input example - # to make it broadcastable to mask - if self.use_weights: for weight, mask in zip(weights, current_mask): if mask is not None: @@ -1224,6 +1309,7 @@ def _process_ablated_out_full( ) eval_diff = eval_diff.to(total_attrib[i].device) total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0) + return total_attrib, weights def _fut_tuple_to_accumulate_fut_list( diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 3657c00fc2..0d64f1d8b0 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # pyre-strict -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric @@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors( inputs: Tuple[Tensor, ...], input_mask: Tuple[Tensor, ...], baselines: BaselineType, - feature_idx: int, - tensor_idxs: List[int], + feature_idxs: List[int], + feature_idx_to_tensor_idx: Dict[int, List[int]], + current_num_ablated_features: int, ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]: current_masks: List[Optional[Tensor]] = [] - for i, mask in enumerate(input_mask): - if i in tensor_idxs: - current_masks.append((mask == feature_idx).to(inputs[0].device)) - else: + tensor_idxs = { + tensor_idx + for sublist in ( + feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs + ) + for tensor_idx in sublist + } + permuted_inputs = [] + for i, input_tensor in enumerate(inputs): + if i not in tensor_idxs: current_masks.append(None) - feature_masks = tuple(current_masks) - permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks) - return permuted_outputs, feature_masks + permuted_inputs.append(input_tensor) + continue + tensor_mask = [] + permuted_input = input_tensor.clone() + for j, feature_idx in enumerate(feature_idxs): + original_input_size = ( + input_tensor.shape[0] // current_num_ablated_features + ) + start_idx = j * original_input_size + end_idx = (j + 1) * original_input_size + + mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool() + if mask.ndim == 0: + mask = mask.reshape((1,) * input_tensor.dim()) + tensor_mask.append(mask) + permuted_input[start_idx:end_idx] = self.perm_func( + input_tensor[start_idx:end_idx], mask + ) + current_masks.append(torch.stack(tensor_mask, dim=0)) + permuted_inputs.append(permuted_input) + + return tuple(permuted_inputs), tuple(current_masks) diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py index 2833bad6bc..73bbad038a 100644 --- a/captum/testing/helpers/basic_models.py +++ b/captum/testing/helpers/basic_models.py @@ -662,6 +662,16 @@ def forward(self, x1: Tensor, x2: Tensor, x3: Tensor, scale: int): return self.model(scale * (x1 + x2 + x3)) +class BasicModel_MultiLayer_TupleInput(nn.Module): + def __init__(self) -> None: + super().__init__() + self.model = BasicModel_MultiLayer() + + @no_type_check + def forward(self, x: Tuple[Tensor, Tensor, Tensor]) -> Tensor: + return self.model(x[0] + x[1] + x[2]) + + class BasicModel_MultiLayer_MultiInput_with_Future(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/tests/attr/neuron/test_neuron_ablation.py b/tests/attr/neuron/test_neuron_ablation.py index b316400feb..02b92cb65b 100644 --- a/tests/attr/neuron/test_neuron_ablation.py +++ b/tests/attr/neuron/test_neuron_ablation.py @@ -83,8 +83,8 @@ def test_multi_input_ablation_with_mask(self) -> None: inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) - mask2 = torch.tensor([[0, 1, 2]]) - mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) expected = ( [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index c8f9802d6a..5c3101ad01 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -164,6 +164,19 @@ def test_multi_sample_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_sample_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer()) + ablation_algo.use_weights = True + inp = torch.tensor([[2.0, 10.0, 3.0], [20.0, 50.0, 30.0]]) + mask = torch.tensor([[0, 0, 1], [1, 1, 0]]) + self._ablation_test_assert( + ablation_algo, + inp, + [[41.0, 41.0, 12.0], [280.0, 280.0, 120.0]], + feature_mask=mask, + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) @@ -207,6 +220,50 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_input_ablation_with_mask_weighted(self) -> None: + ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) + ablation_algo.use_weights = True + inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]]) + inp2 = torch.tensor([[20.0, 50.0, 30.0], [0.0, 100.0, 0.0]]) + inp3 = torch.tensor([[0.0, 100.0, 10.0], [2.0, 10.0, 3.0]]) + mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) + mask2 = torch.tensor([[3, 4, 2]]) + mask3 = torch.tensor([[5, 6, 7], [5, 5, 5]]) + expected = ( + [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], + [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], + [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2), + expected[0:1], + additional_input=(inp3, 1), + feature_mask=(mask1, mask2), + perturbations_per_eval=(1, 2, 3), + ) + expected_with_baseline = ( + [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], + [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], + [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_with_baseline, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + ) + def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) inp1 = torch.tensor([[23.0, 100.0, 0.0], [20.0, 50.0, 30.0]])