From 40b9b2d5892de8fa41ac2b6a9431d3e44d03adb9 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Wed, 25 Jan 2023 13:49:43 +0200
Subject: [PATCH] unscaling and clipping moved to step

---
 src/super_gradients/training/sg_trainer/sg_trainer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 31912a9842..349f31693e 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -546,16 +546,17 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context
         self.scaler.scale(loss).backward()
         self.phase_callback_handler.on_train_batch_backward_end(context)
 
-        # APPLY GRADIENT CLIPPING IF REQUIRED
-        if self.training_params.clip_grad_norm:
-            torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)
-
         # ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
         integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1
         if integrated_batches_num % self.batch_accumulate == 0:
             self.phase_callback_handler.on_train_batch_gradient_step_start(context)
 
+            # APPLY GRADIENT CLIPPING IF REQUIRED
+            if self.training_params.clip_grad_norm:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)
+
             # SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
             self.scaler.step(self.optimizer)
             self.scaler.update()
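
Note: the patch moves gradient clipping inside the accumulation-step branch and calls scaler.unscale_(optimizer) first, so clipping sees the true (unscaled) gradient magnitudes of the fully accumulated gradients rather than GradScaler-scaled values. Below is a minimal, self-contained sketch of that ordering with standard PyTorch AMP APIs; it is not the SuperGradients trainer itself, the names model, loader, batch_accumulate and clip_grad_norm are illustrative, and a CUDA device is assumed.

import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # pass enabled=False to disable mixed precision
loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
batch_accumulate = 2
clip_grad_norm = 1.0

for batch_idx, (inputs, targets) in enumerate(loader):
    inputs, targets = inputs.cuda(), targets.cuda()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    # Gradients produced here are scaled by the GradScaler.
    scaler.scale(loss).backward()

    # Only act once the gradients for this step are fully accumulated.
    if (batch_idx + 1) % batch_accumulate == 0:
        # Unscale before clipping so clip_grad_norm applies to real gradient norms.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

When the scaler is constructed with enabled=False, scale(), unscale_() and step() pass through unchanged, so the same code path serves both mixed-precision and full-precision training, mirroring the "SCALER IS ENABLED ONLY IF mixed_precision=True" comment in the patch.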