 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     List,
@@ -76,7 +77,7 @@ def _get_context():


 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
@@ -149,11 +150,8 @@ def __init__(
         r"""
         Args:

-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.Module): One or more PyTorch modules (models) for attribution
                 visualization.
-                We plan to support visualizing and comparing multiple models
-                in the future, but currently this supports only a single
-                model.
             classes (list of string): List of strings corresponding to the names of
                 classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +193,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,13 +202,24 @@ class scores.
         self._dataset_iter = iter(dataset)

     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
-        return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        c = self._outputs[input_index][1]
+        result = self._calculate_vis_output(
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
         )

+        # Type checking doesn't allow indexing into an Optional value, so we must
+        # manually check it and then cast it
+        if not result:
+            return None
+        else:
+            return result[0]
+
     def _update_config(self, settings):
         self._config = FilterConfig(
             attribution_method=settings["attribution_method"],
@@ -344,67 +354,97 @@ def _should_keep_prediction(
         return True

     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
         )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )

-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        # e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            # e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )

-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )

-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
             )
-        ]

-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+        return results if results else None

-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []

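For context, a minimal usage sketch of what this change enables: passing more than one model to the visualizer and getting one visualization output per model. This sketch is not part of the diff; it assumes the surrounding class is Captum Insights' AttributionVisualizer and that Batch and ImageFeature are available (import paths vary across Captum versions), and it uses toy models and random data purely for illustration.

import torch
import torch.nn as nn
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature  # import path is version-dependent

# Two toy classifiers over 3x8x8 "images"; real use would pass trained models.
model_a = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))
model_b = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2))

def dataset():
    # One tiny batch of random inputs with integer labels, wrapped in Batch
    # as the visualizer's dataset iterator expects.
    yield Batch(inputs=torch.rand(4, 3, 8, 8), labels=torch.randint(0, 2, (4,)))

visualizer = AttributionVisualizer(
    models=[model_a, model_b],  # multiple models, as the updated docstring allows
    classes=["cat", "dog"],
    features=[
        ImageFeature(
            "Image",
            baseline_transforms=[lambda x: x * 0],  # all-zero baseline for brevity
            input_transforms=[],
        )
    ],
    dataset=dataset(),
)

# Internally, _calculate_vis_output now returns one VisualizationOutput per model
# (each tagged with model_index), and _calculate_attribution_from_cache(input_index,
# model_index, target) unwraps the single-model case by returning result[0].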