enable None model checkpoint default #3669

Merged · 22 commits · Sep 27, 2020

Commits (changes from all commits):
34ec8ce  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
430c656  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
919e506  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
a0a354e  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
fd9513f  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
6db6315  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
219fdc4  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
aec27c0  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
237846f  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
ccf6d70  enable None model checkpoint default  (williamFalcon, Sep 26, 2020)
94eb9bb  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
cc88049  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
eeb367d  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
4391f53  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
96fca9b  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
397cbe9  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
4142ab8  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
a2e7722  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
038471e  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
6f47bbd  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
dcf53d2  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
38fa422  enable None model checkpoint default  (williamFalcon, Sep 27, 2020)
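
In short, the PR makes the monitored metric optional for checkpointing: with no monitor configured, the default ModelCheckpoint simply saves the latest epoch's checkpoint instead of requiring a synthesized 'checkpoint_on'/'val_loss' key. A minimal sketch of the default behavior exercised by test_default_checkpoint_behavior in the diff below; TinyModel is an illustrative stand-in for the test's EvalModelTemplate, not part of the PR:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer, LightningModule

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x = torch.randn(64, 32)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)

# No checkpoint_callback and no monitored metric: with this PR, the
# default ModelCheckpoint still runs and saves the latest epoch.
trainer = Trainer(max_epochs=3, limit_train_batches=5)
trainer.fit(TinyModel())
# expected (per the test below): lightning_logs/version_0/checkpoints/epoch=2.ckpt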
269 changes: 167 additions & 102 deletions pytorch_lightning/callbacks/model_checkpoint.py

Large diffs are not rendered by default.
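
Since the main diff is not rendered, here is a hedged sketch of the guard that the new test_none_monitor_top_k exercises. The helper name _validate_monitor and the exact condition are assumptions for illustration; the real model_checkpoint.py implementation is larger and may differ:

from typing import Optional
from pytorch_lightning.utilities.exceptions import MisconfigurationException

def _validate_monitor(monitor: Optional[str], save_top_k: Optional[int]) -> None:
    # saving a top_k of checkpoints is meaningless without a metric to rank by;
    # the test expects a MisconfigurationException matching this message prefix
    if monitor is None and save_top_k not in (None, 0, 1):
        raise MisconfigurationException(
            f'To save checkpoints for a top_k metric, ModelCheckpoint(monitor=...) '
            f'must be set; got monitor=None with save_top_k={save_top_k}.'
        )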

11 changes: 0 additions & 11 deletions pytorch_lightning/trainer/logging.py
@@ -149,17 +149,6 @@ def process_dict_result(self, output, train=False):
         progress_bar_metrics = recursive_detach(progress_bar_metrics)
         log_metrics = recursive_detach(log_metrics)

-        # replace loss with checkpoint_on
-        if 'loss' in callback_metrics:
-            callback_metrics['checkpoint_on'] = callback_metrics['loss']
-            callback_metrics['early_stop_on'] = callback_metrics['loss']
-            del callback_metrics['loss']
-
-        if 'val_loss' in callback_metrics:
-            callback_metrics['checkpoint_on'] = callback_metrics['val_loss']
-            callback_metrics['early_stop_on'] = callback_metrics['val_loss']
-            del callback_metrics['val_loss']
-
         return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens

     def reduce_distributed_output(self, output, num_gpus):
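
With this remapping gone, callbacks see the raw metric keys ('loss', 'val_loss') instead of the synthesized 'checkpoint_on'/'early_stop_on' entries, which is why the tests below now assert on callback_metrics['loss'] and pass an explicit monitor. A one-line illustration, assuming the post-PR API:

from pytorch_lightning.callbacks import ModelCheckpoint

# point the callback at the raw metric key directly,
# rather than relying on an implicit 'checkpoint_on' rename
ckpt = ModelCheckpoint(monitor='val_loss', save_top_k=1)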
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -513,7 +513,6 @@ def run_sanity_check(self, ref_model):
         self.on_sanity_check_end()
         self.running_sanity_check = False

-    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)

Review thread on the removed decorator:

williamFalcon (Contributor, author): @Borda let's please get rid of these... makes readability impossible (and IDE inspection)

Member: ok, we can remove wrappers, just keep the state assignment..

     def test(
         self,
         model: Optional[LightningModule] = None,
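
A minimal sketch of the alternative the reviewer suggests above: drop the @trainer_state wrapper but keep explicit state transitions inside the method. The TrainerState values and the _run_test helper here are stand-ins for illustration, not the actual Lightning internals:

from enum import Enum

class TrainerState(Enum):
    INITIALIZING = 0
    RUNNING = 1
    FINISHED = 2

class Trainer:
    def __init__(self):
        self.state = TrainerState.INITIALIZING

    def test(self, model=None):
        # explicit state assignment instead of the @trainer_state decorator
        self.state = TrainerState.RUNNING
        try:
            return self._run_test(model)
        finally:
            # runs even if the test loop raises, like the wrapper did
            self.state = TrainerState.FINISHED

    def _run_test(self, model):
        # stand-in for the real test loop
        return [{'dummy_metric': 1.0}]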
23 changes: 23 additions & 0 deletions tests/base/model_valid_epoch_ends.py
@@ -7,6 +7,29 @@ class ValidationEpochEndVariations(ABC):
     """
     Houses all variations of validation_epoch_end steps
     """
+    def validation_epoch_end_no_monitor(self, outputs):
+        """
+        Called at the end of validation to aggregate outputs
+
+        Args:
+            outputs: list of individual outputs of each validation step
+        """
+        # if validation_step returned a scalar, outputs is a list of tensor scalars
+        # in that case we return just the average (if we want)
+        def _mean(res, key):
+            # recursive mean for multilevel dicts
+            return torch.stack([x[key] if isinstance(x, dict) else _mean(x, key) for x in res]).mean()
+
+        val_acc_mean = _mean(outputs, 'val_acc')
+
+        # alternate between tensor and scalar
+        if self.current_epoch % 2 == 0:
+            val_acc_mean = val_acc_mean.item()
+
+        metrics_dict = {'val_acc': val_acc_mean}
+        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
+        return results
+
     def validation_epoch_end(self, outputs):
         """
24 changes: 24 additions & 0 deletions tests/base/model_valid_steps.py
@@ -34,6 +34,30 @@ def validation_step(self, batch, batch_idx, *args, **kwargs):
         })
         return output

+    def validation_step_no_monitor(self, batch, batch_idx, *args, **kwargs):
+        """
+        Lightning calls this inside the validation loop
+        :param batch:
+        :return:
+        """
+        self.validation_step_called = True
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        y_hat = self(x)
+
+        loss_val = self.loss(y, y_hat)
+
+        # acc
+        labels_hat = torch.argmax(y_hat, dim=1)
+        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+        val_acc = torch.tensor(val_acc).type_as(x)
+
+        output = OrderedDict({
+            'val_acc': val_acc,
+            'test_dic': {'val_loss_a': loss_val}
+        })
+        return output
+
     def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
         x, y = batch
         x = x.view(x.size(0), -1)
78 changes: 57 additions & 21 deletions tests/callbacks/test_model_checkpoint.py
@@ -13,6 +13,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
@@ -21,7 +22,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     tutils.reset_seed()
     model = EvalModelTemplate()

-    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
+    checkpoint = ModelCheckpoint(monitor='val_loss', filepath=None, save_top_k=save_top_k)

     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -97,7 +98,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
     model = EvalModelTemplate()
     num_epochs = 4
-    model_checkpoint = ModelCheckpointTestInvocations(
+    model_checkpoint = ModelCheckpointTestInvocations(monitor='val_loss',
         expected_count=num_epochs, save_top_k=-1
     )
     trainer = Trainer(
@@ -131,23 +132,24 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
     assert ckpt_name == 'test@epoch=3,acc=0.03000'
     ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
     # no filepath set
-    ckpt_name = ModelCheckpoint(filepath=None).format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=None).format_checkpoint_name(3, {})
     assert ckpt_name == 'epoch=3.ckpt'
-    ckpt_name = ModelCheckpoint(filepath='').format_checkpoint_name(5, {})
+    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='').format_checkpoint_name(5, {})
     assert ckpt_name == 'epoch=5.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(filepath='.').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath='.').format_checkpoint_name(3, {})
     assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
-    ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
     assert ckpt_name == tmpdir / 'test-dir.ckpt'
     # now, dir exists
     os.mkdir(filepath)
-    ckpt_name = ModelCheckpoint(filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='val_loss', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
     assert ckpt_name == filepath / 'test-epoch=3.ckpt'
     # with ver
-    ckpt_name = ModelCheckpoint(filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
+    ckpt_name = ModelCheckpoint(monitor='val_loss',
+                                filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
     assert ckpt_name == tmpdir / 'test-name-v3.ckpt'


@@ -156,7 +158,7 @@ def test_model_checkpoint_save_last(tmpdir):
     model = EvalModelTemplate()
     epochs = 3
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
-    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -167,18 +169,27 @@ def test_model_checkpoint_save_last(tmpdir):
     last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set(
-        [f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs']
-    )
+    assert set(os.listdir(tmpdir)) == \
+        set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs'])
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


+def test_none_monitor_top_k(tmpdir):
+    """
+    Make sure that monitor cannot be None when saving a top_k of anything other than 1
+    """
+    seed_everything(100)
+    num_epochs = 3
+    with pytest.raises(MisconfigurationException, match=r'To save checkpoints for a top_k metric.*'):
+        ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+
+
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     """Tests that the save_last checkpoint contains the latest information."""
     seed_everything(100)
     model = EvalModelTemplate()
     num_epochs = 3
-    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -193,10 +204,6 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(model_checkpoint.last_model_path)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
-    assert all(
-        ckpt_last["callbacks"][type(model_checkpoint)][k] == ckpt_last_epoch["callbacks"][type(model_checkpoint)][k]
-        for k in ("best_model_score", "best_model_path")
-    )

     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
@@ -208,7 +215,7 @@ def test_model_checkpoint_none_monitor(tmpdir):
     model = EvalModelTemplate()
     epochs = 2
-    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor=None, save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=-1)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -240,7 +247,7 @@ def test_ckpt_metric_names(tmpdir):
         progress_bar_refresh_rate=0,
         limit_train_batches=0.01,
         limit_val_batches=0.01,
-        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
+        checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
     )

     trainer.fit(model)
@@ -253,6 +260,35 @@ def test_ckpt_metric_names(tmpdir):
     assert len(val) > 3


+def test_default_checkpoint_behavior(tmpdir):
+    seed_everything(1234)
+
+    os.environ['PL_DEV_DEBUG'] = '1'
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step_no_monitor
+    model.validation_epoch_end = model.validation_epoch_end_no_monitor
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=5,
+        limit_val_batches=5,
+    )
+
+    trainer.fit(model)
+    results = trainer.test()
+
+    assert len(results) == 1
+    assert results[0]['test_acc'] >= 0.80
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
+
+    # make sure the checkpoint we saved uses the default epoch-based name
+    ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
+    assert len(ckpts) == 1
+    assert ckpts[0] == 'epoch=2.ckpt'
+
+
 def test_ckpt_metric_names_results(tmpdir):
     model = EvalModelTemplate()
     model.training_step = model.training_step_result_obj
@@ -271,7 +307,7 @@
         progress_bar_refresh_rate=0,
         limit_train_batches=0.01,
         limit_val_batches=0.01,
-        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + "/{val_loss:.2f}"),
+        checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir + "/{val_loss:.2f}"),
     )

     trainer.fit(model)
@@ -294,7 +330,7 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     model.validation_step = None
     trainer = Trainer(
         default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(filepath=tmpdir, save_top_k=0, save_last=save_last),
+        checkpoint_callback=ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=0, save_last=save_last),
         max_epochs=max_epochs,
     )
     trainer.fit(model)
4 changes: 2 additions & 2 deletions tests/core/test_datamodules.py
@@ -191,7 +191,7 @@ def test_train_loop_only(tmpdir):
     # fit model
     result = trainer.fit(model, dm)
     assert result == 1
-    assert trainer.logger_connector.callback_metrics['checkpoint_on'] < 0.6
+    assert trainer.logger_connector.callback_metrics['loss'] < 0.6


 def test_train_val_loop_only(tmpdir):
@@ -213,7 +213,7 @@ def test_train_val_loop_only(tmpdir):
     # fit model
     result = trainer.fit(model, dm)
     assert result == 1
-    assert trainer.logger_connector.callback_metrics['checkpoint_on'] < 0.6
+    assert trainer.logger_connector.callback_metrics['loss'] < 0.6


 def test_test_loop_only(tmpdir):
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
@@ -162,7 +162,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
         max_epochs=2,
         limit_train_batches=0.4,
         limit_val_batches=0.2,
-        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
         default_root_dir=tmpdir,
     )
16 changes: 8 additions & 8 deletions tests/trainer/test_eval_loop_dict_return.py
@@ -126,7 +126,7 @@ def test_validation_step_dict_return(tmpdir):
     # eval_results are output of _evaluate
     callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
     assert len(callback_metrics) == 2
-    assert len(callback_metrics[0]) == 6
+    assert len(callback_metrics[0]) == 5
     assert len(eval_results) == 2
     assert eval_results[0]['log']['log_acc1'] == 12
     assert eval_results[1]['log']['log_acc1'] == 13
...
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [8, 9]
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -198,7 +198,7 @@ def test_val_step_step_end(tmpdir):
     # eval_results are output of _evaluate
     callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
     assert len(callback_metrics) == 2
-    assert len(callback_metrics[0]) == 7
+    assert len(callback_metrics[0]) == 6

     callback_metrics = callback_metrics[0]
     assert callback_metrics['val_step_end'] == 1802
...
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
+    assert len(trainer.logger_connector.callback_metrics) in [10, 11]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -243,7 +243,7 @@ def test_no_val_step_end(tmpdir):
     # eval_results are output of _evaluate
     callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
     assert len(callback_metrics) == 1
-    assert len(callback_metrics[0]) == 7
+    assert len(callback_metrics[0]) == 6
     assert len(eval_results) == 1

     eval_results = eval_results[0]
...
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
+    assert len(trainer.logger_connector.callback_metrics) in [10, 11]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -286,7 +286,7 @@ def test_full_val_loop(tmpdir):
     # eval_results are output of _evaluate
     callback_metrics, eval_results = trainer.run_evaluation(test_mode=False)
     assert len(callback_metrics) == 1
-    assert len(callback_metrics[0]) == 8
+    assert len(callback_metrics[0]) == 7
     assert len(eval_results) == 1

     eval_results = eval_results[0]
...
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) in [10, 11]
+    assert len(trainer.logger_connector.callback_metrics) in [11, 12]

     # make sure correct steps were called
     assert model.validation_step_called
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
@@ -400,7 +400,7 @@ def mock_save_function(filepath, *args):
     # simulated losses
     losses = [10, 9, 2.8, 5, 2.5]

-    checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last,
+    checkpoint_callback = ModelCheckpoint(tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last,
                                           prefix=file_prefix, verbose=1)
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
@@ -507,7 +507,7 @@ def increment_on_load_checkpoint(self, _):
         max_epochs=2,
         limit_train_batches=0.65,
         limit_val_batches=1,
-        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        checkpoint_callback=ModelCheckpoint(tmpdir, monitor='val_loss', save_top_k=-1),
         default_root_dir=tmpdir,
         early_stop_callback=False,
         val_check_interval=1.,
@@ -664,7 +664,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
         max_epochs=2,
         progress_bar_refresh_rate=0,
         default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(save_top_k=save_top_k),
+        checkpoint_callback=ModelCheckpoint(monitor='val_loss', save_top_k=save_top_k),
     )
     trainer.fit(model)
     if ckpt_path == 'best':