diff --git a/CHANGELOG.md b/CHANGELOG.md index ff2c124eafd63..2b24478ae8854 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). +- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a65855e49cda5..dcce9d23d9054 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -43,6 +43,7 @@ class ModelCheckpoint(Callback): monitor: quantity to monitor. verbose: verbosity mode. Default: ``False``. + save_last: always saves the model at the end of the epoch. Default: ``False``. save_top_k: if `save_top_k == k`, the best k models according to the quantity monitored will be saved. @@ -83,7 +84,7 @@ class ModelCheckpoint(Callback): """ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False, - save_top_k: int = 1, save_weights_only: bool = False, + save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: @@ -103,6 +104,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve else: self.dirpath, self.filename = os.path.split(filepath) os.makedirs(self.dirpath, exist_ok=True) + self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period @@ -217,6 +219,10 @@ def on_validation_end(self, trainer, pl_module): self.epoch_last_check = epoch + if self.save_last: + filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt') + self._save_model(filepath) + filepath = self.format_checkpoint_name(epoch, metrics) version_cnt = 0 while os.path.isfile(filepath): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1c2c169191564..91bda756202d9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -255,19 +255,21 @@ def test_dp_output_reduce(): assert reduced['b']['c'] == out['b']['c'] -@pytest.mark.parametrize(["save_top_k", "file_prefix", "expected_files"], [ - pytest.param(-1, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}, +@pytest.mark.parametrize(["save_top_k", "save_last", "file_prefix", "expected_files"], [ + pytest.param(-1, False, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}, id="CASE K=-1 (all)"), - pytest.param(1, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}, + pytest.param(1, False, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}, id="CASE K=1 (2.5, epoch 4)"), - pytest.param(2, '', {'epoch=4.ckpt', 'epoch=2.ckpt'}, + pytest.param(2, False, '', {'epoch=4.ckpt', 'epoch=2.ckpt'}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), - pytest.param(4, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'}, + pytest.param(4, False, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'}, id="CASE K=4 (save all 4 base)"), - pytest.param(3, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, + pytest.param(3, False, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, id="CASE K=3 (save the 2nd, 3rd, 4th model)"), + pytest.param(1, True, '', {'epoch=4.ckpt', 'last.ckpt'}, + id="CASE K=1 (save the 4th model and the last model)"), ]) -def test_model_checkpoint_options(tmpdir, save_top_k, file_prefix, expected_files): +def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files): """Test ModelCheckpoint options.""" def mock_save_function(filepath, *args): @@ -276,7 +278,8 @@ def mock_save_function(filepath, *args): # simulated losses losses = [10, 9, 2.8, 5, 2.5] - checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, prefix=file_prefix, verbose=1) + checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last, + prefix=file_prefix, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer()