Two bugs in AdamW #14539

Closed · manuelciosici opened this issue Nov 26, 2021 · 16 comments · Fixed by #14744

@manuelciosici (Contributor)

Environment info

  • transformers version: 4.13.0.dev0
  • Platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.9.7
  • PyTorch version (GPU?): 1.10.0+cu113 (True)
  • Tensorflow version (GPU?): 2.7.0 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@thomwolf and @stas00 should be able to help based on git blame

Information

There are two bugs in the implementation of AdamW.

Here's the current code: https://github.com/manuelciosici/transformers/blob/04683c0659aacf31a1e1df8aa2e6cf7b447a6f12/src/transformers/optimization.py#L324-L371

Weight decay bug

Look at lines 369-370. The weight decay is multiplied with p.data which no longer corresponds to theta_{t-1} since p.data was modified in line 369. Below is a picture of Algorithm 2 from the original AdamW paper, which shows on line 12 that the weight decay should be multiplied by the previous step's parameters (i.e., theta_{t-1}).

[Screenshot: Algorithm 2 from the AdamW paper, showing weight decay applied to theta_{t-1} on line 12]

From what I can tell, this is a regression, since the original AdamW implementation in transformers applied weight decay properly. Here's the commit that introduced the bug: ec07cf5#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0

For confirmation that weight decay is currently buggy, see the original AdamW implementation, where, on line 74, the weight decay is multiplied by the old parameters, as opposed to the new parameters that are calculated on line 71.
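
To make the ordering issue concrete, here is a minimal sketch with made-up tensors (the variable names mirror those in optimization.py, but this is not the library's exact code):

```python
import torch

# Illustrative values only; p, exp_avg, denom stand in for the optimizer state.
torch.manual_seed(0)
p = torch.randn(4)              # theta_{t-1}
exp_avg = torch.randn(4)        # first moment m_t
denom = torch.rand(4) + 1e-8    # denominator sqrt(v_hat_t) + eps
step_size, lr, weight_decay = 1e-3, 1e-3, 0.1

# Current (buggy) order: weight decay reads p *after* the Adam update,
# i.e. it decays theta_t instead of theta_{t-1}.
p_buggy = p.clone()
p_buggy.addcdiv_(exp_avg, denom, value=-step_size)
p_buggy.add_(p_buggy, alpha=-lr * weight_decay)

# Order matching Algorithm 2, line 12: the decay uses the previous parameters.
p_fixed = p.clone()
prev = p_fixed.clone()
p_fixed.addcdiv_(exp_avg, denom, value=-step_size)
p_fixed.add_(prev, alpha=-lr * weight_decay)

print((p_buggy - p_fixed).abs().max())  # non-zero: the two orderings differ
```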

Denominator computation bug

The second bug appears in the computation of the denominator corresponding to line 10 in Algorithm 2 above. In the current code (see the link in the Information section), on line 351, the denominator excludes the division by math.sqrt(bias_correction2). On line 357, division by math.sqrt(bias_correction2) appears, but, by this time, eps has already been added to denom, making the division not equivalent to line 10 in Algorithm 2.

From what I can tell, this bug was also introduced as part of commit ec07cf5#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0. The previous line update = next_m / (next_v.sqrt() + group['e']) was correct.

For confirmation that the denominator is not properly calculated, see the original AdamW implementation, where, on line 64, the denominator is computed correctly.
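
A small numeric sketch of the difference, assuming typical values for beta2 and eps (illustrative only, not the library code):

```python
import math
import torch

# exp_avg_sq stands in for v_t; beta2 and eps are the usual defaults.
exp_avg_sq = torch.rand(4)
beta2, eps, step = 0.999, 1e-8, 10
bias_correction2 = 1.0 - beta2 ** step

# Line 10 of Algorithm 2: bias-correct v_t first, then add eps.
denom_paper = (exp_avg_sq / bias_correction2).sqrt() + eps

# What the current code effectively computes: eps is added before the
# division by sqrt(bias_correction2), so eps gets rescaled as well.
denom_code = (exp_avg_sq.sqrt() + eps) / math.sqrt(bias_correction2)

print((denom_paper - denom_code).abs().max())  # small, eps-sized discrepancy
```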

To reproduce

Steps to reproduce the behavior:

  1. Check out the branch at https://github.com/manuelciosici/transformers/tree/reveal_broken_adamw
  2. Run the unit tests in tests/test_optimization.py
  3. Tests test_compare_adamw_no_weight_decay and test_compare_adamw_with_weight_decay should fail (see the attached failed_tests.txt)

Expected behavior

The two implementations of AdamW should match their parameter updates.
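
For illustration, a rough sketch of what such a comparison could look like; the actual tests on the branch may differ in structure and tolerances:

```python
import torch
from transformers.optimization import AdamW as HFAdamW

torch.manual_seed(0)
w_hf = torch.nn.Parameter(torch.randn(10))
w_pt = torch.nn.Parameter(w_hf.detach().clone())

# eps is set explicitly because the two classes use different defaults.
opt_hf = HFAdamW([w_hf], lr=1e-3, eps=1e-8, weight_decay=0.1, correct_bias=True)
opt_pt = torch.optim.AdamW([w_pt], lr=1e-3, eps=1e-8, weight_decay=0.1)

for _ in range(100):
    grad = torch.randn(10)
    w_hf.grad = grad.clone()
    w_pt.grad = grad.clone()
    opt_hf.step()
    opt_pt.step()

# With the two bugs present, the parameters drift apart over the 100 steps.
print(torch.allclose(w_hf, w_pt, atol=1e-6))
```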

Proposed fix

Check out the branch at https://github.com/manuelciosici/transformers/tree/fix_adamw. It contains both the unit tests above and a fix for the two bugs mentioned above.

I can make a PR once we agree on the two bugs and the fix.

@stas00 (Contributor) commented Nov 30, 2021

Thank you for submitting this bug report and the investigation, @manuelciosici

Look at lines 369-370. The weight decay is multiplied with p.data which no longer corresponds to theta_{t-1} since p.data was modified in line 369.

You must have meant line 359 in the sentence above.

Your investigation looks correct on both accounts, @manuelciosici. I was able to follow through your helpful notes.

I suspect the denominator buglet was an optimization, since epsilon is tiny and it's there only to avoid a division by zero. The missing part of the denominator is eps*(sqrt(bias_correction2)-1). Since you can choose a slightly different epsilon without breaking the algorithm, I believe this missing part is practically irrelevant. Please correct me if I'm wrong. If the current code remains unchanged, we should definitely add a comment that eps_1 + eps_2 = eps_3 is still an eps.
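
As a rough illustration, assuming the typical defaults eps=1e-8 and beta2=0.999, the current code effectively uses eps/sqrt(bias_correction2) in place of eps, which is larger early in training and converges to eps afterwards:

```python
import math

# Effective epsilon implied by the current code, for a few step counts.
eps, beta2 = 1e-8, 0.999
for step in (1, 10, 100, 1000, 10000):
    bc2 = 1.0 - beta2 ** step
    print(step, eps / math.sqrt(bc2))  # still an eps-sized constant
```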

The decay part applied to t instead of t-1 does appear to be significant.

Since I wasn't the one involved in writing this code (I only did a small adjustment), I will let @thomwolf and perhaps @LysandreJik and @sgugger confirm.

p.s. I did see references where the choice of epsilon was important.

@sgugger (Collaborator) commented Nov 30, 2021

I was not the one who made the adjustments, which may have been made on purpose for some reason.

I don't think the current behavior should be changed (even if different from the original paper) as it might break all reported results on all our examples, and this implementation of AdamW has worked quite well on all our tasks. Furthermore, PyTorch now has an implementation of AdamW, so people should use that one for a "bug-free" version.
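
For reference, switching to PyTorch's implementation is a small change in most training scripts (a sketch; the model and hyper-parameter values are only illustrative):

```python
import torch

model = torch.nn.Linear(10, 2)  # stand-in model

# PyTorch's AdamW follows the decoupled weight decay algorithm from the paper.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5e-5, eps=1e-8, weight_decay=0.01
)
```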

@stas00 (Contributor) commented Dec 2, 2021

@manuelciosici, if you could indulge my curiosity - what was the impetus for checking the AdamW implementation?

I'm just trying to understand the actual impact of this different implementation on the training stability/convergence/etc.

Thank you.

@manuelciosici (Contributor, Author)

@stas00 I was reading it as a reference implementation while trying to understand deepspeed's CPU AdamW implementation.

One thing to note is that the magnitude of both bugs is a function of AdamW's hyper-parameters (i.e., it is influenced by the learning rate, epsilon, and weight decay). For example, for prompt tuning, where learning rates can be as high as 0.3, the effect of the buggy weight decay will be more pronounced.

@sgugger I understand the concerns that fixing the optimizer will lead to misalignment with existing examples and documentation. However, ignoring the bugs is not good either. Since opening the issue, I have found that I was not the first one to discover the weight decay issue. I expect that, if the code stays as is, the two bugs will be rediscovered periodically.

An alternative to ignoring the bugs would be for transformers to deprecate its AdamW implementation with a removal target of, say, transformers>=5.0.0 (or 6.0.0 if a longer sunset is necessary), and add a comment in the AdamW implementation explaining the two bugs. This way, current examples & documentation can continue to work as expected, while users migrate to torch's AdamW. How does this sound?

@sgugger (Collaborator) commented Dec 2, 2021

Yes, I agree with your last suggestion, @manuelciosici, and I think this is the right way to go. Deprecation with a removal in v5.0.0 sounds about right, and then the Trainer can have an additional TrainingArguments that one can use to already use the right implementation of AdamW from PyTorch instead of our class.

Are you interested in making a PR for this, @manuelciosici?

@stas00 (Contributor) commented Dec 2, 2021

In addition to @sgugger's notes: the updated AdamW API should include a new arg like disable_deprecate_warning=False, so that by default the deprecation warning is printed, but the user is able to shut it off if they want to continue using this version.

then the Trainer can have an additional TrainingArguments that one can use to already use the right implementation of AdamW from PyTorch instead of our class.

The question is whether we switch HF Trainer to use torch's implementation by default or not.

Also, if we are rewriting the optimizer API, perhaps we can add a generic --optim flag which could support various optimizers. I'm proposing that since I'm going to suggest shortly that the HF Trainer support BNB (https://github.com/facebookresearch/bitsandbytes), which saves 3/4 of the optimizer memory and has so far been tested to work great.

So we can have:

  • --optim adamw_hf
  • --optim adamw_torch
  • --optim adamw_bnb
  • --optim some_other
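
A rough sketch of how such an --optim flag could dispatch to different optimizer implementations (names and structure are illustrative, not the Trainer's actual code; the bitsandbytes import assumes that library's 8-bit Adam optimizer is available):

```python
import torch
from transformers.optimization import AdamW as HFAdamW

def create_optimizer(name: str, params, lr: float = 5e-5, weight_decay: float = 0.0):
    """Illustrative dispatch for a hypothetical --optim flag."""
    if name == "adamw_hf":
        return HFAdamW(params, lr=lr, weight_decay=weight_decay)
    if name == "adamw_torch":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    if name == "adamw_bnb":
        # Assumes bitsandbytes is installed; it provides an 8-bit Adam optimizer.
        from bitsandbytes.optim import Adam8bit
        return Adam8bit(params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"Unknown optimizer: {name}")
```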

@manuelciosici (Contributor, Author)

@sgugger Yes, I can write a PR deprecating AdamW, including @stas00's suggestions.

@stas00 BNB sounds exciting. How should we split the work into PRs? I can also help with BNB. I think that could be fun.

@stas00 (Contributor) commented Dec 2, 2021

We don't need to worry about BNB here; I was just suggesting adding a generic --optim HF Trainer arg, rather than, for example, --use-torch-adamw, which opens up opportunities for new optimizers to be supported.

Adding BNB to transformers is a bit intricate since it calls for an embedding layernorm which we currently don't have. I will open an issue where we can discuss the details. That additional layernorm proved to be essential for the stability of the gpt-104B training we are working on at BigScience.

@sgugger (Collaborator) commented Dec 3, 2021

The plan is not to add any new optimizer to the Transformers library. It is a library of models, not optimizers, and no one in the team has the bandwidth to support other optimizers. We are deprecating the few we have. Adding support for optimizers implemented in other libraries is completely fine however.

Adding an --optim argument is fine, though the default learning rate might not be suitable for every optimizer added, so we might have to be careful with the options accepted.

The question is whether we switch HF Trainer to use torch's implementation by default or not.

Given the fact it is breaking, the Trainer should stay with the current optimizer for now, and we can either switch in v5 or when someone has checked all examples and seen comparable results, whichever comes first.

@stas00 (Contributor) commented Dec 3, 2021

Apologies for not being clear. I was not proposing to add a new optimizer, but to add an integration for a new optimizer, i.e., we will not need to support it. It's just that it's not just a matter of importing it; it requires some tweaks on our side. I will make a separate issue about it.

Given the fact it is breaking, the Trainer should stay with the current optimizer for now

OK, so the default remains the current version.

Here is the updated spec then:

So with HF Trainer:

  1. the default is the current --optim adamw_hf, but it prints a deprecation warning which includes info on how to enable torch's version
  2. --optim adamw_torch - switches to torch.optim.AdamW

With the AdamW class itself:

  1. the default is to print a deprecation warning, unless no_deprecation_warning=True is passed.

Sylvain, please confirm that this is the correct spec before @manuelciosici starts on it. Thank you.
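
A minimal sketch of what the deprecation path in the AdamW class could look like under this spec (the exact warning text and defaults are up to the PR; step() is omitted):

```python
import warnings
import torch

class AdamW(torch.optim.Optimizer):
    """Simplified signature, for illustration of the deprecation plumbing only."""

    def __init__(self, params, lr=1e-3, weight_decay=0.0, correct_bias=True,
                 no_deprecation_warning: bool = False):
        if not no_deprecation_warning:
            warnings.warn(
                "This implementation of AdamW is deprecated; use torch.optim.AdamW "
                "instead, or pass no_deprecation_warning=True to silence this warning.",
                FutureWarning,
            )
        defaults = dict(lr=lr, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)
        # step() omitted; this sketch only shows the warning behaviour.
```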

@sgugger (Collaborator) commented Dec 3, 2021

Thanks for the summary @stas00, this looks great to me!

@manuelciosici (Contributor, Author)

@stas00 Thank you. I will work on this over the weekend.

@stas00 (Contributor) commented Dec 9, 2021

The NVIDIA engineers have been profiling a few things, and torch's AdamW is faster than ours (apparently apex's is even faster), so I will add this to the performance docs once I'm able to benchmark it when your PR is ready, @manuelciosici

#14708

@stas00 (Contributor) commented Dec 9, 2021

It appears that apex.optimizers.FusedAdam is even faster. So we can plug that one in as well.
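
For reference, plugging in apex's fused optimizer would look roughly like this (a sketch; assumes NVIDIA apex is installed and that adam_w_mode selects decoupled, AdamW-style weight decay):

```python
import torch
# Requires NVIDIA apex: https://github.com/NVIDIA/apex
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 2)  # stand-in model; hyper-parameters are illustrative
optimizer = FusedAdam(model.parameters(), lr=5e-5, weight_decay=0.01, adam_w_mode=True)
```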

@joostgrunwald commented Mar 16, 2022

This implementation of AdamW, although slower, seems to give me better performance than the PyTorch one in terms of accuracy and F1. I'm not sure if I'm the only one with this result, but if this is the case for multiple people, deprecating it could be a shame.

@stas00 (Contributor) commented Mar 16, 2022

The key thing to understand is that it's not implementing AdamW, but a slightly different algorithm.

Users expect an exact algorithm implementation out of the box, and if it's not exact, it should be named differently.

Perhaps AdamWHF?
