33# pyre-strict 
44
55import  math 
6- from  typing  import  Any , Callable , cast , Generator , List , Optional , Tuple , TypeVar , Union 
6+ from  typing  import  (
7+     Any ,
8+     Callable ,
9+     cast ,
10+     Dict ,
11+     Generator ,
12+     List ,
13+     Optional ,
14+     Tuple ,
15+     TypeVar ,
16+     Union ,
17+ )
718
819import  torch 
920from  captum ._utils .common  import  (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465476        attrib_type : dtype ,
466477        ** kwargs : Any ,
467478    ) ->  Tuple [List [Tensor ], List [Tensor ]]:
479+         feature_idx_to_tensor_idx : Dict [int , List [int ]] =  {}
480+         for  i , mask  in  enumerate (formatted_feature_mask ):
481+             for  feature_idx  in  torch .unique (mask ):
482+                 if  feature_idx .item () not  in feature_idx_to_tensor_idx :
483+                     feature_idx_to_tensor_idx [feature_idx .item ()] =  []
484+                 feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
485+ 
468486        for  (
469487            current_inputs ,
470488            current_mask ,
471489        ) in  self ._ablation_generator (
472490            formatted_inputs ,
473491            baselines ,
474492            formatted_feature_mask ,
493+             feature_idx_to_tensor_idx ,
475494            ** kwargs ,
476495        ):
477496            # modified_eval has (n_feature_perturbed * n_outputs) elements 
@@ -511,27 +530,28 @@ def _ablation_generator(
511530        inputs : Tuple [Tensor , ...],
512531        baselines : BaselineType ,
513532        input_mask : Tuple [Tensor , ...],
533+         feature_idx_to_tensor_idx : Dict [int , List [int ]],
514534        ** kwargs : Any ,
515535    ) ->  Generator [
516536        Tuple [
517537            Tuple [Tensor , ...],
518-             Tuple [Tensor , ...],
538+             Tuple [Optional [ Tensor ] , ...],
519539        ],
520540        None ,
521541        None ,
522542    ]:
523-         unique_feature_ids  =  torch .unique (
524-             torch .cat ([mask .flatten () for  mask  in  input_mask ])
525-         ).tolist ()
526- 
527543        if  isinstance (baselines , torch .Tensor ):
528544            baselines  =  baselines .reshape ((1 ,) +  tuple (baselines .shape ))
529545
530546        # Process one feature per time, rather than processing every input tensor 
531-         for  feature_idx  in  unique_feature_ids :
547+         for  feature_idx  in  feature_idx_to_tensor_idx . keys () :
532548            ablated_inputs , current_masks  =  (
533549                self ._construct_ablated_input_across_tensors (
534-                     inputs , input_mask , baselines , feature_idx 
550+                     inputs ,
551+                     input_mask ,
552+                     baselines ,
553+                     feature_idx ,
554+                     feature_idx_to_tensor_idx [feature_idx ],
535555                )
536556            )
537557            yield  ablated_inputs , current_masks 
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
542562        input_mask : Tuple [Tensor , ...],
543563        baselines : BaselineType ,
544564        feature_idx : int ,
545-     ) ->  Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]]:
565+         tensor_idxs : List [int ],
566+     ) ->  Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
546567
547568        ablated_inputs  =  []
548-         current_masks  =  []
569+         current_masks :  List [ Optional [ Tensor ]]  =  []
549570        for  i , input_tensor  in  enumerate (inputs ):
550-             mask  =  input_mask [i ]
551-             tensor_mask  =  mask  ==  feature_idx 
552-             if  not  tensor_mask .any ():
571+             if  i  not  in tensor_idxs :
553572                ablated_inputs .append (input_tensor )
554-                 current_masks .append (torch . zeros_like ( tensor_mask ) )
573+                 current_masks .append (None )
555574                continue 
556-             tensor_mask  =  tensor_mask .to (input_tensor .device ).long ()
575+             tensor_mask  =  ( input_mask [ i ]  ==   feature_idx ) .to (input_tensor .device ).long ()
557576            baseline  =  baselines [i ] if  isinstance (baselines , tuple ) else  baselines 
558577            if  isinstance (baseline , torch .Tensor ):
559578                baseline  =  baseline .reshape (
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
11731192    def  _process_ablated_out_full (
11741193        self ,
11751194        modified_eval : Tensor ,
1176-         current_mask : Tuple [Tensor , ...],
1195+         current_mask : Tuple [Optional [ Tensor ] , ...],
11771196        flattened_initial_eval : Tensor ,
11781197        inputs : TensorOrTupleOfTensorsGeneric ,
11791198        n_outputs : int ,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
11951214
11961215        if  self .use_weights :
11971216            for  weight , mask  in  zip (weights , current_mask ):
1198-                 weight  +=  mask .float ().sum (dim = 0 )
1217+                 if  mask  is  not None :
1218+                     weight  +=  mask .float ().sum (dim = 0 )
11991219        for  i , mask  in  enumerate (current_mask ):
1200-             if  inputs [i ].numel () ==  0 :
1220+             if  mask   is   None   or   inputs [i ].numel () ==  0 :
12011221                continue 
12021222            eval_diff  =  eval_diff .reshape (
12031223                eval_diff_shape  +  (inputs [i ].dim () -  1 ) *  (1 ,)
0 commit comments