Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions captum/insights/attr_vis/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -150,8 +149,8 @@ def __init__(
r"""
Args:

models (torch.nn.module): One or more PyTorch modules (models) for attribution
visualization.
models (torch.nn.module): One or more PyTorch modules (models) for
attribution visualization.
classes (list of string): List of strings corresponding to the names of
classes for classification.
features (list of BaseFeature): List of BaseFeatures, which correspond
Expand Down Expand Up @@ -399,27 +398,25 @@ def _calculate_vis_output(
# *an input contains multiple features that represent it
# e.g. all the pixels that describe an image is an input

attrs_per_input_feature = (
self.attribution_calculation.calculate_attribution(
baselines,
transformed_inputs,
additional_forward_args,
target,
self._config.attribution_method,
self._config.attribution_arguments,
model,
)
attrs_per_feature = self.attribution_calculation.calculate_attribution(
baselines,
transformed_inputs,
additional_forward_args,
target,
self._config.attribution_method,
self._config.attribution_arguments,
model,
)

net_contrib = self.attribution_calculation.calculate_net_contrib(
attrs_per_input_feature
attrs_per_feature
)

# the features per input given
features_per_input = [
feature.visualize(attr, data, contrib)
for feature, attr, data, contrib in zip(
self.features, attrs_per_input_feature, inputs, net_contrib
self.features, attrs_per_feature, inputs, net_contrib
)
]

Expand Down
14 changes: 7 additions & 7 deletions captum/insights/attr_vis/attribution_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def calculate_predicted_scores(
) -> Tuple[
List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
]:
# Check to see if these inputs already have caches baselines and transformed inputs
hashableInputs = tuple(inputs)
if hashableInputs in self.baseline_cache:
baselines_group = self.baseline_cache[hashableInputs]
transformed_inputs = self.transformed_input_cache[hashableInputs]
# Check if inputs have cached baselines and transformed inputs
hashable_inputs = tuple(inputs)
if hashable_inputs in self.baseline_cache:
baselines_group = self.baseline_cache[hashable_inputs]
transformed_inputs = self.transformed_input_cache[hashable_inputs]
else:
# Initialize baselines
baseline_transforms_len = 1 # todo support multiple baselines
Expand All @@ -79,8 +79,8 @@ def calculate_predicted_scores(

baselines = cast(List[List[Optional[Tensor]]], baselines)
baselines_group = [tuple(b) for b in baselines]
self.baseline_cache[hashableInputs] = baselines_group
self.transformed_input_cache[hashableInputs] = transformed_inputs
self.baseline_cache[hashable_inputs] = baselines_group
self.transformed_input_cache[hashable_inputs] = transformed_inputs

outputs = _run_forward(
model,
Expand Down