4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -51,7 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added Fault Tolerant Training to `DataFetcher` ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Replaced old prefetch iterator with new `DataFetcher` in training loop ([#8953](https://github.com/PyTorchLightning/pytorch-lightning/pull/8953))

- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
- Checkpoint saving & loading extensibility:
* Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
    * Refactored CheckpointConnector to offload validation logic to the checkpoint IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))


- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))
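The `CheckpointIO` entries above describe a new extension point for checkpoint saving and loading. Below is a minimal, hedged usage sketch (not part of this PR's diff). It assumes, based on the plugin code later in this diff, that a training type plugin such as `SingleDevicePlugin` accepts a `checkpoint_io` argument and that `CheckpointIO` exposes `save_checkpoint`/`load_checkpoint` with the signatures shown in the tests at the end; exact import paths and constructor signatures may differ between versions.

```python
from typing import Any, Dict, Optional

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin
from pytorch_lightning.utilities.types import _PATH


class MyCheckpointIO(CheckpointIO):
    """Plain torch.save / torch.load backend (hypothetical example name)."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        return torch.load(path)


# Assumed wiring: the training type plugin takes a `checkpoint_io` argument,
# as introduced by the CheckpointIO refactor referenced in the changelog above.
trainer = Trainer(
    plugins=SingleDevicePlugin(torch.device("cpu"), checkpoint_io=MyCheckpointIO()),
    fast_dev_run=True,
)
# trainer.fit(model) would then route all checkpoint reads and writes
# through MyCheckpointIO instead of the default TorchCheckpointIO.
```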
11 changes: 10 additions & 1 deletion pytorch_lightning/plugins/io/torch_plugin.py
@@ -16,7 +16,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.types import _PATH

@@ -51,5 +51,14 @@ def load_checkpoint(
locations.

Returns: The loaded checkpoint.

Raises:
FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
"""

        # Verify that a checkpoint exists at `path` before loading; raise ``FileNotFoundError`` otherwise.
fs = get_filesystem(path)
if not fs.exists(path):
raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")

return pl_load(path, map_location=map_location)
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -697,12 +697,12 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.load_full_weights and self.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.broadcast(checkpoint_path)
return super().load_checkpoint_file(checkpoint_path)
return super().load_checkpoint(checkpoint_path)

# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerFn
@@ -730,7 +730,7 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
return False

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
if self.load_full_weights and self.zero_stage_3:
self.model_to_device()
self._restore_zero_state(checkpoint)
@@ -782,7 +782,7 @@ def load(module: torch.nn.Module, prefix=""):
load(self.lightning_module, prefix="")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint()`
pass

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -154,7 +154,8 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
"""
return self._results

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
torch.cuda.empty_cache()
return self.checkpoint_io.load_checkpoint(checkpoint_path)

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
14 changes: 3 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -60,16 +60,8 @@ def resume_start(self) -> None:
if not checkpoint_path:
return

# clear cache before restore
torch.cuda.empty_cache()

# Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
fs = get_filesystem(checkpoint_path)
if not fs.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

rank_zero_info(f"Restoring states from the checkpoint file at {checkpoint_path}")
self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)

def resume_end(self) -> None:
"""Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
@@ -152,7 +144,7 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) ->
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
if checkpoint_path is not None:
checkpoint = self.trainer.training_type_plugin.load_checkpoint_file(checkpoint_path)
checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path)

self.trainer.lightning_module.on_load_checkpoint(checkpoint)
self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
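For orientation, here is a small, hedged sketch (not part of this PR) of the path the connector change above exercises: resuming a `Trainer` so that `resume_start` delegates to `training_type_plugin.load_checkpoint(...)`, with the existence check now living in `TorchCheckpointIO.load_checkpoint`. The model import is a placeholder, and the resume argument spelling (`resume_from_checkpoint` vs. `trainer.fit(..., ckpt_path=...)`) depends on the Lightning version.

```python
from pytorch_lightning import Trainer

from my_project import MyLitModel  # hypothetical user model

model = MyLitModel()

# Assumed API of this era: `resume_from_checkpoint` on the Trainer; newer
# versions move this to `trainer.fit(..., ckpt_path=...)`.
trainer = Trainer(resume_from_checkpoint="lightning_logs/version_0/checkpoints/last.ckpt")

# resume_start() -> training_type_plugin.load_checkpoint(path)
#   -> checkpoint_io.load_checkpoint(path), which raises FileNotFoundError
#      if the path does not exist (the check moved out of the connector).
trainer.fit(model)
```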
4 changes: 2 additions & 2 deletions tests/accelerators/test_cpu.py
@@ -82,9 +82,9 @@ def pre_dispatch(self) -> None:
def restore_checkpoint_after_pre_dispatch(self) -> bool:
return restore_after_pre_dispatch

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
assert self.predispatched_called == restore_after_pre_dispatch
return super().load_checkpoint_file(checkpoint_path)
return super().load_checkpoint(checkpoint_path)

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
5 changes: 0 additions & 5 deletions tests/plugins/test_checkpoint_io_plugin.py
@@ -27,15 +27,10 @@


class CustomCheckpointIO(CheckpointIO):
save_checkpoint_called: bool = False
load_checkpoint_file_called: bool = False

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
self.save_checkpoint_called = True
torch.save(checkpoint, path)

def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
self.load_checkpoint_file_called = True
return torch.load(path)

