Support early stopping during training inside of the early stopping callback #7033

Closed
ananthsub opened this issue Apr 15, 2021 · 0 comments · Fixed by #6944 or #7069
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)

ananthsub (Contributor) commented Apr 15, 2021

🚀 Feature

The ability to early stop during training, as controlled by the callback instead of the training loop.

Motivation

This removes the need for the training loop to manually run these callback hooks when the validation epoch won't be run: https://github.com/PyTorchLightning/pytorch-lightning/blob/5bd3cd5f712b65d38812b27cf957261bb06b83c5/pytorch_lightning/trainer/training_loop.py#L152-L159

Similar to the flexibility offered by checkpointing callbacks, this could eventually let users specify separate early-stop criteria for training and validation.

Pitch

Add a check_on_train_epoch_end flag to the callback constructor. See #6944 for a sketch.

This flag controls whether the check runs during training or validation. Because the monitored metric may be logged during training but not validation, or vice versa, the flag makes the checks in these two hooks mutually exclusive.

    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
        # Only run the check here when the flag is set
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)

    def on_validation_end(self, trainer, pl_module):
        # Only run the check here when the flag is not set
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)
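To make the gating logic above concrete, here is a minimal standalone sketch of a callback with the proposed flag. The class name, the metric-dict hooks, and the internal counters are illustrative assumptions, not the actual Lightning implementation:

```python
class EarlyStoppingSketch:
    """Illustrative sketch of mutually exclusive early-stop checks.

    Assumptions: this is NOT the Lightning API; the hooks here take a
    plain metrics dict instead of a trainer, to keep the sketch runnable.
    """

    def __init__(self, monitor, min_delta=0.0, patience=3,
                 check_on_train_epoch_end=False):
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self._check_on_train_epoch_end = check_on_train_epoch_end
        self.best = float("inf")  # assumes a metric we want to minimize
        self.wait = 0
        self.stopped = False

    def _run_early_stopping_check(self, metrics):
        current = metrics.get(self.monitor)
        if current is None:
            return  # monitored metric not logged in this phase
        if current < self.best - self.min_delta:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped = True

    def on_train_epoch_end(self, metrics):
        # Checks only when the flag is set
        if self._check_on_train_epoch_end:
            self._run_early_stopping_check(metrics)

    def on_validation_end(self, metrics):
        # Checks only when the flag is not set
        if not self._check_on_train_epoch_end:
            self._run_early_stopping_check(metrics)
```

With `check_on_train_epoch_end=True`, calls to `on_validation_end` become no-ops and stopping is driven entirely by training-epoch metrics, which is the mutual exclusivity the pitch describes.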

For parity with existing behavior, this flag will default to False. With the feature enabled, users can specify a callback for early stopping during training like so:

stop = EarlyStopping(monitor='abc', min_delta=0.1, patience=0, check_on_train_epoch_end=True)

Users could then create multiple such callbacks:

train_stop = EarlyStopping(monitor='abc', min_delta=0.1, patience=0, check_on_train_epoch_end=True)
val_stop = EarlyStopping(monitor='val_loss', min_delta=0.5, patience=3)
trainer = Trainer(..., callbacks=[train_stop, val_stop], ...)
trainer.fit(...)

Alternatives

Keep as is

Additional context
