Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check early stopping metric in the beginning of the training #542

Merged

Conversation

kuynzereb
Copy link
Contributor

As was discussed in #524 we should check the availability of the metric required by the early stopping in the very beginning of the training. To do so I force validation sanity check if early stopping is enabled and check the obtained callback metrics.

Also this PR set default early_stopping_callback to False. It was slightly discussed in #524 and #536. I suggest to finally decide whether it should be turned on or off. I think it is better to turn it off because:

  1. By default the user doesn't expect that his training will be interrupted by some callbacks.
  2. Default early stopping will work only if the user has defined val_loss. Thus if the user doesn't know about early stopping then it will be pure coincidence, in some cases it will work and in some cases it will not. And if the user knows about the early stopping then he can easily turn it on himself.

The changes can be tested with the following code:

from time import sleep
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class DummyDataset(Dataset):
    def __init__(self, n):
        super().__init__()
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return torch.rand(10)


sleep_time = 0.25


class CoolSystem(pl.LightningModule):
    def __init__(self):
        super(CoolSystem, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_nb):
        sleep(sleep_time)
        return {'loss': torch.mean(self.forward(batch) ** 2)}

    def validation_step(self, batch, batch_nb):
        sleep(sleep_time)
        return {}

    def validation_end(self, outputs):
        return {'_': torch.tensor(0)}

    def test_step(self, batch, batch_nb):
        sleep(sleep_time)
        return {}

    def test_end(self, outputs):
        return {}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.layer.parameters())]

    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(DummyDataset(10), batch_size=1)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(DummyDataset(5), batch_size=1)

    @pl.data_loader
    def test_dataloader(self):
        return DataLoader(DummyDataset(5), batch_size=1)


model = CoolSystem()
trainer = pl.Trainer(weights_summary=None, checkpoint_callback=False, min_nb_epochs=1,
                     early_stop_callback=True)
try:
    trainer.fit(model)
except RuntimeError as e:
    print(e)

print()
print()

model = CoolSystem()
trainer = pl.Trainer(weights_summary=None, checkpoint_callback=False, min_nb_epochs=1,
                     nb_sanity_val_steps=0, early_stop_callback=True)
try:
    trainer.fit(model)
except RuntimeError as e:
    print(e)

print()
print()

#If early_stop_callback=False and nb_sanity_val_steps=0 there is no validation sanity check.
model = CoolSystem()
trainer = pl.Trainer(weights_summary=None, checkpoint_callback=False, min_nb_epochs=1,
                     nb_sanity_val_steps=0, early_stop_callback=False, max_nb_epochs=1)
try:
    trainer.fit(model)
except RuntimeError as e:
    print(e)

print()
print()

def validation_end_correct(self, outputs):
    return {'val_loss': torch.tensor(0)}

CoolSystem.validation_end = validation_end_correct
model = CoolSystem()
trainer = pl.Trainer(weights_summary=None, checkpoint_callback=False, min_nb_epochs=1,
                     nb_sanity_val_steps=5, early_stop_callback=True)
try:
    trainer.fit(model)
except RuntimeError as e:
    print(e)

print()
print()

model = CoolSystem()
trainer = pl.Trainer(weights_summary=None, checkpoint_callback=False, min_nb_epochs=1,
                     nb_sanity_val_steps=0, early_stop_callback=True)
try:
    trainer.fit(model)
except RuntimeError as e:
    print(e)

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure about some sections...

RuntimeWarning)
stop_training = True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then you should return True

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly. Return True was before and it caused the interruption of the training if the required metric was not found. And now it just gives a warning and training just proceeds as though without early stopping. The point is that the callback should not stop the training if it can't find the metrics.

Actually, in the current implementation this branch is not reachable because we check for the availability of the metric in the trainer initialization. But my idea was that if we decide to set early_stopping to True by default, then it can be used to give a warning but not to stop the training.

You can also look at #524 for better understanding.

@@ -53,7 +53,7 @@ class Trainer(TrainerIOMixin,
def __init__(self,
logger=True,
checkpoint_callback=True,
early_stop_callback=True,
early_stop_callback=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why changing the default config?

Copy link
Contributor Author

@kuynzereb kuynzereb Nov 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I have written in the description we should discuss whether early stopping should be turned on or off by the default. I think that it is better to be turned off. Again, please look at #524

Copy link
Contributor

@williamFalcon williamFalcon Nov 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early stopping should default to True - this is the most common use case.
This is a best practice for research with patience of 3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that it is conditioned on the val_loss and it will work only if the user guess right the name. For this reason it seems to me that the better option is when the user deliberately enables early stopping. It is also very easy (just set True).

But if you insist I can suggest the following: we enable early stopping by default, but if there is no val_loss we will just warn the user that early stopping will not work and training will proceed with disabled early stopping. I just don't like that there will be some warnings when you run the trainer with the default settings :)

@@ -185,6 +185,8 @@ def __init__(self,
# creates a default one if none passed in
self.early_stop_callback = None
self.configure_early_stopping(early_stop_callback, logger)
if self.enable_early_stop:
self.nb_sanity_val_steps = max(1, self.nb_sanity_val_steps)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe max(1, nb_sanity_val_steps) since earlier you have

if self.fast_dev_run:
    self.nb_sanity_val_steps = 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But exactly by that reason it should be max(1, self.nb_sanity_val_steps) :)

We just take the previously defined final self.nb_sanity_val_steps and set it to 1 if it is less than 1.

If we made as you have suggested then self.nb_sanity_val_steps will be equal to the user defined value in fast dev run mode, but it should be 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not do this. People need to have the option of turning sanity_val_check off

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But how then we will check that early stopping will work correctly? (Note that we force this check only if early stopping is turned on.)

Copy link
Contributor

@williamFalcon williamFalcon Nov 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand what you're saying, but restricting EVERYONE to force sanity check will certainly block some esoteric research or production cases, so we can't do this.

But I think this is on the user at this point. If they turned off sanity check then it's on them at that point and are willingly exposing themselves to these kinds of issues... but for people who keep it on, then we use what you suggest.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@williamFalcon
Copy link
Contributor

@kuynzereb ok, a few things:

  1. early stopping needs to default to True.
  2. cannot remove the option of turning off sanity_val_check.

I think it's good to check that the metric exists in the beginning of training, which we can do at the first instance of a val result:
either in sanity check or after the first validation is returned.

@kuynzereb
Copy link
Contributor Author

Hmmm, okay, I think I got it. Our position is: If the user turn off validation sanity check then he is responsible for any failures which can arise after very first validation loop.

@williamFalcon
Copy link
Contributor

exactly

@@ -113,6 +113,13 @@ def run_training_epoch(self):
if self.fast_dev_run or should_check_val:
self.run_evaluation(test=self.testing)

if (self.enable_early_stop and
self.callback_metrics.get(self.early_stop_callback.monitor) is None):
raise RuntimeError(f"Early stopping was configured to monitor "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@kuynzereb
Copy link
Contributor Author

I guess the tests fail because early stopping is turned on by default but some test models don't specify correct validation_end. So we maybe should turn off early stopping when testing.

@kuynzereb
Copy link
Contributor Author

I have fixed the tests. While I was doing it I have discovered some more problems:

  1. Current implementation will not work it the user doesn't define validation_step at all. Moreover, in such a case the total number of batches in progress bar is wrong now. The problem is that, if there is no validation_step we run run_evaluation() as usual but this function just does nothing and raises no error if it can't find validation_step(). I think we should do such a check in pretrain_routine(), so the trainer will be aware that there will be no validation computation.
  2. I think that when the user set val_percent_check=0 he expects that it will effectively disable validation. But now because of
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
self.nb_val_batches = max(1, self.nb_val_batches)

it will not happen and there will be validation runs.

@williamFalcon
Copy link
Contributor

@kuynzereb awesome! mind fixing the conflicts?

@kuynzereb
Copy link
Contributor Author

Yep, I will try to fix the conflicts ASAP.

@kuynzereb
Copy link
Contributor Author

Well, I have quite rethought all the thing, so @Borda and @williamFalcon please look at this.

I have added check_metrics method and strict parameter to EarlyStoppingCallback. Basically, check_metrics looks for the monitor in given metrics and raises the error if strict=True and makes a warning if strict=False and verbose=True. If there is no monitor in metrics and strict=False and verbose=False, then EarlyStoppingCallback just have no effect.

It allows us just to call self.early_stop_callback.check_metrics() after any validation loop to check for monitor metric. strict and verbose parameters allows us to control the behavior if no monitor is found.

Finally, now I have set early_stop_callback=None by default, which creates default EarlyStoppingCallback with strict=False and verbose=False. So, if there is val_loss defined by the user, we will have early stopping by default. Otherwise, training will go as if there is no early stopping at all.

And if you set early_stop_callback=True it will create default EarlyStoppingCallback with strict=True and verbose=True.

@Borda
Copy link
Member

Borda commented Jan 14, 2020

well, let me see the latest changes...

@awaelchli
Copy link
Contributor

If early_stop_callback=None I would expect it to be turned off even if the user wants val_loss or any other metric to be returned and logged.

@kuynzereb
Copy link
Contributor Author

Yeah, it is indeed somewhat ambiguous. But at the same time it is a good practice to use None for the default values, as far as I know.

The problem is that:

  1. @williamFalcon wants early stopping to be turned on by default.
  2. If early stopping set to True then it crashes the training if it can't find the appropriate metric.
  3. We should not crash the training when the user use default parameters (see default EarlyStopping callback should not fail on missing val_loss data #524).

So, one possible solution is to handle the default early stopping in a special way. And the only way I know to distinguish between user set and default value is set it to None. But yes, then it starts to be ambiguous...

Another solution may be not to crash training at all, even it early stopping was deliberately turned on, just give warnings.

@williamFalcon
Copy link
Contributor

@neggert @Borda @kuynzereb we good with this change?

@williamFalcon
Copy link
Contributor

@kuynzereb could you rebase master to make sure you have the latest test fixes on GPU?

@kuynzereb
Copy link
Contributor Author

I have merged master and updated the docs

@Borda
Copy link
Member

Borda commented Jan 22, 2020

I would say that rebase is a better option but merge master should be also fine...

@williamFalcon williamFalcon merged commit 50881c0 into Lightning-AI:master Jan 23, 2020
@kuynzereb kuynzereb deleted the early_stopping_callback_fix branch January 24, 2020 08:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants