
Commit e4520f3

Reubend authored and facebook-github-bot committed
Add ability to compare multiple models in Captum Insights (#551)
Summary: This PR adds the ability to compare multiple models in Captum Insights.

![Screenshot of model comparison](https://user-images.githubusercontent.com/13208038/101406612-869ed600-388e-11eb-9520-62797a9ae3db.png)

To test this, I went through two scenarios. First, I made sure there are no regressions to single-model workflows:

1. Start the Insights example with `python3 -m captum.insights.example`
2. Ensure that the original functionality still works and nothing has changed, other than the visual changes to column headers

Then, I tested comparing multiple models by duplicating the model in the existing example:

1. Go to `example.py`
2. Duplicate the example model by changing `models=[model]` to `models=[model, model, model]`
3. Check that it renders properly, and that selecting different target classes correctly updates the data for each visualization

Pull Request resolved: #551

Reviewed By: edward-io

Differential Revision: D25379744

Pulled By: Reubend

fbshipit-source-id: 4999c1ef0f18b8f735cd47a890cef413a7c6548e
1 parent 03f89a5 commit e4520f3

23 files changed: +538 −348 lines
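
Before the diffs, a quick sketch of the workflow the test plan above exercises. This is illustrative only; the toy model, class names, and random data below are invented stand-ins for the setup that `captum.insights.example` provides, and are not part of this commit:

```python
# Hedged sketch of the multi-model comparison workflow described in the
# summary. The tiny classifier and random "images" are placeholders.
import torch
import torch.nn as nn
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature


def make_toy_model() -> nn.Module:
    # Minimal classifier over 3x8x8 images, for illustration only
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2), nn.Softmax(dim=1))


def toy_batches():
    # One batch of four random RGB "images" with dummy labels
    yield Batch(
        inputs=torch.rand(4, 3, 8, 8),
        labels=torch.zeros(4, dtype=torch.long),
    )


model = make_toy_model()
visualizer = AttributionVisualizer(
    # Before this PR the list had to contain exactly one model; now
    # duplicating entries (as in the test plan) renders them side by side.
    models=[model, model, model],
    classes=["cat", "dog"],
    features=[ImageFeature("Photo", baseline_transforms=[], input_transforms=[])],
    dataset=toy_batches(),
)
visualizer.render()  # or visualizer.serve() outside a notebook
```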

captum/insights/attr_vis/app.py

Lines changed: 100 additions & 63 deletions

```diff
@@ -3,6 +3,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
@@ -76,7 +77,7 @@ def _get_context():


 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
@@ -149,11 +150,8 @@ def __init__(
         r"""
         Args:

-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.module): One or more PyTorch modules (models) for attribution
                           visualization.
-                          We plan to support visualizing and comparing multiple models
-                          in the future, but currently this supports only a single
-                          model.
             classes (list of string): List of strings corresponding to the names of
                           classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +193,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,13 +202,21 @@ class scores.
         self._dataset_iter = iter(dataset)

     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
-        return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        c = self._outputs[input_index][1]
+        result = self._calculate_vis_output(
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
         )

+        if not result:
+            return None
+        return result[0]
+
     def _update_config(self, settings):
         self._config = FilterConfig(
             attribution_method=settings["attribution_method"],
@@ -344,67 +351,97 @@ def _should_keep_prediction(
         return True

     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
         )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )

-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        # e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            # e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )

-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )

-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
             )
-        ]

-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+        return results if results else None

-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []
```

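To make the new contract in `app.py` concrete: `_calculate_vis_output` now fans out over the models and returns one `VisualizationOutput` per model, and `model_index` records each model's position in the full `models` list even when only a single model is re-rendered. Below is a small self-contained sketch of that bookkeeping; it is not Captum code, and `vis_outputs_for` is an invented name:

```python
# Standalone sketch of the per-model fan-out and index preservation.
from collections import namedtuple
from typing import List, Optional

VisualizationOutput = namedtuple(
    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
)


def vis_outputs_for(models, single_model_index=None) -> Optional[List[VisualizationOutput]]:
    models_used = (
        [models[single_model_index]] if single_model_index is not None else models
    )
    results = []
    for model_index, _model in enumerate(models_used):
        results.append(
            VisualizationOutput(
                feature_outputs=[],  # attribution details elided
                actual=None,
                predicted=[],
                active_index=0,
                # Keep the position in the full list, even when only one
                # model was rendered
                model_index=single_model_index
                if single_model_index is not None
                else model_index,
            )
        )
    return results or None


outs = vis_outputs_for(models=["m0", "m1", "m2"])
assert [o.model_index for o in outs] == [0, 1, 2]

# Re-rendering only the third model still reports model_index == 2
single = vis_outputs_for(models=["m0", "m1", "m2"], single_model_index=2)
assert single[0].model_index == 2
```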
captum/insights/attr_vis/attribution_calculation.py

Lines changed: 38 additions & 31 deletions

```diff
@@ -41,42 +41,49 @@ def __init__(
         self.features = features
         self.score_func = score_func
         self.use_label_for_attr = use_label_for_attr
+        self.baseline_cache: dict = {}
+        self.transformed_input_cache: dict = {}

     def calculate_predicted_scores(
-        self, inputs, additional_forward_args
+        self, inputs, additional_forward_args, model
     ) -> Tuple[
         List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
     ]:
-        net = self.models[0]  # TODO process multiple models
-
-        # initialize baselines
-        baseline_transforms_len = 1  # todo support multiple baselines
-        baselines: List[List[Optional[Tensor]]] = [
-            [None] * len(self.features) for _ in range(baseline_transforms_len)
-        ]
-        transformed_inputs = list(inputs)
-
-        for feature_i, feature in enumerate(self.features):
-            transformed_inputs[feature_i] = self._transform(
-                feature.input_transforms, transformed_inputs[feature_i], True
-            )
-            for baseline_i in range(baseline_transforms_len):
-                if baseline_i > len(feature.baseline_transforms) - 1:
-                    baselines[baseline_i][feature_i] = torch.zeros_like(
-                        transformed_inputs[feature_i]
-                    )
-                else:
-                    baselines[baseline_i][feature_i] = self._transform(
-                        [feature.baseline_transforms[baseline_i]],
-                        transformed_inputs[feature_i],
-                        True,
-                    )
-
-        baselines = cast(List[List[Tensor]], baselines)
-        baselines_group = [tuple(b) for b in baselines]
+        # Check to see if these inputs already have cached baselines and transformed inputs
+        hashableInputs = tuple(inputs)
+        if hashableInputs in self.baseline_cache:
+            baselines_group = self.baseline_cache[hashableInputs]
+            transformed_inputs = self.transformed_input_cache[hashableInputs]
+        else:
+            # Initialize baselines
+            baseline_transforms_len = 1  # todo support multiple baselines
+            baselines: List[List[Optional[Tensor]]] = [
+                [None] * len(self.features) for _ in range(baseline_transforms_len)
+            ]
+            transformed_inputs = list(inputs)
+            for feature_i, feature in enumerate(self.features):
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+                for baseline_i in range(baseline_transforms_len):
+                    if baseline_i > len(feature.baseline_transforms) - 1:
+                        baselines[baseline_i][feature_i] = torch.zeros_like(
+                            transformed_inputs[feature_i]
+                        )
+                    else:
+                        baselines[baseline_i][feature_i] = self._transform(
+                            [feature.baseline_transforms[baseline_i]],
+                            transformed_inputs[feature_i],
+                            True,
+                        )
+
+            baselines = cast(List[List[Optional[Tensor]]], baselines)
+            baselines_group = [tuple(b) for b in baselines]
+            self.baseline_cache[hashableInputs] = baselines_group
+            self.transformed_input_cache[hashableInputs] = transformed_inputs

         outputs = _run_forward(
-            net,
+            model,
             tuple(transformed_inputs),
             additional_forward_args=additional_forward_args,
         )
@@ -105,10 +112,10 @@ def calculate_attribution(
         label: Optional[Union[Tensor]],
         attribution_method_name: str,
         attribution_arguments: Dict,
+        model: Module,
     ) -> Tuple[Tensor, ...]:
-        net = self.models[0]
         attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
-        attribution_method = attribution_cls(net)
+        attribution_method = attribution_cls(model)
         param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
         if param_config.post_process:
             for k, v in attribution_arguments.items():
```

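The new `baseline_cache` and `transformed_input_cache` avoid recomputing input transforms and baselines once per model. One point worth making explicit: `tuple(inputs)` keys the dict by the tensor objects themselves, and `torch.Tensor` hashes by object identity, so the cache hits only when the exact same tensor objects recur; that holds here because every model is fed the same cached sample inputs. A minimal sketch of the pattern follows (the class and function names are invented for illustration, not part of this commit):

```python
# Identity-keyed transform cache, mirroring the pattern added above.
from typing import Dict, List, Tuple

import torch


def expensive_transform(t: torch.Tensor) -> torch.Tensor:
    return t * 2  # stand-in for the per-feature input transforms


class TransformCache:
    def __init__(self) -> None:
        self._cache: Dict[Tuple[torch.Tensor, ...], List[torch.Tensor]] = {}

    def get_or_compute(self, inputs: Tuple[torch.Tensor, ...]) -> List[torch.Tensor]:
        # torch.Tensor hashes by object identity, so this key only matches
        # when the very same tensor objects are passed again
        if inputs not in self._cache:
            self._cache[inputs] = [expensive_transform(t) for t in inputs]
        return self._cache[inputs]


cache = TransformCache()
x = torch.ones(2, 2)
first = cache.get_or_compute((x,))   # computed
second = cache.get_or_compute((x,))  # cache hit: same tensor object
assert first is second
assert cache.get_or_compute((torch.ones(2, 2),)) is not first  # new object, new entry
```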
captum/insights/attr_vis/frontend/src/App.module.css

Lines changed: 31 additions & 3 deletions

```diff
@@ -58,8 +58,7 @@
   padding: 12px 8px;
 }

-.filter-panel__column__title,
-.panel__column__title {
+.filter-panel__column__title {
   font-weight: bold;
   color: #1c1e21;
   padding-bottom: 12px;
@@ -164,12 +163,19 @@
   padding: 24px;
   background: white;
   border-radius: 8px;
-  display: flex;
   box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18);
   transition: opacity 0.2s; /* for loading */
   overflow-y: scroll;
 }

+.panel__column__title {
+  font-weight: 700;
+  border-bottom: 2px solid #c1c1c1;
+  color: #1c1e21;
+  padding-bottom: 2px;
+  margin-bottom: 15px;
+}
+
 .panel--loading {
   opacity: 0.5;
   pointer-events: none; /* disables all interactions inside panel */
@@ -346,3 +352,25 @@
     transform: rotate(360deg);
   }
 }
+
+.visualization-container {
+  display: flex;
+}
+
+.model-number {
+  display: block;
+  height: 2em;
+  font-size: 16px;
+  font-weight: 800;
+}
+
+.model-number-spacer {
+  display: block;
+  height: 2em;
+}
+
+.model-separator {
+  width: 100%;
+  border-bottom: 2px solid #c1c1c1;
+  margin: 10px 0px;
+}
```
