From fd9631c4c30de212491f428790ab04f741b83af2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 22 Feb 2021 18:11:49 +0100
Subject: [PATCH 01/10] Resolve bug

---
 .../callbacks/model_checkpoint.py            | 117 ++++++++++--------
 .../plugins/training_type/rpc_sequential.py  |   2 +-
 tests/checkpointing/test_model_checkpoint.py |  18 ++-
 3 files changed, 80 insertions(+), 57 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 461c211baab12..0be36a3f3d693 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -254,28 +254,39 @@ def save_checkpoint(self, trainer, pl_module):
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save all checkpoints OR only the top k
-        if self.save_top_k:
-            self._save_top_k_checkpoints(trainer, pl_module, monitor_candidates)
-        # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+        # Mode 1: save the top k checkpoints
+        if self.monitor is not None and self.save_top_k != 0:
+            self._save_top_k_checkpoint(trainer, pl_module, monitor_candidates)
+
+        # Mode 2: save monitor=None checkpoints
+        if self.monitor is None and self.save_top_k in (None, -1):
+            self._save_none_monitor_checkpoint(trainer, pl_module, monitor_candidates)
+
+        # Mode 3: save last checkpoints
+        if self.save_last:
+            self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
-            if self.save_top_k not in [None, -1, 0]:
+            if self.save_top_k not in (None, -1, 0):
                 raise MisconfigurationException(
                     f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                     ' configuration. No quantity for top_k to track.'
                 )
             if self.save_last:
                 rank_zero_warn(
-                    'ModelCheckpoint(save_last=True, monitor=None) is a redundant configuration.'
+                    'ModelCheckpoint(save_last=True, save_top_k=None, monitor=None) is a redundant configuration.'
                     ' You can save the last checkpoint with ModelCheckpoint(save_top_k=None, monitor=None).'
                 )
+            if self.save_top_k == -1 and self.save_last:
+                rank_zero_info(
+                    'ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)'
+                    ' will duplicate the last checkpoint saved.'
+                )
 
     def __init_ckpt_dir(self, dirpath, filename, save_top_k):
@@ -321,8 +332,16 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
-    def _save_model(self, filepath: str, trainer, pl_module):
-        # Todo: required argument `pl_module` is not used
+    def _save_model(self, trainer, pl_module, filepath: str):
+        if trainer.training_type_plugin.rpc_enabled:
+            # RPCPlugin manages saving all model states
+            # TODO: the rpc pluging should wrap trainer.save_checkpoint
+            # instead of us having to do it here manually
+            trainer.training_type_plugin.rpc_save_model(self._do_save, filepath, trainer, pl_module)
+        else:
+            self._do_save(trainer, filepath)
+
+    def _do_save(self, trainer, filepath: str):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
@@ -336,7 +355,7 @@ def _save_model(self, filepath: str, trainer, pl_module):
         else:
             raise ValueError(".save_function() not set")
 
-    def check_monitor_top_k(self, current) -> bool:
+    def check_monitor_top_k(self, current: torch.Tensor) -> bool:
         if current is None:
             return False
 
@@ -511,48 +530,27 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
-        should_save_last = self.monitor is None or self.save_last
-        if not should_save_last:
-            return
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics: dict):
+        filepath = self._format_checkpoint_name(
+            self.CHECKPOINT_NAME_LAST,
+            trainer.current_epoch,
+            trainer.global_step,
+            ckpt_name_metrics,
+            prefix=self.prefix
+        )
+        filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")
 
-        # when user ALSO asked for the 'last.ckpt' change the name
-        if self.save_last:
-            last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST,
-                trainer.current_epoch,
-                trainer.global_step,
-                ckpt_name_metrics,
-                prefix=self.prefix
-            )
-            last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
-        else:
-            last_filepath = self._get_metric_interpolated_filepath_name(
-                ckpt_name_metrics,
-                trainer.current_epoch,
-                trainer.global_step,
-                trainer,
-            )
+        self._save_model(trainer, pl_module, filepath)
 
-        if trainer.training_type_plugin.rpc_enabled:
-            # RPCPlugin manages saving all model states
-            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
-        else:
-            self._save_model(last_filepath, trainer, pl_module)
-        if (
-            self.last_model_path and self.last_model_path != last_filepath
-            and (self.save_top_k != -1 or self.save_last) and trainer.is_global_zero
-        ):
+        if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
             self._del_model(self.last_model_path)
 
-        self.last_model_path = last_filepath
-
-        if self.monitor is None:
-            self.best_model_path = self.last_model_path
+        self.last_model_path = filepath
 
-    def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
-        current = metrics.get(self.monitor)
-        epoch = metrics.get("epoch")
-        step = metrics.get("step")
+    def _save_top_k_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+        current = ckpt_name_metrics.get(self.monitor)
+        epoch = ckpt_name_metrics.get("epoch")
+        step = ckpt_name_metrics.get("step")
 
         # when `val_loss` is being logged and no ModelCheckpoint is being provided
         # `val_loss` will be selected for monitor and need to be reduced to
@@ -563,10 +561,29 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
-        elif self.monitor is not None and self.verbose:
+            self._update_best_and_save(current, epoch, step, trainer, pl_module, ckpt_name_metrics)
+        elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
+    def _save_none_monitor_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+        filepath = self._get_metric_interpolated_filepath_name(
+            ckpt_name_metrics,
+            trainer.current_epoch,
+            trainer.global_step,
+            trainer,
+        )
+        self._save_model(trainer, pl_module, filepath)
+
+        if (
+            self.save_top_k is None
+            and self.best_model_path
+            and self.best_model_path != filepath
+            and trainer.is_global_zero
+        ):
+            self._del_model(self.best_model_path)
+
+        self.best_model_path = filepath
+
     def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0
 
@@ -605,7 +622,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-        self._save_model(filepath, trainer, pl_module)
+        self._save_model(trainer, pl_module, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index 3878aa9db3ea4..67b9608dbc5fe 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -275,7 +275,7 @@ def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> No
             save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
         )
         pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
-        save_model_fn(last_filepath, trainer, pl_module)
+        save_model_fn(last_filepath, trainer)
         pl_module.sequential_module = current_layers
 
     def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index bd4a02536c5c3..a615c8ec56118 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -17,6 +17,7 @@
 import platform
 import re
 from argparse import Namespace
+from logging import INFO
 from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
@@ -379,20 +380,20 @@ def test_none_monitor_top_k(tmpdir):
 
 def test_none_monitor_save_last(tmpdir):
     """ Test that a warning appears for save_last=True with monitor=None. """
-    with pytest.warns(UserWarning, match=r'ModelCheckpoint\(save_last=True, monitor=None\) is a redundant.*'):
+    with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'):
         ModelCheckpoint(dirpath=tmpdir, save_last=True)
     # These should not fail
     ModelCheckpoint(dirpath=tmpdir, save_last=None)
     ModelCheckpoint(dirpath=tmpdir, save_last=False)
 
 
-def test_model_checkpoint_none_monitor(tmpdir):
+def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     seed_everything()
     model = LogInTwoMethods()
     epochs = 2
-    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
@@ -401,17 +402,22 @@ def test_model_checkpoint_none_monitor(tmpdir):
         max_epochs=epochs,
         logger=False,
     )
-    trainer.fit(model)
+
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+        assert "will duplicate the last checkpoint saved" in caplog.text
 
     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt'
+    assert checkpoint_callback.best_model_path == tmpdir / 'epoch=1-step=19.ckpt'
+    assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt'
     assert checkpoint_callback.best_model_score is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
 
     # check that the correct ckpts were created
     expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])]
+    expected.append('last.ckpt')
     assert set(os.listdir(tmpdir)) == set(expected)
 
@@ -971,7 +977,7 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir):
     # check best_k_models state
     assert {Path(f).name for f in mc.best_k_models.keys()} == expected
     # check created ckpts
-    assert set(sorted(os.listdir(tmpdir))) == expected
+    assert set(os.listdir(tmpdir)) == expected
 
 
 def test_model_checkpoint_mode_options():

From 0c060881b6e9cfa91f4a748b5d81082ed01a04e0 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 22 Feb 2021 18:12:34 +0100
Subject: [PATCH 02/10] ckpt_name_metrics -> monitor candidates

---
 .../callbacks/model_checkpoint.py | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 0be36a3f3d693..fe8e9f2cfc4c8 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -510,17 +510,17 @@ def _validate_monitor_key(self, trainer):
 
     def _get_metric_interpolated_filepath_name(
         self,
-        ckpt_name_metrics: Dict[str, Any],
+        monitor_candidates: Dict[str, Any],
         epoch: int,
         step: int,
         trainer,
         del_filepath: Optional[str] = None,
     ) -> str:
-        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
+        filepath = self.format_checkpoint_name(epoch, step, monitor_candidates)
         version_cnt = self.STARTING_VERSION
         while self.file_exists(filepath, trainer) and filepath != del_filepath:
-            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, step, monitor_candidates, ver=version_cnt)
             version_cnt += 1
 
         return filepath
@@ -530,12 +530,12 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics: dict):
+    def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
         filepath = self._format_checkpoint_name(
             self.CHECKPOINT_NAME_LAST,
             trainer.current_epoch,
             trainer.global_step,
-            ckpt_name_metrics,
+            monitor_candidates,
             prefix=self.prefix
         )
         filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")
@@ -547,10 +547,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics: dict):
 
         self.last_model_path = filepath
 
-    def _save_top_k_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
-        current = ckpt_name_metrics.get(self.monitor)
-        epoch = ckpt_name_metrics.get("epoch")
-        step = ckpt_name_metrics.get("step")
+    def _save_top_k_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
+        current = monitor_candidates.get(self.monitor)
+        epoch = monitor_candidates.get("epoch")
+        step = monitor_candidates.get("step")
 
         # when `val_loss` is being logged and no ModelCheckpoint is being provided
         # `val_loss` will be selected for monitor and need to be reduced to
@@ -561,13 +561,13 @@ def _save_top_k_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, ckpt_name_metrics)
+            self._update_best_and_save(current, epoch, step, trainer, pl_module, monitor_candidates)
         elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
-    def _save_none_monitor_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
+    def _save_none_monitor_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
         filepath = self._get_metric_interpolated_filepath_name(
-            ckpt_name_metrics,
+            monitor_candidates,
             trainer.current_epoch,
             trainer.global_step,
             trainer,
         )
@@ -588,7 +588,7 @@ def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0
 
     def _update_best_and_save(
-        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, ckpt_name_metrics
+        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, monitor_candidates: Dict[str, Any]
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
@@ -601,7 +601,7 @@ def _update_best_and_save(
         if isinstance(current, torch.Tensor) and torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
 
-        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, step, trainer, del_filepath)
 
         # save the current score
         self.current_score = current

From fb8f6117c60908c8d0f67cded7d43df5f7947b5a Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 22 Feb 2021 18:16:42 +0100
Subject: [PATCH 03/10] Update CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8dfc9a1c021ac..9b20e3704aa75 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
 
+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 

From 7366a13ae86e9983e5e86e4b1ae405f1a8a3f9d7 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 23 Feb 2021 03:14:05 +0100
Subject: [PATCH 04/10] Resolve bug. BC support

---
 .../callbacks/model_checkpoint.py           | 36 +++++++++++--------
 .../plugins/training_type/rpc.py            |  9 +++--
 .../plugins/training_type/rpc_sequential.py | 12 +++----
 tests/deprecated_api/test_remove_1-5.py     | 25 +++++++++++++
 tests/plugins/test_rpc_plugin.py            |  2 +-
 5 files changed, 57 insertions(+), 27 deletions(-)
 create mode 100644 tests/deprecated_api/test_remove_1-5.py

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index fe8e9f2cfc4c8..74d7aad301427 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -209,7 +209,7 @@ def on_validation_end(self, trainer, pl_module):
         """
         checkpoints can be saved at the end of the val loop
         """
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)
 
     def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
@@ -224,12 +224,18 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, unused: Optional = None):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
         to handle correct behaviour in distributed training, i.e., saving only on rank 0.
         """
+        if unused is not None:
+            rank_zero_warn(
+                "`ModelCheckpoint.save_checkpoint` signature has changed in v1.3. The `pl_module` parameter"
+                " has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
+            )
+
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
@@ -257,15 +263,15 @@ def save_checkpoint(self, trainer, pl_module):
 
         # Mode 1: save the top k checkpoints
         if self.monitor is not None and self.save_top_k != 0:
-            self._save_top_k_checkpoint(trainer, pl_module, monitor_candidates)
+            self._save_top_k_checkpoint(trainer, monitor_candidates)
 
         # Mode 2: save monitor=None checkpoints
         if self.monitor is None and self.save_top_k in (None, -1):
-            self._save_none_monitor_checkpoint(trainer, pl_module, monitor_candidates)
+            self._save_none_monitor_checkpoint(trainer, monitor_candidates)
 
         # Mode 3: save last checkpoints
         if self.save_last:
-            self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -332,12 +338,12 @@ def _del_model(self, filepath: str):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
-    def _save_model(self, trainer, pl_module, filepath: str):
+    def _save_model(self, trainer, filepath: str):
         if trainer.training_type_plugin.rpc_enabled:
             # RPCPlugin manages saving all model states
             # TODO: the rpc pluging should wrap trainer.save_checkpoint
             # instead of us having to do it here manually
-            trainer.training_type_plugin.rpc_save_model(self._do_save, filepath, trainer, pl_module)
+            trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
         else:
             self._do_save(trainer, filepath)
 
@@ -530,7 +536,7 @@ def _monitor_candidates(self, trainer):
         monitor_candidates.update(step=trainer.global_step, epoch=trainer.current_epoch)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
+    def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         filepath = self._format_checkpoint_name(
             self.CHECKPOINT_NAME_LAST,
             trainer.current_epoch,
@@ -540,14 +546,14 @@ def _save_last_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str
         )
         filepath = os.path.join(self.dirpath, f"{filepath}{self.FILE_EXTENSION}")
 
-        self._save_model(trainer, pl_module, filepath)
+        self._save_model(trainer, filepath)
 
         if self.last_model_path and self.last_model_path != filepath and trainer.is_global_zero:
             self._del_model(self.last_model_path)
 
         self.last_model_path = filepath
 
-    def _save_top_k_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
+    def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         current = monitor_candidates.get(self.monitor)
         epoch = monitor_candidates.get("epoch")
         step = monitor_candidates.get("step")
@@ -561,18 +567,18 @@ def _save_top_k_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[st
             current = trainer.training_type_plugin.reduce(current, reduce_op="mean")
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(current, epoch, step, trainer, pl_module, monitor_candidates)
+            self._update_best_and_save(current, epoch, step, trainer, monitor_candidates)
         elif self.verbose:
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
-    def _save_none_monitor_checkpoint(self, trainer, pl_module, monitor_candidates: Dict[str, Any]):
+    def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         filepath = self._get_metric_interpolated_filepath_name(
             monitor_candidates,
             trainer.current_epoch,
             trainer.global_step,
             trainer,
         )
-        self._save_model(trainer, pl_module, filepath)
+        self._save_model(trainer, filepath)
 
         if (
             self.save_top_k is None
@@ -588,7 +594,7 @@ def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0
 
     def _update_best_and_save(
-        self, current: torch.Tensor, epoch: int, step: int, trainer, pl_module, monitor_candidates: Dict[str, Any]
+        self, current: torch.Tensor, epoch: int, step: int, trainer, monitor_candidates: Dict[str, Any]
     ):
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
@@ -622,7 +628,7 @@ def _update_best_and_save(
                 f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
                 f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
-        self._save_model(trainer, pl_module, filepath)
+        self._save_model(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
             self._del_model(del_filepath)
diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py
index 3c016f3cb8e25..faf528d76b768 100644
--- a/pytorch_lightning/plugins/training_type/rpc.py
+++ b/pytorch_lightning/plugins/training_type/rpc.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import List, Optional
+from typing import List, Optional, Callable
 
 import torch
 
@@ -63,16 +63,15 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
             rpc._set_rpc_timeout(self.rpc_timeout_sec)
             self._is_rpc_initialized = True
 
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         """
         Override to save model to disk. This is required as the main process will be required to handle
         aggregating model states from RPC processes.
 
         Args:
-            save_model_fn: The saving function to save final model.
-            last_filepath: The filepath to save the model to.
             trainer: The trainer object.
-            pl_module: The LightningModule.
+            save_model_fn: The saving function to save final model.
+            filepath: The filepath to save the model to.
         """
         raise NotImplementedError
diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py
index 67b9608dbc5fe..09959addd296e 100644
--- a/pytorch_lightning/plugins/training_type/rpc_sequential.py
+++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py
@@ -13,7 +13,7 @@
 # limitations under the License
 import logging
 import os
-from typing import List, Optional
+from typing import List, Optional, Callable
 
 import torch
 import torch.distributed as torch_distrib
@@ -266,17 +266,17 @@ def configure_ddp(self):
         self._model.require_backward_grad_sync = False
 
     @rank_zero_only
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, trainer, save_model_fn: Callable, filepath: str) -> None:
         model = self.lightning_module
         if not hasattr(model.sequential_module, "foreach_worker"):
             return
-        current_layers = pl_module.sequential_module
+        current_layers = model.sequential_module
         model.sequential_module.foreach_worker(
             save_layers_on_all_rank_zero_workers, {"gpus_per_model": self.gpus_per_model}, include_self=True
         )
-        pl_module.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
-        save_model_fn(last_filepath, trainer)
-        pl_module.sequential_module = current_layers
+        model.sequential_module = load_sequential_from_saved_layers(self.gpus_per_model)
+        save_model_fn(trainer, filepath)
+        model.sequential_module = current_layers
 
     def worker_optimizer_step(self, model: LightningModule, opt_idx: int, *args, **kwargs) -> None:
         model.sequential_module.foreach_worker(
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
new file mode 100644
index 0000000000000..0900133c90bb7
--- /dev/null
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -0,0 +1,25 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test deprecated functionality which will be removed in v1.5.0"""
+import pytest
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+
+def test_v1_5_0_model_checkpoint_save_checkpoint():
+    model_ckpt = ModelCheckpoint()
+    model_ckpt.save_function = lambda *_, **__: None
+    with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
+        model_ckpt.save_checkpoint(Trainer(), object())
\ No newline at end of file
diff --git a/tests/plugins/test_rpc_plugin.py b/tests/plugins/test_rpc_plugin.py
index 2c074e6c3afda..a783051c3f6a4 100644
--- a/tests/plugins/test_rpc_plugin.py
+++ b/tests/plugins/test_rpc_plugin.py
@@ -58,7 +58,7 @@ def __init__(self, **kwargs):
         self.rpc_save_model_count = 0
         self.worker_optimizer_step_count = 0
 
-    def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
+    def rpc_save_model(self, *_) -> None:
         self.rpc_save_model_count += 1
 
     def barrier(self, name: Optional[str] = None) -> None:

From 8a8afb470db3232ba462b9eb9044ddd9369dd3af Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 23 Feb 2021 03:15:44 +0100
Subject: [PATCH 05/10] EOF

---
 tests/deprecated_api/test_remove_1-5.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 0900133c90bb7..2816fc7b9d52a 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -22,4 +22,4 @@ def test_v1_5_0_model_checkpoint_save_checkpoint():
     model_ckpt = ModelCheckpoint()
     model_ckpt.save_function = lambda *_, **__: None
     with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
-        model_ckpt.save_checkpoint(Trainer(), object())
\ No newline at end of file
+        model_ckpt.save_checkpoint(Trainer(), object())

From 5269594bcfe136872859b114cc99c3a8502613b1 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 23 Feb 2021 03:25:01 +0100
Subject: [PATCH 06/10] Move ifs inside functions

---
 .../callbacks/model_checkpoint.py | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 74d7aad301427..a1f716d14513b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -260,18 +260,12 @@ def save_checkpoint(self, trainer, unused: Optional = None):
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save the top k checkpoints
-        if self.monitor is not None and self.save_top_k != 0:
-            self._save_top_k_checkpoint(trainer, monitor_candidates)
-
+        self._save_top_k_checkpoint(trainer, monitor_candidates)
         # Mode 2: save monitor=None checkpoints
-        if self.monitor is None and self.save_top_k in (None, -1):
-            self._save_none_monitor_checkpoint(trainer, monitor_candidates)
-
+        self._save_none_monitor_checkpoint(trainer, monitor_candidates)
         # Mode 3: save last checkpoints
-        if self.save_last:
-            self._save_last_checkpoint(trainer, monitor_candidates)
+        self._save_last_checkpoint(trainer, monitor_candidates)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -537,6 +531,9 @@ def _monitor_candidates(self, trainer):
         return monitor_candidates
 
     def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if not self.save_last:
+            return
+
         filepath = self._format_checkpoint_name(
             self.CHECKPOINT_NAME_LAST,
             trainer.current_epoch,
@@ -551,6 +548,9 @@ def _save_last_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
         self.last_model_path = filepath
 
     def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is None or self.save_top_k == 0:
+            return
+
         current = monitor_candidates.get(self.monitor)
         epoch = monitor_candidates.get("epoch")
         step = monitor_candidates.get("step")
@@ -572,6 +572,9 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
     def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
+        if self.monitor is not None:
+            return
+
         filepath = self._get_metric_interpolated_filepath_name(

From 5717fda9c740b375e9006bd7c1365c004346d9e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 23 Feb 2021 13:48:22 +0100
Subject: [PATCH 07/10] Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: ananthsub
---
 pytorch_lightning/callbacks/model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a1f716d14513b..9f21b7ca5bbeb 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -335,7 +335,7 @@ def _del_model(self, filepath: str):
     def _save_model(self, trainer, filepath: str):
         if trainer.training_type_plugin.rpc_enabled:
             # RPCPlugin manages saving all model states
-            # TODO: the rpc pluging should wrap trainer.save_checkpoint
+            # TODO: the rpc plugin should wrap trainer.save_checkpoint
             # instead of us having to do it here manually
             trainer.training_type_plugin.rpc_save_model(trainer, self._do_save, filepath)
         else:
             self._do_save(trainer, filepath)

From 2ae7200a5b859ce6c3d2388014a0588b4aef44b3 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 25 Feb 2021 21:32:22 +0100
Subject: [PATCH 08/10] Fix merge

---
 tests/deprecated_api/test_remove_1-5.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 2d04df8190ee9..831d80a2d916a 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -14,12 +14,11 @@
 """Test deprecated functionality which will be removed in v1.5.0"""
 import pytest
 
-from pytorch_lightning import Trainer, Callback
-from tests.helpers import BoringModel
-from tests.helpers.utils import no_warning_callimport pytest
-
+from pytorch_lightning import Callback
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
 
 
 def test_v1_5_0_model_checkpoint_save_checkpoint():

From b1ee23be0dfcc0fa5eaefc4470f34e4e5a2ef2be Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 26 Feb 2021 15:59:35 +0100
Subject: [PATCH 09/10] Fix bug with top_k=0 and save_last=True

---
 CHANGELOG.md                                    | 3 +++
 pytorch_lightning/callbacks/model_checkpoint.py | 3 +--
 tests/checkpointing/test_model_checkpoint.py    | 7 ++++---
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bad3fa2202881..150b6401159dc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
 
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+
 - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 8377cca8aadbe..876dbe11f8430 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -221,7 +221,6 @@ def save_checkpoint(self, trainer, unused: Optional = None):
 
         if (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
@@ -541,7 +540,7 @@ def _save_top_k_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
             rank_zero_info(f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}")
 
     def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
-        if self.monitor is not None:
+        if self.monitor is not None or self.save_top_k == 0:
             return
 
         filepath = self._get_metric_interpolated_filepath_name(
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 72813fc79c168..8d105639a703a 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -568,7 +568,7 @@ def test_model_checkpoint_period(tmpdir, period):
 def test_model_checkpoint_topk_zero(tmpdir):
     """ Test that no checkpoints are saved when save_top_k=0. """
     model = LogInTwoMethods()
-    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
+    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=0, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[checkpoint_callback],
@@ -582,8 +582,9 @@ def test_model_checkpoint_topk_zero(tmpdir):
     assert checkpoint_callback.best_model_score is None
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''
-    # check that no ckpts were created
-    assert len(os.listdir(tmpdir)) == 0
+    # check that only the last ckpt was created
+    assert os.listdir(tmpdir) == ['last.ckpt']
+    assert checkpoint_callback.last_model_path == tmpdir / 'last.ckpt'
 
 
 def test_model_checkpoint_topk_all(tmpdir):

From 9564ba1be269114c8fd3698d41a4dc66f3116607 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 1 Mar 2021 14:33:14 +0100
Subject: [PATCH 10/10] Formatting

---
 tests/deprecated_api/test_remove_1-5.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index ae25d098f8f8a..384c809e20a45 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -19,12 +19,10 @@
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call
-
 
 
 def test_v1_5_0_model_checkpoint_save_checkpoint():
     model_ckpt = ModelCheckpoint()
     model_ckpt.save_function = lambda *_, **__: None