Add possibility to switch between APEX and AMP in Trainer #9137

Merged · 4 commits · Dec 15, 2020
57 changes: 31 additions & 26 deletions src/transformers/trainer.py
@@ -93,8 +93,7 @@
from .utils import logging


_use_native_amp = False
_use_apex = False
_is_native_amp_available = False

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -110,16 +109,10 @@

if is_apex_available():
from apex import amp
_use_apex = True
else:
_use_native_amp = True
_is_native_amp_available = True
from torch.cuda.amp import autocast

if version.parse(torch.__version__) < version.parse("1.2"):
_use_ddp_no_sync = False
else:
_use_ddp_no_sync = True

if is_datasets_available():
import datasets

@@ -292,13 +285,30 @@ def __init__(
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")

# Mixed precision setup
self.use_apex = False
self.use_amp = False
if args.fp16:
if args.fp16_backend == "auto":
backend = "amp" if _is_native_amp_available else "apex"
else:
backend = args.fp16_backend

if backend == "amp":
self.use_amp = True
self.scaler = torch.cuda.amp.GradScaler()
else:
if not is_apex_available():
raise ImportError(
"Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex."
)
self.use_apex = True

self.state = TrainerState()
self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
# state at each call to self.log.
self._total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
self.use_tune_checkpoints = False
default_label_names = (
@@ -625,9 +635,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

# Mixed precision training with apex (torch < 1.6)
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
if self.use_apex:
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

# Multi-gpu training (should be after apex fp16 initialization)
@@ -756,11 +764,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
if (step + 1) % self.args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)

if (
((step + 1) % self.args.gradient_accumulation_steps != 0)
and self.args.local_rank != -1
and _use_ddp_no_sync
):
if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1:
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
tr_loss += self.training_step(model, inputs)
else:
@@ -772,17 +777,17 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
steps_in_epoch <= self.args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
if self.args.fp16 and _use_native_amp:
if self.use_amp:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
elif self.args.fp16 and _use_apex:
elif self.use_apex:
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

if is_torch_tpu_available():
xm.optimizer_step(self.optimizer)
elif self.args.fp16 and _use_native_amp:
elif self.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
@@ -1089,7 +1094,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
model.train()
inputs = self._prepare_inputs(inputs)

if self.args.fp16 and _use_native_amp:
if self.use_amp:
with autocast():
loss = self.compute_loss(model, inputs)
else:
@@ -1101,9 +1106,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
if self.args.gradient_accumulation_steps > 1:
loss = loss / self.args.gradient_accumulation_steps

if self.args.fp16 and _use_native_amp:
if self.use_amp:
self.scaler.scale(loss).backward()
elif self.args.fp16 and _use_apex:
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
@@ -1498,7 +1503,7 @@ def prediction_step(
ignore_keys = []

with torch.no_grad():
if self.args.fp16 and _use_native_amp:
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
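For orientation, here is a minimal, self-contained sketch of the logic this diff introduces: how `fp16_backend` is resolved to a backend, and how the `use_amp` branch of the training step uses `GradScaler`/`autocast`. The helper names (`resolve_fp16_backend`, `fp16_training_step`, `compute_loss`) are illustrative only and not part of the PR; gradient clipping and accumulation are omitted for brevity.

```python
# Illustrative sketch, not code from the PR.
import torch
from packaging import version

try:
    from apex import amp  # only needed when the "apex" backend is selected
    _apex_available = True
except ImportError:
    _apex_available = False

# Native AMP (torch.cuda.amp) shipped with PyTorch 1.6.
_is_native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")


def resolve_fp16_backend(fp16: bool, fp16_backend: str = "auto") -> str:
    """Mirror of the selection logic added to Trainer.__init__."""
    if not fp16:
        return "none"
    if fp16_backend == "auto":
        return "amp" if _is_native_amp_available else "apex"
    if fp16_backend == "apex" and not _apex_available:
        raise ImportError("fp16_backend='apex' requested but APEX is not installed.")
    return fp16_backend


def fp16_training_step(model, optimizer, scaler, compute_loss, batch, use_amp):
    # Mirrors the use_amp branch in the diff: autocast wraps the forward pass,
    # the scaled loss drives backward, and the scaler handles step/update.
    if use_amp:
        with torch.cuda.amp.autocast():
            loss = compute_loss(model, batch)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss = compute_loss(model, batch)
        loss.backward()
        optimizer.step()
    optimizer.zero_grad()
    return loss.detach()
```

The APEX path is unchanged by this sketch: when `use_apex` is selected, the Trainer still wraps the model and optimizer with `amp.initialize` and scales the loss via `amp.scale_loss`, as shown in the diff above.
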
8 changes: 8 additions & 0 deletions src/transformers/training_args.py
@@ -211,6 +211,10 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
step can take a long time) but will not yield the same results as the interrupted training would have.
fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
other choices will force the requested backend.
"""

output_dir: str = field(
@@ -378,6 +382,10 @@ class TrainingArguments:
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
},
)
fp16_backend: str = field(
default="auto",
metadata={"help": "The backend to be used for mixed precision. Should be one of 'auto', 'amp' or 'apex'."},
)

def __post_init__(self):
if self.disable_tqdm is None:
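As a usage note, the new argument is set directly on `TrainingArguments`. A hedged sketch follows; the output path and the commented-out model/dataset are placeholders, not part of this PR.

```python
from transformers import Trainer, TrainingArguments

# Force the APEX backend even on PyTorch >= 1.6; the default "auto" would pick
# native AMP there, and "amp" forces torch.cuda.amp explicitly.
args = TrainingArguments(
    output_dir="output",  # placeholder path
    fp16=True,
    fp16_backend="apex",  # one of "auto", "amp", "apex"
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
```
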
54 changes: 29 additions & 25 deletions tests/test_trainer.py
@@ -798,34 +798,38 @@ def test_num_train_epochs_in_training(self):

def test_early_stopping_callback(self):
# early stopping stops training before num_training_epochs
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
self.assertLess(train_output.global_step, 20 * 64 / 16)
with tempfile.TemporaryDirectory() as tmp_dir:
[Review comment from the author] This test was saving things in a regression folder and adding lots of unwanted files. Moving it to a temp folder.

trainer = get_regression_trainer(
output_dir=tmp_dir,
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
self.assertLess(train_output.global_step, 20 * 64 / 16)

# Invalid inputs to trainer with early stopping callback result in assertion error
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1))
self.assertEqual(trainer.state.global_step, 0)
try:
trainer.train()
except AssertionError:
with tempfile.TemporaryDirectory() as tmp_dir:
[Review comment from the author] Same here.

trainer = get_regression_trainer(
output_dir=tmp_dir,
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1))
self.assertEqual(trainer.state.global_step, 0)
try:
trainer.train()
except AssertionError:
self.assertEqual(trainer.state.global_step, 0)

def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1)
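Following up on the author's review comments about stray regression folders, here is a hypothetical reduced test showing the temp-directory pattern the diff applies. It assumes the module-level `get_regression_trainer` helper from this test file; the test name and body are illustrative.

```python
import tempfile


def test_writes_only_to_temp_dir():
    # Everything the trainer writes (checkpoints, logs) lands in tmp_dir and is
    # deleted when the context manager exits, so no stray "regression" folders
    # are left behind in the repository.
    with tempfile.TemporaryDirectory() as tmp_dir:
        trainer = get_regression_trainer(output_dir=tmp_dir, num_train_epochs=1)
        trainer.train()
```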