This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

fix a bug when using fp16 training & gradient clipping (#5426)
* fix a bug when using fp16 training & gradient clipping

* fix format

* format

Co-authored-by: yk x <xyk1021355229@gmail.com>
Co-authored-by: epwalsh <epwalsh10@gmail.com>
3 people authored Oct 7, 2021
1 parent a63e28c commit 17ef1aa
Showing 1 changed file with 19 additions and 1 deletion.
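
Background on the bug: under torch.cuda.amp, the gradients produced by backward() are multiplied by the GradScaler's scale factor. The old code path (see the removed enable_gradient_clipping call at the bottom of the diff) attached clipping to the backward pass itself, so an unscaled clipping threshold was compared against still-scaled gradient values. The fix clips only after explicitly unscaling. For reference, a minimal sketch of the standard PyTorch unscale-then-clip recipe; the toy model, optimizer, and clip value are illustrative placeholders, not part of this commit:

    import torch
    from torch.cuda import amp

    model = torch.nn.Linear(10, 2).cuda()  # toy model, for illustration only
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = amp.GradScaler()

    def train_step(inputs, targets, clip_value=1.0):
        optimizer.zero_grad()
        with amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()          # gradients are scaled here
        scaler.unscale_(optimizer)             # unscale before clipping
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        scaler.step(optimizer)                 # skips the update on inf/nan gradients
        scaler.update()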
allennlp/training/gradient_descent_trainer.py
@@ -12,6 +12,7 @@
 from torch.cuda import amp
 from torch.nn.utils import clip_grad_norm_
 import torch.distributed as dist
+from torch.cuda.amp.grad_scaler import OptState

 from allennlp.common.checks import ConfigurationError, check_for_gpu
 from allennlp.common import util as common_util, Tqdm, Lazy
@@ -349,6 +350,23 @@ def _pytorch_model(self):
             return self.model
         return self._ddp_wrapped_model.model

+    def clip_gradient(self):
+        """
+        Performs gradient clipping.
+        If the model is training in mixed precision, the gradients are first unscaled.
+        """
+        if self._grad_clipping is not None:
+            # 1. We have to unscale the gradients before clipping.
+            if self._scaler is not None:
+                optimizer_state = self._scaler._per_optimizer_states[id(self.optimizer)]
+                # 2. `unscale_` must not run more than once per optimizer per step call,
+                # so we only call it if it has not already been called.
+                if optimizer_state["stage"] is not OptState.UNSCALED:
+                    self._scaler.unscale_(self.optimizer)
+            torch.nn.utils.clip_grad_value_(
+                [p for p in self.model.parameters() if p.grad is not None], self._grad_clipping
+            )
+
     def rescale_gradients(self) -> Optional[float]:
         """
         Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
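Note that the new method reads GradScaler._per_optimizer_states, a private PyTorch attribute, to check whether unscale_ has already run for this optimizer in the current step: PyTorch raises a RuntimeError if unscale_ is called twice between step() calls, so the OptState.UNSCALED guard keeps clip_gradient safe to call no matter whether something else (for example, gradient norm rescaling) has already unscaled. Being private, the attribute could change in a future PyTorch release.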
@@ -518,6 +536,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
                 train_loss += batch_loss

             batch_grad_norm = self.rescale_gradients()
+            self.clip_gradient()

             if self._learning_rate_scheduler:
                 self._learning_rate_scheduler.step_batch(self._total_batches_completed + 1)
@@ -756,7 +775,6 @@ def train(self) -> Dict[str, Any]:
                 callback.on_end(self, metrics=metrics, epoch=epoch, is_primary=self._primary)

     def _try_train(self) -> Tuple[Dict[str, Any], int]:
-        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

         logger.info("Beginning training.")
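The deleted line removes the hook-based clipping that training_util.enable_gradient_clipping used to install at the start of training; clipping now happens explicitly in the batch loop, after rescaling. A hedged usage sketch of the trainer with both mixed precision and value clipping enabled; the model, optimizer, and train_loader objects are assumed to be built elsewhere, and the parameter names follow the AllenNLP 2.x trainer API:

    from allennlp.training import GradientDescentTrainer

    trainer = GradientDescentTrainer(
        model=model,            # an allennlp.models.Model, built elsewhere
        optimizer=optimizer,
        data_loader=train_loader,
        num_epochs=3,
        use_amp=True,           # mixed precision via torch.cuda.amp
        grad_clipping=1.0,      # per-element value clipping, unscaled first as of this fix
    )
    metrics = trainer.train()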
