Continue Jeremy's early stopping PR #1504 #2391

Merged · 155 commits · Jun 29, 2020

Changes from 1 commit

Commits (155)
d5472c0
add state_dict for early stopping
Apr 16, 2020
339506a
move best attr after monitor_op defined
Apr 16, 2020
a5a1198
improve early stopping and model checkpoint callbacks
Apr 25, 2020
a07be00
fix formatting
Apr 25, 2020
47fdd16
fix attr init order
Apr 25, 2020
535314b
clean up setting of default_root_dir attr
Apr 25, 2020
ed980ea
logger needs default root dir set first
Apr 25, 2020
98e240b
reorg trainer init
Apr 25, 2020
a316e02
remove direct references to checkpoint callback
Apr 26, 2020
048a5f3
more fixes
Apr 26, 2020
980ac2b
more bugfixes
Apr 26, 2020
c80c0e7
run callbacks at epoch end
Apr 28, 2020
84d1d54
update tests to use on epoch end
Apr 29, 2020
35f193c
PR cleanup
Apr 29, 2020
f32d22b
address failing tests
Apr 30, 2020
78d2545
refactor for homogeneity
May 1, 2020
f236d3e
fix merge conflict
May 5, 2020
9972af9
separate tests
May 22, 2020
fc7c6e8
tests for early stopping bug regressions
May 23, 2020
016da66
small fixes
May 23, 2020
5a0028c
revert model checkpoint change
May 23, 2020
d6087af
typo fix
May 23, 2020
1386365
fix tests
May 24, 2020
2200c81
update train loop
May 25, 2020
be80c8e
cannot pass an int as default_save_path
May 25, 2020
994e25b
refactor log message
May 26, 2020
fa488eb
fix test case
May 27, 2020
1a56edd
appease the linter
May 27, 2020
9c83128
fix some doctests
May 29, 2020
42b39c5
move config to callback
May 30, 2020
4e414d6
fixes from rebase
Jun 1, 2020
52d60e8
fixes from rebase
Jun 1, 2020
f084371
chlog
Borda Jun 1, 2020
24f4dfe
docs
Borda Jun 1, 2020
2a0a9c2
reformat
Borda Jun 1, 2020
949d3e6
formatting
Borda Jun 1, 2020
668b2ca
fix
Borda Jun 1, 2020
02914cf
fix
Borda Jun 1, 2020
a3243d5
Merge branch 'master' into bugfix/early-stopping-state
Jun 16, 2020
4837abe
fixes from rebase
Jun 16, 2020
a8a39d5
add new test for patience
Jun 16, 2020
c49c231
Merge branch 'bugfix/early-stopping-state' of https://github.com/jere…
Jun 16, 2020
053ce18
Update pytorch_lightning/callbacks/model_checkpoint.py
jeremyjordan Jun 16, 2020
5902d82
Update pytorch_lightning/callbacks/model_checkpoint.py
jeremyjordan Jun 16, 2020
83a754d
Update tests/callbacks/test_early_stopping.py
jeremyjordan Jun 16, 2020
1cf8a47
Merge branch 'master' into bugfix/early-stopping-state
Jun 17, 2020
fa669ef
fix formatting
Jun 17, 2020
33f6e2d
remove enable_early_stop attribute
Jun 17, 2020
4e31335
add state_dict for early stopping
Apr 16, 2020
9763351
move best attr after monitor_op defined
Apr 16, 2020
0f4fc5f
improve early stopping and model checkpoint callbacks
Apr 25, 2020
1a24c81
fix formatting
Apr 25, 2020
bd0d23a
fix attr init order
Apr 25, 2020
fb8c858
clean up setting of default_root_dir attr
Apr 25, 2020
892bff3
logger needs default root dir set first
Apr 25, 2020
368787d
reorg trainer init
Apr 25, 2020
191f3e8
remove direct references to checkpoint callback
Apr 26, 2020
b0a0b22
more fixes
Apr 26, 2020
3fe257e
more bugfixes
Apr 26, 2020
1649e9c
run callbacks at epoch end
Apr 28, 2020
4a3146a
update tests to use on epoch end
Apr 29, 2020
460fcef
PR cleanup
Apr 29, 2020
9f51575
address failing tests
Apr 30, 2020
e47d251
refactor for homogeneity
May 1, 2020
78f7efb
fix merge conflict
May 5, 2020
a4c72cc
separate tests
May 22, 2020
d81c90c
tests for early stopping bug regressions
May 23, 2020
fc616f2
small fixes
May 23, 2020
78a092d
revert model checkpoint change
May 23, 2020
8da8f64
typo fix
May 23, 2020
02694cf
fix tests
May 24, 2020
5692f5f
update train loop
May 25, 2020
47c2c74
fix test case
May 27, 2020
6aee109
appease the linter
May 27, 2020
84a2da7
fix some doctests
May 29, 2020
6bc50df
move config to callback
May 30, 2020
9b44672
fixes from rebase
Jun 1, 2020
e8d7c37
fixes from rebase
Jun 1, 2020
d1d7aa2
chlog
Borda Jun 1, 2020
6aee4d2
docs
Borda Jun 1, 2020
8daadc1
reformat
Borda Jun 1, 2020
8a623f0
formatting
Borda Jun 1, 2020
b20b1c1
fix
Borda Jun 1, 2020
a52983f
fix
Borda Jun 1, 2020
5e4e710
fixes from rebase
Jun 16, 2020
4ea2d99
add new test for patience
Jun 16, 2020
3c4d31e
Update pytorch_lightning/callbacks/model_checkpoint.py
jeremyjordan Jun 16, 2020
d893b18
Update pytorch_lightning/callbacks/model_checkpoint.py
jeremyjordan Jun 16, 2020
d650b74
Update tests/callbacks/test_early_stopping.py
jeremyjordan Jun 16, 2020
4eb8905
fix formatting
Jun 17, 2020
9f34584
remove enable_early_stop attribute
Jun 17, 2020
7e6b93e
Merge branch 'bugfix/early-stopping-state' of https://github.com/jere…
Jun 17, 2020
5beb38f
fix test with new epoch indexing
Jun 17, 2020
2a2250c
Merge branch 'master' into bugfix/early-stopping-state
Jun 17, 2020
827f573
Merge branch 'master' into bugfix/early-stopping-state
awaelchli Jun 22, 2020
1a39e1d
fix progress bar totals
awaelchli Jun 22, 2020
c5330ed
fix off by one error (see #2289) epoch starts at 0 now
awaelchli Jun 22, 2020
c86d08c
added missing imports
awaelchli Jun 22, 2020
776bc64
fix hpc_save folderpath
Jun 23, 2020
3b9dbde
fix formatting
Jun 23, 2020
3767d94
Merge branch 'master' into bugfix/early-stopping-state
Borda Jun 23, 2020
47a02a1
fix tests
Jun 24, 2020
2f8c62c
Merge branch 'bugfix/early-stopping-state' of https://github.com/jere…
Jun 24, 2020
780b0f2
small fixes from a rebase
Jun 24, 2020
a46ab9a
fix
Borda Jun 24, 2020
50174ae
tmpdir
Borda Jun 24, 2020
5180ce0
tmpdir
Borda Jun 24, 2020
655a60c
tmpdir
Borda Jun 24, 2020
a19989d
wandb
Borda Jun 24, 2020
8b6f4cc
Merge branch 'bugfix/early-stopping-state' of https://github.com/jere…
Jun 25, 2020
5b3d639
Merge branch 'master' into bugfix/early-stopping-state
Jun 25, 2020
ed4d66c
Merge branch 'master' into bugfix/early-stopping-jeremy
awaelchli Jun 27, 2020
7e08c6e
fix merge conflict
awaelchli Jun 27, 2020
16f1448
add back evaluation after training
awaelchli Jun 27, 2020
04f20a5
test_resume_early_stopping_from_checkpoint TODO
awaelchli Jun 27, 2020
6e48dd8
Merge branch 'master' into bugfix/early-stopping-jeremy
awaelchli Jun 27, 2020
86bd66f
undo the horovod check
awaelchli Jun 27, 2020
f34ac7d
update changelog
awaelchli Jun 27, 2020
02ccd19
remove a duplicate test from merge error
awaelchli Jun 27, 2020
a525299
try fix dp_resume test
awaelchli Jun 27, 2020
ed1e0c6
Merge branch 'master' into bugfix/early-stopping-jeremy
awaelchli Jun 27, 2020
651fb09
add the logger fix from master
awaelchli Jun 27, 2020
335a2e5
try remove default_root_dir
awaelchli Jun 28, 2020
ce39095
Merge branch 'master' into bugfix/early-stopping-jeremy
awaelchli Jun 28, 2020
aa7fb92
try mocking numpy
awaelchli Jun 28, 2020
978bed0
try import numpy in docs test
awaelchli Jun 28, 2020
a06970a
fix wandb test
awaelchli Jun 28, 2020
6a2acf3
pep 8 fix
awaelchli Jun 28, 2020
594795a
skip if no amp
awaelchli Jun 28, 2020
b6c99b4
dont mock when doctesting
awaelchli Jun 28, 2020
45c1cbf
install extra
awaelchli Jun 28, 2020
2ea95ec
Merge branch 'docs/dont-mock-when-doctesting' into bugfix/early-stopp…
awaelchli Jun 28, 2020
4e694f1
fix the resume ES test
awaelchli Jun 28, 2020
c20cc4f
Merge branch 'master' into bugfix/early-stopping-state
awaelchli Jun 28, 2020
5f72cec
undo conf.py changes
awaelchli Jun 28, 2020
ae75fa4
revert remove comet pickle from test
awaelchli Jun 28, 2020
2463b41
Update CHANGELOG.md
williamFalcon Jun 28, 2020
1e822ad
Update weights_loading.rst
williamFalcon Jun 28, 2020
e4d450e
Update weights_loading.rst
williamFalcon Jun 28, 2020
d840226
Update weights_loading.rst
williamFalcon Jun 28, 2020
21d6c8c
renamed flag
williamFalcon Jun 28, 2020
7170205
renamed flag
williamFalcon Jun 28, 2020
37304c6
revert the None check in logger experiment name/version
awaelchli Jun 28, 2020
b23d1fa
Merge branch 'master' into bugfix/early-stopping-jeremy
awaelchli Jun 28, 2020
c16cf77
add the old comments
awaelchli Jun 28, 2020
88454f0
_experiment
awaelchli Jun 28, 2020
d3edf9c
test chckpointing on DDP
Jun 28, 2020
0b3d402
skip the ddp test on windows
awaelchli Jun 28, 2020
190e761
cloudpickle
Borda Jun 28, 2020
97b04c4
Merge branch 'master' into bugfix/early-stopping-state
Borda Jun 28, 2020
c62150f
renamed flag
williamFalcon Jun 28, 2020
137e38f
renamed flag
williamFalcon Jun 28, 2020
d15cb70
parentheses for clarity
Jun 28, 2020
18cc130
apply suggestion max epochs
awaelchli Jun 28, 2020
cea1c0d
Merge branch 'master' into bugfix/early-stopping-state
williamFalcon Jun 29, 2020
Commit 9972af9f00f409c6b347f003b263f8c684203d07: separate tests
Jeremy Jordan committed Jun 1, 2020
83 changes: 0 additions & 83 deletions tests/callbacks/test_callbacks.py
@@ -5,7 +5,6 @@
import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate

@@ -200,85 +199,3 @@ def on_test_end(self, trainer, pl_module):
    assert not test_callback.on_validation_end_called
    assert not test_callback.on_validation_batch_end_called
    assert not test_callback.on_validation_batch_start_called


def test_early_stopping_no_val_step(tmpdir):
    """Test that early stopping callback falls back to training metrics when no validation defined."""

    class CurrentModel(EvalModelTemplate):
        def training_step(self, *args, **kwargs):
            output = super().training_step(*args, **kwargs)
            output.update({'my_train_metric': output['loss']})  # could be anything else
            return output

    model = CurrentModel()
    model.validation_step = None
    model.val_dataloader = None

    stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=stopping,
        overfit_pct=0.20,
        max_epochs=5,
    )
    result = trainer.fit(model)

    assert result == 1, 'training failed to complete'
    assert trainer.current_epoch < trainer.max_epochs


def test_pickling(tmpdir):
    import pickle
    early_stopping = EarlyStopping()
    ckpt = ModelCheckpoint(tmpdir)

    early_stopping_pickled = pickle.dumps(early_stopping)
    ckpt_pickled = pickle.dumps(ckpt)

    early_stopping_loaded = pickle.loads(early_stopping_pickled)
    ckpt_loaded = pickle.loads(ckpt_pickled)

    assert vars(early_stopping) == vars(early_stopping_loaded)
    assert vars(ckpt) == vars(ckpt_loaded)


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """ Test that None in checkpoint callback is valid and that ckpt_path is set correctly """
    tutils.reset_seed()
    model = EvalModelTemplate()

    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

    trainer = Trainer(default_root_dir=tmpdir,
                      checkpoint_callback=checkpoint,
                      overfit_pct=0.20,
                      max_epochs=5
                      )
    trainer.fit(model)

    # These should be different if the dirpath has been overridden
    assert trainer.ckpt_path != trainer.default_root_dir


@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()
    model = EvalModelTemplate()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=logger
    )
    trainer.fit(model)

    ckpt_version = Path(trainer.ckpt_path).parent.name
    assert ckpt_version == expected
44 changes: 44 additions & 0 deletions tests/callbacks/test_early_stopping.py
@@ -0,0 +1,44 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from pathlib import Path


# TODO remove this test
def test_early_stopping_no_val_step(tmpdir):
"""Test that early stopping callback falls back to training metrics when no validation defined."""

class CurrentModel(EvalModelTemplate):
def training_step(self, *args, **kwargs):
output = super().training_step(*args, **kwargs)
output.update({'my_train_metric': output['loss']}) # could be anything else
return output

model = CurrentModel()
model.validation_step = None
model.val_dataloader = None

stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=stopping,
overfit_pct=0.20,
max_epochs=5,
)
result = trainer.fit(model)

assert result == 1, 'training failed to complete'
assert trainer.current_epoch < trainer.max_epochs


def test_pickling(tmpdir):
import pickle
early_stopping = EarlyStopping()
early_stopping_pickled = pickle.dumps(early_stopping)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
    assert vars(early_stopping) == vars(early_stopping_loaded)
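The test_pickling round trip above compares vars() before and after serialization. The same pattern extends naturally to the state this PR introduces; here is a sketch of such a check, assuming EarlyStopping.state_dict() and load_state_dict() follow the usual torch contract (the exact keys they carry are not visible in this diff):

def test_early_stopping_state_round_trip():
    import pickle
    # Sketch: callback state should survive a serialize/deserialize cycle.
    src = EarlyStopping(patience=5)
    state = pickle.loads(pickle.dumps(src.state_dict()))

    dst = EarlyStopping()
    dst.load_state_dict(state)
    assert dst.state_dict() == src.state_dict()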
78 changes: 78 additions & 0 deletions tests/callbacks/test_learning_rate_logger.py
@@ -0,0 +1,78 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import LearningRateLogger
from tests.base import EvalModelTemplate


def test_lr_logger_single_lr(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__single_scheduler

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert results == 1
assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_multi_lrs(tmpdir):
""" Test that learning rates are extracted and logged for multi lr schedulers """
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__multiple_schedulers

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert results == 1
assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of lr schedulers'
assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'


def test_lr_logger_param_groups(tmpdir):
""" Test that learning rates are extracted and logged for single lr scheduler"""
tutils.reset_seed()

model = EvalModelTemplate()
model.configure_optimizers = model.configure_optimizers__param_groups

lr_logger = LearningRateLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=5,
val_percent_check=0.1,
train_percent_check=0.5,
callbacks=[lr_logger]
)
results = trainer.fit(model)

assert lr_logger.lrs, 'No learning rates logged'
assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
'Number of learning rates logged does not match number of param groups'
assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
'Names of learning rates not set correctly'
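The asserted key names encode the logger's naming scheme: the optimizer class name ('lr-Adam'), a numeric suffix to disambiguate a second optimizer of the same class ('lr-Adam-1'), and a '/pgN' suffix per parameter group ('lr-Adam/pg1'). Below is a hedged sketch of an optimizer setup that would produce the param-group keys; the Linear modules stand in for whatever configure_optimizers__param_groups builds internally:

import torch
from torch import nn

backbone = nn.Linear(32, 16)
head = nn.Linear(16, 2)

# One Adam optimizer with two parameter groups -> keys 'lr-Adam/pg1' and 'lr-Adam/pg2'
optimizer = torch.optim.Adam([
    {'params': backbone.parameters(), 'lr': 1e-4},  # logged as lr-Adam/pg1
    {'params': head.parameters(), 'lr': 1e-3},      # logged as lr-Adam/pg2
])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# Inside a LightningModule this pair would be returned from configure_optimizers:
#     return [optimizer], [scheduler]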
58 changes: 58 additions & 0 deletions tests/callbacks/test_model_checkpoint.py
@@ -0,0 +1,58 @@
import pytest

import tests.base.utils as tutils
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
from pathlib import Path


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
model = EvalModelTemplate()

checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)

trainer = Trainer(default_root_dir=tmpdir,
checkpoint_callback=checkpoint,
overfit_pct=0.20,
max_epochs=5
)
trainer.fit(model)

# These should be different if the dirpath has be overridden
assert trainer.ckpt_path != trainer.default_root_dir


@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()
    model = EvalModelTemplate()
    logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=logger
    )
    trainer.fit(model)

    ckpt_version = Path(trainer.ckpt_path).parent.name
    assert ckpt_version == expected


def test_pickling(tmpdir):
    import pickle
    ckpt = ModelCheckpoint(tmpdir)
    ckpt_pickled = pickle.dumps(ckpt)
    ckpt_loaded = pickle.loads(ckpt_pickled)
    assert vars(ckpt) == vars(ckpt_loaded)
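test_model_checkpoint_path above pins down a small naming rule: integer logger versions, including the integer a fresh TensorBoardLogger auto-assigns when version=None, get a 'version_' prefix, while string versions are used verbatim as the directory name. A standalone sketch of that rule follows; _format_version is a hypothetical helper, not the library's API:

def _format_version(version):
    # Mirrors the rule the parametrized test asserts.
    if version is None:
        version = 0  # a fresh logger auto-assigns the next free integer version
    return f'version_{version}' if isinstance(version, int) else str(version)


assert _format_version(None) == 'version_0'
assert _format_version(1) == 'version_1'
assert _format_version('awesome') == 'awesome'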