diff --git a/captum/testing/helpers/basic_models.py b/captum/testing/helpers/basic_models.py
index 77584594a9..2833bad6bc 100644
--- a/captum/testing/helpers/basic_models.py
+++ b/captum/testing/helpers/basic_models.py
@@ -2,7 +2,7 @@
 
 # pyre-strict
 
-from typing import no_type_check, Optional, Tuple, Union
+from typing import Dict, no_type_check, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -418,7 +418,7 @@ def forward(self, input1, input2, input3=None):
         return self.linear2(self.relu(self.linear1(embeddings))).sum(1)
 
 
-class GradientUnsupportedLayerOutput(nn.Module):
+class PassThroughLayerOutput(nn.Module):
     """
     This layer is used to test the case where the model returns a layer that
    is not supported by the gradient computation.
@@ -428,10 +428,8 @@ def __init__(self) -> None:
         super().__init__()
 
     @no_type_check
-    def forward(
-        self, unsupported_layer_output: PassThroughOutputType
-    ) -> PassThroughOutputType:
-        return unsupported_layer_output
+    def forward(self, output: PassThroughOutputType) -> PassThroughOutputType:
+        return output
 
 
 class BasicModel_GradientLayerAttribution(nn.Module):
@@ -456,7 +454,7 @@ def __init__(
 
         self.relu = nn.ReLU(inplace=inplace)
         self.relu_alt = nn.ReLU(inplace=False)
-        self.unsupportedLayer = GradientUnsupportedLayerOutput()
+        self.unsupported_layer = PassThroughLayerOutput()
 
         self.linear2 = nn.Linear(4, 2)
         self.linear2.weight = nn.Parameter(torch.ones(2, 4))
@@ -466,15 +464,19 @@ def __init__(
         self.linear3.weight = nn.Parameter(torch.ones(2, 4))
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))
 
+        self.int_layer = PassThroughLayerOutput()  # sample layer with an int output
+
     @no_type_check
-    def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self, x: Tensor, add_input: Optional[Tensor] = None
+    ) -> Dict[str, Tensor]:
         input = x if add_input is None else x + add_input
         lin0_out = self.linear0(input)
         lin1_out = self.linear1(lin0_out)
         lin1_out_alt = self.linear1_alt(lin0_out)
 
         if self.unsupported_layer_output is not None:
-            self.unsupportedLayer(self.unsupported_layer_output)
+            self.unsupported_layer(self.unsupported_layer_output)  # unsupported_layer's output is not used in the forward func.
 
         self.relu_alt(
             lin1_out_alt
@@ -483,9 +485,17 @@ def forward(self, x: Tensor, add_input: Optional[Tensor] = None) -> Tensor:
         relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)
 
-        lin3_out = self.linear3(lin1_out_alt).to(torch.int64)
+        lin3_out = self.linear3(lin1_out_alt)
+        int_output = self.int_layer(lin3_out.to(torch.int64))
+
+        output_tensors = torch.cat((lin2_out, int_output), dim=1)
 
-        return torch.cat((lin2_out, lin3_out), dim=1)
+        # return a dictionary of tensors to test the case
+        # where an output accessor is required
+        return {
+            "task {}".format(i + 1): output_tensors[:, i]
+            for i in range(output_tensors.shape[1])
+        }
 
 
 class MultiRelu(nn.Module):
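
For context, a minimal usage sketch of the updated test model (not part of the diff). The "task i" keys come from the dict comprehension added in forward(); the batch size and the select_task helper below are illustrative assumptions, not an existing Captum API.

# Sketch only: exercises the new Dict[str, Tensor] output of
# BasicModel_GradientLayerAttribution introduced by this diff.
import torch

from captum.testing.helpers.basic_models import BasicModel_GradientLayerAttribution

model = BasicModel_GradientLayerAttribution()
inputs = torch.randn(2, 3)  # linear0 expects 3 input features

outputs = model(inputs)
# linear2 and linear3 each contribute 2 columns, so there are 4 per-task tensors
print(sorted(outputs))  # ['task 1', 'task 2', 'task 3', 'task 4']


# A test that needs one scalar per example can reduce the dict with an
# accessor; select_task is a hypothetical helper used only for illustration.
def select_task(out, task="task 1"):
    return out[task]


print(select_task(outputs).shape)  # torch.Size([2])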