@@ -82,7 +82,7 @@ def attribute(
         additional_forward_args: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
-        split_channels: bool = False,
+        attr_dim_summation: bool = True,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -150,10 +150,10 @@ def attribute(
                         otherwise, by default, both positive and negative
                         attributions are returned.
                         Default: False
-            split_channels (bool, optional): Indicates whether to
-                        keep attributions split per channel.
-                        The default (False) means to sum per channels.
-                        Default: False
+            attr_dim_summation (bool, optional): Indicates whether to
+                        sum attributions along dimension 1 (usually channel).
+                        The default (True) means to sum along dimension 1.
+                        Default: True

         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -213,16 +213,17 @@ def attribute(
             for layer_grad in layer_gradients
         )

-        if split_channels:
+        if attr_dim_summation:
             scaled_acts = tuple(
-                summed_grad * layer_eval
+                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
+                summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
+
         if relu_attributions:
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
         return _format_output(len(scaled_acts) > 1, scaled_acts)
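
A quick usage sketch of the new flag. This assumes the method above is captum.attr.LayerGradCam.attribute (the GradCAM-style scaling of layer activations suggests so); the toy model, layer choice, and shapes below are illustrative assumptions, not part of this change.

# Sketch: effect of attr_dim_summation on the attribution shape
# (toy CNN is hypothetical; only the attr_dim_summation flag comes from the diff above).
import torch
import torch.nn as nn
from captum.attr import LayerGradCam

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # layer we attribute against
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

gradcam = LayerGradCam(model, model[0])
inputs = torch.randn(2, 3, 32, 32)

# Default behaviour: sum over dimension 1 (channels), keepdim=True.
summed = gradcam.attribute(inputs, target=0)
# Opt out of the summation to keep per-channel attributions.
per_channel = gradcam.attribute(inputs, target=0, attr_dim_summation=False)

print(summed.shape)       # torch.Size([2, 1, 32, 32])
print(per_channel.shape)  # torch.Size([2, 8, 32, 32])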