LayerDeepLift fails when used on a MaxPooling layer? #382
Comments
Hi @Holt59, thank you for the question. That's interesting! I'll debug it.
@Holt59, I've been debugging this issue. There seem to be some inconsistencies in the backward pass. In the meantime, as a workaround, if you attribute to the inputs of the MaxPool2d layer instead, it will work. By default we attribute to the outputs of the layer.
Actually, what you were doing would be equivalent to the workaround sketched below.
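For reference, a minimal sketch of that workaround, assuming `attribute_to_layer_input=True` is the mechanism for attributing to the layer's inputs; the layer index, input shape, and target below are illustrative, not taken from the original snippet:

```python
import torch
from torchvision.models import vgg16
from captum.attr import LayerDeepLift

# Layer index 23 (a MaxPool2d layer in VGG16's `features`), the input shape,
# and the target class are illustrative assumptions.
model = vgg16().eval()
inputs = torch.randn(1, 3, 224, 224)
baselines = torch.zeros_like(inputs)

layer_dl = LayerDeepLift(model, model.features[23])

# Workaround: attribute to the *inputs* of the MaxPool2d layer rather than to
# its outputs (the default), which sidesteps the backward-hook problem.
attributions = layer_dl.attribute(
    inputs,
    baselines=baselines,
    target=0,
    attribute_to_layer_input=True,
)

print(attributions.shape if torch.is_tensor(attributions)
      else tuple(a.shape for a in attributions))
```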
@Holt59, did that workaround work for you?
The workaround seems to work, but I cannot use it in my code base like this, since I am computing attributions for multiple layers (and I don't know which layer follows each one). That's not a big issue, though; I'm not particularly interested in the MaxPool layers, so I can leave them out.
@Holt59, this PR #390 will fix the problem with MaxPool. To give more context, the problem happened because of how the tensor returned from the forward hook interacts with `register_backward_hook` on the same module. More details about the issue can be found here: https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook

Another point that I wanted to bring up: in VGG the modules might get reused (you might want to check that). We want to make sure that this isn't happening for the layer algorithms and DeepLift.
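On the module-reuse point, here is a small, hypothetical check (the helper name `find_reused_modules` is made up for illustration) for whether any module instance is registered under more than one name in a model, which is the situation hook-based layer algorithms need to avoid:

```python
from collections import Counter

import torch.nn as nn
from torchvision.models import vgg16


def find_reused_modules(model: nn.Module) -> dict:
    """Names whose module instance is also registered under another name.

    This only catches modules registered more than once; it does not detect
    a module that is registered once but called several times in forward().
    """
    named = list(model.named_modules(remove_duplicate=False))
    counts = Counter(id(module) for _, module in named)
    return {name: type(module).__name__
            for name, module in named
            if counts[id(module)] > 1}


print(find_reused_modules(vgg16()) or "no reused modules found")
```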
…n issue for MaxPool (#390)

Summary: Related to issue #382; asserts that `grad_inputs` and `inputs` have the same shape. More description of the workaround and of why the issue happens can be found in the description of the assert. The error occurs when we attribute to the outputs of the layer, because of the input or output tensor returned in the forward hook. Added `forward_hook_with_return_excl_modules`, which contains the list of modules for which we don't want to return anything from the forward hook. This is only used in DeepLift and can be used for any algorithm that attributes to maxpool and at the same time has a backward hook set on it. Added test cases for layer and neuron use cases.

Pull Request resolved: #390
Reviewed By: edward-io
Differential Revision: D22197030
Pulled By: NarineK
fbshipit-source-id: e6cf712103900190f46c5c1e9051519f3eaa933f
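As a rough illustration of the kind of check described in that summary (this is not Captum's actual implementation; the hook functions below are hypothetical), a forward hook can record a module's inputs so that a backward hook can assert that `grad_input` has matching shapes:

```python
import torch
import torch.nn as nn

saved_inputs = {}

def forward_hook(module, inputs, output):
    # Remember the tensors the module saw during the forward pass.
    saved_inputs[module] = inputs

def backward_hook(module, grad_input, grad_output):
    # Check in the spirit of the assert described above: each grad_input
    # entry should have the same shape as the corresponding forward input.
    for inp, g_inp in zip(saved_inputs[module], grad_input):
        if g_inp is not None:
            assert inp.shape == g_inp.shape, (
                "grad_input shape does not match the module's input shape"
            )

pool = nn.MaxPool2d(kernel_size=2)
pool.register_forward_hook(forward_hook)
pool.register_backward_hook(backward_hook)

x = torch.randn(1, 3, 8, 8, requires_grad=True)
pool(x).sum().backward()  # both hooks fire here; the assertion holds in this plain setup
```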
This got fixed through: #390
I am trying to use `LayerDeepLift` on multiple layers of a VGG16 model from `torchvision.models`. It works for all layers except `MaxPool2d` layers. The following (layer `23` is a `MaxPool2d` layer) raises an error:
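A minimal sketch of the kind of call being described, assuming a standard `torchvision` VGG16 and a dummy input; the input shape and target are illustrative assumptions:

```python
import torch
from torchvision.models import vgg16
from captum.attr import LayerDeepLift

model = vgg16().eval()
print(model.features[23])  # a MaxPool2d layer in VGG16's feature extractor

inputs = torch.randn(1, 3, 224, 224)
baselines = torch.zeros_like(inputs)

# Attributing to the *outputs* of the MaxPool2d layer (the default behaviour)
# is the case that failed before the fix in #390.
layer_dl = LayerDeepLift(model, model.features[23])
attributions = layer_dl.attribute(inputs, baselines=baselines, target=0)
```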
It works on all layers except the `MaxPool2d` layers of `vgg16.features` (it works with the average pooling layer). I am not sure if this is a restriction of DeepLift or an error in the implementation.

Also, when the error occurs, the model seems to be left in some weird state, as re-using it leads to `IndexError: tuple index out of range` (even with a brand new `captum.attr.LayerDeepLift` instance).