
how to properly skip samples that cause inf/nan gradients/loss  #4956

@levhaikin

Description

tl;dr

does the approach in the code snippet below look OK, or is there a better alternative for automatically skipping the few "bad" samples in the data that cause inf/nan gradients/loss? (and is this a good practice altogether?)

details

sometimes there is a small percentage (but annoyingly large in absolute numbers) of "dirty" samples in the data that cause the loss to be nan, even though the neural-network architecture itself is fine and numerically stable.
one approach is to automatically stop training (using terminate_on_nan) and then somehow isolate all these samples and remove them from the data permanently. but..
sometimes we simply want to skip these samples automatically, as if they never existed (perhaps with a warning), and continue training.
I couldn't find any documentation about how to do that, nor anyone who has asked this question, so I decided to ask and offer a solution I found, for others who might need it as well.
in the end, I came up with the following approach - override the on_after_backward method in my lightning-module with the following code:

code

    # assumes a module-level `import torch` and `log = logging.getLogger(__name__)`
    def on_after_backward(self) -> None:
        # scan all parameter gradients; stop at the first non-finite one
        valid_gradients = True
        for name, param in self.named_parameters():
            if param.grad is not None:
                valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
                if not valid_gradients:
                    break

        if not valid_gradients:
            log.warning('detected inf or nan values in gradients. not updating model parameters')
            self.zero_grad()

pros

  • this code successfully identifies nan/inf gradients and skips the parameter update by zeroing the gradients for that specific batch
  • supports multi-GPU (at least DDP, which I tested). by detecting inf/nan gradients (instead of an inf/nan loss) we avoid a potential loss of synchronization between processes: typically only one of the processes generates an inf loss, while the others don't. if we stop only that one process from doing its backward pass, the processes fall out of sync and end up waiting for each other forever - training stalls. when we check the gradients instead, the check runs after the gradients in all processes have already been affected by the bad inf loss, so all processes skip the update together and stay synchronized. (for contrast, the loss-level alternative is sketched right after this list)
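
for contrast, here is roughly what the loss-level alternative looks like - a minimal sketch, assuming compute_loss is a user-defined helper and that your lightning version treats a None return from training_step as "skip this batch":

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.compute_loss(x, y)   # compute_loss is an assumed user-defined helper
        if not torch.isfinite(loss):
            # skips the batch, but only in the process that saw the bad loss;
            # under ddp the other processes keep waiting for its gradients and training stalls
            return None
        return loss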

cons

  • can't catch the bad samples themselves that way.. need to work harder for that (one possible workaround is sketched right after this list)..
  • might not be future-proof
  • clutters the lightning-module code (it is essentially architecture-agnostic, boiler-plate code)
  • perhaps there is a better way..
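
regarding the first con, one possible workaround (a sketch, untested beyond the idea) is to stash the batch index in training_step and log it when the skip happens, so the offending samples can be hunted down later; anything finer-grained (per-sample ids) would require the dataloader to yield them. compute_loss is again an assumed placeholder:

    def training_step(self, batch, batch_idx):
        self._last_batch_idx = batch_idx     # stash for on_after_backward
        x, y = batch
        return self.compute_loss(x, y)       # assumed user-defined helper

    def on_after_backward(self) -> None:
        has_bad_grads = any(
            not torch.isfinite(p.grad).all()
            for p in self.parameters() if p.grad is not None
        )
        if has_bad_grads:
            log.warning(f'non-finite gradients in batch {self._last_batch_idx}; skipping update')
            self.zero_grad()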

final question

is it worth having such functionality integrated into lightning as a simple command-line switch/parameter?
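
in the meantime, one way to package this without touching the trainer could be a small reusable callback - a sketch that assumes the on_after_backward(trainer, pl_module) callback hook is available in your lightning version; SkipNonFiniteGradients is just a made-up name:

    import logging

    import torch
    from pytorch_lightning.callbacks import Callback

    log = logging.getLogger(__name__)


    class SkipNonFiniteGradients(Callback):
        # zero the gradients (and thereby skip the optimizer step) whenever any
        # parameter gradient contains nan or inf
        def on_after_backward(self, trainer, pl_module):
            finite = all(
                torch.isfinite(p.grad).all()
                for p in pl_module.parameters() if p.grad is not None
            )
            if not finite:
                log.warning('detected inf or nan gradients; skipping parameter update')
                pl_module.zero_grad()

the lightning-module then stays clean, and usage is just trainer = Trainer(callbacks=[SkipNonFiniteGradients()]).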

Labels: feature (Is an improvement or enhancement), question (Further information is requested), won't fix (This will not be worked on)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions