Skip to content

Commit 45a79cb

Browse files
sarahtranfbfacebook-github-bot
authored andcommitted
Fix total_forwards calculation in ablation/permutation for cross-tensor attribution
Summary: Previously, we could iterate over the feature masks and get the feature count from the ID range in the mask. Now mask ID/indices are global Differential Revision: D72480911
1 parent 6bcc13b commit 45a79cb

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

captum/attr/_core/feature_ablation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131
from captum._utils.progress import progress, SimpleProgress
3232
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
3333
from captum.attr._utils.attribution import PerturbationAttribution
34-
from captum.attr._utils.common import _format_input_baseline
34+
from captum.attr._utils.common import (
35+
_format_input_baseline,
36+
get_total_features_from_mask,
37+
)
3538
from captum.log import log_usage
3639
from torch import dtype, Tensor
3740
from torch.futures import collect_all, Future
@@ -894,7 +897,9 @@ def _attribute_progress_setup(
894897
formatted_inputs, feature_mask, **kwargs
895898
)
896899
total_forwards = (
897-
math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
900+
math.ceil(
901+
get_total_features_from_mask(feature_mask) / perturbations_per_eval
902+
)
898903
if enable_cross_tensor_attribution
899904
else sum(
900905
math.ceil(count / perturbations_per_eval) for count in feature_counts

captum/attr/_utils/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,16 @@ def _construct_default_feature_mask(
390390
total_features = current_num_features
391391
feature_mask = tuple(feature_mask)
392392
return feature_mask, total_features
393+
394+
395+
def get_total_features_from_mask(
396+
feature_mask: Tuple[Tensor, ...],
397+
) -> int:
398+
"""
399+
Return the numbers of input features based on the total unique
400+
feature IDs/indices in the feature mask.
401+
"""
402+
seen_idxs = set()
403+
for mask in feature_mask:
404+
seen_idxs |= set(torch.unique(mask).tolist())
405+
return len(seen_idxs)

tests/utils/test_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
parse_version,
1515
safe_div,
1616
)
17+
from captum.attr._utils.common import get_total_features_from_mask
1718
from captum.testing.helpers.basic import (
1819
assertTensorAlmostEqual,
1920
assertTensorTuplesAlmostEqual,
@@ -174,6 +175,16 @@ def test_get_max_feature_index(self) -> None:
174175

175176
assert _get_max_feature_index(mask) == 100
176177

178+
def test_mask_unique_elem(self) -> None:
179+
res = get_total_features_from_mask((torch.tensor([0, 0, 0]),))
180+
self.assertEqual(res, 1)
181+
res = get_total_features_from_mask((torch.tensor([0, 0, 4]),))
182+
self.assertEqual(res, 2)
183+
res = get_total_features_from_mask(
184+
(torch.tensor([0, 0, 4]), torch.tensor([0, 4, 5]))
185+
)
186+
self.assertEqual(res, 3)
187+
177188

178189
class TestParseVersion(BaseTest):
179190
def test_parse_version_dev(self) -> None:

0 commit comments

Comments
 (0)