fix apex gradient clipping (#2829)
ruotianluo authored Aug 5, 2020
1 parent 5bbcb8d commit 6034d5e
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -291,7 +291,7 @@ def transfer_batch_to_tpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def clip_gradients(self):
+    def clip_gradients(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -817,7 +817,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
         # ------------------
         if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
             self.scaler.unscale_(optimizer)
-        self.clip_gradients()
+        self.clip_gradients(optimizer)
 
         # ------------------
         # .STEP + ZERO_GRAD
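For context, a hedged usage sketch (not part of the commit) of how this code path is exercised; the flag names assume the Trainer API of this release. On a PyTorch build without native AMP, precision=16 falls back to apex, and a non-zero gradient_clip_val makes the backward pass above call clip_gradients(optimizer).

# Hypothetical usage example (assumed 0.8.x-era Trainer flags), not taken from the repository.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,
    precision=16,           # 16-bit training; apex is used when native AMP (torch >= 1.6) is unavailable
    amp_level='O2',         # apex optimization level; O2 keeps FP32 master weights in the optimizer
    gradient_clip_val=0.5,  # clip gradients to a max L2 norm of 0.5 via clip_gradients(optimizer)
)
# trainer.fit(model)  # model: any LightningModule defined elsewhere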
15 changes: 13 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
@@ -27,9 +27,17 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
 
@@ -60,14 +68,17 @@ def restore(self, *args):
     def fit(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def clip_gradients(self):
+    def clip_gradients(self, optimizer):
 
         # this code is a modification of torch.nn.utils.clip_grad_norm_
         # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val <= 0:
             return
         model = self.get_model()
-        parameters = model.parameters()
+        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
+            parameters = amp.master_params(optimizer)
+        else:
+            parameters = model.parameters()
         max_norm = float(self.gradient_clip_val)
         norm_type = float(2.0)
         if isinstance(parameters, torch.Tensor):
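The fix boils down to clipping the parameters the optimizer actually steps on. Below is a minimal standalone sketch of the same idea, assuming apex is installed and the model/optimizer went through amp.initialize(..., opt_level='O2'); it is not the library's implementation, it just delegates to torch.nn.utils.clip_grad_norm_ after selecting the right parameter set.

import torch

def clip_grads(model, optimizer, max_norm, using_apex):
    """With apex O2 the optimizer holds FP32 master copies of the FP16 model
    weights, so clipping model.parameters() would miss the tensors that are
    actually updated; amp.master_params(optimizer) yields the master copies."""
    if using_apex:
        from apex import amp  # requires NVIDIA apex
        parameters = amp.master_params(optimizer)
    else:
        parameters = model.parameters()
    # Delegate the clipping itself to PyTorch once the parameter set is chosen.
    return torch.nn.utils.clip_grad_norm_(
        [p for p in parameters if p.grad is not None], max_norm
    )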
