Introduce CheckpointIO Plugin #8743

Merged
40 commits merged on Aug 13, 2021
Commits (changes shown below are from 37 of the 40 commits)
a93e452  poc API (Jul 19, 2021)
72f4dfd  Merge branch 'master' into feat/ckpt_plugin (Aug 5, 2021)
e7d2b66  Fix up the API, unsure on connection (Aug 5, 2021)
b41e794  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
9161980  Example API (Aug 5, 2021)
7aa4e8c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
dffe088  Update all constructors (Aug 5, 2021)
cacf0e5  Move towards having the checkpoint plugin not require the plugin, and… (Aug 6, 2021)
028ac38  Remove import (Aug 6, 2021)
99c7a46  Fix tests (Aug 6, 2021)
3adc486  Change name (Aug 6, 2021)
b7d5b55  Cleanups (Aug 9, 2021)
0a0a068  Fixes/Cleanups (Aug 9, 2021)
97fb2a2  Use property (Aug 9, 2021)
402156e  Fixes to signature (Aug 9, 2021)
5310a7f  Merge branch 'master' into feat/ckpt_plugin (Aug 9, 2021)
d7f567a  Add warning for TPU plugins that they do not support custom checkpoin… (Aug 9, 2021)
fcc24b4  Cleanup API, introduce storage options (Aug 10, 2021)
4421276  Update signature to be more general (Aug 10, 2021)
d84cce1  Address feedback, add test for support check (Aug 11, 2021)
38c22a2  Merge branch 'master' into feat/ckpt_plugin (Aug 11, 2021)
b7f37ee  Add CHANGELOG.md (Aug 11, 2021)
49086cc  fix tests (Aug 11, 2021)
936f65a  change name (Aug 11, 2021)
049a676  Fix mypy (Aug 11, 2021)
1ff0912  Reviews (Aug 12, 2021)
1841d3b  Add ability to pass checkpoint plugin through the trainer (Aug 12, 2021)
b909dfe  Add constraints (Aug 12, 2021)
50b11b5  Match signature to see if mypy works (Aug 12, 2021)
9e16e34  Address review points (Aug 13, 2021)
642e6fa  Revert changes to typing (Aug 13, 2021)
5c9e973  Add docs/doc strings and API (Aug 13, 2021)
6361c87  Address feedback (Aug 13, 2021)
c921dbb  Update pytorch_lightning/plugins/training_type/training_type_plugin.py (Aug 13, 2021)
fd82276  Address reviews (Aug 13, 2021)
2fc3558  Update typing (Aug 13, 2021)
21783f6  Refactor name (Aug 13, 2021)
3b8c3f5  Clear up signature of function; checkpoint_plugin -> checkpoint_io (Aug 13, 2021)
9cfe98f  Slightly cleaner (Aug 13, 2021)
8f234e0  Address reviews (Aug 13, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))


- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))


### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
52 changes: 52 additions & 0 deletions docs/source/advanced/checkpoint_io.rst
@@ -0,0 +1,52 @@
Custom Checkpointing IO
=======================

.. warning:: The Checkpoint IO API is experimental and subject to change.

Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
that is managed by the ``TrainingTypePlugin``.

``CheckpointIO`` can be extended to provide custom save/load functionality to and from a path, with the object being passed to either a ``Trainer`` object or a ``TrainingTypePlugin`` as shown below.

.. code-block:: python

from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin


class CustomCheckpointPlugin(CheckpointIO):
def save_checkpoint(
self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
...

def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
...


checkpoint_plugin = CustomCheckpointPlugin()

# Pass into the Trainer object
model = MyModel()
trainer = Trainer(
plugins=[checkpoint_plugin],
callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# pass into TrainingTypePlugin
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
plugins=SingleDevicePlugin(device, checkpoint_plugin=checkpoint_plugin),
callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

.. note::

Some ``TrainingTypePlugins`` do not support custom ``CheckpointIO`` as the checkpointing logic is not modifiable.
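For reference, here is a minimal concrete sketch of the two elided methods above. It assumes plain local filesystem paths and simply mirrors what the default ``TorchCheckpointIO`` does with ``torch.save``/``torch.load``; the class name is illustrative only.

.. code-block:: python

    import os
    from pathlib import Path
    from typing import Any, Dict, Optional, Union

    import torch

    from pytorch_lightning.plugins import CheckpointIO


    class TorchBackedCheckpointIO(CheckpointIO):
        def save_checkpoint(
            self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
        ) -> None:
            # make sure the parent directory exists, then dump the checkpoint dict to disk
            os.makedirs(os.path.dirname(str(path)) or ".", exist_ok=True)
            torch.save(checkpoint, path)

        def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
            # load the checkpoint dict back onto CPU; callers can re-map devices afterwards
            return torch.load(path, map_location="cpu")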
12 changes: 12 additions & 0 deletions docs/source/api_references.rst
@@ -127,6 +127,18 @@ Cluster Environments
KubeflowEnvironment
SLURMEnvironment

Checkpoint IO Plugins
^^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: pytorch_lightning.plugins.io

.. autosummary::
:toctree: api
:nosignatures:
:template: classtemplate.rst

CheckpointIO
TorchCheckpointIO

Profiler API
------------
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -55,6 +55,7 @@ PyTorch Lightning Documentation
advanced/multi_gpu
advanced/advanced_gpu
common/weights_loading
advanced/checkpoint_io
common/optimizers
advanced/profiler
advanced/sequences
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,4 +1,6 @@
from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401
call_training_type_register_plugins,
TrainingTypePluginsRegistry,
@@ -29,6 +31,8 @@
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

__all__ = [
"CheckpointIO",
"TorchCheckpointIO",
"ApexMixedPrecisionPlugin",
"DataParallelPlugin",
"DDP2Plugin",
15 changes: 15 additions & 0 deletions pytorch_lightning/plugins/io/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401
57 changes: 57 additions & 0 deletions pytorch_lightning/plugins/io/checkpoint_plugin.py
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from pytorch_lightning.utilities.types import _PATH


class CheckpointIO(ABC):
"""
Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``.

Most plugins use the Torch-based IO plugin, ``TorchCheckpointIO``, by default, but some may
require particular handling depending on the plugin.

In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it
to the Trainer, e.g. ``Trainer(plugins=[MyCustomCheckpointIO()])``.

.. note::

For some plugins, it is not possible to use a custom checkpoint plugin, as the checkpointing logic
is not modifiable.

"""

@abstractmethod
def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: dict containing model and trainer state
path: write-target path
storage_options: Optional parameters when saving the model/training states.
"""

@abstractmethod
def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]:
"""
Load a checkpoint from a path when resuming or loading a checkpoint for the test/validate/predict stages.

Args:
path: Path to checkpoint
storage_options: Optional parameters when loading the model/training states.

Returns: The loaded checkpoint.
"""
55 changes: 55 additions & 0 deletions pytorch_lightning/plugins/io/torch_plugin.py
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.types import _PATH


class TorchCheckpointIO(CheckpointIO):
"""
CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load`
to save and load checkpoints respectively, common for most use cases.
"""

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
try:
# write the checkpoint dictionary on the file
atomic_save(checkpoint, path)
except AttributeError as err:
# todo (sean): is this try catch necessary still?
# https://github.com/PyTorchLightning/pytorch-lightning/pull/431
key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
checkpoint.pop(key, None)
rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
atomic_save(checkpoint, path)

def load_checkpoint(
self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
) -> Dict[str, Any]:
"""
Loads checkpoint using torch.load, with additional handling for fsspec remote loading of files.

Args:
path: Path to checkpoint
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations.

Returns: The loaded checkpoint.
"""
return pl_load(path, map_location=map_location)
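As a quick illustration (not part of this diff), ``TorchCheckpointIO`` can also be exercised standalone. A minimal sketch, assuming a toy ``nn.Linear`` model and a local path:

.. code-block:: python

    from torch import nn

    from pytorch_lightning.plugins import TorchCheckpointIO

    model = nn.Linear(4, 2)
    checkpoint_io = TorchCheckpointIO()

    # write a minimal checkpoint dict and read it back with the default map_location
    checkpoint_io.save_checkpoint({"state_dict": model.state_dict()}, "example.ckpt")
    loaded = checkpoint_io.load_checkpoint("example.ckpt")
    model.load_state_dict(loaded["state_dict"])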
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -32,6 +32,7 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
@@ -75,14 +76,19 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: Optional[int] = None,
cluster_environment: ClusterEnvironment = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_plugin: Optional[CheckpointIO] = None,
sync_batchnorm: Optional[bool] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_plugin=checkpoint_plugin,
)
self.interactive_ddp_procs = []
if num_nodes is not None:
rank_zero_deprecation(
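With the constructor change above, a custom ``CheckpointIO`` can be handed to ``DDPPlugin`` directly. A hedged sketch, assuming a two-GPU machine and that the plugin is passed to the ``Trainer`` via ``plugins=`` as in the docs example earlier in this PR:

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin, TorchCheckpointIO

    # route DDP checkpoint writes through an explicit CheckpointIO instance
    trainer = Trainer(
        gpus=2,
        plugins=DDPPlugin(checkpoint_plugin=TorchCheckpointIO()),
    )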
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -25,6 +25,7 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
@@ -62,14 +63,19 @@ def __init__(
self,
parallel_devices: Optional[List[torch.device]] = None,
num_nodes: Optional[int] = None,
cluster_environment: ClusterEnvironment = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_plugin: Optional[CheckpointIO] = None,
sync_batchnorm: Optional[bool] = None,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Any,
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_plugin=checkpoint_plugin,
)
if num_nodes is not None:
rank_zero_deprecation(
"Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "
15 changes: 14 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -27,6 +27,7 @@
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
@@ -274,8 +275,11 @@ def __init__(
pin_memory = cpu_offload_use_pin_memory

super().__init__(
parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
parallel_devices=parallel_devices,
num_nodes=num_nodes,
cluster_environment=cluster_environment,
)

self.config = self._load_config(config)
if self.config is None:
# User has not overridden config, set defaults
@@ -679,6 +683,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
filepath: write-target file's path
"""
if self.zero_stage_3 and self._multi_device and self.is_global_zero:
# todo (sean): Add link to docs once docs are merged.
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
@@ -818,3 +823,11 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
offload_params_device="nvme",
offload_optimizer_device="nvme",
)

@property
def checkpoint_plugin(self) -> CheckpointIO:
return self._checkpoint_plugin

@checkpoint_plugin.setter
def checkpoint_plugin(self, plugin: CheckpointIO) -> None:
raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.")
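Given the read-only property above, assigning a custom ``CheckpointIO`` to the DeepSpeed plugin is expected to fail. A small sketch of that behaviour, assuming ``deepspeed`` is installed so the plugin can be constructed:

.. code-block:: python

    from pytorch_lightning.plugins import DeepSpeedPlugin, TorchCheckpointIO
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    plugin = DeepSpeedPlugin()

    try:
        # the setter above rejects custom checkpoint IO for DeepSpeed
        plugin.checkpoint_plugin = TorchCheckpointIO()
    except MisconfigurationException as err:
        print(err)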
11 changes: 9 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -17,6 +17,7 @@
from torch.nn import DataParallel

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -29,8 +30,14 @@ class DataParallelPlugin(ParallelPlugin):
device and each gets a split of the data.
"""

def __init__(self, parallel_devices: Optional[List[torch.device]]):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
def __init__(
self,
parallel_devices: Optional[List[torch.device]],
checkpoint_plugin: Optional[CheckpointIO] = None,
):
super().__init__(
parallel_devices=parallel_devices, cluster_environment=None, checkpoint_plugin=checkpoint_plugin
)

@property
def global_rank(self) -> int: