Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grad clipping stage + upscale gradients before clipping #653

Merged
merged 2 commits into from
Jan 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,16 +546,17 @@ def _backward_step(self, loss: torch.Tensor, epoch: int, batch_idx: int, context
self.scaler.scale(loss).backward()
self.phase_callback_handler.on_train_batch_backward_end(context)

# APPLY GRADIENT CLIPPING IF REQUIRED
if self.training_params.clip_grad_norm:
torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)

# ACCUMULATE GRADIENT FOR X BATCHES BEFORE OPTIMIZING
integrated_batches_num = batch_idx + len(self.train_loader) * epoch + 1

if integrated_batches_num % self.batch_accumulate == 0:
self.phase_callback_handler.on_train_batch_gradient_step_start(context)

# APPLY GRADIENT CLIPPING IF REQUIRED
if self.training_params.clip_grad_norm:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.training_params.clip_grad_norm)

# SCALER IS ENABLED ONLY IF self.training_params.mixed_precision=True
self.scaler.step(self.optimizer)
self.scaler.update()
Expand Down