torch.cuda.amp > apex.amp #818
Comments
Can torch.cuda.amp be used only for inference on an FP32 model? See #750 and #809
Yes.
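For illustration, a minimal sketch of FP32-model inference under autocast (the Linear model and random inputs are placeholders; any FP32 model works the same way):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; parameters are created in default FP32.
model = nn.Linear(128, 10).cuda().eval()
inputs = torch.randn(32, 128, device="cuda")

with torch.no_grad():
    # Ops inside the autocast region run in lower precision where it's safe;
    # the FP32 parameters themselves are left untouched.
    with torch.cuda.amp.autocast():
        outputs = model(inputs)  # float16 output for this matmul-based layer
```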
Before we dive into torch.cuda.amp, should we expect a behavior change versus this issue #475?
@mcarilli what about the opt_level O1 / O2, etc.? I can't find whether that's already natively supported by torch.cuda.amp.
Another question: will this be supported by TorchScript?
How can I migrate from apex.amp to torch.cuda.amp if I already have a pre-trained model saved with the apex wrapper? Can apex-wrapped models now be loaded like regular PyTorch models?
@Damiox
@blizda If you have a model state dict from any source saved on disk, you shouldn't need to do anything special to migrate to native Amp. Create a model in default precision (fp32) and call its load_state_dict as usual. After migrating to native amp, for bitwise accurate saving/restoring, include calls to saved = scaler.state_dict() and scaler.load_state_dict(saved) alongside your usual state_dict/load_state_dict calls.
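A minimal sketch of that checkpoint round trip (the Linear model, SGD optimizer, and checkpoint.pt filename are placeholders):

```python
import torch
import torch.nn as nn

# Placeholder model/optimizer; the pattern is the same for any FP32 model.
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

# Saving: store the scaler state alongside the usual state dicts.
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, "checkpoint.pt")

# Restoring: load into freshly constructed default-precision (FP32) objects.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])
```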
@vince62s However, there may be a problem with
It seems like after training with torch.cuda.amp's autocast(), the model's dtype is still fp32. If I want to deploy the model, does it need to be converted to fp16 manually? It's a little confusing.
@mcarilli It's clear how to switch with O1, but how can I use O2 optimization with torch.cuda.amp?
@ysystudio Autocast does not touch the model object itself, so its dtype (param type) remains as you created it (leaving it at the default FP32 is recommended). Save the trained model, then deploy it in whatever format you want.
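To make that concrete, a small sketch (the Linear model is a placeholder; the half() conversion is a deployment choice, not something autocast requires):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()  # placeholder; params created in default FP32

# Even after training inside autocast regions, the parameters are still FP32:
print(next(model.parameters()).dtype)  # torch.float32

# Option 1: deploy as FP32 and run inference under torch.cuda.amp.autocast().
# Option 2: if the deployment target specifically wants FP16 weights,
# convert explicitly (in place) after training:
model.half()
print(next(model.parameters()).dtype)  # torch.float16
```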
@trytolose O2 isn't a thing in torch.cuda.amp.
@mcarilli
@SeungjunNah The options available in native Amp are a better representation of what users should control. Some apex options, like opt-level O2, are unsafe for general use. If an option is present in apex amp but not in native amp, it's probably not an important knob for the user to experiment with, so including it would make the API more cluttered and confusing. For example, I'm not aware of any network where setting max_loss_scale was required for convergence. If you have evidence that max_loss_scale is required, I can add it. In general,
@mcarilli A workaround could be to recompute the loss scaling until the overflow is avoided, but I didn't find a way to implement it myself. I'd appreciate it if you could add a max_loss_scale option to torch.cuda.amp.GradScaler.
That's true, but N is a large value (2000 by default). After the initial few iterations where GradScaler calibrates, it settles to a steady state where step skipping should only occur once every 2000 iterations (when it attempts a higher scale value). Generally, letting GradScaler dynamically find a steady state scale value is the best approach. Skipping one out of every 2000 iterations should have a negligible effect on both convergence and performance. What you're suggesting is more like "static loss scaling": locking the scale to a user-defined value rather than letting GradScaler adjust it dynamically. This is also possible (though not recommended) with the native API without an additional max_loss_scale constructor arg: call scaler.update(static_scale) with your chosen value each iteration instead of the no-argument scaler.update().
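For illustration, a minimal sketch of that static-scaling workaround (the model, data, and the 2**16 value are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()          # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

static_scale = 2.**16  # example value; pinning the scale is generally not recommended

for _ in range(100):
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    # Passing a value to update() overrides the dynamic adjustment,
    # which amounts to static loss scaling.
    scaler.update(static_scale)
```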
Ok, skipping with a 1/2000 ratio doesn't hurt practically. I wanted to see if there were ways to control the number of iterations completely, though. Thanks for the explanation!
@mcarilli I just watched a video that says you can use FusedAdam, FusedSGD, etc. for a faster optimizer when using amp. How do we use these in native PyTorch 1.6 with amp? Thanks!
@mcarilli
I can't find an example that tests ImageNet performance with torch.cuda.amp.
@mcarilli I found one case where we might need min_loss_scale. In my training with AMP, the first several iterations have NaN gradients quite often, so the first usable scaling value becomes 0.0325 (or something like that). Does a scaling value that small make sense?
"O2" is stable for me where "O1" and native amp give me NaNs. It would be really nice if there were some way to duplicate 02 behavior using native torch.cuda.amp. I've tried casting all batch norms to 32, but that didn't do it. So I guess something else is happening under the hood. |
For a while now my main focus has been moving mixed precision functionality into Pytorch core. It was merged about a month ago:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html
and is now usable via master or nightly pip/conda packages. (Full features did not make the 1.5 release, unfortunately.)
torch.cuda.amp is more flexible and intuitive, and the native integration brings more future optimizations into scope. Also, torch.cuda.amp fixes many of apex.amp's known pain points. Some things native amp can handle that apex amp can't: torch.cuda.amp.autocast() has no effect outside regions where it's enabled, so it should serve cases that formerly struggled with multiple calls to apex.amp.initialize() (including cross-validation) without difficulty. Multiple convergence runs in the same script should each use a fresh GradScaler instance, but GradScalers are lightweight and self-contained so that's not a problem.
If all you want is to try mixed precision, and you're comfortable using a recent Pytorch, you don't need Apex.
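For anyone comparing against an existing apex.amp script, the basic native pattern per the linked examples boils down to a loop like this (the model, optimizer, and random data are placeholders):

```python
import torch
import torch.nn as nn

# Placeholder model, optimizer, and data; the autocast/GradScaler usage is the
# part that matters.
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(100):
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        # Forward pass and loss run in mixed precision inside the region.
        loss = nn.functional.cross_entropy(model(inputs), targets)

    # Scale the loss, backward, then unscale + step + adjust the scale.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```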