revlib cannot work with torch AMP #7
Dear authors, how can I solve this? Thanks in advance.
Unfortunately, I don't know. I don't use
Dear author, I suspect the error happens in the RevResNet backward pass, where the feature map dtype (float16) does not match the conv weight dtype (float32). It does not happen in the forward pass because the forward pass is wrapped in the torch.cuda.amp.autocast() context, where conv operations automatically cast their weights to half precision.
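The suspected mechanism can be reproduced in isolation. The sketch below is a simplified stand-in for the revlib case, under stated assumptions: it uses CPU autocast with bfloat16 instead of CUDA autocast with float16 so it runs without a GPU, and a plain `nn.Conv1d` instead of a RevResNet block. It shows that autocast changes the dtype of the activation but not of the stored weights, so re-running the layer on that activation outside autocast fails.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1)  # parameters stay float32
x = torch.randn(2, 4, 8)

# Forward under autocast: the conv runs in bfloat16, so its output is
# bfloat16 even though conv.weight itself remains float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = conv(x)
print(conv.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16

# Feeding that low-precision activation back into the float32 conv outside
# autocast (as a backward-pass recomputation would) raises a dtype-mismatch
# RuntimeError; the exact message differs between CPU and CUDA.
mismatch_error = None
try:
    conv(y)
except RuntimeError as err:
    mismatch_error = err
print(type(mismatch_error).__name__)  # RuntimeError
```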
Do you have a minimal example to reproduce the error?
I have an important deadline coming up, so I'll reply as soon as I can. Thanks for your help.
When loss.backward() is wrapped in the torch.cuda.amp.autocast() context, this error is not raised.
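This is consistent with the dtype-mismatch hypothesis: autocast state is thread-local, so running loss.backward() inside autocast also puts the recomputation in revlib's backward under autocast. PyTorch's AMP docs advise against running backward passes under autocast, though. A more targeted option is to record the autocast state in forward and re-enter it only around the recomputation, which is what `torch.utils.checkpoint` does; for CUDA, `torch.cuda.amp.custom_fwd`/`custom_bwd` (or `torch.amp.custom_fwd`/`custom_bwd` with `device_type` on PyTorch 2.4+) provide the same behaviour as decorators. The sketch below illustrates the idea on a hypothetical `RecomputeBlock` and `train_step` (not revlib's code), again assuming CPU bfloat16 autocast so it runs without a GPU.

```python
import contextlib

import torch
import torch.nn as nn


def cpu_autocast_state():
    """Return (enabled, dtype) of CPU autocast on old and new PyTorch."""
    try:  # PyTorch >= 2.4
        return torch.is_autocast_enabled("cpu"), torch.get_autocast_dtype("cpu")
    except (TypeError, AttributeError):  # older releases
        return torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype()


class RecomputeBlock(torch.autograd.Function):
    """Hypothetical toy, not revlib's code: keep only the input of `module`
    in forward and recompute its output in backward."""

    @staticmethod
    def forward(ctx, x, module, replay_autocast):
        ctx.module = module
        ctx.replay_autocast = replay_autocast
        ctx.autocast_enabled, ctx.autocast_dtype = cpu_autocast_state()
        ctx.save_for_backward(x)
        return module(x)  # autograd does not record ops inside forward()

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Replay the forward-time autocast state around the recomputation
        # only; without it, a bfloat16 `x` meets float32 conv weights.
        if ctx.replay_autocast:
            amp_ctx = torch.autocast(device_type="cpu", dtype=ctx.autocast_dtype,
                                     enabled=ctx.autocast_enabled)
        else:
            amp_ctx = contextlib.nullcontext()
        with torch.enable_grad(), amp_ctx:
            out = ctx.module(x)
        torch.autograd.backward(out, grad_out)
        return x.grad, None, None  # no grads for `module` / `replay_autocast`


def train_step(replay_autocast):
    torch.manual_seed(0)
    conv_a = nn.Conv1d(4, 4, kernel_size=3, padding=1)
    conv_b = nn.Conv1d(4, 4, kernel_size=3, padding=1)
    x = torch.randn(2, 4, 8)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        h = conv_a(x)  # bfloat16 activation, like an earlier block's output
        out = RecomputeBlock.apply(h, conv_b, replay_autocast)
        loss = out.float().sum()
    loss.backward()  # called outside autocast, as the AMP docs recommend
    return conv_b.weight.grad


grad = train_step(replay_autocast=True)
print(grad.dtype)  # torch.float32

try:
    train_step(replay_autocast=False)
    failed_without_replay = False
except RuntimeError:
    failed_without_replay = True
print(failed_without_replay)  # True
```

Replaying the state, rather than forcing autocast on, keeps the recomputation bit-for-bit consistent with what forward did: if forward ran without autocast, backward recomputes without it too.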
Please share a minimal script to reproduce this error. I'll be able to take it from there. The best next action would be a PR with a unit test for torch amp.
Thank you very much for RevLib. I hope to contribute myself once I have enough time.
Dear authors,
When using revlib with torch AMP, I get the following error:
Traceback (most recent call last):
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same