diff --git a/CHANGELOG.md b/CHANGELOG.md
index c90af8b9c97cd..fd56399bcd943 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -51,7 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
 * Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))
 
-- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
+- Checkpoint saving & loading extensibility:
+  * Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
+  * Refactored `CheckpointConnector` to offload validation logic to the checkpoint IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))
 
 - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
 

diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py
index e95f3d3b226f7..2aa66e65cc30b 100644
--- a/pytorch_lightning/plugins/io/torch_plugin.py
+++ b/pytorch_lightning/plugins/io/torch_plugin.py
@@ -16,7 +16,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.types import _PATH
 
@@ -51,5 +51,14 @@ def load_checkpoint(
             locations.
 
         Returns: The loaded checkpoint.
+
+        Raises:
+            FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
         """
+
+        # Try to read the checkpoint at `path`. If it does not exist, do not restore the checkpoint.
+        fs = get_filesystem(path)
+        if not fs.exists(path):
+            raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
+
         return pl_load(path, map_location=map_location)

diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 31fdfde234462..94fb868d1c646 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -697,12 +697,12 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
 
-    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
+    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
         if self.load_full_weights and self.zero_stage_3:
             # Broadcast to ensure we load from the rank 0 checkpoint
             # This doesn't have to be the case when using deepspeed sharded checkpointing
             checkpoint_path = self.broadcast(checkpoint_path)
-            return super().load_checkpoint_file(checkpoint_path)
+            return super().load_checkpoint(checkpoint_path)
 
         # Rely on deepspeed to load the checkpoint and necessary information
         from pytorch_lightning.trainer.states import TrainerFn
@@ -730,7 +730,7 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
         return False
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
+        # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
         if self.load_full_weights and self.zero_stage_3:
             self.model_to_device()
             self._restore_zero_state(checkpoint)
@@ -782,7 +782,7 @@ def load(module: torch.nn.Module, prefix=""):
         load(self.lightning_module, prefix="")
 
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
-        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
+        # override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
         pass
 
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:

diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cdc3bd47ab966..eed6de63f60b4 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -154,7 +154,8 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         """
         return self._results
 
-    def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+    def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        torch.cuda.empty_cache()
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 5b967fed810e2..2c13b67697f01 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -60,16 +60,8 @@ def resume_start(self) -> None:
         if not checkpoint_path:
             return
 
-        # clear cache before restore
-        torch.cuda.empty_cache()
-
-        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
-        fs = get_filesystem(checkpoint_path)
-        if not fs.exists(checkpoint_path):
-            raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")
-
-        rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
-        self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
+        rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
+        self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
 
     def resume_end(self) -> None:
         """Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
@@ -152,7 +144,7 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
         """Restore only the model weights."""
         checkpoint = self._loaded_checkpoint
         if checkpoint_path is not None:
-            checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
+            checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)
 
         self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py
index 98899360592f5..7280afea02f76 100644
--- a/tests/accelerators/test_cpu.py
+++ b/tests/accelerators/test_cpu.py
@@ -82,9 +82,9 @@ def pre_dispatch(self) -> None:
         def restore_checkpoint_after_pre_dispatch(self) -> bool:
             return restore_after_pre_dispatch
 
-        def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
+        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
             assert self.predispatched_called == restore_after_pre_dispatch
-            return super().load_checkpoint_file(checkpoint_path)
+            return super().load_checkpoint(checkpoint_path)
 
     model = BoringModel()
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)

diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py
index ef43b8b14b146..810127a03f361 100644
--- a/tests/plugins/test_checkpoint_io_plugin.py
+++ b/tests/plugins/test_checkpoint_io_plugin.py
@@ -27,15 +27,10 @@
 
 class CustomCheckpointIO(CheckpointIO):
-    save_checkpoint_called: bool = False
-    load_checkpoint_file_called: bool = False
-
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
-        self.save_checkpoint_called = True
         torch.save(checkpoint, path)
 
     def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
-        self.load_checkpoint_file_called = True
         return torch.load(path)
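For completeness, a hedged usage sketch of how such a plugin is attached to the `Trainer`: it is passed through a training type plugin's `checkpoint_io` argument, the same wiring the `CustomCheckpointIO` test above exercises. `BoringModel` is the standard test helper used elsewhere in this PR; the trainer flags are arbitrary.

```python
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin
from tests.helpers.boring_model import BoringModel

# Assumption: `SingleDevicePlugin` accepts `checkpoint_io`, as introduced in
# #8743; `ValidatingCheckpointIO` is the hypothetical sketch shown earlier.
checkpoint_io = ValidatingCheckpointIO()
plugin = SingleDevicePlugin(torch.device("cpu"), checkpoint_io=checkpoint_io)

trainer = Trainer(plugins=plugin, fast_dev_run=True)
trainer.fit(BoringModel())
```

On resume, `CheckpointConnector.resume_start` now delegates straight to `TrainingTypePlugin.load_checkpoint`, which clears the CUDA cache and forwards to `checkpoint_io.load_checkpoint`, so the plugin above decides how a missing path is reported.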