LearningRateFinder defining max validation batches for entire training loop #17412

Closed
blainehoak opened this issue Apr 18, 2023 · 1 comment · Fixed by #17636
Labels: bug (Something isn't working), help wanted (Open to be worked on), tuner, ver: 2.0.x
Milestone: 2.0.x

Comments

blainehoak commented Apr 18, 2023

Bug description

When the LearningRateFinder callback is used, the num_training_steps parameter passed at init (default: 100) ends up defining how many validation batches are run for the entire length of training. This means that if num_training_steps is smaller than the total number of batches in your validation set, every validation loop during training only sees a subset of the validation data.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

The following code fails because trainer.num_val_batches[0] ends up being 5 (the LR finder's num_training_steps) instead of the expected 50.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import LearningRateFinder


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, lr=0.1):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.lr = lr

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=self.lr)


def run():
    train_data = DataLoader(RandomDataset(32, 100), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 100), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=[LearningRateFinder(num_training_steps=5)],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    assert trainer.num_val_batches[0] == 50  # fails: the LR finder's limit of 5 batches is still in effect
    trainer.validate(model, dataloaders=val_data)


if __name__ == "__main__":
    run()

Using the same base code as above but removing the LearningRateFinder, this code passes:

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    assert trainer.num_val_batches[0] == 50
    trainer.validate(model, dataloaders=val_data)

However, num_val_batches does get updated once .validate() is called. With the LearningRateFinder added back in but the assert statement moved after the .validate() call, this code passes:

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=[LearningRateFinder(num_training_steps=5)],
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.validate(model, dataloaders=val_data)
    assert trainer.num_val_batches[0] == 50

Environment

- PyTorch Lightning Version: 2.0.1.post0
- PyTorch Version: 2.0.0
- Python version: 3.10.9
- OS: Darwin (macOS)
- How you installed Lightning: pip

More info

I did some more digging into why this might be happening, and the problem likely comes from the fact that trainer.fit_loop.epoch_loop.val_loop.setup_data() is called for the first time while the learning rate finder is running, so trainer.fit_loop.epoch_loop.val_loop._max_batches gets set according to the parameters the learning rate finder passed in.

Even though the learning rate finder restores the trainer's original parameters once it finishes, the setup_data() method never runs a full setup again, so the _max_batches attribute is never updated.
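
A quick way to confirm the stale state is to inspect the attribute named above after training (a diagnostic sketch only; it pokes at private internals, so treat it as debugging aid rather than supported API):

# after trainer.fit(...) with LearningRateFinder(num_training_steps=5)
print(trainer.fit_loop.epoch_loop.val_loop._max_batches)  # [5] instead of the expected [50]
print(trainer.num_val_batches)                            # reflects the same stale limit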

One way to fix this might be to redo the data setup once the learning rate finder has completed, similar to how setup is redone when .validate() is called.
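
Until a fix lands, one possible user-side workaround (a sketch only, not the change merged in #17636) is to run the learning-rate search through the Tuner API on a separate, throwaway Trainer and then fit with a fresh Trainer that never sees the LearningRateFinder callback, so the finder cannot leak its dataloader limits into the training run. The num_training argument name and the suggestion() accessor below are assumptions about the 2.0 Tuner API; the rest reuses the BoringModel and RandomDataset from the reproduction above.

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

model = BoringModel()
train_data = DataLoader(RandomDataset(32, 100), batch_size=2)
val_data = DataLoader(RandomDataset(32, 100), batch_size=2)

# Run the LR search on a throwaway Trainer so its reduced dataloader limits
# never touch the Trainer used for the real fit.
search_trainer = Trainer(max_epochs=1, num_sanity_val_steps=0, enable_model_summary=False)
lr_finder = Tuner(search_trainer).lr_find(
    model, train_dataloaders=train_data, val_dataloaders=val_data, num_training=5
)
model.lr = lr_finder.suggestion()

# Fit with a fresh Trainer that the LR finder never ran against.
fit_trainer = Trainer(max_epochs=1, num_sanity_val_steps=0, enable_model_summary=False)
fit_trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
assert fit_trainer.num_val_batches[0] == 50  # full validation set is used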

blainehoak added the bug and needs triage labels on Apr 18, 2023
awaelchli added the tuner and help wanted labels and removed the needs triage label on Apr 23, 2023
awaelchli added this to the 2.0.x milestone on Apr 23, 2023
awaelchli (Contributor) commented

@blainehoak Thanks for reporting. Help on this would be appreciated :) You are right, the finder is probably not resetting all variables correctly.

Borda edited the issue title on May 3, 2023