Skip to content

Commit 68d4f24

Browse files
edward-iofacebook-github-bot
authored andcommitted
fix flake8 issues (#554)
Summary: Pull Request resolved: #554 Reviewed By: NarineK Differential Revision: D25418107 Pulled By: edward-io fbshipit-source-id: 960ed22c5f6845ac9fedff8793196660d6fd5529
1 parent e4520f3 commit 68d4f24

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

captum/insights/attr_vis/app.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import (
44
Any,
55
Callable,
6-
cast,
76
Dict,
87
Iterable,
98
List,
@@ -150,8 +149,8 @@ def __init__(
150149
r"""
151150
Args:
152151
153-
models (torch.nn.module): One or more PyTorch modules (models) for attribution
154-
visualization.
152+
models (torch.nn.module): One or more PyTorch modules (models) for
153+
attribution visualization.
155154
classes (list of string): List of strings corresponding to the names of
156155
classes for classification.
157156
features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -399,27 +398,25 @@ def _calculate_vis_output(
399398
# *an input contains multiple features that represent it
400399
# e.g. all the pixels that describe an image is an input
401400

402-
attrs_per_input_feature = (
403-
self.attribution_calculation.calculate_attribution(
404-
baselines,
405-
transformed_inputs,
406-
additional_forward_args,
407-
target,
408-
self._config.attribution_method,
409-
self._config.attribution_arguments,
410-
model,
411-
)
401+
attrs_per_feature = self.attribution_calculation.calculate_attribution(
402+
baselines,
403+
transformed_inputs,
404+
additional_forward_args,
405+
target,
406+
self._config.attribution_method,
407+
self._config.attribution_arguments,
408+
model,
412409
)
413410

414411
net_contrib = self.attribution_calculation.calculate_net_contrib(
415-
attrs_per_input_feature
412+
attrs_per_feature
416413
)
417414

418415
# the features per input given
419416
features_per_input = [
420417
feature.visualize(attr, data, contrib)
421418
for feature, attr, data, contrib in zip(
422-
self.features, attrs_per_input_feature, inputs, net_contrib
419+
self.features, attrs_per_feature, inputs, net_contrib
423420
)
424421
]
425422

captum/insights/attr_vis/attribution_calculation.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def calculate_predicted_scores(
4949
) -> Tuple[
5050
List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
5151
]:
52-
# Check to see if these inputs already have caches baselines and transformed inputs
53-
hashableInputs = tuple(inputs)
54-
if hashableInputs in self.baseline_cache:
55-
baselines_group = self.baseline_cache[hashableInputs]
56-
transformed_inputs = self.transformed_input_cache[hashableInputs]
52+
# Check if inputs have cached baselines and transformed inputs
53+
hashable_inputs = tuple(inputs)
54+
if hashable_inputs in self.baseline_cache:
55+
baselines_group = self.baseline_cache[hashable_inputs]
56+
transformed_inputs = self.transformed_input_cache[hashable_inputs]
5757
else:
5858
# Initialize baselines
5959
baseline_transforms_len = 1 # todo support multiple baselines
@@ -79,8 +79,8 @@ def calculate_predicted_scores(
7979

8080
baselines = cast(List[List[Optional[Tensor]]], baselines)
8181
baselines_group = [tuple(b) for b in baselines]
82-
self.baseline_cache[hashableInputs] = baselines_group
83-
self.transformed_input_cache[hashableInputs] = transformed_inputs
82+
self.baseline_cache[hashable_inputs] = baselines_group
83+
self.transformed_input_cache[hashable_inputs] = transformed_inputs
8484

8585
outputs = _run_forward(
8686
model,

0 commit comments

Comments
 (0)