New FusedAdam (since #424 commit 8 Aug 2019) Issues #475
In terms of accuracy, the updated FusedAdam incorporated this upstream fix. In terms of memory, there could also be other changes affecting how memory gets reused, but those should not strictly use more. Given the dynamic nature of PyTorch's caching allocator, any change in the order and lifecycle of tensors can cause an OOM if you are very close to the limit, even though you are not using more memory in total.
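(To check that concretely, one option is to log the allocator's peak statistics around the optimizer step and compare old vs. new FusedAdam runs. A minimal sketch, assuming a standard training loop; the helper name is ours, not from this thread:)

```python
import torch

def log_cuda_memory(tag):
    # Currently allocated vs. peak allocated memory from PyTorch's caching allocator.
    # Changes in allocation order / tensor lifetime can still OOM even if these match.
    alloc = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"[{tag}] allocated={alloc:.1f} MiB, peak={peak:.1f} MiB")

# Usage sketch inside a training step:
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()
#   log_cuda_memory("after zero_grad")
```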
I'll go for a full run and check if there is some regression, thanks for the explanation.
What is the rationale for this change?
We want fused optimizers to work with AMP O1, without doing it by special-casing in both the optimizer and the amp backend. @mcarilli, speaking of the optimizer not being the memory peak, I think the new FusedAdam keeps more memory than before after zero_grad now.
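(For reference, a minimal sketch of what FusedAdam under the new API with O1 looks like; the toy model and tensors are placeholders, not OpenNMT-py code:)

```python
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(32, 32).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3)

# The new FusedAdam goes through amp.initialize like any other optimizer,
# instead of the old FusedAdam-specific FP16_Optimizer wrapper.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 32, device="cuda")
loss = model(x).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()
```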
OK, I trained a bit longer but I am pretty sure something is not so good.
Old FusedAdam:
New FusedAdam:
Losing 1 to 2 BLEU depending on the test set.
Hi @vince62s, I have tested the new FusedAdam against the latest upstream PyTorch AdamW, with this example.
My O2 example above had already converged.
How big is the difference in the final result? It is 'minor' per upstream: pytorch/pytorch@fed5ca1
It seems it is not related to the "minor" fix at all.
At OpenNMT-py, the new Apex API was integrated in this commit: OpenNMT/OpenNMT-py@aaa220b. After that commit, Adam in FP16 was using the new API while FusedAdam was still using the FP16_Optimizer path, as mentioned by @FDecaYed in a previous post. Since then we did all our training with FusedAdam.
Starting Aug 7/8, with the merge of #424, FusedAdam switched to the new API.
We have tested several runs with the same preprocessed data and all hyperparameters being the same. Here is what we have:
Bottom line: the issue is with the new API, or with the way we use the new API. We really need to find out what is going on, because if we want no regression we would have to go back to FP16_Optimizer.
NB: I can give you detailed numbers by email if needed. Also, the dataset can be downloaded and I can give you a small script so that you can replicate. Cheers.
Looks to me like your integration isn't right. You're passing a list with 2 models, but only receiving one on the LHS. Also, if you are initializing like this:

```python
model, optimizer = apex.amp.initialize(
    model,
    optimizer,
    opt_level=opt.apex_opt_level,
    loss_scale=loss_scale,
    # keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)
    keep_batchnorm_fp32=True)
```

One benefit of the latest FusedAdam API is that it can handle a mixture of fp32 and fp16 params, so you may say keep_batchnorm_fp32=True.
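(To illustrate the LHS point above, a corrected-shape call might look like the sketch below, reusing the variable names from the quoted snippet; `generator` stands in for whatever second module is in the list:)

```python
# amp.initialize returns a list of models when given a list, so unpack the
# same number of modules on the left-hand side.
[model, generator], optimizer = apex.amp.initialize(
    [model, generator],
    optimizer,
    opt_level=opt.apex_opt_level,
    loss_scale=loss_scale,
    keep_batchnorm_fp32=True)
```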
Your suggestion does not work (we tested it as a first attempt at the beginning, and I retried it); it throws a cast error. Not even talking about FusedAdam: is there anywhere a regression test between the old API (FP16_Optimizer) and the new API, with plain Adam, on a real-world dataset (not just a toy set)? I'll give a try with keep_batchnorm_fp32=True.
There are small unit tests but not an end-to-end test. However, we also use the new FusedAdam for full workloads and it seems fine. Does the latest version with keep_batchnorm_fp32=True help?
Does the amp API with PyTorch Adam work before #424? This question will help us determine where to look next. Some other points:
To your first question:
However, when we use amp the way we use it with the old FusedAdam, it works, maybe because it falls back to FP16_Optimizer; but at least it shows that we call amp correctly. NB: can you be more specific about the clipping code thing? (Yes, it should be unrelated to our issue since I am not using max_grad_norm in my tests.)
For reference, how does accuracy look with O1? In general O1 is always best practice, and the new FusedAdam should permit it (i.e. if that was the only reason you used O2, it is no longer relevant, but maybe there is some other reason).
For your way of using amp O2, I feel it is probably OK. For clipping:
Just finished the first 5000 iterations on 4 GPUs with keep_batchnorm_fp32=True ==> no change.
With O1, the model weights are left as FP32 (in other words, the model weights are the master weights). O1 is more conservative about casting to FP32 for internal functions (which is why it's recommended as an out-of-the-box approach), so it may use more memory, but hopefully not much.
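(A sketch of the practical consequence, with a toy model rather than OpenNMT-py code:)

```python
import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# O1: the model's own parameters stay FP32 and *are* the master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
assert all(p.dtype == torch.float32 for p in model.parameters())

# Under O2 the model would be cast to FP16 and amp would keep separate FP32
# master copies. In either mode, amp.master_params(optimizer) yields the FP32
# copies, which is also what you would hand to gradient clipping.
assert all(p.dtype == torch.float32 for p in amp.master_params(optimizer))
```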
It actually starts much, much better. Letting it run with Adam first; will do FusedAdam just after.
O2 was our original hypothesized recipe for mixed precision, but maintaining separate master weights is more confusing for everyone. IMO the only remaining utility of O2 is to support certain internal use cases. O1 is what I hope all public codes are able to use, and O1 is the implementation I'm working on for Amp upstream (pytorch/pytorch#25081). The API will be different (as requested by upstream) but more flexible and powerful.
One thing though: is O1 slower than O2?
It depends. Currently O1 in Apex relies on a lot of Python-side patching logic, so if any sections of your model are CPU bound, those sections may slow down. The upstream integration of O1-style casting will be entirely on the C++ side, and will be faster.
Okay, an update with FusedAdam O1. Getting this:
Since I need a working solution, I tried the following: I copied the old FusedAdam class into my optimizers.py code, and used the current master FP16_Optimizer wrapper. It works fine, and is actually a bit faster. If I try to use the current master FusedAdam in the same way (that is to say, FusedAdam without amp), then it does not work. I think I may stick to the old way until all of this gets integrated into PyTorch. But if O1 is the right way, then something might be missing for FusedAdam.
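(Roughly what that workaround looks like, as a sketch only: `LegacyFusedAdam` is a hypothetical name for the copied pre-#424 class, and the exact FP16_Optimizer import path may differ between Apex versions.)

```python
import torch
from apex.fp16_utils import FP16_Optimizer
# Hypothetical import: the old (pre-#424) FusedAdam class copied into OpenNMT-py.
from onmt.utils.optimizers import FusedAdam as LegacyFusedAdam

model = torch.nn.Linear(32, 32).cuda().half()
optimizer = LegacyFusedAdam(model.parameters(), lr=1e-3)

# Wrap with the FP16_Optimizer path instead of going through amp.initialize;
# the wrapper owns the FP32 master weights and the loss scaling.
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

x = torch.randn(8, 32, device="cuda", dtype=torch.float16)
loss = model(x).float().sum()
optimizer.backward(loss)  # scales the loss and copies grads to the master weights
optimizer.step()
optimizer.zero_grad()
```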
Hmm, it's beginning to sound more like there's something wrong with the new FusedAdam, but like I said, we have used it ourselves for several applications. Where are you implementing gradient clipping in your code, if anywhere?
Here: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/utils/optimizers.py#L328-L339
Yes, with the new FusedAdam + new API (no FP16_Optimizer) I don't believe you need to special-case the clipping code for FP16 at all. You can replace

```python
if self._fp16:
    if hasattr(self._optimizer, "update_master_grads"):
        self._optimizer.update_master_grads()
    if hasattr(self._optimizer, "clip_master_grads") and \
            self._max_grad_norm > 0:
        import apex
        torch.nn.utils.clip_grad_norm_(
            apex.amp.master_params(self), self._max_grad_norm)
for group in self._optimizer.param_groups:
    group['lr'] = learning_rate
    if not self._fp16 and self._max_grad_norm > 0:
        clip_grad_norm_(group['params'], self._max_grad_norm)
```

with simply

```python
for group in self._optimizer.param_groups:
    group['lr'] = learning_rate
    if not self._fp16 and self._max_grad_norm > 0:
        clip_grad_norm_(group['params'], self._max_grad_norm)
```

if your intention is to clip per-param-group. However, if you're not clipping at all and still seeing problems, this is likely unrelated... Are you passing any sparse params/gradients to FusedAdam? I don't think it supports sparse gradients, but I don't think the old FusedAdam did either...
No.
RE: clipping, I intended to do this: https://github.com/OpenNMT/OpenNMT-py/pull/1560/files#diff-423be3bd1890af1b892704ba31891075R358
In the non-legacy case you are clipping per param group though, as opposed to clipping the gradients of all the params together.
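(To make the distinction concrete, a sketch of the two alternatives, assuming `optimizer` has already been set up through amp and `max_grad_norm` is the threshold:)

```python
import torch
from apex import amp

# Alternative 1 - per-param-group clipping: each group's gradient norm is
# computed and clipped separately, so the effective threshold differs from a
# single global clip.
for group in optimizer.param_groups:
    torch.nn.utils.clip_grad_norm_(group["params"], max_grad_norm)

# Alternative 2 - global clipping: one norm over all (master) parameters
# together, closer to what the old FP16_Optimizer-style clip_master_grads did.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
```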
@mcarilli Clipping by group when using the old FusedAdam is also a known issue. We are working on our BERT fine-tuning example to see what's the best way to address that.
@FDecaYed I am not sure what we want to prove here.
Hope this is clearer.
ok...
Anyhow, we now need to find out why AdamW makes things unstable with FP16.
If the new upstream AdamW with O1 also causes the crash (and I suspect upstream AdamW with O2 could also cause an accuracy drop similar to the one we see with the new FusedAdam under O2), then it does seem the algorithm change causes issues when working with FP16. I feel it's a problem worth looking into. Let's probably start with a repro and add the old algorithm back to FusedAdam.
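(For concreteness, a single-tensor sketch of the kind of algorithmic difference in question: classic L2-regularised Adam versus decoupled, AdamW-style weight decay. This is illustrative only, not the Apex kernel, and the names are ours:)

```python
import torch

def adam_like_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                   betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01,
                   decoupled=False):
    """One illustrative Adam step on a single tensor.

    decoupled=False: classic Adam + L2, decay flows through the moment estimates.
    decoupled=True:  AdamW-style, weights are decayed directly.
    """
    beta1, beta2 = betas
    if decoupled:
        # Decay the weights directly, outside the adaptive update.
        p.mul_(1 - lr * weight_decay)
    else:
        # Fold decay into the gradient, so it is rescaled by the adaptive denominator.
        grad = grad + weight_decay * p
    exp_avg.mul_(beta1).add_(grad * (1 - beta1))
    exp_avg_sq.mul_(beta2).add_(grad * grad * (1 - beta2))
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt() + eps
    p.add_(-lr * (exp_avg / (1 - beta1 ** step)) / denom)
    return p
```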
@mcarilli shall we look into and track this?
Team,
This huge PR #424 was not squashed, and its commits, spread over a long period of time, trigger lots of side effects.
As an example, if I use FusedAdam as of commit 4a8c4ac, the GPU memory footprint allows training with a certain batch size (e.g. 3072 tokens per batch).
On master, I have to reduce the batch size to avoid a CUDA OOM
(e.g. 2944 tokens per batch; the use case is OpenNMT-py).
More concerning, during training, using the same config (batch size 2944 tokens) and all other params being equal:
MASTER of Sept 5 2019
[2019-09-05 12:29:50,751 INFO] Step 100/50000; acc: 3.55; ppl: 6812.21; xent: 8.83; lr: 0.00002; 23557/28574 tok/s; 148 sec
[2019-09-05 12:31:02,317 INFO] Step 200/50000; acc: 6.69; ppl: 2112.41; xent: 7.66; lr: 0.00003; 48483/58933 tok/s; 220 sec
[2019-09-05 12:32:14,252 INFO] Step 300/50000; acc: 10.53; ppl: 597.04; xent: 6.39; lr: 0.00005; 48349/58938 tok/s; 292 sec
[2019-09-05 12:33:26,245 INFO] Step 400/50000; acc: 13.50; ppl: 339.95; xent: 5.83; lr: 0.00006; 48524/58874 tok/s; 364 sec
[2019-09-05 12:34:41,417 INFO] Step 500/50000; acc: 15.14; ppl: 235.72; xent: 5.46; lr: 0.00008; 46461/56365 tok/s; 439 sec
[2019-09-05 12:35:53,654 INFO] Step 600/50000; acc: 17.41; ppl: 176.58; xent: 5.17; lr: 0.00009; 48243/58731 tok/s; 511 sec
[2019-09-05 12:37:06,072 INFO] Step 700/50000; acc: 19.30; ppl: 139.99; xent: 4.94; lr: 0.00011; 48149/58553 tok/s; 584 sec
Commit 4a8c4ac
[2019-09-05 12:42:05,446 INFO] Step 100/50000; acc: 3.61; ppl: 6631.33; xent: 8.80; lr: 0.00002; 23708/28735 tok/s; 147 sec
[2019-09-05 12:43:16,428 INFO] Step 200/50000; acc: 7.97; ppl: 1824.98; xent: 7.51; lr: 0.00003; 48888/59853 tok/s; 218 sec
[2019-09-05 12:44:27,514 INFO] Step 300/50000; acc: 11.75; ppl: 524.17; xent: 6.26; lr: 0.00005; 48952/59220 tok/s; 289 sec
[2019-09-05 12:45:38,700 INFO] Step 400/50000; acc: 14.74; ppl: 278.04; xent: 5.63; lr: 0.00006; 49226/59553 tok/s; 360 sec
[2019-09-05 12:46:53,512 INFO] Step 500/50000; acc: 17.45; ppl: 181.13; xent: 5.20; lr: 0.00008; 46560/56793 tok/s; 435 sec
[2019-09-05 12:48:05,109 INFO] Step 600/50000; acc: 19.77; ppl: 132.15; xent: 4.88; lr: 0.00009; 48644/59191 tok/s; 507 sec
[2019-09-05 12:49:16,393 INFO] Step 700/50000; acc: 22.03; ppl: 101.56; xent: 4.62; lr: 0.00011; 48692/59330 tok/s; 578 sec
The accuracy / ppl seem much better on the old FusedAdam.
Any clue?
@FDecaYed @mcarilli