Allow ModelCheckpoint monitor to be None #3633

Merged 2 commits · Sep 25, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added error when AUROC metric is used for multiclass problems ([#3350](https://github.com/PyTorchLightning/pytorch-lightning/pull/3350))

- Allow `ModelCheckpoint` monitor to be `None`, meaning it will always save ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
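A minimal usage sketch of the new behaviour, mirroring the test added further down in this PR (the filepath and epoch count are arbitrary, and `model` stands in for any `LightningModule`):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# With monitor=None the callback no longer tracks a metric: it simply saves a
# checkpoint at the end of every epoch, and best_model_score stays unset.
checkpoint_callback = ModelCheckpoint(filepath='/tmp/ckpts', monitor=None, save_top_k=-1)

trainer = Trainer(checkpoint_callback=checkpoint_callback, max_epochs=2)
# trainer.fit(model)  # `model` is any LightningModule
```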
18 changes: 9 additions & 9 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -36,7 +36,7 @@

class ModelCheckpoint(Callback):
r"""
Save the model after every epoch if it improves.
Save the model after every epoch by monitoring a quantity.

After training finishes, use :attr:`best_model_path` to retrieve the path to the
best checkpoint file and :attr:`best_model_score` to retrieve its score.
@@ -63,7 +63,7 @@ class ModelCheckpoint(Callback):
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
and if the Trainer uses a logger, the path will also contain logger name and version.

monitor: quantity to monitor.
monitor: quantity to monitor. If None, a checkpoint will be saved every epoch.
verbose: verbosity mode. Default: ``False``.
save_last: always saves the model at the end of the epoch. Default: ``False``.
save_top_k: if ``save_top_k == k``,
@@ -120,7 +120,7 @@ class ModelCheckpoint(Callback):
def __init__(
self,
filepath: Optional[str] = None,
monitor: str = "checkpoint_on",
monitor: Optional[str] = "checkpoint_on",
verbose: bool = False,
save_last: bool = False,
save_top_k: int = 1,
@@ -174,7 +174,7 @@ def __init__(
"min": (torch_inf, "min"),
"max": (-torch_inf, "max"),
"auto": (-torch_inf, "max")
if "acc" in self.monitor or self.monitor.startswith("fmeasure")
if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
else (torch_inf, "min"),
}
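
The `auto` entry above resolves the comparison direction from the monitor name; a standalone sketch of that rule (with `torch_inf` replaced by `float('inf')`):

```python
torch_inf = float('inf')

def resolve_auto_mode(monitor):
    # "max" only when a monitor is given and it looks like an accuracy / f-measure
    # metric; None (or any other name) falls back to "min", matching the diff above.
    if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure")):
        return -torch_inf, "max"
    return torch_inf, "min"

assert resolve_auto_mode("val_acc") == (-torch_inf, "max")
assert resolve_auto_mode("val_loss") == (torch_inf, "min")
assert resolve_auto_mode(None) == (torch_inf, "min")
```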

@@ -337,11 +337,11 @@ def on_validation_end(self, trainer, pl_module):
epoch = trainer.current_epoch

# validate metric
if not self._is_valid_monitor_key(metrics):
keys = list(metrics.keys())
m = f"""
ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics ({keys}),
"did you call result.log(f'{self.monitor}', tensor)?"""
if not (self.monitor is None or self._is_valid_monitor_key(metrics)):
m = (
f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
f" {list(metrics.keys())}. HINT: Did you call result.log('{self.monitor}', tensor)?"
)
raise MisconfigurationException(m)

if (
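For context on the validation above, a sketch of how the monitored key is expected to reach the metrics dict with the Result-object API that the error hint refers to (the module is reduced to the one relevant hook, and the loss is a placeholder):

```python
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    # Only the validation hook matters for this sketch; training_step,
    # configure_optimizers, dataloaders, etc. are omitted.
    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder for the real validation loss
        result = pl.EvalResult(checkpoint_on=loss)
        # Logging under the same key passed to ModelCheckpoint(monitor=...) is what
        # lets _is_valid_monitor_key() succeed; with monitor=None that check is
        # skipped and a checkpoint is saved every epoch regardless.
        result.log('val_loss', loss)
        return result
```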
24 changes: 24 additions & 0 deletions tests/callbacks/test_model_checkpoint.py
@@ -204,6 +204,30 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
assert w0.eq(w1).all()


def test_model_checkpoint_none_monitor(tmpdir):
model = EvalModelTemplate()
epochs = 2
checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor=None, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=checkpoint_callback,
max_epochs=epochs,
)
trainer.fit(model)

# these should not be set if monitor is None
assert checkpoint_callback.best_model_path == ''
assert checkpoint_callback.best_model_score == 0
assert checkpoint_callback.best_k_models == {}
assert checkpoint_callback.kth_best_model_path == ''

# check that the correct ckpts were created
expected = ['lightning_logs']
expected.extend(f'epoch={e}.ckpt' for e in range(epochs))
assert set(os.listdir(tmpdir)) == set(expected)


def test_ckpt_metric_names(tmpdir):
model = EvalModelTemplate()
