LRP TypeError: Module type <class 'torch.nn.modules.upsampling.Upsample'> is not supported.No default rule defined. #712

Open
bugsuse opened this issue Jul 8, 2021 · 12 comments

Comments

@bugsuse

bugsuse commented Jul 8, 2021

I want to use LRP to explain a U-Net model, but it raises a TypeError as follows:

TypeError: Module type <class 'torch.nn.modules.upsampling.Upsample'> is not supported.No default rule defined.

The same error is raised when I use nn.ConvTranspose2d to upsample.

Any help is appreciated!

@NarineK
Contributor

NarineK commented Jul 9, 2021

@bugsuse, since nn.ConvTranspose2d is not a very common layer, we don't have default rules for it. Have you tried explicitly attaching a rule to it, similar to this tutorial?
#668

@bugsuse
Author

bugsuse commented Jul 9, 2021

@NarineK Thanks a lot! I tested it following the tutorial in #668. I need nn.Upsample or nn.ConvTranspose2d to upsample the output back to the original size, because I want to use a U-Net model for a semantic segmentation task and explain it.

How should I add rules to support nn.Upsample or nn.ConvTranspose2d, or skip these unsupported layers, when using LRP in captum to compute attributions?

@NarineK
Contributor

NarineK commented Jul 9, 2021

@bugsuse, for upsampling or ConvTranspose2d you might want to start with the Epsilon rule. Have you tried setting the rules with something like this?

from captum.attr._utils.lrp_rules import EpsilonRule

# `layer` stands for the attribute name of the module in your own model
model.layer.rule = EpsilonRule()
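
To apply the same idea to every upsampling layer at once, one could walk the module tree and set the attribute by layer type. This is a minimal sketch on a toy model: the helper `attach_rule`, its `make_rule` parameter, and `PlaceholderRule` are hypothetical names and not part of captum, and the thread does not confirm that `EpsilonRule` is an appropriate rule for `nn.Upsample`.

```python
import torch.nn as nn


class PlaceholderRule:
    """Stand-in so the sketch runs without captum (an assumption)."""


def attach_rule(model, layer_types, make_rule):
    """Set `.rule` on every submodule that is an instance of `layer_types`.

    `make_rule` is any zero-argument callable returning a fresh rule object,
    for example captum's `EpsilonRule` class.
    """
    tagged = []
    for name, module in model.named_modules():
        if isinstance(module, layer_types):
            module.rule = make_rule()  # one rule instance per layer
            tagged.append(name)
    return tagged


# Toy decoder standing in for the U-Net upsampling path (an assumption).
toy = nn.Sequential(
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    nn.Upsample(scale_factor=2),
    nn.ConvTranspose2d(4, 1, kernel_size=2, stride=2),
)

# With captum installed, one would pass EpsilonRule instead:
#   from captum.attr._utils.lrp_rules import EpsilonRule
#   attach_rule(toy, (nn.Upsample, nn.ConvTranspose2d), EpsilonRule)
tagged = attach_rule(toy, (nn.Upsample, nn.ConvTranspose2d), PlaceholderRule)
print(tagged)  # names of the layers that received a rule
```

Creating a new rule object per layer, rather than sharing one instance, keeps any per-layer state that a rule stores during the forward pass separate.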

@bugsuse
Author

bugsuse commented Jul 10, 2021

Hey @NarineK, I'm new to captum and XAI. It isn't clear to me how to add the Epsilon rule to nn.ConvTranspose2d. Is there a similar example in captum? I'd like to learn from it and try adding the rule myself. Thanks so much!

@bugsuse
Author

bugsuse commented Jul 11, 2021

Hey @NarineK, I tried adding nn.ConvTranspose2d: EpsilonRule to SUPPORTED_LAYERS_WITH_RULES or SUPPORTED_NON_LINEAR_LAYERS in lrp.py, but I get an AssertionError related to the target dimension. The error output is below:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-20-90353dc30936> in <module>
      1 lrp = LRP(model)
----> 2 attribution = lrp.attribute(vifea.cuda(), target=torch.unsqueeze(vtar, dim=1).cuda(), verbose=True)
      3 attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()

/data3/works/ai/phd/github/captum/captum/log/__init__.py in wrapper(*args, **kwargs)
     33             @wraps(func)
     34             def wrapper(*args, **kwargs):
---> 35                 return func(*args, **kwargs)
     36 
     37             return wrapper

/data3/works/ai/phd/github/captum/captum/attr/_core/lrp.py in attribute(self, inputs, target, additional_forward_args, return_convergence_delta, verbose)
    191         try:
    192             # 1. Forward pass: Change weights of layers according to selected rules.
--> 193             output = self._compute_output_and_change_weights(
    194                 inputs, target, additional_forward_args
    195             )

/data3/works/ai/phd/github/captum/captum/attr/_core/lrp.py in _compute_output_and_change_weights(self, inputs, target, additional_forward_args)
    341         try:
    342             self._register_weight_hooks()
--> 343             output = _run_forward(self.model, inputs, target, additional_forward_args)
    344         finally:
    345             self._remove_forward_hooks()

/data3/works/ai/phd/github/captum/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    453         else inputs
    454     )
--> 455     return _select_targets(output, target)
    456 
    457 

/data3/works/ai/phd/github/captum/captum/_utils/common.py in _select_targets(output, target)
    472             return torch.gather(output, 1, target.reshape(len(output), 1))
    473         else:
--> 474             raise AssertionError(
    475                 "Tensor target dimension %r is not valid. %r"
    476                 % (target.shape, output.shape)

AssertionError: Tensor target dimension torch.Size([16, 1, 1, 64, 64]) is not valid. torch.Size([16, 1, 1, 64, 64])

The code is below:

lrp = LRP(model) # Unet model for semantic segmentation
attribution = lrp.attribute(vifea.cuda(), target=torch.unsqueeze(vtar, dim=1).cuda(), verbose=True)
attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()

Is it because LRP of captum does not currently support semantic segmentation models?

@NarineK
Contributor

NarineK commented Jul 12, 2021

@bugsuse, instead of a tensor, do you mind representing the target as a list of tuples, similar to this?
https://github.com/pytorch/captum/blob/master/tests/attr/helpers/test_config.py#L199

Remember that the first dimension corresponds to the example index. The target for each example has to have length 4,
such as [(0, 0, 0, 0), (0, 0, 24, 24)... ]
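
The tuple format can be sketched as follows. Each tuple indexes one example's output with the batch dimension removed, so its length is the number of output dimensions minus one. The helper `pixel_targets` is a hypothetical name, not a captum function.

```python
def pixel_targets(batch_size, index):
    """Repeat one output index for every example in the batch.

    `index` addresses a single output element of one example, i.e. it
    indexes the output shape without its leading batch dimension.
    """
    index = tuple(int(i) for i in index)
    return [index for _ in range(batch_size)]


# For an output of shape (16, 1, 1, 64, 64), as reported in the
# AssertionError above, each per-example index has 4 entries.
targets = pixel_targets(16, (0, 0, 24, 24))
print(len(targets), targets[0])
```

The resulting list would then be passed as `target=targets` to `lrp.attribute(...)`, with one tuple per example in the input batch.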

@bugsuse
Author

bugsuse commented Jul 13, 2021

@NarineK I replaced the tensor with a list of tuples using the code below, but it raised a new error:

lrp = LRP(model)
attribution = lrp.attribute(vifea[0:1].cuda(), target=[(0, 0, 0)], verbose=True)
attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()

The shapes of vifea and the U-Net model output are (16, 24, 64, 64) and (16, 1, 64, 64), respectively.

The error output is below:

Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b57472e8be0> on layer Conv2d(24, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747309190> on layer BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747309d60> on layer Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747232580> on layer BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b57473571c0> on layer Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747357a90> on layer Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747357dc0> on layer BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747357c10> on layer Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747357730> on layer BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747357d30> on layer Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327df0> on layer Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327fd0> on layer BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327f10> on layer Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327580> on layer BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327e20> on layer Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327f70> on layer Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327ee0> on layer BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b5747327fa0> on layer Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a940> on layer BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735afa0> on layer Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735aac0> on layer Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a2e0> on layer BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a790> on layer Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735abb0> on layer BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735ad00> on layer Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a970> on layer Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a6d0> on layer BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a820> on layer Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a6a0> on layer BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735ae20> on layer Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735aa90> on layer BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a9a0> on layer Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a7f0> on layer BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a520> on layer Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a310> on layer BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a430> on layer Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a070> on layer BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a4f0> on layer Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a580> on layer BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a7c0> on layer Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574735a910> on layer BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b57472426a0> on layer Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b57472421f0> on layer BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711c6d0> on layer Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711ce80> on layer BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711c7f0> on layer Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711c100> on layer BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711cf70> on layer Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711cc40> on layer BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Applied <captum.attr._utils.lrp_rules.EpsilonRule object at 0x2b574711cd90> on layer Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-107-b865fd048fdd> in <module>
      1 lrp = LRP(model)
----> 2 attribution = lrp.attribute(vifea[0:1].cuda(), target=[(0, 0, 0)], verbose=True)
      3 attribution = attribution.squeeze().permute(1, 2, 0).detach().numpy()

/data3/yangl/works/ai/phd/github/captum/captum/log/__init__.py in wrapper(*args, **kwargs)
     33             @wraps(func)
     34             def wrapper(*args, **kwargs):
---> 35                 return func(*args, **kwargs)
     36 
     37             return wrapper

/data3/yangl/works/ai/phd/github/captum/captum/attr/_core/lrp.py in attribute(self, inputs, target, additional_forward_args, return_convergence_delta, verbose)
    197             # propagation and execute back-propagation.
    198             self._register_forward_hooks()
--> 199             normalized_relevances = self.gradient_func(
    200                 self._forward_fn_wrapper, inputs, target, additional_forward_args
    201             )

/data3/yangl/works/ai/phd/github/captum/captum/_utils/gradient.py in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
    117         # torch.unbind(forward_out) is a list of scalar tensor tuples and
    118         # contains batch_size * #steps elements
--> 119         grads = torch.autograd.grad(torch.unbind(outputs), inputs)
    120     return grads
    121 

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    200         retain_graph = create_graph
    201 
--> 202     return Variable._execution_engine.run_backward(
    203         outputs, grad_outputs_, retain_graph, create_graph,
    204         inputs, allow_unused)

RuntimeError: hook 'backward_hook_activation' has changed the size of value

@NarineK
Contributor

NarineK commented Jul 14, 2021

@bugsuse, this is interesting. Maybe some activations got reused. Would you be able to share a Colab notebook so that we can debug it?
cc: @nanohanno

@bugsuse
Author

bugsuse commented Jul 15, 2021

@NarineK see the Colab notebook; the pretrained weights and test data are here. Please let me know if you need more information.

@nanohanno
Contributor

I have not been able to run the code yet to check the behaviour. Could it have something to do with the additional output added in 084d755?

BTW, I just saw that register_backward_hook(hook) is deprecated in the current PyTorch version in favor of register_full_backward_hook(hook).

https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=backward_hook#torch.nn.Module.register_backward_hook
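
For reference, a minimal sketch of the newer hook API on a single layer, independent of captum. The toy `nn.Linear` layer and the `record_shapes` hook are illustrative assumptions; the hook only records gradient shapes and returns nothing, so gradients pass through unchanged.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
shapes = {}


def record_shapes(module, grad_input, grad_output):
    # grad_input: gradients with respect to the module's inputs
    # grad_output: gradients with respect to the module's outputs
    shapes["grad_input"] = tuple(grad_input[0].shape)
    shapes["grad_output"] = tuple(grad_output[0].shape)
    # Returning None leaves the gradients unchanged.


handle = layer.register_full_backward_hook(record_shapes)
x = torch.randn(4, 3, requires_grad=True)
layer(x).sum().backward()
handle.remove()  # detach the hook once it is no longer needed
print(shapes)
```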

@NarineK
Contributor

NarineK commented Jul 23, 2021

@nanohanno, register_full_backward_hook isn't fully backward compatible with register_backward_hook. We are waiting for a new version that fixes the backward-compatibility issue.

We were expecting additional arguments to be needed for dropouts. We need to debug this case. @vivekmig and @nanohanno, do you have time to debug this issue?

@survivebycoding

@NarineK

model.layer.rule = EpsilonRule()

I tried to use this, but my layer name is layer_0/kernel:0.
When I use this name I get a syntax error. I also tried 'layer_0/kernel:0' and "layer_0/kernel:0".
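
In `model.layer.rule`, `layer` is a Python attribute access, so it has to be a valid identifier; a name containing `/` or `:` cannot be written that way, which would explain the syntax error. One possible workaround is to look the module up by its string name instead. This is a minimal sketch assuming a PyTorch `nn.Module`; the toy model and the names registered in it are illustrative assumptions.

```python
import torch.nn as nn

# A name with "/" or ":" is not a valid Python identifier.
print("layer_0/kernel:0".isidentifier())  # False

toy = nn.Sequential()
toy.add_module("layer_0", nn.Conv2d(3, 8, kernel_size=3))
toy.add_module("up", nn.Upsample(scale_factor=2))

# named_modules() yields (name, module) pairs; names are dotted strings.
modules_by_name = dict(toy.named_modules())
print(sorted(modules_by_name))  # list the names that can be looked up

layer = modules_by_name["layer_0"]
# A rule could now be attached with: layer.rule = EpsilonRule()
print(type(layer).__name__)  # Conv2d
```

Printing the available names first shows which strings actually correspond to modules in the model.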


4 participants