Introduce CheckpointIO Plugin #8743

Merged 40 commits on Aug 13, 2021

Changes shown are from 7 of the 40 commits.

Commits:
a93e452  poc API (Jul 19, 2021)
72f4dfd  Merge branch 'master' into feat/ckpt_plugin (Aug 5, 2021)
e7d2b66  Fix up the API, unsure on connection (Aug 5, 2021)
b41e794  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
9161980  Example API (Aug 5, 2021)
7aa4e8c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 5, 2021)
dffe088  Update all constructors (Aug 5, 2021)
cacf0e5  Move towards having the checkpoint plugin not require the plugin, and… (Aug 6, 2021)
028ac38  Remove import (Aug 6, 2021)
99c7a46  Fix tests (Aug 6, 2021)
3adc486  Change name (Aug 6, 2021)
b7d5b55  Cleanups (Aug 9, 2021)
0a0a068  Fixes/Cleanups (Aug 9, 2021)
97fb2a2  Use property (Aug 9, 2021)
402156e  Fixes to signature (Aug 9, 2021)
5310a7f  Merge branch 'master' into feat/ckpt_plugin (Aug 9, 2021)
d7f567a  Add warning for TPU plugins that they do not support custom checkpoin… (Aug 9, 2021)
fcc24b4  Cleanup API, introduce storage options (Aug 10, 2021)
4421276  Update signature to be more general (Aug 10, 2021)
d84cce1  Address feedback, add test for support check (Aug 11, 2021)
38c22a2  Merge branch 'master' into feat/ckpt_plugin (Aug 11, 2021)
b7f37ee  Add CHANGELOG.md (Aug 11, 2021)
49086cc  fix tests (Aug 11, 2021)
936f65a  change name (Aug 11, 2021)
049a676  Fix mypy (Aug 11, 2021)
1ff0912  Reviews (Aug 12, 2021)
1841d3b  Add ability to pass checkpoint plugin through the trainer (Aug 12, 2021)
b909dfe  Add constraints (Aug 12, 2021)
50b11b5  Match signature to see if mypy works (Aug 12, 2021)
9e16e34  Address review points (Aug 13, 2021)
642e6fa  Revert changes to typing (Aug 13, 2021)
5c9e973  Add docs/doc strings and API (Aug 13, 2021)
6361c87  Address feedback (Aug 13, 2021)
c921dbb  Update pytorch_lightning/plugins/training_type/training_type_plugin.py (Aug 13, 2021)
fd82276  Address reviews (Aug 13, 2021)
2fc3558  Update typing (Aug 13, 2021)
21783f6  Refactor name (Aug 13, 2021)
3b8c3f5  Clear up signature of function; checkpoint_plugin -> checkpoint_io (Aug 13, 2021)
9cfe98f  Slightly cleaner (Aug 13, 2021)
8f234e0  Address reviews (Aug 13, 2021)
60 changes: 60 additions & 0 deletions pytorch_lightning/plugins/checkpoint/checkpoint.py
@@ -0,0 +1,60 @@
from abc import ABC
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Mapping, Union

from torch.nn import Module

from pytorch_lightning import LightningModule

if TYPE_CHECKING:
    # imported for type annotations only, to avoid a circular import at runtime
    from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class CheckpointPlugin(ABC):
def __init__(self):
self._training_type_plugin = None

@property
def training_type_plugin(self) -> "TrainingTypePlugin":
return self._training_type_plugin

@training_type_plugin.setter
    def training_type_plugin(self, plugin: "TrainingTypePlugin") -> None:
self._training_type_plugin = plugin

@property
def lightning_module(self) -> LightningModule:
return self.training_type_plugin.lightning_module

@property
def model(self) -> Module:
return self.training_type_plugin.model

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
"""
Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages.
Args:
checkpoint_path: Path to checkpoint

Returns: The loaded checkpoint.
"""
pass

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
"""
Given the loaded checkpoint file, loads the state dict into the model.
Args:
checkpoint: The loaded checkpoint file.
"""

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
"""
Given the loaded checkpoint file, loads the optimizer state dicts into optimizers.
Args:
checkpoint: The loaded checkpoint file.
"""
123 changes: 123 additions & 0 deletions pytorch_lightning/plugins/checkpoint/deepspeed.py
@@ -0,0 +1,123 @@
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Union

import torch
from torch import Tensor

from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()
if _DEEPSPEED_AVAILABLE:
import deepspeed


class DeepSpeedCheckpointPlugin(CheckpointPlugin):
    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
"""
if (
self.training_type_plugin.zero_stage_3
and self.training_type_plugin._multi_device
and self.training_type_plugin.is_global_zero
):
warning_cache.warn(
"When saving the DeepSpeed Stage 3 checkpoint, "
"each worker will save a shard of the checkpoint within a directory. "
"If a single file is required after training, "
"see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
"deepspeed-zero-stage-3-single-file for instructions."
)
# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.model.save_checkpoint(filepath, client_state=checkpoint)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3:
# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
checkpoint_path = self.training_type_plugin.broadcast(checkpoint_path)
return super().load_checkpoint_file(checkpoint_path)

# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerFn

is_fitting = self.lightning_module.trainer.state.fn == TrainerFn.FITTING
_, client_state = self.model.load_checkpoint(
checkpoint_path, load_optimizer_states=is_fitting, load_lr_scheduler_states=is_fitting
)
if client_state is None:
raise MisconfigurationException(
"DeepSpeed was unable to load the checkpoint. Ensure you passed in a DeepSpeed compatible checkpoint "
"or a single checkpoint file with `Trainer(plugins=DeepSpeedPlugin(load_full_weights=True))`."
)
return client_state

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint_file()`
if self.training_type_plugin.load_full_weights and self.training_type_plugin.zero_stage_3:
self.training_type_plugin.model_to_device()
self._restore_zero_state(checkpoint)

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
# override to do nothing, deepspeed engine already loaded the states in `load_checkpoint_file()`
pass

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
"""Returns model state."""
model = self.lightning_module
return model.state_dict()

def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None:
"""
Overrides the normal load_state_dict behaviour in PyTorch to ensure
we gather parameters that may be sharded across processes before loading
the state dictionary when using ZeRO stage 3.
This is then automatically synced across processes.

Args:
ckpt: The ckpt file.
"""

def load(module: torch.nn.Module, prefix=""):

missing_keys = []
unexpected_keys = []
error_msgs = []
state_dict = ckpt["state_dict"]

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if self.training_type_plugin.is_global_zero:
module._load_from_state_dict(
state_dict=state_dict,
prefix=prefix,
local_metadata=local_metadata,
strict=True,
missing_keys=missing_keys,
unexpected_keys=unexpected_keys,
error_msgs=error_msgs,
)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

load(self.lightning_module, prefix="")
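
As a usage note, the single-file path referenced by the MisconfigurationException above could be exercised roughly as follows. This is a hedged sketch: only stage=3 and load_full_weights=True come from the surrounding code; the gpus/precision values and the checkpoint filename are illustrative.

# Sketch: resume/evaluate ZeRO stage 3 training from a single consolidated checkpoint file.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = Trainer(gpus=2, precision=16, plugins=DeepSpeedPlugin(stage=3, load_full_weights=True))
# trainer.test(model, ckpt_path="single_file.ckpt")  # the plugin broadcasts the path and loads the full weights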
39 changes: 39 additions & 0 deletions pytorch_lightning/plugins/checkpoint/torch.py
@@ -0,0 +1,39 @@
from pathlib import Path
from typing import Any, Dict, Mapping, Union

from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load


class TorchCheckpointPlugin(CheckpointPlugin):
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
# dump states as a checkpoint dictionary object
try:
# write the checkpoint dictionary on the file
atomic_save(checkpoint, filepath)
except AttributeError as err:
key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY
checkpoint.pop(key, None)
rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
atomic_save(checkpoint, filepath)

def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage))

def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
self.lightning_module.load_state_dict(checkpoint["state_dict"])

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
optimizer_states = checkpoint["optimizer_states"]
for optimizer, opt_state in zip(self.lightning_module.trainer.accelerator.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
"""Returns model state."""
model = self.lightning_module
return model.state_dict()
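
To show how the default implementation above can be extended, here is a small sketch of a subclass that thins out the checkpoint dict before delegating to the atomic save. The subclass name and the dropped "callbacks" key are illustrative assumptions, not part of this diff.

# Sketch of extending the default torch-based plugin (illustrative).
from typing import Any, Dict

from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin


class SlimTorchCheckpointPlugin(TorchCheckpointPlugin):
    """Hypothetical variant that strips callback state to keep checkpoint files smaller."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        slim_checkpoint = {k: v for k, v in checkpoint.items() if k != "callbacks"}
        super().save_checkpoint(slim_checkpoint, filepath)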
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -31,6 +31,7 @@
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import (
@@ -75,14 +76,19 @@ def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: Optional[int] = None,
-        cluster_environment: ClusterEnvironment = None,
+        cluster_environment: Optional[ClusterEnvironment] = None,
+        checkpoint_plugin: Optional[CheckpointPlugin] = None,
         sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
         **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=cluster_environment,
+            checkpoint_plugin=checkpoint_plugin,
+        )
         self.interactive_ddp_procs = []
         if num_nodes is not None:
             rank_zero_deprecation(
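
Putting the pieces together, the new checkpoint_plugin constructor argument can be supplied when building the DDP plugin. In this sketch the torch-based plugin from this diff is passed explicitly, but any CheckpointPlugin subclass (for example the one sketched after checkpoint.py) would work the same way; the gpus count is an assumption.

# Sketch of routing checkpoint IO through a plugin instance via DDPPlugin.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.checkpoint.torch import TorchCheckpointPlugin

trainer = Trainer(gpus=2, plugins=DDPPlugin(checkpoint_plugin=TorchCheckpointPlugin()))
# trainer.fit(model)  # checkpoint saves/loads are now routed through the plugin instance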
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -24,6 +24,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.plugins.checkpoint.checkpoint import CheckpointPlugin
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
@@ -63,13 +64,18 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: Optional[int] = None,
         cluster_environment: ClusterEnvironment = None,
+        checkpoint_plugin: Optional[CheckpointPlugin] = None,
         sync_batchnorm: Optional[bool] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
         **kwargs: Any,
     ):
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=cluster_environment,
+            checkpoint_plugin=checkpoint_plugin,
+        )
         if num_nodes is not None:
             rank_zero_deprecation(
                 "Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. "