@@ -76,7 +76,7 @@ def _get_context():
 
 
 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted active_index"
+    "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
 SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
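
For orientation, a minimal runnable sketch of the extended namedtuple. The field
values below are hypothetical placeholders; only the new model_index field comes
from this change, and it records which entry of the visualizer's model list
produced a given output:

    from collections import namedtuple

    VisualizationOutput = namedtuple(
        "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
    )

    # Hypothetical values: an output attributed to the second model in the list.
    out = VisualizationOutput(
        feature_outputs=[], actual=None, predicted=[], active_index=0, model_index=1
    )
    assert out.model_index == 1
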
@@ -149,11 +149,8 @@ def __init__(
         r"""
         Args:
 
-            models (torch.nn.module): PyTorch module (model) for attribution
+            models (torch.nn.module): One or more PyTorch modules (models) for attribution
                 visualization.
-                We plan to support visualizing and comparing multiple models
-                in the future, but currently this supports only a single
-                model.
             classes (list of string): List of strings corresponding to the names of
                 classes for classification.
             features (list of BaseFeature): List of BaseFeatures, which correspond
@@ -195,6 +192,7 @@ class scores.
         self.classes = classes
         self.features = features
         self.dataset = dataset
+        self.models = models
         self.attribution_calculation = AttributionCalculation(
             models, classes, features, score_func, use_label_for_attr
         )
@@ -203,12 +201,16 @@ class scores.
         self._dataset_iter = iter(dataset)
 
     def _calculate_attribution_from_cache(
-        self, index: int, target: Optional[Tensor]
+        self, input_index: int, model_index: int, target: Optional[Tensor]
     ) -> Optional[VisualizationOutput]:
-        c = self._outputs[index][1]
+        c = self._outputs[input_index][1]
         return self._calculate_vis_output(
-            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
-        )
+            c.inputs,
+            c.additional_forward_args,
+            c.label,
+            torch.tensor(target),
+            model_index,
+        )[0]
 
     def _update_config(self, settings):
         self._config = FilterConfig(
@@ -344,67 +346,97 @@ def _should_keep_prediction(
         return True
 
     def _calculate_vis_output(
-        self, inputs, additional_forward_args, label, target=None
-    ) -> Optional[VisualizationOutput]:
-        actual_label_output = None
-        if label is not None and len(label) > 0:
-            label_index = int(label[0])
-            actual_label_output = OutputScore(
-                score=100, index=label_index, label=self.classes[label_index]
-            )
-
-        (
-            predicted_scores,
-            baselines,
-            transformed_inputs,
-        ) = self.attribution_calculation.calculate_predicted_scores(
-            inputs, additional_forward_args
+        self,
+        inputs,
+        additional_forward_args,
+        label,
+        target=None,
+        single_model_index=None,
+    ) -> Optional[List[VisualizationOutput]]:
+        # Use all models, unless the user wants to render data for a particular one
+        models_used = (
+            [self.models[single_model_index]]
+            if single_model_index is not None
+            else self.models
         )
+        results = []
+        for model_index, model in enumerate(models_used):
+            # Get list of model visualizations for each input
+            actual_label_output = None
+            if label is not None and len(label) > 0:
+                label_index = int(label[0])
+                actual_label_output = OutputScore(
+                    score=100, index=label_index, label=self.classes[label_index]
+                )
+
+            (
+                predicted_scores,
+                baselines,
+                transformed_inputs,
+            ) = self.attribution_calculation.calculate_predicted_scores(
+                inputs, additional_forward_args, model
+            )
 
-        # Filter based on UI configuration
-        if actual_label_output is None or not self._should_keep_prediction(
-            predicted_scores, actual_label_output
-        ):
-            return None
-
-        if target is None:
-            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
-
-        # attributions are given per input*
-        # inputs given to the model are described via `self.features`
-        #
-        # *an input contains multiple features that represent it
-        # e.g. all the pixels that describe an image is an input
-
-        attrs_per_input_feature = self.attribution_calculation.calculate_attribution(
-            baselines,
-            transformed_inputs,
-            additional_forward_args,
-            target,
-            self._config.attribution_method,
-            self._config.attribution_arguments,
-        )
+            # Filter based on UI configuration
+            if actual_label_output is None or not self._should_keep_prediction(
+                predicted_scores, actual_label_output
+            ):
+                continue
+
+            if target is None:
+                target = (
+                    predicted_scores[0].index if len(predicted_scores) > 0 else None
+                )
+
+            # attributions are given per input*
+            # inputs given to the model are described via `self.features`
+            #
+            # *an input contains multiple features that represent it
+            # e.g. all the pixels that describe an image is an input
+
+            attrs_per_input_feature = (
+                self.attribution_calculation.calculate_attribution(
+                    baselines,
+                    transformed_inputs,
+                    additional_forward_args,
+                    target,
+                    self._config.attribution_method,
+                    self._config.attribution_arguments,
+                    model,
+                )
+            )
 
-        net_contrib = self.attribution_calculation.calculate_net_contrib(
-            attrs_per_input_feature
-        )
+            net_contrib = self.attribution_calculation.calculate_net_contrib(
+                attrs_per_input_feature
+            )
 
-        # the features per input given
-        features_per_input = [
-            feature.visualize(attr, data, contrib)
-            for feature, attr, data, contrib in zip(
-                self.features, attrs_per_input_feature, inputs, net_contrib
+            # the features per input given
+            features_per_input = [
+                feature.visualize(attr, data, contrib)
+                for feature, attr, data, contrib in zip(
+                    self.features, attrs_per_input_feature, inputs, net_contrib
+                )
+            ]
+
+            results.append(
+                VisualizationOutput(
+                    feature_outputs=features_per_input,
+                    actual=actual_label_output,
+                    predicted=predicted_scores,
+                    active_index=target
+                    if target is not None
+                    else actual_label_output.index,
+                    # Even if we only iterated over one model, the index should be fixed
+                    # to show the index the model would have had in the list
+                    model_index=single_model_index
+                    if single_model_index is not None
+                    else model_index,
+                )
             )
-        ]
 
-        return VisualizationOutput(
-            feature_outputs=features_per_input,
-            actual=actual_label_output,
-            predicted=predicted_scores,
-            active_index=target if target is not None else actual_label_output.index,
-        )
+        return results if results else None
 
-    def _get_outputs(self) -> List[Tuple[VisualizationOutput, SampleCache]]:
+    def _get_outputs(self) -> List[Tuple[List[VisualizationOutput], SampleCache]]:
         batch_data = next(self._dataset_iter)
         vis_outputs = []
 
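
To make the new control flow concrete, here is a self-contained sketch of the
per-model loop added above, with prediction scoring, UI filtering, and attribution
stubbed out. The helper name calculate_vis_output and the placeholder model values
are hypothetical; only the list-or-None return shape and the model_index
bookkeeping mirror the diff:

    from collections import namedtuple
    from typing import List, Optional

    VisualizationOutput = namedtuple(
        "VisualizationOutput", "feature_outputs actual predicted active_index model_index"
    )

    def calculate_vis_output(
        models, single_model_index=None
    ) -> Optional[List[VisualizationOutput]]:
        # Use all models unless a single one was requested, as in the diff above.
        models_used = (
            [models[single_model_index]] if single_model_index is not None else models
        )
        results = []
        for model_index, model in enumerate(models_used):
            # ... scoring, filtering, and attribution would run here per model ...
            results.append(
                VisualizationOutput(
                    feature_outputs=[],
                    actual=None,
                    predicted=[],
                    active_index=0,
                    # Report the model's position in the full list even when
                    # only one model was requested.
                    model_index=single_model_index
                    if single_model_index is not None
                    else model_index,
                )
            )
        return results if results else None

    # One VisualizationOutput per model when all models are used...
    assert [o.model_index for o in calculate_vis_output(["m0", "m1"])] == [0, 1]
    # ...and a one-element list (hence the trailing [0] in
    # _calculate_attribution_from_cache) when a single model_index is requested.
    assert calculate_vis_output(["m0", "m1"], single_model_index=1)[0].model_index == 1
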