33from typing import (
44 Any ,
55 Callable ,
6+ cast ,
67 Dict ,
78 Iterable ,
89 List ,
@@ -76,7 +77,7 @@ def _get_context():
7677
7778
7879VisualizationOutput = namedtuple (
79- "VisualizationOutput" , "feature_outputs actual predicted active_index"
80+ "VisualizationOutput" , "feature_outputs actual predicted active_index model_index "
8081)
8182Contribution = namedtuple ("Contribution" , "name percent" )
8283SampleCache = namedtuple ("SampleCache" , "inputs additional_forward_args label" )
@@ -149,11 +150,8 @@ def __init__(
149150 r"""
150151 Args:
151152
152- models (torch.nn.module): PyTorch module (model ) for attribution
153+ models (torch.nn.module): One or more PyTorch modules (models ) for attribution
153154 visualization.
154- We plan to support visualizing and comparing multiple models
155- in the future, but currently this supports only a single
156- model.
157155 classes (list of string): List of strings corresponding to the names of
158156 classes for classification.
159157 features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +193,7 @@ class scores.
195193 self .classes = classes
196194 self .features = features
197195 self .dataset = dataset
196+ self .models = models
198197 self .attribution_calculation = AttributionCalculation (
199198 models , classes , features , score_func , use_label_for_attr
200199 )
@@ -203,13 +202,21 @@ class scores.
203202 self ._dataset_iter = iter (dataset )
204203
205204 def _calculate_attribution_from_cache (
206- self , index : int , target : Optional [Tensor ]
205+ self , input_index : int , model_index : int , target : Optional [Tensor ]
207206 ) -> Optional [VisualizationOutput ]:
208- c = self ._outputs [index ][1 ]
209- return self ._calculate_vis_output (
210- c .inputs , c .additional_forward_args , c .label , torch .tensor (target )
207+ c = self ._outputs [input_index ][1 ]
208+ result = self ._calculate_vis_output (
209+ c .inputs ,
210+ c .additional_forward_args ,
211+ c .label ,
212+ torch .tensor (target ),
213+ model_index ,
211214 )
212215
216+ if not result :
217+ return None
218+ return result [0 ]
219+
213220 def _update_config (self , settings ):
214221 self ._config = FilterConfig (
215222 attribution_method = settings ["attribution_method" ],
@@ -344,67 +351,97 @@ def _should_keep_prediction(
344351 return True
345352
346353 def _calculate_vis_output (
347- self , inputs , additional_forward_args , label , target = None
348- ) -> Optional [VisualizationOutput ]:
349- actual_label_output = None
350- if label is not None and len (label ) > 0 :
351- label_index = int (label [0 ])
352- actual_label_output = OutputScore (
353- score = 100 , index = label_index , label = self .classes [label_index ]
354- )
355-
356- (
357- predicted_scores ,
358- baselines ,
359- transformed_inputs ,
360- ) = self .attribution_calculation .calculate_predicted_scores (
361- inputs , additional_forward_args
354+ self ,
355+ inputs ,
356+ additional_forward_args ,
357+ label ,
358+ target = None ,
359+ single_model_index = None ,
360+ ) -> Optional [List [VisualizationOutput ]]:
361+ # Use all models, unless the user wants to render data for a particular one
362+ models_used = (
363+ [self .models [single_model_index ]]
364+ if single_model_index is not None
365+ else self .models
362366 )
367+ results = []
368+ for model_index , model in enumerate (models_used ):
369+ # Get list of model visualizations for each input
370+ actual_label_output = None
371+ if label is not None and len (label ) > 0 :
372+ label_index = int (label [0 ])
373+ actual_label_output = OutputScore (
374+ score = 100 , index = label_index , label = self .classes [label_index ]
375+ )
376+
377+ (
378+ predicted_scores ,
379+ baselines ,
380+ transformed_inputs ,
381+ ) = self .attribution_calculation .calculate_predicted_scores (
382+ inputs , additional_forward_args , model
383+ )
363384
364- # Filter based on UI configuration
365- if actual_label_output is None or not self ._should_keep_prediction (
366- predicted_scores , actual_label_output
367- ):
368- return None
369-
370- if target is None :
371- target = predicted_scores [0 ].index if len (predicted_scores ) > 0 else None
372-
373- # attributions are given per input*
374- # inputs given to the model are described via `self.features`
375- #
376- # *an input contains multiple features that represent it
377- # e.g. all the pixels that describe an image is an input
378-
379- attrs_per_input_feature = self .attribution_calculation .calculate_attribution (
380- baselines ,
381- transformed_inputs ,
382- additional_forward_args ,
383- target ,
384- self ._config .attribution_method ,
385- self ._config .attribution_arguments ,
386- )
385+ # Filter based on UI configuration
386+ if actual_label_output is None or not self ._should_keep_prediction (
387+ predicted_scores , actual_label_output
388+ ):
389+ continue
390+
391+ if target is None :
392+ target = (
393+ predicted_scores [0 ].index if len (predicted_scores ) > 0 else None
394+ )
395+
396+ # attributions are given per input*
397+ # inputs given to the model are described via `self.features`
398+ #
399+ # *an input contains multiple features that represent it
400+ # e.g. all the pixels that describe an image is an input
401+
402+ attrs_per_input_feature = (
403+ self .attribution_calculation .calculate_attribution (
404+ baselines ,
405+ transformed_inputs ,
406+ additional_forward_args ,
407+ target ,
408+ self ._config .attribution_method ,
409+ self ._config .attribution_arguments ,
410+ model ,
411+ )
412+ )
387413
388- net_contrib = self .attribution_calculation .calculate_net_contrib (
389- attrs_per_input_feature
390- )
414+ net_contrib = self .attribution_calculation .calculate_net_contrib (
415+ attrs_per_input_feature
416+ )
391417
392- # the features per input given
393- features_per_input = [
394- feature .visualize (attr , data , contrib )
395- for feature , attr , data , contrib in zip (
396- self .features , attrs_per_input_feature , inputs , net_contrib
418+ # the features per input given
419+ features_per_input = [
420+ feature .visualize (attr , data , contrib )
421+ for feature , attr , data , contrib in zip (
422+ self .features , attrs_per_input_feature , inputs , net_contrib
423+ )
424+ ]
425+
426+ results .append (
427+ VisualizationOutput (
428+ feature_outputs = features_per_input ,
429+ actual = actual_label_output ,
430+ predicted = predicted_scores ,
431+ active_index = target
432+ if target is not None
433+ else actual_label_output .index ,
434+ # Even if we only iterated over one model, the index should be fixed
435+ # to show the index the model would have had in the list
436+ model_index = single_model_index
437+ if single_model_index is not None
438+ else model_index ,
439+ )
397440 )
398- ]
399441
400- return VisualizationOutput (
401- feature_outputs = features_per_input ,
402- actual = actual_label_output ,
403- predicted = predicted_scores ,
404- active_index = target if target is not None else actual_label_output .index ,
405- )
442+ return results if results else None
406443
407- def _get_outputs (self ) -> List [Tuple [VisualizationOutput , SampleCache ]]:
444+ def _get_outputs (self ) -> List [Tuple [List [ VisualizationOutput ] , SampleCache ]]:
408445 batch_data = next (self ._dataset_iter )
409446 vis_outputs = []
410447
0 commit comments