From 323f9c6d9399492071c40e9e03fd218230875f78 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Sun, 27 Sep 2020 18:16:21 +0200
Subject: [PATCH 1/5] Fix ModelCheckpoint period

---
 CHANGELOG.md                             |  2 +
 .../callbacks/model_checkpoint.py        | 41 +++++++------------
 tests/callbacks/test_model_checkpoint.py | 25 ++++++++++-
 3 files changed, 40 insertions(+), 28 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10022219857ee..874b68f7aa47f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -77,6 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))

+- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
+
 ## [0.9.0] - YYYY-MM-DD

 ### Added

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index f85e11e0f92d3..46099b0fb502c 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -181,27 +181,21 @@ def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint
         """
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
-
-        # no models are saved
-        if self.save_top_k == 0:
-            return
-
-        # don't save anything during sanity check
-        if trainer.running_sanity_check:
-            return
+        epoch = trainer.current_epoch

-        # skip this epoch
-        if self._should_skip_epoch(trainer):
+        if (
+            trainer.global_rank != 0  # only run on main process
+            or self.save_top_k == 0  # no models are saved
+            or self.period < 1  # no models are saved
+            or (epoch + 1) % self.period  # skip epoch
+            or trainer.running_sanity_check  # don't save anything during sanity check
+            or self.epoch_last_check == epoch  # already saved
+        ):
             return

         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)

-        epoch = trainer.current_epoch
-
         # track epoch when ckpt was last checked
         self.epoch_last_check = trainer.current_epoch
@@ -278,7 +272,7 @@ def __init_monitor_mode(self, monitor, mode):

         if mode not in mode_dict:
             rank_zero_warn(
-                f"ModelCheckpoint mode {mode} is unknown, " f"fallback to auto mode.",
+                f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode",
                 RuntimeWarning,
             )
             mode = "auto"
@@ -290,7 +284,6 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)

     def _save_model(self, filepath: str, trainer, pl_module):
-
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
@@ -317,9 +310,7 @@ def check_monitor_top_k(self, current) -> bool:
             current = torch.tensor(current)

         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-
-        val = monitor_op(current, self.best_k_models[self.kth_best_model_path])
-        return val
+        return monitor_op(current, self.best_k_models[self.kth_best_model_path]).item()

     @classmethod
     def _format_checkpoint_name(
@@ -443,10 +434,6 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)

-    def _should_skip_epoch(self, trainer):
-        epoch = trainer.current_epoch
-        return (self.epoch_last_check is not None) and (epoch - self.epoch_last_check) < self.period
-
     def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
@@ -496,8 +483,10 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
         if current is None:
             m = f"Can save best model only with {self.monitor} available, skipping."
             if self.monitor == 'checkpoint_on':
-                m = 'No checkpoint_on found. Hint: Did you set it in EvalResult(checkpoint_on=tensor) or ' \
-                    'TrainResult(checkpoint_on=tensor)?'
+                m = (
+                    'No checkpoint_on found. HINT: Did you set it in '
+                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?',
+                )
             rank_zero_warn(m, RuntimeWarning)
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch, trainer, pl_module)
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 7139d33ed46cc..3b5a3d1139a80 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -261,6 +261,7 @@ def test_model_checkpoint_none_monitor(tmpdir):
         early_stop_callback=False,
         checkpoint_callback=checkpoint_callback,
         max_epochs=epochs,
+        logger=False,
     )
     trainer.fit(model)

@@ -272,8 +273,28 @@ def test_model_checkpoint_none_monitor(tmpdir):
     assert checkpoint_callback.kth_best_model_path == ''

     # check that the correct ckpts were created
-    expected = ['lightning_logs']
-    expected.extend(f'epoch={e}.ckpt' for e in range(epochs))
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)]
     assert set(os.listdir(tmpdir)) == set(expected)
+
+
+@pytest.mark.parametrize("period", list(range(4)))
+def test_model_checkpoint_period(tmpdir, period):
+    model = EvalModelTemplate()
+    epochs = 5
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, period=period)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=checkpoint_callback,
+        max_epochs=epochs,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        logger=False,
+    )
+    trainer.fit(model)
+
+    # check that the correct ckpts were created
+    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    assert set(os.listdir(tmpdir)) == set(expected)
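Note on the core of patch 1: `save_checkpoint` now bails out early unless `(epoch + 1) % period == 0`, and a `period` below 1 disables saving entirely. A minimal sketch of that condition, pulled out into a standalone `should_save` helper (the helper name is hypothetical; in the patch the logic lives inline in `ModelCheckpoint.save_checkpoint`):

    # Sketch of the skip condition from patch 1 (hypothetical helper; the real
    # logic is the inlined `if (...): return` in ModelCheckpoint.save_checkpoint).
    def should_save(epoch: int, period: int, epoch_last_check=None) -> bool:
        if period < 1:                     # non-positive period: no models are saved
            return False
        if (epoch + 1) % period:           # non-zero remainder: skip this epoch
            return False
        if epoch_last_check == epoch:      # already saved this epoch
            return False
        return True

    # With 5 epochs and period=2, checkpoints land on (0-indexed) epochs 1 and 3,
    # matching the expectation encoded in test_model_checkpoint_period above.
    assert [e for e in range(5) if should_save(e, period=2)] == [1, 3]
    assert [e for e in range(5) if should_save(e, period=1)] == [0, 1, 2, 3, 4]
    assert [e for e in range(5) if should_save(e, period=0)] == []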
From a13b7c5d6292188ed9b6ed2e7b923a47188f0906 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Sun, 27 Sep 2020 18:47:00 +0200
Subject: [PATCH 2/5] Remove comma

---
 pytorch_lightning/callbacks/model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 46099b0fb502c..a7eb33cafff41 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -485,7 +485,7 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
             if self.monitor == 'checkpoint_on':
                 m = (
                     'No checkpoint_on found. HINT: Did you set it in '
-                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?',
+                    'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
                 )
             rank_zero_warn(m, RuntimeWarning)
         elif self.check_monitor_top_k(current):

From 90d132314cc275da785f7b390467805e3bfa7382 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Mon, 28 Sep 2020 02:36:06 +0200
Subject: [PATCH 3/5] Minor changes

---
 tests/callbacks/test_model_checkpoint.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index 3b5a3d1139a80..cffdf6f0b80b0 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -181,13 +181,13 @@ def test_model_checkpoint_save_last(tmpdir):
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
         max_epochs=epochs,
+        logger=False,
     )
     trainer.fit(model)
     last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == \
-        set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename, 'lightning_logs'])
+    assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
@@ -317,7 +317,7 @@ def test_model_checkpoint_topk_zero(tmpdir):
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
     # check that no ckpts were created
-    assert len(set(os.listdir(tmpdir))) == 0
+    assert len(os.listdir(tmpdir)) == 0


 def test_ckpt_metric_names(tmpdir):

From 00d9e77b81e7dc9a6cb90c676f744bce23514cb7 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 28 Sep 2020 10:37:50 +0200
Subject: [PATCH 4/5] skip check

---
 .pyrightconfig.json | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.pyrightconfig.json b/.pyrightconfig.json
index 41c90bbc3573e..005a39956c87d 100644
--- a/.pyrightconfig.json
+++ b/.pyrightconfig.json
@@ -33,6 +33,7 @@
         "pytorch_lightning/trainer/lr_scheduler_connector.py",
         "pytorch_lightning/trainer/training_loop_temp.py",
         "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
+        "pytorch_lightning/trainer/connectors/data_connector.py",
         "pytorch_lightning/tuner",
         "pytorch_lightning/plugins"
     ],

From 34086ba77f0fab4b85c7788d2f682aaf7448f097 Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Tue, 29 Sep 2020 15:03:39 +0200
Subject: [PATCH 5/5] Revert "skip check"

Already pushed to master.

This reverts commit 00d9e77b81e7dc9a6cb90c676f744bce23514cb7.
---
 .pyrightconfig.json | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.pyrightconfig.json b/.pyrightconfig.json
index 005a39956c87d..41c90bbc3573e 100644
--- a/.pyrightconfig.json
+++ b/.pyrightconfig.json
@@ -33,7 +33,6 @@
         "pytorch_lightning/trainer/lr_scheduler_connector.py",
         "pytorch_lightning/trainer/training_loop_temp.py",
         "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
-        "pytorch_lightning/trainer/connectors/data_connector.py",
         "pytorch_lightning/tuner",
         "pytorch_lightning/plugins"
     ],
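For context on patch 2 ("Remove comma"): the parenthesized message added in patch 1 ended with a trailing comma, which makes `m` a one-element tuple rather than a string, so the warning would have rendered a tuple repr. A quick illustration of the pitfall:

    # With the trailing comma from patch 1, m is a 1-element tuple, not a str:
    m = (
        'No checkpoint_on found. HINT: Did you set it in '
        'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?',
    )
    assert isinstance(m, tuple)

    # After patch 2 removes the comma, implicit string concatenation yields a str:
    m = (
        'No checkpoint_on found. HINT: Did you set it in '
        'EvalResult(checkpoint_on=tensor) or TrainResult(checkpoint_on=tensor)?'
    )
    assert isinstance(m, str)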