Deprecates AdamW and adds --optim #14744

Merged: 55 commits, Jan 13, 2022

The diff below shows changes from 8 of the 55 commits.

Commits (55)
- 32392af Add AdamW deprecation warning (manuelciosici, Dec 10, 2021)
- c637f37 Add --optim to Trainer (manuelciosici, Dec 10, 2021)
- b4d0b6d Update src/transformers/optimization.py (manuelciosici, Dec 13, 2021)
- bcc2408 Update src/transformers/optimization.py (manuelciosici, Dec 13, 2021)
- 460eff4 Update src/transformers/optimization.py (manuelciosici, Dec 13, 2021)
- 6dc78a6 Update src/transformers/optimization.py (manuelciosici, Dec 13, 2021)
- 68dd581 Update src/transformers/training_args.py (manuelciosici, Dec 13, 2021)
- 9560350 Update src/transformers/training_args.py (manuelciosici, Dec 13, 2021)
- 01f1c7b Update src/transformers/training_args.py (stas00, Dec 27, 2021)
- 0c79a5f Merge remote-tracking branch 'origin/master' into deprecate_adamw (stas00, Dec 27, 2021)
- 7ec094f fix style (stas00, Dec 27, 2021)
- 1c9cccf fix (stas00, Dec 29, 2021)
- 9807d35 Regroup adamws together (manuelciosici, Dec 30, 2021)
- 7a063ab Change --adafactor to --optim adafactor (manuelciosici, Dec 30, 2021)
- d599a38 Use Enum for optimizer values (manuelciosici, Dec 30, 2021)
- 1f9210c fixup! Change --adafactor to --optim adafactor (manuelciosici, Dec 30, 2021)
- a80b39e fixup! Change --adafactor to --optim adafactor (manuelciosici, Dec 30, 2021)
- fdf40b2 fixup! Change --adafactor to --optim adafactor (manuelciosici, Dec 30, 2021)
- d5dc69a Merge branch 'master' into deprecate_adamw (manuelciosici, Dec 30, 2021)
- 0acba0c fixup! Use Enum for optimizer values (manuelciosici, Dec 30, 2021)
- 2b7d9dd Improved documentation for --adafactor (manuelciosici, Dec 31, 2021)
- 7c3139a Add mention of no_deprecation_warning (manuelciosici, Dec 31, 2021)
- 234f7d1 Rename OptimizerOptions to OptimizerNames (manuelciosici, Dec 31, 2021)
- 1786d42 Use choices for --optim (manuelciosici, Dec 31, 2021)
- 210ed37 Move optimizer selection code to a function and add a unit test (manuelciosici, Dec 31, 2021)
- 7e62da9 Change optimizer names (manuelciosici, Dec 31, 2021)
- 0e7f955 Rename method (manuelciosici, Jan 1, 2022)
- 12a9e37 Rename method (manuelciosici, Jan 1, 2022)
- c5853b0 Remove TODO comment (manuelciosici, Jan 1, 2022)
- d59aa52 Rename variable (manuelciosici, Jan 1, 2022)
- e7ffd71 Rename variable (manuelciosici, Jan 1, 2022)
- b64fc03 Rename function (manuelciosici, Jan 1, 2022)
- c5b5443 Rename variable (manuelciosici, Jan 1, 2022)
- 91aff78 Parameterize the tests for supported optimizers (manuelciosici, Jan 1, 2022)
- f3505db Refactor (manuelciosici, Jan 1, 2022)
- 91c35f2 Attempt to make tests pass on CircleCI (manuelciosici, Jan 1, 2022)
- bcd8a0d Add a test with apex (manuelciosici, Jan 2, 2022)
- f8cb39c rework to add apex to parameterized; add actual train test (stas00, Jan 2, 2022)
- 98f0f2f fix import when torch is not available (stas00, Jan 2, 2022)
- eba41bd fix optim_test_params when torch is not available (stas00, Jan 2, 2022)
- aaee305 fix optim_test_params when torch is not available (stas00, Jan 2, 2022)
- 071198c re-org (stas00, Jan 2, 2022)
- 182dac8 small re-org (stas00, Jan 2, 2022)
- 2b46361 fix test_fused_adam_no_apex (stas00, Jan 2, 2022)
- 470a1d7 Update src/transformers/training_args.py (manuelciosici, Jan 12, 2022)
- cb85474 Update src/transformers/training_args.py (manuelciosici, Jan 12, 2022)
- b2675f8 Update src/transformers/training_args.py (manuelciosici, Jan 12, 2022)
- 1e8acec Remove .value from OptimizerNames (manuelciosici, Jan 12, 2022)
- b32a194 Rename optimizer strings s|--adam_|--adamw_| (manuelciosici, Jan 12, 2022)
- b839e80 Also rename Enum options (manuelciosici, Jan 12, 2022)
- e73249c small fix (stas00, Jan 12, 2022)
- 7ac8dc0 Fix instantiation of OptimizerNames. Remove redundant test (manuelciosici, Jan 12, 2022)
- a2363cd Use ExplicitEnum instead of Enum (manuelciosici, Jan 12, 2022)
- ea02877 Add unit test with string optimizer (manuelciosici, Jan 12, 2022)
- ec92011 Change optimizer default to string value (manuelciosici, Jan 13, 2022)
src/transformers/optimization.py: 10 additions & 0 deletions

@@ -15,6 +15,7 @@
"""PyTorch optimization for BERT model."""

import math
import warnings
from typing import Callable, Iterable, Optional, Tuple, Union

import torch
@@ -287,6 +288,8 @@ class AdamW(Optimizer):
Decoupled weight decay to apply.
correct_bias (:obj:`bool`, `optional`, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`).
no_deprecation_warning (:obj:`bool`, `optional`, defaults to :obj:`False`):
A flag used to disable the deprecation warning (set to :obj:`True` to disable the warning).
"""

def __init__(
@@ -297,7 +300,14 @@ def __init__(
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
no_deprecation_warning: bool = False,
):
if not no_deprecation_warning:
warnings.warn(
"This implementation of AdamW is deprecated and will be removed in a future version. Use the"
"PyTorch implementation torch.optim.AdamW instead.",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
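For context, a minimal sketch of what this change means for callers: constructing the Hugging Face `AdamW` now emits a `FutureWarning`, and the new `no_deprecation_warning` flag silences it while migrating to `torch.optim.AdamW`. The dummy parameter setup below is purely illustrative, not taken from the PR.

```python
import warnings

import torch
from transformers.optimization import AdamW  # deprecated in favor of torch.optim.AdamW

# Dummy parameters purely for illustration.
params = [torch.nn.Parameter(torch.zeros(2, 2))]

# Instantiating the deprecated class now raises a FutureWarning...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    AdamW(params, lr=1e-3)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# ...which can be silenced explicitly while migrating away from it.
AdamW(params, lr=1e-3, no_deprecation_warning=True)

# The drop-in replacement suggested by the warning:
torch.optim.AdamW(params, lr=1e-3)
```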
src/transformers/trainer.py: 36 additions & 10 deletions

@@ -77,7 +77,7 @@
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, AdamW, get_scheduler
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
@@ -818,17 +818,43 @@ def create_optimizer(self):
"weight_decay": 0.0,
},
]
optimizer_cls = Adafactor if self.args.adafactor else AdamW
if self.args.adafactor:
if self.args.adafactor and self.args.optim not in {"adamw_hf", "adafactor"}:
Review comment from @stas00 (Contributor), Dec 13, 2021:

this logic is unnecessarily complicated, IMHO

  1. deprecate --adafactor in favor of --optim adafactor
  2. if --adafactor is passed set --optim adafactor

now you no longer need to ever consider self.args.adafactor other than to deprecate it.

raise ValueError(f"You passed the --adafactor flag and optimizer {self.args.optim}.")

optimizer_kwargs = {"lr": self.args.learning_rate}

adam_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}

# TODO the following code is a good candidate for PEP 622 once Python 3.10 becomes the
# minimum required version. See, https://www.python.org/dev/peps/pep-0622/
if (self.args.adafactor and self.args.optim == "adamw_hf") or self.args.optim == "adafactor":
optimizer_cls = Adafactor
optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
else:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
elif self.args.optim == "adamw_hf":
from .optimization import AdamW

optimizer_cls = AdamW
optimizer_kwargs = {
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
optimizer_kwargs.update(adam_kwargs)
elif self.args.optim == "adamw_torch":
from torch.optim import AdamW

optimizer_kwargs.update(adam_kwargs)
elif self.args.optim == "apex_fused_adam":
try:
from apex.optimizers import FusedAdam

optimizer_cls = FusedAdam
optimizer_kwargs.update(adam_kwargs)
except ImportError:
raise ValueError(
"Trainer attempted to instantiate apex.optimizers.FusedAdam but apex is not installed!"
)
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {self.args.optim}")

if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
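Because the add/remove markers are lost in this view, here is a condensed, stand-alone sketch of the dispatch this hunk introduces. The function name is hypothetical; the real logic lives inline in `Trainer.create_optimizer`, which additionally validates the deprecated `--adafactor` flag against `--optim` and handles sharded DDP.

```python
def select_optimizer_cls_and_kwargs(args):
    """Hypothetical stand-alone version of the --optim dispatch added in this PR."""
    optimizer_kwargs = {"lr": args.learning_rate}
    adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon}

    if (args.adafactor and args.optim == "adamw_hf") or args.optim == "adafactor":
        from transformers.optimization import Adafactor

        optimizer_cls = Adafactor
        optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
    elif args.optim == "adamw_hf":
        from transformers.optimization import AdamW  # deprecated HF implementation

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == "adamw_torch":
        from torch.optim import AdamW

        optimizer_cls = AdamW
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == "apex_fused_adam":
        try:
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ValueError("Trainer attempted to instantiate apex.optimizers.FusedAdam but apex is not installed!")
        optimizer_cls = FusedAdam
        optimizer_kwargs.update(adam_kwargs)
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")

    return optimizer_cls, optimizer_kwargs
```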
src/transformers/training_args.py: 15 additions & 2 deletions

@@ -330,9 +330,11 @@ class TrainingArguments:
- :obj:`"tpu_metrics_debug"`: print debug metrics on TPU

The options should be separated by whitespaces.
optim (:obj:`str`, `optional`, defaults to :obj:`adamw_hf`):
The optimizer to use: adamw_hf, adamw_torch, adafactor, or apex_fused_adam.
adafactor (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the :class:`~transformers.Adafactor` optimizer instead of
:class:`~transformers.AdamW`.
This argument is deprecated. Use ``--optim adafactor`` instead. Whether or not to use the
:class:`~transformers.Adafactor` optimizer instead of :class:`~transformers.AdamW`.
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
@@ -646,6 +648,10 @@ class TrainingArguments:
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
optim: str = field(
default="adamw_hf",
metadata={"help": "The optimizer to use: adamw_hf, adamw_torch, adafactor, or apex_fused_adam."},
)
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
group_by_length: bool = field(
default=False,
@@ -807,6 +813,13 @@ def __post_init__(self):
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")

if self.adafactor:
warnings.warn(
"`adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim adafactor` instead",
FutureWarning,
)

if (
is_torch_available()
and self.device.type != "cuda"
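To round things off, a minimal usage sketch of the new argument as it stands at this point in the PR (per the commit list, the option strings were renamed later, e.g. `--adam_*` to `--adamw_*`, and enum-based validation came in subsequent commits); `output_dir="out"` is just a placeholder.

```python
from transformers import TrainingArguments

# New style: choose the optimizer by name (the default stays "adamw_hf").
args = TrainingArguments(output_dir="out", optim="adamw_torch")

# Old style: still accepted, but __post_init__ now emits a FutureWarning
# pointing users to `--optim adafactor`.
legacy_args = TrainingArguments(output_dir="out", adafactor=True)

# On the command line (e.g. in the HfArgumentParser-based example scripts),
# the equivalent would be passing `--optim adafactor`.
```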