Adjust indices in LayerAttributor mask for individual neurons #1531

Closed · wants to merge 2 commits
5 changes: 3 additions & 2 deletions captum/_utils/common.py
@@ -194,17 +194,18 @@ def _is_mask_valid(mask: Tensor, inp: Tensor) -> bool:
def _format_feature_mask(
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]],
inputs: Tuple[Tensor, ...],
start_idx: int = 0,
) -> Tuple[Tensor, ...]:
"""
Format a feature mask into a tuple of tensors.
The `inputs` should be correctly formatted first
If `feature_mask` is None, assign each non-batch dimension with a consecutive
integer from 0.
integer from `start_idx`.
If `feature_mask` is a tensor, wrap it in a tuple.
"""
if feature_mask is None:
formatted_mask = []
current_num_features = 0
current_num_features = start_idx
for inp in inputs:
# the following can handle empty tensor where numel is 0
# empty tensor will be added to the feature mask
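For context: the new `start_idx` parameter lets a caller offset the auto-generated feature indices so they do not collide with indices assigned elsewhere (e.g. for individual neurons, per the PR title). Below is a minimal, Captum-independent sketch of the documented default-mask behavior; the helper name `format_default_mask` and its wiring are illustrative, not the library API.

import torch

def format_default_mask(inputs, start_idx=0):
    # Mimic the documented behavior of _format_feature_mask when feature_mask
    # is None: give every non-batch element of each input tensor a consecutive
    # feature index, starting from `start_idx` instead of 0.
    masks = []
    current = start_idx
    for inp in inputs:
        num_features = inp[0].numel()  # features per example (non-batch dims)
        mask = torch.arange(current, current + num_features).reshape(inp.shape[1:])
        masks.append(mask.unsqueeze(0))  # broadcastable over the batch dim
        current += num_features
    return tuple(masks)

# Two inputs with 3 and 2 features per example: indices 5..7 and 8..9.
x1, x2 = torch.rand(2, 3), torch.rand(2, 2)
print(format_default_mask((x1, x2), start_idx=5))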
208 changes: 147 additions & 61 deletions captum/attr/_core/feature_ablation.py
@@ -353,10 +353,12 @@ def attribute(
formatted_feature_mask,
attr_progress,
flattened_initial_eval,
initial_eval,
n_outputs,
total_attrib,
weights,
attrib_type,
perturbations_per_eval,
**kwargs,
)
else:
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
formatted_feature_mask: Tuple[Tensor, ...],
attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
flattened_initial_eval: Tensor,
initial_eval: Tensor,
n_outputs: int,
total_attrib: List[Tensor],
weights: List[Tensor],
attrib_type: dtype,
perturbations_per_eval: int,
**kwargs: Any,
) -> Tuple[List[Tensor], List[Tensor]]:
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,78 @@
if feature_idx.item() not in feature_idx_to_tensor_idx:
feature_idx_to_tensor_idx[feature_idx.item()] = []
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
all_feature_idxs = list(feature_idx_to_tensor_idx.keys())

additional_args_repeated: object
if perturbations_per_eval > 1:
# Repeat features and additional args for batch size.
all_features_repeated = tuple(
torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
for j in range(len(formatted_inputs))
)
additional_args_repeated = (
_expand_additional_forward_args(
formatted_additional_forward_args, perturbations_per_eval
)
if formatted_additional_forward_args is not None
else None
)
target_repeated = _expand_target(target, perturbations_per_eval)
else:
all_features_repeated = formatted_inputs
additional_args_repeated = formatted_additional_forward_args
target_repeated = target
num_examples = formatted_inputs[0].shape[0]

current_additional_args: object
if isinstance(baselines, tuple):
reshaped = False
reshaped_baselines: list[Union[Tensor, int, float]] = []
for baseline in baselines:
if isinstance(baseline, Tensor):
reshaped = True
reshaped_baselines.append(
baseline.reshape((1,) + tuple(baseline.shape))
)
else:
reshaped_baselines.append(baseline)
baselines = tuple(reshaped_baselines) if reshaped else baselines
for i in range(0, len(all_feature_idxs), perturbations_per_eval):
current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
current_num_ablated_features = min(
perturbations_per_eval, len(current_feature_idxs)
)

# Store appropriate inputs and additional args based on batch size.
if current_num_ablated_features != perturbations_per_eval:
current_additional_args = (
_expand_additional_forward_args(
formatted_additional_forward_args, current_num_ablated_features
)
if formatted_additional_forward_args is not None
else None
)
current_target = _expand_target(target, current_num_ablated_features)
expanded_inputs = tuple(
feature_repeated[0 : current_num_ablated_features * num_examples]
for feature_repeated in all_features_repeated
)
else:
current_additional_args = additional_args_repeated
current_target = target_repeated
expanded_inputs = all_features_repeated

current_inputs, current_masks = (
self._construct_ablated_input_across_tensors(
expanded_inputs,
formatted_feature_mask,
baselines,
current_feature_idxs,
feature_idx_to_tensor_idx,
current_num_ablated_features,
)
)

for (
current_inputs,
current_mask,
) in self._ablation_generator(
formatted_inputs,
baselines,
formatted_feature_mask,
feature_idx_to_tensor_idx,
**kwargs,
):
# modified_eval has (n_feature_perturbed * n_outputs) elements
# shape:
# agg mode: (*initial_eval.shape)
@@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks(
modified_eval = _run_forward(
self.forward_func,
current_inputs,
target,
formatted_additional_forward_args,
current_target,
current_additional_args,
)

if attr_progress is not None:
@@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks(

total_attrib, weights = self._process_ablated_out_full(
modified_eval,
current_mask,
current_masks,
flattened_initial_eval,
formatted_inputs,
initial_eval,
current_inputs,
n_outputs,
num_examples,
total_attrib,
weights,
attrib_type,
perturbations_per_eval,
)
return total_attrib, weights

def _ablation_generator(
self,
inputs: Tuple[Tensor, ...],
baselines: BaselineType,
input_mask: Tuple[Tensor, ...],
feature_idx_to_tensor_idx: Dict[int, List[int]],
**kwargs: Any,
) -> Generator[
Tuple[
Tuple[Tensor, ...],
Tuple[Optional[Tensor], ...],
],
None,
None,
]:
if isinstance(baselines, torch.Tensor):
baselines = baselines.reshape((1,) + tuple(baselines.shape))

# Process one feature per time, rather than processing every input tensor
for feature_idx in feature_idx_to_tensor_idx.keys():
ablated_inputs, current_masks = (
self._construct_ablated_input_across_tensors(
inputs,
input_mask,
baselines,
feature_idx,
feature_idx_to_tensor_idx[feature_idx],
)
)
yield ablated_inputs, current_masks

def _construct_ablated_input_across_tensors(
self,
inputs: Tuple[Tensor, ...],
input_mask: Tuple[Tensor, ...],
baselines: BaselineType,
feature_idx: int,
tensor_idxs: List[int],
feature_idxs: List[int],
feature_idx_to_tensor_idx: Dict[int, List[int]],
current_num_ablated_features: int,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:

ablated_inputs = []
current_masks: List[Optional[Tensor]] = []
tensor_idxs = {
tensor_idx
for sublist in (
feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
)
for tensor_idx in sublist
}

for i, input_tensor in enumerate(inputs):
if i not in tensor_idxs:
ablated_inputs.append(input_tensor)
current_masks.append(None)
continue
tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
tensor_mask = []
ablated_input = input_tensor.clone()
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
if isinstance(baseline, torch.Tensor):
baseline = baseline.reshape(
(1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
for j, feature_idx in enumerate(feature_idxs):
original_input_size = (
input_tensor.shape[0] // current_num_ablated_features
)
assert baseline is not None, "baseline must be provided"
ablated_input = (
input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
) + (baseline * tensor_mask.to(input_tensor.dtype))
start_idx = j * original_input_size
end_idx = (j + 1) * original_input_size

mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
if mask.ndim == 0:
mask = mask.reshape((1,) * input_tensor.dim())
tensor_mask.append(mask)

assert baseline is not None, "baseline must be provided"
ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
1 - mask
) + (baseline * mask.to(input_tensor.dtype))
current_masks.append(torch.stack(tensor_mask, dim=0))
ablated_inputs.append(ablated_input)
current_masks.append(tensor_mask)

return tuple(ablated_inputs), tuple(current_masks)

def _initial_eval_to_processed_initial_eval_fut(
@@ -784,7 +839,7 @@ def _attribute_progress_setup(
formatted_inputs, feature_mask, **kwargs
)
total_forwards = (
int(sum(feature_counts))
math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
if enable_cross_tensor_attribution
else sum(
math.ceil(count / perturbations_per_eval) for count in feature_counts
@@ -1194,13 +1249,46 @@ def _process_ablated_out_full(
modified_eval: Tensor,
current_mask: Tuple[Optional[Tensor], ...],
flattened_initial_eval: Tensor,
initial_eval: Tensor,
inputs: TensorOrTupleOfTensorsGeneric,
n_outputs: int,
num_examples: int,
total_attrib: List[Tensor],
weights: List[Tensor],
attrib_type: dtype,
perturbations_per_eval: int,
) -> Tuple[List[Tensor], List[Tensor]]:
modified_eval = self._parse_forward_out(modified_eval)
# if perturbations_per_eval > 1, the output shape must grow with the
# input batch and must not be aggregated
current_batch_size = inputs[0].shape[0]

# number of perturbations, which is not the same as
# perturbations_per_eval when there are not enough features left to perturb
n_perturb = current_batch_size / num_examples
if perturbations_per_eval > 1 and not self._is_output_shape_valid:

current_output_shape = modified_eval.shape

# use initial_eval as the forward of perturbations_per_eval = 1
initial_output_shape = initial_eval.shape

assert (
# check if the output is not a scalar
current_output_shape
and initial_output_shape
# check if the output grows in the same ratio, i.e., is not aggregated
and current_output_shape[0] == n_perturb * initial_output_shape[0]
), (
"When perturbations_per_eval > 1, forward_func's output "
"should be a tensor whose 1st dim grow with the input "
f"batch size: when input batch size is {num_examples}, "
f"the output shape is {initial_output_shape}; "
f"when input batch size is {current_batch_size}, "
f"the output shape is {current_output_shape}"
)

self._is_output_shape_valid = True

# reshape the leading dim for n_feature_perturbed
# flatten each feature's eval outputs into 1D of (n_outputs)
Expand All @@ -1209,9 +1297,6 @@ def _process_ablated_out_full(
eval_diff = flattened_initial_eval - modified_eval
eval_diff_shape = eval_diff.shape

# append the shape of one input example
# to make it broadcastable to mask

if self.use_weights:
for weight, mask in zip(weights, current_mask):
if mask is not None:
@@ -1224,6 +1309,7 @@
)
eval_diff = eval_diff.to(total_attrib[i].device)
total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)

return total_attrib, weights

def _fut_tuple_to_accumulate_fut_list(
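The feature_ablation.py changes move the batching over features into `_attribute_with_cross_tensor_feature_masks`: when `perturbations_per_eval > 1`, the inputs are repeated along the batch dimension and slice `j` of the repeated batch has only the `j`-th feature group replaced by its baseline, so several perturbations are evaluated in a single forward pass (and the progress total becomes `ceil(total_features / perturbations_per_eval)`). The assertion added to `_process_ablated_out_full` enforces that the first dimension of `forward_func`'s output grows proportionally with the input batch, since an aggregated output cannot be split back per perturbation. A standalone sketch of the slice-wise ablation pattern follows; the function name and the scalar baseline are illustrative, not the diff's exact code.

import torch

def ablate_feature_groups(input_tensor, mask, feature_idxs, baseline=0.0):
    # Repeat the batch once per feature group, then replace only that group's
    # positions with the baseline inside its own slice of the repeated batch.
    num_examples = input_tensor.shape[0]
    repeated = torch.cat([input_tensor] * len(feature_idxs), dim=0)
    for j, feature_idx in enumerate(feature_idxs):
        fmask = (mask == feature_idx).to(input_tensor.dtype)
        start, end = j * num_examples, (j + 1) * num_examples
        repeated[start:end] = repeated[start:end] * (1 - fmask) + baseline * fmask
    return repeated

x = torch.arange(6.0).reshape(2, 3)   # batch of 2 examples, 3 features each
mask = torch.tensor([[0, 1, 1]])      # feature 0 -> column 0, feature 1 -> columns 1-2
print(ablate_feature_groups(x, mask, feature_idxs=[0, 1], baseline=-1.0))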
46 changes: 36 additions & 10 deletions captum/attr/_core/feature_permutation.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3

# pyre-strict
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -391,15 +391,41 @@ def _construct_ablated_input_across_tensors(
inputs: Tuple[Tensor, ...],
input_mask: Tuple[Tensor, ...],
baselines: BaselineType,
feature_idx: int,
tensor_idxs: List[int],
feature_idxs: List[int],
feature_idx_to_tensor_idx: Dict[int, List[int]],
current_num_ablated_features: int,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
current_masks: List[Optional[Tensor]] = []
for i, mask in enumerate(input_mask):
if i in tensor_idxs:
current_masks.append((mask == feature_idx).to(inputs[0].device))
else:
tensor_idxs = {
tensor_idx
for sublist in (
feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
)
for tensor_idx in sublist
}
permuted_inputs = []
for i, input_tensor in enumerate(inputs):
if i not in tensor_idxs:
current_masks.append(None)
feature_masks = tuple(current_masks)
permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
return permuted_outputs, feature_masks
permuted_inputs.append(input_tensor)
continue
tensor_mask = []
permuted_input = input_tensor.clone()
for j, feature_idx in enumerate(feature_idxs):
original_input_size = (
input_tensor.shape[0] // current_num_ablated_features
)
start_idx = j * original_input_size
end_idx = (j + 1) * original_input_size

mask = (input_mask[i] == feature_idx).to(input_tensor.device).bool()
if mask.ndim == 0:
mask = mask.reshape((1,) * input_tensor.dim())
tensor_mask.append(mask)
permuted_input[start_idx:end_idx] = self.perm_func(
input_tensor[start_idx:end_idx], mask
)
current_masks.append(torch.stack(tensor_mask, dim=0))
permuted_inputs.append(permuted_input)

return tuple(permuted_inputs), tuple(current_masks)
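feature_permutation.py mirrors the same slice-per-feature-group structure, but each slice is permuted via `self.perm_func` instead of being replaced with a baseline. A rough sketch under the assumption that the permutation function shuffles masked values across the examples in a slice; the real behavior is whatever `perm_func` the instance was configured with, and the helper names here are illustrative.

import torch

def permute_masked(batch, mask):
    # Assumed stand-in for perm_func: shuffle the masked feature values across
    # the examples in the batch, leaving unmasked values untouched.
    shuffled = batch[torch.randperm(batch.shape[0])]
    fmask = mask.to(batch.dtype)
    return batch * (1 - fmask) + shuffled * fmask

def permute_feature_groups(repeated_inputs, input_mask, feature_idxs):
    # Each slice of the already-repeated batch gets a different feature group
    # permuted, matching the start_idx/end_idx slicing in the diff.
    num_examples = repeated_inputs.shape[0] // len(feature_idxs)
    out = repeated_inputs.clone()
    for j, feature_idx in enumerate(feature_idxs):
        start, end = j * num_examples, (j + 1) * num_examples
        out[start:end] = permute_masked(repeated_inputs[start:end], input_mask == feature_idx)
    return out

x = torch.arange(12.0).reshape(4, 3)  # batch of 2 examples, repeated twice
mask = torch.tensor([[0, 1, 1]])
print(permute_feature_groups(x, mask, feature_idxs=[0, 1]))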