Skip to content

Commit

Permalink
support for native amp (#1561)
Browse files Browse the repository at this point in the history
* adding native amp suppport

* adding native amp suppport

* adding native amp suppport

* adding native amp suppport

* autocast

* autocast

* autocast

* autocast

* autocast

* autocast

* removed comments

* removed comments

* added state saving

* added state saving

* try install amp again

* added state saving

* drop Apex reinstall

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 23, 2020
1 parent 41b6cbb commit 29ebe92
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .drone.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ steps:
- pip install pip -U
- pip --version
- nvidia-smi
# - bash ./tests/install_AMP.sh
#- bash ./tests/install_AMP.sh
- apt-get update && apt-get install -y cmake
- pip install -r requirements.txt --user -q
- pip install -r ./tests/requirements-devel.txt --user -q
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,15 @@ def backward(self, use_amp, loss, optimizer):
"""
if trainer.precision == 16:

# .backward is not special on 16-bit with TPUs
if not trainer.on_tpu:
if trainer.on_tpu:
return

if self.trainer.use_native_amp:
self.trainer.scaler.scale(loss).backward()

# TODO: remove in v0.8.0
else:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
Expand Down
15 changes: 14 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,22 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer,
if self.trainer.use_tpu and XLA_AVAILABLE:
xm.optimizer_step(optimizer)
elif isinstance(optimizer, torch.optim.LBFGS):

# native amp + lbfgs is a no go right now
if self.use_amp and self.use_native_amp:
m = 'native PyTorch amp and lbfgs are not compatible. To request, please file' \
'a Github issue in PyTorch and tag @mcarilli'
raise MisconfigurationException(m)
optimizer.step(second_order_closure)
else:
optimizer.step()
if self.use_amp and self.use_native_amp:
self.trainer.scaler.step(optimizer)
else:
optimizer.step()

# in native 16-bit we need to update scaler after optimizer step
if self.use_amp and self.use_native_amp:
self.trainer.scaler.update()

# model hook
self.on_before_zero_grad(optimizer)
Expand Down
24 changes: 23 additions & 1 deletion pytorch_lightning/trainer/auto_mix_precision.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn

try:
from apex import amp
Expand All @@ -15,8 +17,28 @@ class TrainerAMPMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
precision: int
use_native_amp: bool

def init_amp(self, use_amp):
# TODO: remove in v 0.8.0
if self.use_native_amp:
rank_zero_warn("`amp_level` has been deprecated since v0.7.4 "
"(native amp does not require it)"
" and this argument will be removed in v0.8.0", DeprecationWarning)

# Backward compatibility, TODO: remove in v0.9.0
if use_amp is not None:
rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
" and this argument will be removed in v0.9.0", DeprecationWarning)
self.precision = 16 if use_amp else 32

assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if use_amp and self.use_native_amp:
log.info('Using 16bit precision.')
return

# TODO: remove all below for v0.8.0
if use_amp and not APEX_AVAILABLE: # pragma: no-cover
raise ModuleNotFoundError("""
You set `use_amp=True` but do not have apex installed.
Expand All @@ -31,4 +53,4 @@ def init_amp(self, use_amp):

@property
def use_amp(self) -> bool:
return self.precision == 16 and APEX_AVAILABLE
return self.precision == 16
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class TrainerDDPMixin(ABC):
amp_level: str
use_tpu: bool
default_root_dir: str
use_native_amp: bool

@property
@abstractmethod
Expand Down Expand Up @@ -350,8 +351,8 @@ def ddp_train(self, process_idx, model):

# AMP
# run through amp wrapper before going to distributed DP
if self.use_amp:
# An example
# TODO: remove in v0.8.0
if self.use_amp and not self.use_native_amp:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers

Expand Down
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ class TrainerDPMixin(ABC):
tpu_local_core_rank: int
tpu_global_core_rank: int
use_tpu: bool
use_native_amp: bool
data_parallel_device_ids: ...
logger: Union[LightningLoggerBase, bool]

Expand Down Expand Up @@ -481,7 +482,8 @@ def single_gpu_train(self, model):
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

if self.use_amp:
# TODO: update for 0.8.0
if self.use_amp and not self.use_native_amp:
# An example
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
Expand Down Expand Up @@ -528,9 +530,16 @@ def dp_train(self, model):

model.cuda(self.root_gpu)

# hack forward to do autocast for the user
model_autocast_original_forward = model.forward
if self.use_amp and self.use_native_amp:
# wrap the user's forward in autocast and give it back at the end
model.forward = torch.cuda.amp.autocast()(model.forward)

# TODO: remove in v0.8.0
# check for this bug (amp + dp + !01 doesn't work)
# https://github.com/NVIDIA/apex/issues/227
if self.use_dp and self.use_amp:
if self.use_dp and self.use_amp and not self.use_native_amp:
if self.amp_level == 'O2':
raise MisconfigurationException(
f'Amp level {self.amp_level} with DataParallel is not supported.'
Expand All @@ -551,6 +560,8 @@ def dp_train(self, model):

self.run_pretrain_routine(model)

model.forward = model_autocast_original_forward

def horovod_train(self, model):
# Horovod: initialize library
hvd.init()
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
# -----------------
# RUN EVALUATION STEP
# -----------------
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
else:
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

# on dp / ddp2 might still want to do something with the batch parts
if test_mode:
Expand Down
24 changes: 11 additions & 13 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def __init__(
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: Optional[str] = 'full',
weights_save_path: Optional[str] = None,
amp_level: str = 'O1',
num_sanity_val_steps: int = 5,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[str] = None,
Expand All @@ -124,6 +123,7 @@ def __init__(
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
Expand Down Expand Up @@ -487,20 +487,18 @@ def __init__(
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)

# 16 bit mixed precision training using apex
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.autocast_original_forward = None
self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.use_native_amp and self.precision == 16:
self.scaler = torch.cuda.amp.GradScaler()
self.precision = precision

# TODO: remove for v0.8.0
self.amp_level = amp_level
self.precision = precision

# Backward compatibility, TODO: remove in v0.9.0
if use_amp is not None:
rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
" and this argument will be removed in v0.9.0", DeprecationWarning)
self.precision = 16 if use_amp else 32

assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

if self.precision == 16 and self.num_tpu_cores is None:
use_amp = True
self.init_amp(use_amp)

# Callback system
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
if on_gpu:
model.cuda(self.root_gpu)

# restore amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

# load training state (affects trainer only)
self.restore_training_state(checkpoint)

Expand Down Expand Up @@ -316,6 +320,10 @@ def dump_checkpoint(self):

checkpoint['state_dict'] = model.state_dict()

# restore native amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

if hasattr(model, "hparams"):
is_namespace = isinstance(model.hparams, Namespace)
checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
Expand Down Expand Up @@ -441,6 +449,10 @@ def hpc_load(self, folderpath, on_gpu):
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

# restore amp scaling
if self.use_amp and self.use_native_amp and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])

if self.root_gpu is not None:
model.cuda(self.root_gpu)

Expand Down
11 changes: 9 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def training_step(self, batch, batch_idx):

import numpy as np
from torch.utils.data import DataLoader
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
Expand Down Expand Up @@ -588,8 +589,12 @@ def run_training_batch(self, batch, batch_idx):
def optimizer_closure():
# forward pass
with self.profiler.profile('model_forward'):
output_dict = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)
if self.use_amp and self.use_native_amp:
with torch.cuda.amp.autocast():
output_dict = self.training_forward(split_batch, batch_idx,
opt_idx, self.hiddens)
else:
output_dict = self.training_forward(split_batch, batch_idx, opt_idx, self.hiddens)

# format and reduce outputs accordingly
processed_output = self.process_output(output_dict, train=True)
Expand Down Expand Up @@ -645,6 +650,8 @@ def optimizer_closure():
self.track_grad_norm)

# clip gradients
if self.use_amp and self.use_native_amp:
self.scaler.unscale_(optimizer)
self.clip_gradients()

# calls .step(), .zero_grad()
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_model(self):
"""Warning: this is just empty shell for code implemented in other class."""

def clip_gradients(self):

# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if self.gradient_clip_val > 0:
Expand Down

0 comments on commit 29ebe92

Please sign in to comment.