diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6f4ddf74662c..c28d2ab73ef99 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added error when AUROC metric is used for multiclass problems ([#3350](https://github.com/PyTorchLightning/pytorch-lightning/pull/3350))

+- Allow `ModelCheckpoint` monitor to be `None`, meaning it will always save ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
+
 ### Changed

 - Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 29eb2b1fb5dd8..fb992977ff21c 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -36,7 +36,7 @@ class ModelCheckpoint(Callback):
     r"""
-    Save the model after every epoch if it improves.
+    Save the model after every epoch by monitoring a quantity.

     After training finishes, use :attr:`best_model_path` to retrieve the path to the
     best checkpoint file and :attr:`best_model_score` to retrieve its score.
@@ -63,7 +63,7 @@ class ModelCheckpoint(Callback):
             :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
             and if the Trainer uses a logger, the path will also contain logger name and version.

-        monitor: quantity to monitor.
+        monitor: quantity to monitor. If None, a checkpoint will be saved every epoch.
         verbose: verbosity mode. Default: ``False``.
         save_last: always saves the model at the end of the epoch. Default: ``False``.
         save_top_k: if ``save_top_k == k``,
@@ -120,7 +120,7 @@ class ModelCheckpoint(Callback):
     def __init__(
         self,
         filepath: Optional[str] = None,
-        monitor: str = "checkpoint_on",
+        monitor: Optional[str] = "checkpoint_on",
         verbose: bool = False,
         save_last: bool = False,
         save_top_k: int = 1,
@@ -174,7 +174,7 @@ def __init__(
             "min": (torch_inf, "min"),
             "max": (-torch_inf, "max"),
             "auto": (-torch_inf, "max")
-            if "acc" in self.monitor or self.monitor.startswith("fmeasure")
+            if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
             else (torch_inf, "min"),
         }

@@ -337,11 +337,11 @@ def on_validation_end(self, trainer, pl_module):
         epoch = trainer.current_epoch

         # validate metric
-        if not self._is_valid_monitor_key(metrics):
-            keys = list(metrics.keys())
-            m = f"""
-                ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics ({keys}),
-                "did you call result.log(f'{self.monitor}', tensor)?"""
+        if not (self.monitor is None or self._is_valid_monitor_key(metrics)):
+            m = (
+                f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
+                f" {list(metrics.keys())}. HINT: Did you call result.log('{self.monitor}', tensor)?"
+            )
             raise MisconfigurationException(m)

         if (
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index f7db35fe045bb..6ea07772ccb00 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -204,6 +204,30 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert w0.eq(w1).all()


+def test_model_checkpoint_none_monitor(tmpdir):
+    model = EvalModelTemplate()
+    epochs = 2
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor=None, save_top_k=-1)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=checkpoint_callback,
+        max_epochs=epochs,
+    )
+    trainer.fit(model)
+
+    # these should not be set if monitor is None
+    assert checkpoint_callback.best_model_path == ''
+    assert checkpoint_callback.best_model_score == 0
+    assert checkpoint_callback.best_k_models == {}
+    assert checkpoint_callback.kth_best_model_path == ''
+
+    # check that the correct ckpts were created
+    expected = ['lightning_logs']
+    expected.extend(f'epoch={e}.ckpt' for e in range(epochs))
+    assert set(os.listdir(tmpdir)) == set(expected)
+
+
 def test_ckpt_metric_names(tmpdir):
     model = EvalModelTemplate()
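
Not part of the patch itself: a minimal usage sketch of the behavior this change enables, mirroring the new test. The `filepath`, `monitor`, and `save_top_k` arguments come from the diff above; the output directory and the `model` object are placeholders.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Per the updated docstring, monitor=None skips the metric lookup and a
# checkpoint is saved every epoch; save_top_k=-1 keeps all of them (as in the test).
checkpoint_callback = ModelCheckpoint(
    filepath="checkpoints/",  # placeholder directory
    monitor=None,
    save_top_k=-1,
)

trainer = Trainer(
    checkpoint_callback=checkpoint_callback,
    max_epochs=2,
)
trainer.fit(model)  # `model` stands in for any LightningModule
```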