diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index aa673abd1f085..70a413680a931 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1091,10 +1091,10 @@ def __clip_gradients(self):
             torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
 
     def __print_nan_grads(self):
-        if self.print_nan_grads:
-            model = self.__get_model()
-            for param in model.parameters():
-                print(param.grad.float().sum())
+        model = self.__get_model()
+        for param in model.parameters():
+            if torch.isnan(param.grad.float()).any():
+                print(param, param.grad)
 
     def __run_tng_batch(self, data_batch, batch_nb):
         if data_batch is None:
@@ -1137,7 +1137,8 @@ def __run_tng_batch(self, data_batch, batch_nb):
                 model_ref.on_after_backward()
 
             # nan grads
-            self.__print_nan_grads()
+            if self.print_nan_grads:
+                self.__print_nan_grads()
 
             # track total loss for logging (avoid mem leaks)
             self.batch_loss_value += loss.item()
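
For context, here is a minimal standalone sketch of the behaviour this patch introduces: rather than printing the gradient sum of every parameter, only parameters whose gradients actually contain NaNs are reported. The free function `print_nan_grads` and the toy `nn.Linear` model below are illustrative assumptions, not the Trainer's private method or Lightning API.

```python
import torch
import torch.nn as nn


def print_nan_grads(model: nn.Module) -> None:
    # Mirrors the patched Trainer.__print_nan_grads: report only params whose
    # gradients contain NaNs. Note that parameters which never received a
    # gradient have grad=None, so a robust version skips them explicitly.
    for param in model.parameters():
        if param.grad is not None and torch.isnan(param.grad.float()).any():
            print(param, param.grad)


if __name__ == "__main__":
    model = nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    # Inject a NaN to show that only the offending parameter is printed.
    model.weight.grad[0, 0] = float("nan")
    print_nan_grads(model)
```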