Description
tl;dr
Does the approach in the code snippet below look OK, or is there a better alternative for automatically skipping a few "bad" samples in the data that cause inf/nan gradients/loss? (Is this a good practice at all?)
details
Sometimes there is a small percentage (but an annoyingly large absolute number) of "dirty" samples in the data that cause the loss to be NaN, even though the neural-network architecture itself is fine and numerically stable.
One approach is to automatically stop training (using terminate_on_nan), then somehow isolate all these samples and remove them from the data permanently. But...
Sometimes we simply want to skip these samples automatically, as if they never existed (perhaps with a warning), and continue training.
I couldn't find any documentation on how to do that, nor anyone who has asked this question, so I decided to ask and offer the solution I found, for others who might need it as well.
In the end, I came up with the following approach: override the on_after_backward hook in my LightningModule with the following code:
code
# inside the LightningModule (assumes `import torch` and a module-level
# logger, e.g. `log = logging.getLogger(__name__)`)
def on_after_backward(self) -> None:
    valid_gradients = True
    for name, param in self.named_parameters():
        if param.grad is not None:
            valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
            if not valid_gradients:
                break

    if not valid_gradients:
        log.warning('detected inf or nan values in gradients. not updating model parameters')
        self.zero_grad()
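For context, here is a minimal standalone sketch (plain PyTorch, made-up shapes, not part of the snippet above) of the failure mode this hook catches: a single corrupt value in one sample is enough to put NaNs into the accumulated gradients, which is why the whole batch gets skipped rather than just the bad sample.

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
x[3, 2] = float('nan')                        # one "dirty" value in one sample
loss = model(x).mean()
loss.backward()
print(torch.isnan(model.weight.grad).any())   # tensor(True)
```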
pros
- this code successfully identifies NaN/inf gradients and skips the parameter update for that batch by zeroing the gradients
- supports multi-GPU (at least DDP, which I tested). Checking for inf/NaN gradients (instead of an inf/NaN loss) avoids a potential loss of synchronization between processes: typically only one process produces an inf loss while the others don't, and if we stop only that process from doing its backward pass, the processes fall out of sync and end up waiting for each other forever, so training stalls. The gradient check happens after the bad loss has already propagated into the gradients of all processes (DDP all-reduces gradients during backward), so every process makes the same decision and synchronization is preserved.
cons
- can't identify which samples were bad this way; that requires extra work
- might not be future-proof
- clutters the LightningModule code (it is essentially architecture-agnostic boilerplate); see the callback sketch after this list for one way around that
- perhaps there is a better way
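One way to address the clutter/boilerplate concern is to move the check into a standalone callback, so any LightningModule can use it without changes. Below is a sketch, assuming your Lightning version exposes the on_after_backward hook on the Callback API; the class name SkipOnInfNanGradients and the log message are placeholders, not an existing Lightning feature.

```python
import logging

import torch
from pytorch_lightning import Callback

log = logging.getLogger(__name__)


class SkipOnInfNanGradients(Callback):
    """Zero all gradients when any of them is NaN/inf, so the parameter
    update for the offending batch is skipped (same idea as the snippet above)."""

    def on_after_backward(self, trainer, pl_module):
        for param in pl_module.parameters():
            if param.grad is None:
                continue
            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                log.warning(
                    'detected inf or nan values in gradients at global step %d; '
                    'not updating model parameters', trainer.global_step
                )
                pl_module.zero_grad()
                break
```

It would then be enabled with `Trainer(callbacks=[SkipOnInfNanGradients()])`, keeping the LightningModule itself free of this boilerplate.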
final question
Is it worth having such functionality integrated into Lightning as a simple switch/parameter?