Description
It would be nice to have the EarlyStopping callback access the learning rate values logged by a LearningRateMonitor.
Motivation
When using the ReduceLROnPlateau scheduler, it is useful to stop training once the learning rate reaches a minimum value.
While LearningRateMonitor logs the learning rate to a logger, the EarlyStopping callback cannot access the values of this metric.
Adding the callback EarlyStopping(monitor='lr', mode='min', stopping_threshold=min_lr_value) makes training fail with:
RuntimeError: Early stopping conditioned on metric `lr` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train_loss`, `val_loss`.
Pitch
Allow the EarlyStopping callback to access learning rate values logged by the LearningRateMonitor. This enables users to stop training when the learning rate reaches a minimum value when using the ReduceLROnPlateau scheduler. Currently, attempting to do so results in a RuntimeError.
Alternatives
As explained in this discussion, the most obvious alternative is for users to log the learning rate themselves during the validation step; a minimal sketch of that workaround is shown below.
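A minimal sketch of that workaround, assuming a single optimizer: read the current learning rate from the trainer's optimizer and log it under the name lr, so that EarlyStopping(monitor='lr', ...) can find it.

def validation_step(self, batch, batch_idx):
    ...  # usual validation computation and val_loss logging
    # read the learning rate of the first (and only) param group and expose it as a metric
    current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
    self.log("lr", current_lr)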
Additional context
A small training script to reproduce the error, based on the Lightning in 15 minutes docs, in case it is useful.
# --- Libraries import
import os
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl
import torch
import torch.utils.data as data
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
# --- Module definition
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = nn.functional.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# --- Init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
# --- Setup data
train_set = MNIST(os.getcwd(), download=True, transform=ToTensor(), train=True)
# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
# instantiate dataloaders
train_loader = data.DataLoader(train_set)
valid_loader = data.DataLoader(valid_set)
# --- Trainer
logger = TensorBoardLogger(save_dir='experiments/logs')
trainer = pl.Trainer(
    accelerator='auto',
    logger=logger,
    max_epochs=4,
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        EarlyStopping(monitor='val_loss', mode='min', patience=2),
        EarlyStopping(monitor='lr', mode='min', stopping_threshold=1e-5,
                      check_on_train_epoch_end=True),
    ],
)
# --- Fit
trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=valid_loader)
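For reference, the Motivation above mentions ReduceLROnPlateau, while the script returns a bare optimizer from configure_optimizers. A hedged sketch of a variant that actually schedules the learning rate on val_loss (the factor and patience values are arbitrary, chosen only for illustration):

def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=1e-3)
    # reduce the learning rate when "val_loss" stops improving, so it eventually approaches a minimum
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=1)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
    }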