|
3 | 3 | from typing import ( |
4 | 4 | Any, |
5 | 5 | Callable, |
6 | | - cast, |
7 | 6 | Dict, |
8 | 7 | Iterable, |
9 | 8 | List, |
@@ -150,8 +149,8 @@ def __init__( |
150 | 149 | r""" |
151 | 150 | Args: |
152 | 151 |
|
153 | | - models (torch.nn.module): One or more PyTorch modules (models) for attribution |
154 | | - visualization. |
| 152 | + models (torch.nn.module): One or more PyTorch modules (models) for |
| 153 | + attribution visualization. |
155 | 154 | classes (list of string): List of strings corresponding to the names of |
156 | 155 | classes for classification. |
157 | 156 | features (list of BaseFeature): List of BaseFeatures, which correspond |
@@ -399,27 +398,25 @@ def _calculate_vis_output( |
399 | 398 | # *an input contains multiple features that represent it |
400 | 399 | # e.g. all the pixels that describe an image is an input |
401 | 400 |
|
402 | | - attrs_per_input_feature = ( |
403 | | - self.attribution_calculation.calculate_attribution( |
404 | | - baselines, |
405 | | - transformed_inputs, |
406 | | - additional_forward_args, |
407 | | - target, |
408 | | - self._config.attribution_method, |
409 | | - self._config.attribution_arguments, |
410 | | - model, |
411 | | - ) |
| 401 | + attrs_per_feature = self.attribution_calculation.calculate_attribution( |
| 402 | + baselines, |
| 403 | + transformed_inputs, |
| 404 | + additional_forward_args, |
| 405 | + target, |
| 406 | + self._config.attribution_method, |
| 407 | + self._config.attribution_arguments, |
| 408 | + model, |
412 | 409 | ) |
413 | 410 |
|
414 | 411 | net_contrib = self.attribution_calculation.calculate_net_contrib( |
415 | | - attrs_per_input_feature |
| 412 | + attrs_per_feature |
416 | 413 | ) |
417 | 414 |
|
418 | 415 | # the features per input given |
419 | 416 | features_per_input = [ |
420 | 417 | feature.visualize(attr, data, contrib) |
421 | 418 | for feature, attr, data, contrib in zip( |
422 | | - self.features, attrs_per_input_feature, inputs, net_contrib |
| 419 | + self.features, attrs_per_feature, inputs, net_contrib |
423 | 420 | ) |
424 | 421 | ] |
425 | 422 |
|
|
0 commit comments