
Commit 945c582

aobo-y authored and facebook-github-bot committed
export DataLoaderAttribution in attr (#1162)
Summary:
Pull Request resolved: #1162

Rename `DataloaderAttribution` to `DataLoaderAttribution`, matching PyTorch's `DataLoader` casing, and export it through `captum.attr`.

Reviewed By: vivekmig

Differential Revision: D46992878

fbshipit-source-id: 2500f64183f329128e57fd0bcb67959242be8801
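For reference, after this commit the class is importable from the top-level `captum.attr` package under its new name; the old lowercase-l spelling is removed. A minimal sketch:

```python
# New, PyTorch-consistent casing, exported at the package top level:
from captum.attr import DataLoaderAttribution

# The pre-rename spelling no longer exists after this commit:
# from captum.attr._core.dataloader_attr import DataloaderAttribution  # ImportError
```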
1 parent c4617ca commit 945c582

File tree

3 files changed: +19 -17 lines changed


captum/attr/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+from captum.attr._core.dataloader_attr import DataLoaderAttribution  # noqa
 from captum.attr._core.deep_lift import DeepLift, DeepLiftShap  # noqa
 from captum.attr._core.feature_ablation import FeatureAblation  # noqa
 from captum.attr._core.feature_permutation import FeaturePermutation  # noqa
@@ -86,6 +87,7 @@
     "NeuronAttribution",
     "LayerAttribution",
     "IntegratedGradients",
+    "DataLoaderAttribution",
     "DeepLift",
     "DeepLiftShap",
     "InputXGradient",
```

captum/attr/_core/dataloader_attr.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -13,7 +13,7 @@
     _run_forward,
 )
 from captum._utils.typing import BaselineType
-from captum.attr import FeatureAblation
+from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._utils.attribution import Attribution
 from torch import Tensor
 
@@ -140,7 +140,7 @@ def _convert_output_shape(
     return tuple(attr)
 
 
-class DataloaderAttribution(Attribution):
+class DataLoaderAttribution(Attribution):
     r"""
     Decorate a perturbation-based attribution algorithm to make it work with dataloaders.
     The decorated instance will calculate attribution in the
```
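For context, `DataLoaderAttribution` wraps a perturbation-based method such as `FeatureAblation` so that attribution runs over a `torch.utils.data.DataLoader` rather than raw tensors. A minimal usage sketch mirroring the tests below; the `sum_forward` stand-in and toy dataset are assumptions, not part of this commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from captum.attr import DataLoaderAttribution, FeatureAblation

def sum_forward(*inputs):
    # Toy stand-in for a model: per-sample sum over features
    return sum(inp.sum(dim=-1) for inp in inputs)

mock_dataset = TensorDataset(torch.randn(8, 3))  # toy dataset
dataloader = DataLoader(mock_dataset, batch_size=2)

fa = FeatureAblation(sum_forward)
dl_fa = DataLoaderAttribution(fa)  # wrap the perturbation-based method

# The wrapper iterates the dataloader internally and returns attributions
# shaped like the dataset's inputs (one tensor per input here).
attributions = dl_fa.attribute(dataloader)
```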

tests/attr/test_dataloader_attr.py

Lines changed: 15 additions & 15 deletions
```diff
@@ -5,7 +5,7 @@
 
 import torch
 
-from captum.attr._core.dataloader_attr import DataloaderAttribution, InputRole
+from captum.attr._core.dataloader_attr import DataLoaderAttribution, InputRole
 from captum.attr._core.feature_ablation import FeatureAblation
 from parameterized import parameterized
 from tests.helpers.basic import (
@@ -75,13 +75,13 @@ class Test(BaseTest):
     )
     def test_dl_attr(self, forward) -> None:
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         dataloader = DataLoader(mock_dataset, batch_size=2)
 
         dl_attributions = dl_fa.attribute(dataloader)
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             batch_attr = fa.attribute(tuple(batch))
@@ -109,13 +109,13 @@ def test_dl_attr_with_mask(self, forward) -> None:
         )
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         dataloader = DataLoader(mock_dataset, batch_size=2)
 
         dl_attributions = dl_fa.attribute(dataloader, feature_mask=masks)
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             batch_attr = fa.attribute(tuple(batch), feature_mask=masks)
@@ -141,13 +141,13 @@ def test_dl_attr_with_baseline(self, forward) -> None:
         )
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         dataloader = DataLoader(mock_dataset, batch_size=2)
 
         dl_attributions = dl_fa.attribute(dataloader, baselines=baselines)
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             batch_attr = fa.attribute(tuple(batch), baselines=baselines)
@@ -188,7 +188,7 @@ def to_metric(accum):
         )
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         batch_size = 2
         dataloader = DataLoader(mock_dataset, batch_size=batch_size)
@@ -243,7 +243,7 @@ def forward(*forward_inputs):
             return sum_forward(*forward_inputs)
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         batch_size = 2
         dataloader = DataLoader(mock_dataset, batch_size=batch_size)
@@ -257,7 +257,7 @@ def forward(*forward_inputs):
         # only inputs need attribution
         self.assertEqual(len(dl_attributions), n_attr_inputs)
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             attr_inputs = tuple(
@@ -283,7 +283,7 @@ def forward(*forward_inputs):
     def test_dl_attr_not_return_input_shape(self) -> None:
         forward = sum_forward
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         dataloader = DataLoader(mock_dataset, batch_size=2)
 
@@ -295,7 +295,7 @@ def test_dl_attr_not_return_input_shape(self) -> None:
         dl_attribution = cast(Tensor, dl_attribution)
         self.assertEqual(dl_attribution.shape, expected_attr_shape)
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             batch_attr = fa.attribute(tuple(batch))
@@ -321,7 +321,7 @@ def test_dl_attr_with_mask_not_return_input_shape(self) -> None:
         )
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         dataloader = DataLoader(mock_dataset, batch_size=2)
 
@@ -340,7 +340,7 @@ def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
         forward = sum_forward
 
         fa = FeatureAblation(forward)
-        dl_fa = DataloaderAttribution(fa)
+        dl_fa = DataLoaderAttribution(fa)
 
         mock_dl_iter = Mock(wraps=DataLoader.__iter__)
 
@@ -360,7 +360,7 @@ def test_dl_attr_with_perturb_per_pass(self, perturb_per_pass) -> None:
             math.ceil(n_features / perturb_per_pass) + n_iter_overhead,
         )
 
-        # default reduce of DataloaderAttribution works the same as concat all batches
+        # default reduce of DataLoaderAttribution works the same as concat all batches
         attr_list = []
         for batch in dataloader:
             batch_attr = fa.attribute(tuple(batch))
```
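The comment repeated throughout these tests states the invariant they verify: with the default reduce, attributing over the whole dataloader is equivalent to attributing each batch with the unwrapped method and concatenating the results. A sketch of that check, reusing the `fa`, `dl_fa`, and `dataloader` stand-ins from the sketch above:

```python
dl_attributions = dl_fa.attribute(dataloader)

# Attribute batch by batch with the unwrapped FeatureAblation, then
# concatenate per input across batches.
attr_list = [fa.attribute(tuple(batch)) for batch in dataloader]
expected = tuple(torch.cat(per_input) for per_input in zip(*attr_list))

for actual, exp in zip(dl_attributions, expected):
    assert torch.allclose(actual, exp)
```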
