33# pyre-strict
44
55import math
6- from typing import Any , Callable , cast , Generator , List , Optional , Tuple , TypeVar , Union
6+ from typing import (
7+ Any ,
8+ Callable ,
9+ cast ,
10+ Dict ,
11+ Generator ,
12+ List ,
13+ Optional ,
14+ Tuple ,
15+ TypeVar ,
16+ Union ,
17+ )
718
819import torch
920from captum ._utils .common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465476 attrib_type : dtype ,
466477 ** kwargs : Any ,
467478 ) -> Tuple [List [Tensor ], List [Tensor ]]:
479+ feature_idx_to_tensor_idx : Dict [int , List [int ]] = {}
480+ for i , mask in enumerate (formatted_feature_mask ):
481+ for feature_idx in torch .unique (mask ):
482+ if feature_idx .item () not in feature_idx_to_tensor_idx :
483+ feature_idx_to_tensor_idx [feature_idx .item ()] = []
484+ feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
485+
468486 for (
469487 current_inputs ,
470488 current_mask ,
471489 ) in self ._ablation_generator (
472490 formatted_inputs ,
473491 baselines ,
474492 formatted_feature_mask ,
493+ feature_idx_to_tensor_idx ,
475494 ** kwargs ,
476495 ):
477496 # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
511530 inputs : Tuple [Tensor , ...],
512531 baselines : BaselineType ,
513532 input_mask : Tuple [Tensor , ...],
533+ feature_idx_to_tensor_idx : Dict [int , List [int ]],
514534 ** kwargs : Any ,
515535 ) -> Generator [
516536 Tuple [
517537 Tuple [Tensor , ...],
518- Tuple [Tensor , ...],
538+ Tuple [Optional [ Tensor ] , ...],
519539 ],
520540 None ,
521541 None ,
522542 ]:
523- unique_feature_ids = torch .unique (
524- torch .cat ([mask .flatten () for mask in input_mask ])
525- ).tolist ()
526-
527543 if isinstance (baselines , torch .Tensor ):
528544 baselines = baselines .reshape ((1 ,) + tuple (baselines .shape ))
529545
530546 # Process one feature per time, rather than processing every input tensor
531- for feature_idx in unique_feature_ids :
547+ for feature_idx in feature_idx_to_tensor_idx . keys () :
532548 ablated_inputs , current_masks = (
533549 self ._construct_ablated_input_across_tensors (
534- inputs , input_mask , baselines , feature_idx
550+ inputs ,
551+ input_mask ,
552+ baselines ,
553+ feature_idx ,
554+ feature_idx_to_tensor_idx [feature_idx ],
535555 )
536556 )
537557 yield ablated_inputs , current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
542562 input_mask : Tuple [Tensor , ...],
543563 baselines : BaselineType ,
544564 feature_idx : int ,
545- ) -> Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]]:
565+ tensor_idxs : List [int ],
566+ ) -> Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
546567
547568 ablated_inputs = []
548- current_masks = []
569+ current_masks : List [ Optional [ Tensor ]] = []
549570 for i , input_tensor in enumerate (inputs ):
550- mask = input_mask [i ]
551- tensor_mask = mask == feature_idx
552- if not tensor_mask .any ():
571+ if i not in tensor_idxs :
553572 ablated_inputs .append (input_tensor )
554- current_masks .append (torch . zeros_like ( tensor_mask ) )
573+ current_masks .append (None )
555574 continue
556- tensor_mask = tensor_mask .to (input_tensor .device ).long ()
575+ tensor_mask = ( input_mask [ i ] == feature_idx ) .to (input_tensor .device ).long ()
557576 baseline = baselines [i ] if isinstance (baselines , tuple ) else baselines
558577 if isinstance (baseline , torch .Tensor ):
559578 baseline = baseline .reshape (
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
11731192 def _process_ablated_out_full (
11741193 self ,
11751194 modified_eval : Tensor ,
1176- current_mask : Tuple [Tensor , ...],
1195+ current_mask : Tuple [Optional [ Tensor ] , ...],
11771196 flattened_initial_eval : Tensor ,
11781197 inputs : TensorOrTupleOfTensorsGeneric ,
11791198 n_outputs : int ,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
11951214
11961215 if self .use_weights :
11971216 for weight , mask in zip (weights , current_mask ):
1198- weight += mask .float ().sum (dim = 0 )
1217+ if mask is not None :
1218+ weight += mask .float ().sum (dim = 0 )
11991219 for i , mask in enumerate (current_mask ):
1200- if inputs [i ].numel () == 0 :
1220+ if mask is None or inputs [i ].numel () == 0 :
12011221 continue
12021222 eval_diff = eval_diff .reshape (
12031223 eval_diff_shape + (inputs [i ].dim () - 1 ) * (1 ,)
0 commit comments