
Commit 2828cd1

Marcio Porto authored and facebook-github-bot committed
Add LayerFeaturePermutation
Summary: Add support for layer attribution via permutation by combining the existing `LayerFeatureAblation` and `FeaturePermutation` attribution classes. See this [doc](https://docs.google.com/document/d/1HwlBYKOEhguA_9OVrndjuBXE5npr7o6rVntFMF8KDtU/edit#heading=h.fuwkwbjpq8z) for the design. Unit tests will be added in a follow-up diff from yucu.

Differential Revision: D54551200

fbshipit-source-id: 419fd5a2ba0129aae9eb90e8af681721483e4ea2
1 parent 38daaed commit 2828cd1
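The distinction the summary leans on: feature ablation replaces a feature with a baseline value, while feature permutation shuffles the feature's values across the batch, preserving its marginal distribution. A minimal illustrative sketch of the two perturbations (toy tensors, not Captum internals):

```python
import torch

# Toy batch: 3 samples, 2 features.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Ablation-style perturbation: overwrite feature 0 with a baseline (here 0.0).
ablated = x.clone()
ablated[:, 0] = 0.0

# Permutation-style perturbation: shuffle feature 0 across the batch instead,
# so the feature keeps its value distribution but loses its per-sample pairing.
permuted = x.clone()
permuted[:, 0] = x[torch.randperm(x.size(0)), 0]
```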

File tree: 4 files changed (+29, -3 lines)

- captum/attr/__init__.py
- captum/attr/_core/feature_permutation.py
- captum/attr/_core/layer/layer_feature_ablation.py
- captum/attr/_core/layer/layer_feature_permutation.py

captum/attr/__init__.py (3 additions, 0 deletions)

```diff
@@ -21,6 +21,9 @@
     LayerDeepLiftShap,
 )
 from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation  # noqa
+from captum.attr._core.layer.layer_feature_permutation import (  # noqa
+    LayerFeaturePermutation,
+)
 from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap  # noqa
 from captum.attr._core.layer.layer_gradient_x_activation import (  # noqa
     LayerGradientXActivation,
```

captum/attr/_core/feature_permutation.py (4 additions, 0 deletions)

```diff
@@ -255,6 +255,10 @@ def attribute(  # type: ignore
         >>> attr = feature_perm.attribute(input, target=1,
         >>>                               feature_mask=feature_mask)
         """
+        # Remove baselines from kwargs if provided so we don't specify this field
+        # twice in the FeatureAblation.attribute call below.
+        if isinstance(kwargs, dict) and "baselines" in kwargs:
+            del kwargs["baselines"]
         return FeatureAblation.attribute.__wrapped__(
             self,
             inputs,
```
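The guard above exists because `FeaturePermutation.attribute` forwards `**kwargs` into a call that already passes `baselines` explicitly, and Python rejects a keyword argument supplied twice. A stripped-down reproduction of the failure mode (hypothetical function names, not the Captum signatures):

```python
def inner(x, baselines=None, **kwargs):
    return x, baselines, kwargs

def outer(x, **kwargs):
    # Without this pop, a caller passing baselines=... would raise
    # "TypeError: inner() got multiple values for keyword argument 'baselines'",
    # because baselines would arrive both explicitly and inside **kwargs.
    kwargs.pop("baselines", None)
    return inner(x, baselines=None, **kwargs)

outer(1, baselines=0, target=2)  # succeeds only because of the pop above
```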

captum/attr/_core/layer/layer_feature_ablation.py (7 additions, 3 deletions)

```diff
@@ -288,10 +288,10 @@ def forward_hook(module, inp, out=None):
                 else inputs + layer_eval_len
             )
 
-            ablator = FeatureAblation(layer_forward_func)
+            attributor = self.attributor(layer_forward_func)
 
-            layer_attribs = ablator.attribute.__wrapped__(
-                ablator,  # self
+            layer_attribs = attributor.attribute.__wrapped__(
+                attributor,  # self
                 layer_eval,
                 baselines=layer_baselines,
                 additional_forward_args=all_inputs,
@@ -300,3 +300,7 @@ def forward_hook(module, inp, out=None):
             )
         _attr = _format_output(len(layer_attribs) > 1, layer_attribs)
         return _attr
+
+    @property
+    def attributor(self):
+        return FeatureAblation
```
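The edit above converts `LayerFeatureAblation` into a small template method: the perturbation engine is resolved through an overridable `attributor` property rather than a hard-coded `FeatureAblation`. A self-contained sketch of the pattern with illustrative names (not the Captum classes):

```python
class Ablation:
    def __init__(self, forward_func):
        self.forward_func = forward_func

    def run(self, x):
        return f"ablate({x})"


class Permutation(Ablation):
    def run(self, x):
        return f"permute({x})"


class LayerAblation:
    @property
    def attributor(self):
        return Ablation  # default engine

    def attribute(self, x):
        # The engine class is looked up through the property, so a subclass
        # swaps engines by overriding one property instead of copying this logic.
        engine = self.attributor(lambda t: t)
        return engine.run(x)


class LayerPermutation(LayerAblation):
    @property
    def attributor(self):
        return Permutation


print(LayerPermutation().attribute("layer_eval"))  # -> permute(layer_eval)
```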
captum/attr/_core/layer/layer_feature_permutation.py (new file, 15 additions, 0 deletions)

```diff
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+from captum.attr._core.feature_permutation import FeaturePermutation
+from captum.attr._core.layer.layer_feature_ablation import LayerFeatureAblation
+
+
+class LayerFeaturePermutation(LayerFeatureAblation):
+    r"""
+    A perturbation based approach to computing layer attribution similar to
+    LayerFeatureAblation, but using FeaturePermutation under the hood instead
+    of FeatureAblation.
+    """
+
+    @property
+    def attributor(self):
+        return FeaturePermutation
```
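Since the new class overrides only `attributor`, its constructor and `attribute` call should mirror `LayerFeatureAblation`'s. A hypothetical usage sketch under that assumption (unverified; per the summary, unit tests land in a follow-up diff):

```python
import torch
import torch.nn as nn
from captum.attr import LayerFeaturePermutation

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

# Assuming the inherited (forward_func, layer) constructor signature.
layer_perm = LayerFeaturePermutation(model, model[0])

# Permutation shuffles activations across the batch, so use a batch size > 1
# for the perturbation to have any effect.
inputs = torch.randn(16, 8)
attr = layer_perm.attribute(inputs, target=0)
print(attr.shape)  # one attribution value per unit of model[0]'s output
```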
