Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes for early stopping and checkpoint callbacks #1504

Conversation

jeremyjordan
Copy link
Contributor

@jeremyjordan jeremyjordan commented Apr 16, 2020

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

For #1464
For #1463
For #1699
For #2151
Related #1458

  • best attribute isn't being saved
  • wait attribute isn't being reloaded properly
  • wait epoch is lagging by an epoch
  • early stopping callbacks are now being called twice
  • callback throwing an exception on epochs where validation metrics aren't available (due to check_val_every_n_epoch>1)

Adds tests to prevent future regressions.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@mergify mergify bot requested a review from a team April 16, 2020 03:52
@jeremyjordan jeremyjordan removed the request for review from a team April 16, 2020 03:54
@mergify mergify bot requested a review from a team April 16, 2020 03:55
@jeremyjordan
Copy link
Contributor Author

@PyTorchLightning/core-contributors currently, our documentation states that:

In any case, the callback will fall back to the training metrics (returned in training_step(), training_step_end()) looking for a key to monitor if validation is disabled or validation_epoch_end() is not defined.

However, this is not completely true. We only look at callback_metrics which is any key that is not loss, log, or progress. Do we want to update this to look across all values? Or correct the documentation to reflect the current reality?

@pep8speaks
Copy link

pep8speaks commented Apr 25, 2020

Hello @jeremyjordan! Thanks for updating this PR.

Line 244:71: W504 line break after binary operator
Line 245:72: W504 line break after binary operator

Comment last updated at 2020-06-28 06:34:32 UTC

@jeremyjordan jeremyjordan changed the title [WIP] fixes for early stopping callback [WIP] fixes for early stopping and checkpoint callbacks Apr 26, 2020
@jeremyjordan
Copy link
Contributor Author

@Borda any idea why some of the logger tests are failing?

@Borda Borda added the bug Something isn't working label Apr 29, 2020
@Borda Borda added this to the 0.7.6 milestone Apr 29, 2020
Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems that the tests are failing on multiple places not only loggers... let's take it one by one

pytorch_lightning/callbacks/early_stopping.py Outdated Show resolved Hide resolved
pytorch_lightning/callbacks/early_stopping.py Outdated Show resolved Hide resolved
@mergify mergify bot requested a review from a team April 29, 2020 10:05
@jeremyjordan
Copy link
Contributor Author

tests are failing when:

  • logger name or version is None
  • an opaque pickle issue
  • off by one error in checkpointed global step
  • mock error in parsing args

i need to investigate the off by one error, but not sure how the other tests failing are related to the changes in this PR

i want to get these failing tests addressed, then will write more tests for the early stopping callback.

@jeremyjordan
Copy link
Contributor Author

jeremyjordan commented May 1, 2020

ok, there's one remaining failing test and i've tracked down the issue. there's a thread lock being created when you create the OfflineExperiment which is preventing to object from being pickle-able. (see #1682)

@jeremyjordan jeremyjordan changed the title [WIP] fixes for early stopping and checkpoint callbacks [blocked] fixes for early stopping and checkpoint callbacks May 5, 2020
@@ -197,7 +197,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
return filepath

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
def on_epoch_end(self, trainer, pl_module):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Does this change effect checkpointing in the middle of training epoch? Consider the usecase where we train on a large dataset and we want to checkpoint & early stop every X steps instead of every X epoches, for example X = 100, i.e. val_check_interval = 100.

@Borda
Copy link
Member

Borda commented May 11, 2020

@jeremyjordan is it blocked by another pr?

@williamFalcon
Copy link
Contributor

@jeremyjordan which pr is blocking this?

@jeremyjordan
Copy link
Contributor Author

the tests won't pass as is until #1682 is addressed, we'll probably want to merge #1458 and then i can have this as a follow-on PR which ensures that the EarlyStopping callback works well with Checkpointing

@Borda Borda changed the title [blocked] fixes for early stopping and checkpoint callbacks [blocked by #1458] fixes for early stopping and checkpoint callbacks May 19, 2020
@jeremyjordan jeremyjordan changed the title [blocked by #1458] fixes for early stopping and checkpoint callbacks [WIP] fixes for early stopping and checkpoint callbacks May 21, 2020
@jeremyjordan jeremyjordan changed the title [WIP] fixes for early stopping and checkpoint callbacks [blocked by 1458] fixes for early stopping and checkpoint callbacks May 24, 2020
@mergify
Copy link
Contributor

mergify bot commented Jun 23, 2020

This pull request is now in conflict... :(

@awaelchli
Copy link
Contributor

awaelchli commented Jun 23, 2020

Strange. for me it works fine, I get these timings locally when running

tests\models>py.test -v test_hooks.py --durations=10

7.57s call     tests/models/test_hooks.py::test_on_before_zero_grad_called[2]
7.11s call     tests/models/test_hooks.py::test_on_before_zero_grad_called[3]
6.83s call     tests/models/test_hooks.py::test_on_before_zero_grad_called[1]

max_steps seems to work.
EDIT: even before the merge from @Borda I'm getting these fast timings.

@jeremyjordan
Copy link
Contributor Author

@awaelchli are you running on Windows by any chance? that's the only place where tests are passing :D

@awaelchli
Copy link
Contributor

Oh right that must be it..

@mergify
Copy link
Contributor

mergify bot commented Jun 24, 2020

This pull request is now in conflict... :(

@Borda
Copy link
Member

Borda commented Jun 24, 2020

@jeremyjordan mind rebase/merge master? and how is the last test? 🐰

@mergify
Copy link
Contributor

mergify bot commented Jun 26, 2020

This pull request is now in conflict... :(

@awaelchli
Copy link
Contributor

I merged master into this and getting many failed tests 😭 don't know where to begin but i still have hope this can get merged. Will try to fix them this weekend.

Comment on lines 584 to 540
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeremyjordan The tests fail because the evaluation loop is not getting called after the epoch. Did you intend to movie it somewhere else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh i think i only meant to remove the call_checkpoint_callback(), this was my mistake - glad you caught that!

@mergify mergify bot requested a review from a team June 27, 2020 11:49
Comment on lines 45 to 60
if self.logger is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_root_dir)

# weights_save_path overrides anything
if self.weights_save_path is not None:
save_dir = self.weights_save_path

version = self.logger.version if isinstance(
self.logger.version, str) else f'version_{self.logger.version}'
ckpt_path = os.path.join(
save_dir,
self.logger.name,
version,
"checkpoints"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeremyjordan this code was moved to the ModelCheckpoint.on_train_start, and I understand why. However, we have the problem that the logger is already saving a meta.yaml file to the default location before the on_train_start callback is even called an the model checkpoint has the chance to update the weights_save_path.
Any idea how to decouple the checkpoint and logger ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it may be unrelated, since it also happens here #2392. not sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think we should provide shared configuration in the Trainer initialization and not expect these child objects (loggers and checkpoint callbacks) to reach into each other's attributes. this probably also includes moving some attributes from logging (eg. version) up into the Trainer

@mergify mergify bot requested a review from a team June 27, 2020 22:10
@awaelchli
Copy link
Contributor

awaelchli commented Jun 28, 2020

I was able to fix all tests and merge errors.
About this todo comment:

TODO support more generic way for callbacks to persist a state_dict in a checkpoint

What about a callack method on_save_checkpoint (we already have it as model hooks)? Then the checkpoint and early stop callbacks can save their state into the checkpoint and trainer doesn't need to do that.

@Borda Borda force-pushed the bugfix/early-stopping-state branch from ae75fa4 to ba6a5ba Compare June 28, 2020 06:34
@mergify
Copy link
Contributor

mergify bot commented Jun 28, 2020

This pull request is now in conflict... :(

@Borda Borda changed the base branch from master to bugfix/early-stopping-state June 28, 2020 06:35
@Borda Borda merged commit 4cabb88 into Lightning-AI:bugfix/early-stopping-state Jun 28, 2020
@jeremyjordan
Copy link
Contributor Author

What about a callack method on_save_checkpoint (we already have it as model hooks)? Then the checkpoint and early stop callbacks can save their state into the checkpoint and trainer doesn't need to do that.

Yes, I was thinking the same thing. This callback would just return a state_dict which the Trainer could store. The only thing that I am unclear how we should handle is for other callbacks how we want to reinitialize the state. If we can expect that the same exact callbacks will be passed to the Trainer then it should be trivial. Or we could expect that you only pass in a single instance of each callback class (eg. callbacks=[CustomerLogger(), EarlyStopping(), ModelCheckpoint()] and not callbacks=[CustomerLogger(params_a), CustomerLogger(params_b), EarlyStopping(), ModelCheckpoint()] and just keep a mapping of callback class to state dicts. However, if the user passed multiple callback instances of the same class I'm not sure how we would want to handle that.

Maybe for a first iteration we can just document that for on_save_checkpoint you can only have one instance per class?

@Borda
Copy link
Member

Borda commented Jun 28, 2020

@jeremyjordan we moved the PR to #2391 as it is the repo branch and much easier to maintain by other core... :]

@jeremyjordan
Copy link
Contributor Author

@awaelchli I created #2401 for us to continue discussion on your comment

@awaelchli
Copy link
Contributor

perfect!

williamFalcon added a commit that referenced this pull request Jun 29, 2020
* add state_dict for early stopping

* move best attr after monitor_op defined

* improve early stopping and model checkpoint callbacks

* fix formatting

* fix attr init order

* clean up setting of default_root_dir attr

* logger needs default root dir set first

* reorg trainer init

* remove direct references to checkpoint callback

* more fixes

* more bugfixes

* run callbacks at epoch end

* update tests to use on epoch end

* PR cleanup

* address failing tests

* refactor for homogeneity

* fix merge conflict

* separate tests

* tests for early stopping bug regressions

* small fixes

* revert model checkpoint change

* typo fix

* fix tests

* update train loop

* cannot pass an int as default_save_path

* refactor log message

* fix test case

* appease the linter

* fix some doctests

* move config to callback

* fixes from rebase

* fixes from rebase

* chlog

* docs

* reformat

* formatting

* fix

* fix

* fixes from rebase

* add new test for patience

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/callbacks/test_early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix formatting

* remove enable_early_stop attribute

* add state_dict for early stopping

* move best attr after monitor_op defined

* improve early stopping and model checkpoint callbacks

* fix formatting

* fix attr init order

* clean up setting of default_root_dir attr

* logger needs default root dir set first

* reorg trainer init

* remove direct references to checkpoint callback

* more fixes

* more bugfixes

* run callbacks at epoch end

* update tests to use on epoch end

* PR cleanup

* address failing tests

* refactor for homogeneity

* fix merge conflict

* separate tests

* tests for early stopping bug regressions

* small fixes

* revert model checkpoint change

* typo fix

* fix tests

* update train loop

* fix test case

* appease the linter

* fix some doctests

* move config to callback

* fixes from rebase

* fixes from rebase

* chlog

* docs

* reformat

* formatting

* fix

* fix

* fixes from rebase

* add new test for patience

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/callbacks/test_early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix formatting

* remove enable_early_stop attribute

* fix test with new epoch indexing

* fix progress bar totals

* fix off by one error (see #2289) epoch starts at 0 now

* added missing imports

* fix hpc_save folderpath

* fix formatting

* fix tests

* small fixes from a rebase

* fix

* tmpdir

* tmpdir

* tmpdir

* wandb

* fix merge conflict

* add back evaluation after training

* test_resume_early_stopping_from_checkpoint TODO

* undo the horovod check

* update changelog

* remove a duplicate test from merge error

* try fix dp_resume test

* add the logger fix from master

* try remove default_root_dir

* try mocking numpy

* try import numpy in docs test

* fix wandb test

* pep 8 fix

* skip if no amp

* dont mock when doctesting

* install extra

* fix the resume ES test

* undo conf.py changes

* revert remove comet pickle from test

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update weights_loading.rst

* Update weights_loading.rst

* Update weights_loading.rst

* renamed flag

* renamed flag

* revert the None check in logger experiment name/version

* add the old comments

* _experiment

* test chckpointing on DDP

* skip the ddp test on windows

* cloudpickle

* renamed flag

* renamed flag

* parentheses for clarity

* apply suggestion max epochs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working priority: 0 High priority task
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants