Deprecates AdamW and adds --optim #14744

Merged

stas00 merged 55 commits into huggingface:master on Jan 13, 2022

Conversation

@manuelciosici (Contributor) commented Dec 13, 2021

What does this PR do?

Fixes #14539

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@stas00

Misc

@stas00

  • I added FusedAdam based on your comment in Two bugs in AdamW #14539
  • Since both --optim adafactor and --adafactor serve the same purpose, I marked --adafactor as deprecated. I copy-pasted a deprecation warning that mentions transformers version 5. Let me know if the deprecation warning should say something else.
  • Let me know if I missed anything else.

@sgugger (Collaborator) left a comment

Thanks for implementing this. I left some nits and there is one line missing which shows we need a test of this new feature :-)

I'm also wondering if we shouldn't create an Enum for all optimizer values that are supported?
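
For illustration, such an enum might look roughly like the following sketch (the member names and string values are placeholders rather than the exact ones that ended up being merged):

import enum


class OptimizerNames(enum.Enum):
    # illustrative members; the merged PR defines the authoritative set of values
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"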

Review comments (now resolved) on src/transformers/optimization.py, src/transformers/trainer.py, and src/transformers/training_args.py.
manuelciosici and others added 6 commits December 13, 2021 11:49, each co-authored by Sylvain Gugger <35901082+sgugger@users.noreply.github.com>.
@manuelciosici (Contributor, Author):

@sgugger Thank you for the suggestions. I think an enum makes sense. I will add that and fix the missing optimizer_cls line.

@stas00 (Contributor) left a comment

Thank you for working on it, @manuelciosici

I made a few proposals about consistency

Please let me know if it's not clear.

(I also edited/removed my earlier comments, in which I mistakenly thought the actual Adafactor optimizer was being deprecated; I had misread the code.)

@@ -818,17 +818,43 @@ def create_optimizer(self):
                "weight_decay": 0.0,
            },
        ]
-       optimizer_cls = Adafactor if self.args.adafactor else AdamW
-       if self.args.adafactor:
+       if self.args.adafactor and self.args.optim not in {"adamw_hf", "adafactor"}:
@stas00 (Contributor) commented Dec 13, 2021:

This logic is unnecessarily complicated, IMHO:

  1. deprecate --adafactor in favor of --optim adafactor
  2. if --adafactor is passed set --optim adafactor

Now you no longer need to consider self.args.adafactor at all, other than to deprecate it.
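
A minimal, self-contained sketch of that suggestion (the class below is a hypothetical stand-in for the relevant TrainingArguments fields, and the exact warning text in the merged PR may differ):

from dataclasses import dataclass
import warnings


@dataclass
class ArgsSketch:
    # hypothetical stand-in for the relevant TrainingArguments fields
    adafactor: bool = False
    optim: str = "adamw_hf"

    def __post_init__(self):
        # deprecate --adafactor in favor of --optim adafactor, then forget about it
        if self.adafactor:
            warnings.warn(
                "`--adafactor` is deprecated and will be removed in version 5 of Transformers."
                " Use `--optim adafactor` instead.",
                FutureWarning,
            )
            self.optim = "adafactor"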

Further review comments (now resolved) on src/transformers/trainer.py and src/transformers/training_args.py.
@stas00 (Contributor) commented Dec 27, 2021

@manuelciosici, have we lost you or are you just on an extended vacation?

Let's finish this up, so we can test out various optimizers. I will sync the recent changes in doc formatting.

I was just told that apex.optimizers.FusedAdam is an even faster fused optimizer than torch's. We've already added it, but I'm going to benchmark it.
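
For reference, hooking up apex's optimizer typically follows a guarded-import pattern along these lines (a sketch; the surrounding variable names and the error message are illustrative, not the exact merged code):

# sketch: pick apex's FusedAdam, failing loudly if apex is not installed
adam_kwargs = {"betas": (0.9, 0.999), "eps": 1e-8}  # illustrative Adam hyperparameters
try:
    from apex.optimizers import FusedAdam

    optimizer_cls = FusedAdam
    optimizer_kwargs = dict(adam_kwargs)
except ImportError:
    raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")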

stas00 self-assigned this on Dec 27, 2021

@stas00 (Contributor) commented Dec 29, 2021

Update: see the updated benchmarks here:

  1. RTX-3090
  2. A100

I'm working on a neat HF Trainer benchmarking tool, so here it is applied to the changes introduced by this PR:

| Variation | Train samples per second | Diff % | Train loss |
|---|---|---|---|
| --optim adamw_hf | 117.544 | 32 | 2.19851 |
| --optim adamw_torch | 112.951 | 27 | 2.19829 |
| --optim adafactor | 89.194 | 0 | 2.20484 |
| --optim apex_fused_adam | 126.232 | 42 | 2.19832 |

Torch's AdamW appears to be even slower than ours, so apex's fused AdamW is clearly the way to go speed-wise.

Note that the absolute and relative results will differ on a different GPU and a different finetuning setup, but the currently fastest optimizer will most likely remain the fastest.

Reproducibility and other info:

Datetime    : 2021-12-28 20:56:42

Software:
transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.11

Hardware:
1 GPUs      : NVIDIA GeForce RTX 3090, 23.70GB

The benchmark command line was:

CUDA_VISIBLE_DEVICES=0 python \
/hf/transformers-trainer-benchmark/scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 16 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_hf|--optim adamw_torch|--optim adafactor|--optim apex_fused_adam' \
--table-format github --report-metric-keys train_loss

@manuelciosici (Contributor, Author):

@stas00 You have not lost me. I've been on vacation, and before that I accidentally messed up my GitHub notification settings, so I didn't know you had also reviewed the code. I have applied all of your suggestions.

I also changed to an Enum for optimizer string values as @sgugger suggested.

Let me know if I should change anything else.

@manuelciosici (Contributor, Author):

@stas00 Let me know if I should squash all the commits to a single commit so they don't pollute master's commit log.

@manuelciosici (Contributor, Author):

@stas00 Thank you for the code cleanup, figuring out the mocking issue, and for your patience. This PR has been an educational experience (I didn't know about parameterized). I'm looking forward to figuring out what my next contribution should be. Let me know if you have any suggestions.

@stas00 (Contributor) commented Jan 3, 2022

Since you're clearly interested in easier testing, you may find some useful tidbits in this extensive doc: https://huggingface.co/docs/transformers/testing, e.g. for different parametrization situations: https://huggingface.co/docs/transformers/testing#parametrization
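
For instance, a parameterized test over the supported --optim values might look something like this (a sketch: the class and test names are made up, and it assumes Trainer.get_optimizer_cls_and_kwargs as exposed by this PR):

import unittest

from parameterized import parameterized

from transformers import Trainer, TrainingArguments


class OptimNamesTest(unittest.TestCase):
    # hypothetical test, not one of the tests added in this PR
    @parameterized.expand([("adamw_hf",), ("adamw_torch",), ("adafactor",)])
    def test_optim_name_is_accepted(self, optim_name):
        args = TrainingArguments(output_dir="tmp_output", optim=optim_name)
        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
        self.assertTrue(callable(optimizer_cls))
        self.assertIsInstance(optimizer_kwargs, dict)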

Optimizer-wise, I think the next interesting but challenging thing to add is the BNB 8-bit optimizer (#14819), but we are still discussing how it'd work.

The other thing to potentially experiment with is https://www.deepspeed.ai/tutorials/onebit-adam/, but I haven't had a chance to understand it, so I have no idea whether it can be used outside of DeepSpeed or only with DeepSpeed.

@sgugger (Collaborator) left a comment

Thanks for waiting for me! This looks great (nice new tests!) but I have a few last comments.

Review comment (now resolved) on src/transformers/training_args.py.
Comment on lines 854 to 867
        if args.optim == OptimizerNames.ADAFACTOR.value:
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAM_HF.value:
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAM_TORCH.value:
            from torch.optim import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        elif args.optim == OptimizerNames.ADAM_APEX_FUSED.value:
@sgugger (Collaborator) commented:

This is an odd use of an enum. The args.optim attribute should be converted to the enum type (you just need to declare it with type OptimizerNames instead of str; it will then accept both str and OptimizerNames values and do the conversion in the post-init), and those tests should have all the .value removed.
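
Concretely, once args.optim is converted to the enum in the post-init, the branches above compare enum members directly. A sketch of the shape of that change, continuing the snippet above (so Adafactor, AdamW, optimizer_kwargs, and adam_kwargs come from the surrounding Trainer code):

        # same dispatch as above, but comparing enum members rather than strings
        if args.optim == OptimizerNames.ADAFACTOR:      # instead of OptimizerNames.ADAFACTOR.value
            optimizer_cls = Adafactor
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim == OptimizerNames.ADAM_HF:      # instead of OptimizerNames.ADAM_HF.value
            from .optimization import AdamW

            optimizer_cls = AdamW
            optimizer_kwargs.update(adam_kwargs)
        # ...and likewise for the remaining branches, with every .value removed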

@manuelciosici (Contributor, Author) replied:

Thank you. I've updated everything.

Further review comments (now resolved) on src/transformers/training_args.py.
manuelciosici and others added 3 commits January 12, 2022 16:26, each co-authored by Sylvain Gugger <35901082+sgugger@users.noreply.github.com>.
@stas00 (Contributor) commented Jan 12, 2022

I think the only remaining item I didn't hear you weigh in on, @sgugger, is whether we should call these --adam_foo or --adamw_foo, since the class names are AdamW (except apex's).

@sgugger (Collaborator) commented Jan 12, 2022

Oh sorry, I didn't catch that. It should be adamw everywhere IMO.

@stas00 (Contributor) commented Jan 12, 2022

> Oh sorry, I didn't catch that. It should be adamw everywhere IMO.

Thank you for validating that, @sgugger.

@manuelciosici, so one more tiny change please: s|--adam_|--adamw_|

Thank you!

@manuelciosici (Contributor, Author):

I removed .value from everywhere and ensured that the tests pass. I have also changed the optimizer name strings as @stas00 asked.

Let me know if I should change anything else.

@stas00 (Contributor) commented Jan 12, 2022

But the rest needs to be updated to match; e.g., currently many torch tests fail with:

E           ValueError: Trainer cannot instantiate unsupported optimizer: adamw_hf

@manuelciosici (Contributor, Author):

@stas00 I just saw that. I'm trying to figure out what I misunderstood.

@manuelciosici (Contributor, Author):

@stas00 I'm surprised that fixes it. On my end, I fixed it by adding self.optim = OptimizerNames(self.optim) in the post-init. I also had to remove a now-redundant unit test.

@stas00 (Contributor) commented Jan 12, 2022

It's the same as:

elif args.optim == OptimizerNames.ADAMW_HF:

I'm not sure if you can see the failing tests from CI, so I thought it'd be faster to push in the fix as I saw where it was failing.

Are you still planning to push a change? You mentioned removing a unit test.

Let us know when you're done.

@manuelciosici (Contributor, Author):

@stas00 Commit e73249c makes the unit tests pass, but it doesn't work when --optim is explicitly set on the command line. TrainingArguments does not automatically convert strings to enums. For a parallel, see

self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)

So, calling a script from the command line (for example, examples/pytorch/language-modeling/run_clm.py) with any explicit --optim (even just --optim adamw_hf) throws an error from Trainer.get_optimizer_cls_and_kwargs, since get_optimizer_cls_and_kwargs receives a string that does not match any of the if branches, which test for enum values.

I added another test for specifying the optimizer name as a string. Also, I removed test_optim_unsupported since, with these commits, --optim no longer accepts strings that are not in OptimizerNames.

I also changed the default value from OptimizerNames.ADAMW_HF to OptimizerNames.ADAMW_HF.value. With default=OptimizerNames.ADAMW_HF calling --help on the CLI gives:

  --optim {adamw_hf,adamw_torch,adamw_apex_fused,adafactor}
                        The optimizer to use. (default:
                        OptimizerNames.ADAMW_HF)

While default=OptimizerNames.ADAMW_HF.value gives

  --optim {adamw_hf,adamw_torch,adamw_apex_fused,adafactor}
                        The optimizer to use. (default: adamw_hf)

The first one leaks the internal object name, while the second indicates the string we want users to pass.

Finally, I made OptimizerNames inherit from ExplicitEnum instead of Enum because I saw SchedulerType do the same, and it seems more elegant:

class SchedulerType(ExplicitEnum):
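
Putting the pieces above together, the arrangement described here might look roughly like the following self-contained sketch (ExplicitEnum is stubbed out inline, and the class and field names are illustrative rather than the exact merged code):

from dataclasses import dataclass, field
from enum import Enum


class ExplicitEnum(Enum):
    # minimal stand-in for transformers' ExplicitEnum: unknown values raise a
    # clearer error that lists the accepted strings
    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of "
            f"{[m.value for m in cls]}"
        )


class OptimizerNames(ExplicitEnum):
    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"


@dataclass
class TrainingArgumentsSketch:
    # the default stays a plain string so that argparse's --help prints "adamw_hf"
    # instead of "OptimizerNames.ADAMW_HF"
    optim: str = field(default=OptimizerNames.ADAMW_HF.value)

    def __post_init__(self):
        # mirrors the SchedulerType pattern: accept either a string or an enum member
        self.optim = OptimizerNames(self.optim)

With this sketch, TrainingArgumentsSketch(optim="adamw_torch").optim yields OptimizerNames.ADAMW_TORCH, while an unsupported string fails with an error that lists the valid choices.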

@sgugger (Collaborator) left a comment

All looking great on my side! @stas00 I'll let you merge if you're happy as well.

Thanks again for all your work on this @manuelciosici !

stas00 merged commit 7b83feb into huggingface:master on Jan 13, 2022
@stas00 (Contributor) commented Jan 13, 2022

I second that - thank you, @manuelciosici!

@manuelciosici (Contributor, Author):

@stas00 @sgugger Thank you for guiding me!
