RevLib does not work with torch AMP #7

Open
JAYatBUAA opened this issue Jul 24, 2023 · 8 comments

@JAYatBUAA

Dear authors,
when I use RevLib with torch AMP, it reports the following error:

Traceback (most recent call last):
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
    mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

@JAYatBUAA
Author

Dear authors, how can I solve this? Thanks in advance.

@ClashLuke
Member

Unfortunately, I don't know. I don't use torch.amp myself, and I last touched RevLib a while ago. Please let me know if you have a minimal script to reproduce the error or manage to fix it. So that you know, all pull requests are welcome.
Just so we're on the same page: torch.amp doesn't save memory; it only helps with speed, by downcasting matrix multiplications to fp16. Is this what you're after? If you want the memory improvements, please give RevLib's intermediate casting a try.

@JAYatBUAA
Author

Dear author, I suspect the error happens in the RevResNet backward pass, where the feature map dtype (float16) does not match the conv weight dtype (float32). This does not happen in the forward pass, because the forward pass is wrapped in a torch.cuda.amp.autocast() context, where the conv weights are automatically cast to half precision.
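The mismatch described above can be reproduced without RevLib. This is a minimal sketch under a few assumptions: CPU autocast with bfloat16 stands in for `torch.cuda.amp.autocast()` with float16 so it runs without a GPU, `nn.Linear` stands in for the report's `Conv3d`, and `recompute` is a hypothetical stand-in for the backward-time recomputation visible at `core.py`, line 130 of the traceback. The layer's weights stay float32, the forward pass inside autocast yields a low-precision activation, and re-running the layer on that activation outside autocast fails.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(8, 8)   # parameters stay float32, as in the report
x = torch.randn(4, 8)

# Forward pass inside autocast: input and weight are cast on the fly,
# so the activation that comes out is low precision (bfloat16 here).
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)


def recompute(activation):
    # Hypothetical stand-in for a reversible backward pass that re-runs the
    # layer *outside* autocast on the stored low-precision activation.
    return layer(activation)


try:
    recompute(y)
    mismatch_raised = False
except RuntimeError:
    # bfloat16 input vs. float32 weight: the same class of error as the report.
    mismatch_raised = True

print(y.dtype, layer.weight.dtype, mismatch_raised)
```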

@ClashLuke
Member

Do you have a minimal example to reproduce the error?

@JAYatBUAA
Author

I have an important deadline coming up, so I'll reply as soon as I can. Thanks for your help.

@JAYatBUAA
Author

When loss.backward() is wrapped in a torch.cuda.amp.autocast() context, this error is not reported.
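The observation above suggests the recomputation in backward needs the autocast state that was active during forward. One way to provide it, without wrapping the whole `backward()` call, is the pattern `torch.utils.checkpoint` uses: record the autocast state in `forward` and re-enter it while recomputing in `backward`. The sketch below is an assumption, not RevLib's code: `RecomputeWithAutocast` and `cpu_autocast_state` are hypothetical names, CPU/bfloat16 again stands in for CUDA/float16, and whether RevLib's `core.py` can adopt this has not been tested.

```python
import torch
from torch import nn


def cpu_autocast_state():
    # torch >= 2.4 takes a device_type argument; older releases expose
    # CPU-specific getters instead (and reject the argument with TypeError).
    try:
        return torch.is_autocast_enabled("cpu"), torch.get_autocast_dtype("cpu")
    except TypeError:
        return torch.is_autocast_cpu_enabled(), torch.get_autocast_cpu_dtype()


class RecomputeWithAutocast(torch.autograd.Function):
    """Hypothetical helper: keeps no graph in forward and recomputes the
    module in backward under the autocast state seen during forward."""

    @staticmethod
    def forward(ctx, x, module):
        ctx.module = module
        ctx.autocast_enabled, ctx.autocast_dtype = cpu_autocast_state()
        ctx.save_for_backward(x)
        # Runs under the ambient autocast context; custom Function.forward
        # executes without grad tracking, so no graph is recorded here.
        return module(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        # Re-enter the forward-time autocast state so the recomputation casts
        # inputs and weights the same way the original forward pass did.
        with torch.enable_grad(), torch.autocast(
            device_type="cpu",
            dtype=ctx.autocast_dtype,
            enabled=ctx.autocast_enabled,
        ):
            out = ctx.module(x)
        # Accumulates parameter gradients into module.weight.grad / bias.grad.
        torch.autograd.backward(out, grad_out)
        return x.grad, None


torch.manual_seed(0)
layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = RecomputeWithAutocast.apply(x, layer)

# backward() is called outside autocast, as in the original report.
y.float().sum().backward()
print(y.dtype, x.grad.dtype, layer.weight.grad.dtype)
```

For CUDA, PyTorch also ships decorators for custom autograd functions (`torch.cuda.amp.custom_fwd` / `custom_bwd`, or `torch.amp.custom_fwd(device_type="cuda")` in newer releases) that run `backward` with the same autocast state as `forward`; they may be a simpler route than capturing the state by hand.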

@ClashLuke
Member

ClashLuke commented Aug 1, 2023

Please share a minimal script to reproduce this error. I'll be able to take it from there.

The best next action would be a PR with a unit test for torch amp.
Alternatively, RevLib is open to contributions. You're welcome to submit a PR for the fix :)

@JAYatBUAA
Author

Thanks a lot for RevLib. I hope to contribute myself once I have enough time.
