From 74f3a03d256b4f096d71cd088eb19896413a3d08 Mon Sep 17 00:00:00 2001 From: lgvaz Date: Wed, 20 May 2020 16:38:54 -0300 Subject: [PATCH 1/5] saves model every epoch --- pytorch_lightning/callbacks/model_checkpoint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a65855e49cda5..dcce9d23d9054 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -43,6 +43,7 @@ class ModelCheckpoint(Callback): monitor: quantity to monitor. verbose: verbosity mode. Default: ``False``. + save_last: always saves the model at the end of the epoch. Default: ``False``. save_top_k: if `save_top_k == k`, the best k models according to the quantity monitored will be saved. @@ -83,7 +84,7 @@ class ModelCheckpoint(Callback): """ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False, - save_top_k: int = 1, save_weights_only: bool = False, + save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0: @@ -103,6 +104,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve else: self.dirpath, self.filename = os.path.split(filepath) os.makedirs(self.dirpath, exist_ok=True) + self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period @@ -217,6 +219,10 @@ def on_validation_end(self, trainer, pl_module): self.epoch_last_check = epoch + if self.save_last: + filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt') + self._save_model(filepath) + filepath = self.format_checkpoint_name(epoch, metrics) version_cnt = 0 while os.path.isfile(filepath): From 
b37201ad6eb16b415647876eb6e2b8c0db7bdf36 Mon Sep 17 00:00:00 2001 From: lgvaz Date: Wed, 20 May 2020 17:02:56 -0300 Subject: [PATCH 2/5] implement test for save_last --- tests/trainer/test_trainer.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1c2c169191564..7d04e1eac6da0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -255,19 +255,21 @@ def test_dp_output_reduce(): assert reduced['b']['c'] == out['b']['c'] -@pytest.mark.parametrize(["save_top_k", "file_prefix", "expected_files"], [ - pytest.param(-1, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}, +@pytest.mark.parametrize(["save_top_k", "save_last", "file_prefix", "expected_files"], [ + pytest.param(-1, False, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'}, id="CASE K=-1 (all)"), - pytest.param(1, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}, + pytest.param(1, False, 'test_prefix_', {'test_prefix_epoch=4.ckpt'}, id="CASE K=1 (2.5, epoch 4)"), - pytest.param(2, '', {'epoch=4.ckpt', 'epoch=2.ckpt'}, + pytest.param(2, False, '', {'epoch=4.ckpt', 'epoch=2.ckpt'}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), - pytest.param(4, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'}, + pytest.param(4, False, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'}, id="CASE K=4 (save all 4 base)"), - pytest.param(3, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, + pytest.param(3, False, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, + id="CASE K=3 (save the 2nd, 3rd, 4th model)"), + pytest.param(1, True, '', {'epoch=4.ckpt', 'last.ckpt'}, id="CASE K=3 (save the 2nd, 3rd, 4th model)"), ]) -def test_model_checkpoint_options(tmpdir, save_top_k, file_prefix, expected_files): +def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files): 
"""Test ModelCheckpoint options.""" def mock_save_function(filepath, *args): @@ -276,7 +278,8 @@ def mock_save_function(filepath, *args): # simulated losses losses = [10, 9, 2.8, 5, 2.5] - checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, prefix=file_prefix, verbose=1) + checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last, + prefix=file_prefix, verbose=1) checkpoint_callback.save_function = mock_save_function trainer = Trainer() From a7c52a18a9a065fe32acd14eb9e18bfdef774bf9 Mon Sep 17 00:00:00 2001 From: Lucas Vazquez Date: Thu, 21 May 2020 14:37:11 -0300 Subject: [PATCH 3/5] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff2c124eafd63..00eaf2d20e116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). +- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` (https://github.com/PyTorchLightning/pytorch-lightning/pull/1908) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) From 6442a85240609b08590ea752e08cf24bb78f482c Mon Sep 17 00:00:00 2001 From: Lucas Vazquez Date: Thu, 21 May 2020 14:38:05 -0300 Subject: [PATCH 4/5] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 00eaf2d20e116..2b24478ae8854 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)). -- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` (https://github.com/PyTorchLightning/pytorch-lightning/pull/1908) +- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` ([#1908](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)) ### Changed From 6803c2eef19aa816d127facb3dcb180edd57a79d Mon Sep 17 00:00:00 2001 From: Lucas Vazquez Date: Fri, 22 May 2020 00:44:33 -0300 Subject: [PATCH 5/5] changes test description Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> --- tests/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7d04e1eac6da0..91bda756202d9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -267,7 +267,7 @@ def test_dp_output_reduce(): pytest.param(3, False, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'}, id="CASE K=3 (save the 2nd, 3rd, 4th model)"), pytest.param(1, True, '', {'epoch=4.ckpt', 'last.ckpt'}, - id="CASE K=3 (save the 2nd, 3rd, 4th model)"), + id="CASE K=1 (save the 4th model and the last model)"), ]) def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files): """Test ModelCheckpoint options."""