EarlyStopping Integration with LearningRateMonitor #17449

@CrohnEngineer

Description & Motivation

Description

It would be nice to have the EarlyStopping callback access the learning rate values logged by a LearningRateMonitor.

Motivation

When using the ReduceLROnPlateau scheduler, it is often useful to stop training once the learning rate reaches a minimum value.
While LearningRateMonitor logs the learning rate to a logger, the EarlyStopping callback cannot access the values of this metric.
Adding the callback EarlyStopping(monitor='lr', mode='min', stopping_threshold=min_lr_value) makes training fail with a

RuntimeError: Early stopping conditioned on metric `lr` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train_loss`, `val_loss`.
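For reference, a minimal sketch of the configure_optimizers setup the motivation assumes (the scheduler hyperparameters here are only illustrative):

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=1e-3)
    # ReduceLROnPlateau lowers the learning rate when val_loss stops improving;
    # ideally training would stop once the learning rate reaches min_lr
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, min_lr=1e-5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
    }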

Pitch

Allow the EarlyStopping callback to access the learning rate values logged by the LearningRateMonitor, so that training can be stopped once the learning rate reaches a minimum value when using the ReduceLROnPlateau scheduler. Currently, attempting this results in the RuntimeError shown above.
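For illustration, the desired usage could look something like this (the metric name lr-Adam follows LearningRateMonitor's default lr-<optimizer> naming and is only illustrative; today this raises the same RuntimeError because the logged value never reaches the callback metrics):

callbacks = [
    LearningRateMonitor(logging_interval='epoch'),
    # desired behaviour: EarlyStopping can monitor the learning rate logged by LearningRateMonitor
    EarlyStopping(monitor='lr-Adam', mode='min', stopping_threshold=1e-5),
]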

Alternatives

As explained in this discussion, the most obvious alternative is for users to log the learning rate themselves during the validation step.
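A minimal sketch of that workaround, assuming a single optimizer (the metric name manual_lr and the trimmed module are only for illustration):

class LitAutoEncoderWithManualLR(pl.LightningModule):
    # __init__, training_step and configure_optimizers as in the script below

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        self.log("val_loss", nn.functional.mse_loss(x_hat, x))
        # log the current learning rate manually so that EarlyStopping can monitor it
        current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("manual_lr", current_lr)

An EarlyStopping(monitor='manual_lr', mode='min', stopping_threshold=min_lr_value) callback should then find the metric instead of raising the RuntimeError.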

Additional context

Below is a small training script to reproduce the error, based on the "Lightning in 15 minutes" docs, in case it is useful.

# --- Libraries import
import os

import torch
import torch.utils.data as data
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger


# --- Module definition

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
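        # NOTE: the real use case would also return a ReduceLROnPlateau scheduler here,
        # but the EarlyStopping error described above is reproduced even without one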
        return optimizer


# --- Init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# --- Setup data
train_set = MNIST(os.getcwd(), download=True, transform=ToTensor(), train=True)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

# instantiate dataloaders
train_loader = data.DataLoader(train_set)
valid_loader = data.DataLoader(valid_set)

# --- Trainer
logger = TensorBoardLogger(save_dir='experiments/logs')
trainer = pl.Trainer(accelerator='auto', logger=logger, max_epochs=4,
                     callbacks=[LearningRateMonitor(logging_interval='step'),
                                EarlyStopping(monitor='val_loss', mode='min', patience=2),
                                EarlyStopping(monitor='lr', mode='min', stopping_threshold=1e-5,
                                              check_on_train_epoch_end=True),
                                ],
                    )

# --- Fit
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)


cc @Borda @carmocca @Blaizzy @awaelchli
