From 9952cbee680de4d17ca983cd4cddab5ff845bb49 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 15 Dec 2022 10:29:53 +0000
Subject: [PATCH] Add function to remove checkpoint to allow override for
 extended classes

---
 src/pytorch_lightning/callbacks/model_checkpoint.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 0a7b400bb9f05..d7227c78f4e4a 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -649,7 +649,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         previous, self.last_model_path = self.last_model_path, filepath
         self._save_checkpoint(trainer, filepath)
         if previous and previous != filepath:
-            trainer.strategy.remove_checkpoint(previous)
+            self._remove_checkpoint(trainer, previous)
 
     def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         assert self.monitor
@@ -668,7 +668,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
         previous, self.best_model_path = self.best_model_path, filepath
         self._save_checkpoint(trainer, filepath)
         if self.save_top_k == 1 and previous and previous != filepath:
-            trainer.strategy.remove_checkpoint(previous)
+            self._remove_checkpoint(trainer, previous)
 
     def _update_best_and_save(
         self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
@@ -710,7 +710,7 @@ def _update_best_and_save(
         self._save_checkpoint(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
-            trainer.strategy.remove_checkpoint(del_filepath)
+            self._remove_checkpoint(trainer, del_filepath)
 
     def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@@ -727,3 +727,7 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
         state to diverge between ranks."""
         exists = self._fs.exists(filepath)
         return trainer.strategy.broadcast(exists)
+
+    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+        """Calls the strategy to remove the checkpoint file."""
+        trainer.strategy.remove_checkpoint(filepath)
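
A minimal sketch of what this patch enables, assuming it is applied: a subclass
can now intercept checkpoint deletion by overriding the single
`_remove_checkpoint` hook instead of copying the `_save_last_checkpoint`,
`_save_none_monitor_checkpoint`, and `_update_best_and_save` logic.
`ArchivingModelCheckpoint` and its `archive_dir` parameter are illustrative
names, not part of the patch or the pytorch_lightning API:

    import shutil
    from pathlib import Path

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint


    class ArchivingModelCheckpoint(ModelCheckpoint):
        """Archives stale checkpoints instead of deleting them outright."""

        def __init__(self, archive_dir: str, **kwargs):
            super().__init__(**kwargs)
            self.archive_dir = Path(archive_dir)

        def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
            # Copy the checkpoint aside (single-process assumption; guard with
            # trainer.is_global_zero in distributed runs), then delegate to the
            # default strategy-based removal introduced by this patch.
            self.archive_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy2(filepath, self.archive_dir / Path(filepath).name)
            super()._remove_checkpoint(trainer, filepath)

Such a subclass is passed to Trainer(callbacks=[...]) exactly like
ModelCheckpoint; only the removal behaviour changes, since all three call
sites in the diff now route through the hook.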