 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from captum._utils.typing import PassThroughOutputType
 from torch import Tensor
 from torch.futures import Future

@@ -417,6 +418,76 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)


+class GradientUnsupportedLayerOutput(nn.Module):
+    """
+    This layer is used to test the case where the model returns a layer output
+    that is not supported by gradient computation.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    @no_type_check
+    def forward(
+        self, unsupported_layer_output: PassThroughOutputType
+    ) -> PassThroughOutputType:
+        return unsupported_layer_output
+
+
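+# Illustrative note, not asserted by the tests themselves: the module above
+# simply echoes whatever it is given, so a hooked layer can be made to expose
+# an output that cannot join the gradient graph, for example an integer tensor:
+#
+#     layer = GradientUnsupportedLayerOutput()
+#     out = layer(torch.zeros(2, 3, dtype=torch.int64))  # returned unchanged
+
+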
+class BasicModel_GradientLayerAttribution(nn.Module):
+    def __init__(
+        self,
+        inplace: bool = False,
+        unsupported_layer_output: PassThroughOutputType = None,
+    ) -> None:
+        super().__init__()
+        # Linear 0 is simply identity transform
+        self.unsupported_layer_output = unsupported_layer_output
+        self.linear0 = nn.Linear(3, 3)
+        self.linear0.weight = nn.Parameter(torch.eye(3))
+        self.linear0.bias = nn.Parameter(torch.zeros(3))
+        self.linear1 = nn.Linear(3, 4)
+        self.linear1.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.linear1_alt = nn.Linear(3, 4)
+        self.linear1_alt.weight = nn.Parameter(torch.ones(4, 3))
+        self.linear1_alt.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))
+
+        self.relu = nn.ReLU(inplace=inplace)
+        self.relu_alt = nn.ReLU(inplace=False)
+        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+
+        self.linear2 = nn.Linear(4, 2)
+        self.linear2.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear2.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+        self.linear3 = nn.Linear(4, 2)
+        self.linear3.weight = nn.Parameter(torch.ones(2, 4))
+        self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
+
+    @no_type_check
+    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+        input = x if add_input is None else x + add_input
+        lin0_out = self.linear0(input)
+        lin1_out = self.linear1(lin0_out)
+        lin1_out_alt = self.linear1_alt(lin0_out)
+
+        if self.unsupported_layer_output is not None:
+            # unsupportedLayer's output is unused in the forward func.
+            self.unsupportedLayer(self.unsupported_layer_output)
+        # relu_alt's output is supported but it is likewise unused in the forward func.
+        self.relu_alt(lin1_out_alt)
+
+        relu_out = self.relu(lin1_out)
+        lin2_out = self.linear2(relu_out)
+
+        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+
+        return torch.cat((lin2_out, lin3_out), dim=1)
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
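
As a minimal usage sketch (not part of the patch; it assumes the helper classes above are importable, and uses captum.attr.LayerActivation purely for illustration since the tests exercising the new model are not shown in this diff):

    import torch
    from captum.attr import LayerActivation

    model = BasicModel_GradientLayerAttribution(
        unsupported_layer_output=torch.zeros(1, 3, dtype=torch.int64)
    )
    out = model(torch.randn(4, 3))  # concatenated linear2/linear3 outputs, shape (4, 4)

    # Hooking a supported layer works as usual; hooking model.unsupportedLayer is the
    # failure case the new classes exist to exercise.
    layer_act = LayerActivation(model, model.relu)
    attributions = layer_act.attribute(torch.randn(4, 3))  # relu activations, shape (4, 4)
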
@@ -429,7 +500,11 @@ def forward(self, arg1: Tensor, arg2: Tensor) -> Tuple[Tensor, Tensor]:


 class BasicModel_MultiLayer(nn.Module):
-    def __init__(self, inplace: bool = False, multi_input_module: bool = False) -> None:
+    def __init__(
+        self,
+        inplace: bool = False,
+        multi_input_module: bool = False,
+    ) -> None:
         super().__init__()
         # Linear 0 is simply identity transform
         self.multi_input_module = multi_input_module
@@ -461,6 +536,7 @@ def forward(
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
+
         if self.multi_input_module:
             relu_out1, relu_out2 = self.multi_relu(lin1_out, self.linear1_alt(input))
             relu_out = relu_out1 + relu_out2