Make print_nan_grads print grad (#208)
This seems more useful for debugging.
alok authored and williamFalcon committed Sep 7, 2019
1 parent 9f9d386 commit 81df225
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -1091,10 +1091,10 @@ def __clip_gradients(self):
             torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)
 
     def __print_nan_grads(self):
-        if self.print_nan_grads:
-            model = self.__get_model()
-            for param in model.parameters():
-                print(param.grad.float().sum())
+        model = self.__get_model()
+        for param in model.parameters():
+            if torch.isnan(param.grad.float()).any():
+                print(param, param.grad)
 
     def __run_tng_batch(self, data_batch, batch_nb):
         if data_batch is None:
@@ -1137,7 +1137,8 @@ def __run_tng_batch(self, data_batch, batch_nb):
             model_ref.on_after_backward()
 
         # nan grads
-        self.__print_nan_grads()
+        if self.print_nan_grads:
+            self.__print_nan_grads()
 
         # track total loss for logging (avoid mem leaks)
         self.batch_loss_value += loss.item()
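For reference, the new behavior can be reproduced outside the trainer with the short sketch below. This is only a standalone approximation of the check added in this commit; the tiny Linear model, the dummy forward/backward pass, and the injected NaN are illustrative and not part of the change itself.

import torch

# Build a tiny model and run one backward pass so gradients exist.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

# Inject a NaN gradient purely to trigger the check (illustrative only).
model.weight.grad[0, 0] = float("nan")

# Same idea as the updated __print_nan_grads: print the parameter and its
# gradient only when the gradient actually contains NaNs, rather than
# printing the gradient sum of every parameter.
for param in model.parameters():
    if param.grad is not None and torch.isnan(param.grad.float()).any():
        print(param, param.grad)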
