Native Amp Support #1337
This is awesome, will definitely add! ETA on the next PyTorch release? @mcarilli does it still have the issues with saving/loading weights with the loss scaling factor? @PyTorchLightning/core-contributors anyone interested in making this change? One key consideration is saving/loading weights when amp scales the loss.
Yes, bitwise accurate saving/restoring is supported. You just need to call your `GradScaler` instance's `state_dict()` and `load_state_dict()` alongside the usual model/optimizer `state_dict`/`load_state_dict` calls.
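A minimal sketch of what that looks like in a plain PyTorch checkpointing routine, assuming PyTorch 1.6+ (the model, optimizer, and file name below are placeholders, not Lightning internals):

```python
import torch

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

# Save the scaler state alongside the usual model/optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),  # loss scale + growth tracker
}
torch.save(checkpoint, "checkpoint.pt")

# Restore all three so training resumes with the same loss scale.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])
```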
@mcarilli any chance you'd be interested in submitting the PR? I was going to add checks like:

```python
if torch.__version__ >= 1.6:
    # new amp stuff
else:
    # old amp stuff
```
Hmm, I don't know the Lightning codebase at all, aside from the interface. It would take me longer than early next week to be sure I was making the right changes in the right places.

The version is a more complex string though, so I'd use something like:

```python
import torch

version_ge_16 = False
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 6):
    version_ge_16 = True
```
Not sure about the particular condition, though.
Happy to review in-progress PRs though.

Also, that versioning condition is based on what works for us; the particular number (1.6+) is a decent criterion for native amp availability. You could sidestep the version check entirely with a feature check:

```python
has_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
```

Now that I mention it, that's probably better.
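As an illustration of how such a feature check could drive the branch between native amp and Apex (a sketch only; the helper name and fallback logic below are hypothetical, not Lightning's actual API):

```python
import torch

# Native amp ships with torch.cuda.amp.autocast and GradScaler.
has_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")


def make_precision_backend(use_amp: bool):
    """Hypothetical helper: pick a mixed-precision backend based on availability."""
    if not use_amp:
        return None
    if has_native_amp:
        # Native path: a single GradScaler drives scaling, unscaling, and stepping.
        return torch.cuda.amp.GradScaler()
    try:
        from apex import amp  # Apex fallback for older PyTorch builds
        return amp
    except ImportError:
        raise RuntimeError(
            "Mixed precision requested, but neither torch.cuda.amp nor apex is available."
        )
```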
Native automatic mixed precision support (torch.cuda.amp) is finally merged:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html

Apex Amp has many known pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don't even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration brings more future performance optimizations into scope.

If you want to talk about adding torch.cuda.amp to Lightning, with an eye towards it becoming the true source of mixed precision and replacing Apex, message me on the PyTorch Slack anytime. I pinged you there as well, but I'm not sure if you monitor it habitually.
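A minimal sketch of the training-loop pattern the linked docs describe, assuming PyTorch 1.6+ and a CUDA device (the model, data, and hyperparameters below are placeholders):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Stand-in model/optimizer/data; the loop structure is what matters.
model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = GradScaler()

for step in range(100):
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")

    optimizer.zero_grad()
    # Run the forward pass and loss under autocast so eligible ops use float16.
    with autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # Scale the loss, backprop, then let the scaler unscale, step, and update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```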