Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dockers/base-cuda/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ RUN \
pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \
rm -rf apex

RUN \
# install DeepSpeed
pip install deepspeed>=0.3.14

RUN \
# Show what we have
pip --version && \
Expand Down
33 changes: 28 additions & 5 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
cpu_checkpointing: bool = False,
contiguous_memory_optimization: bool = False,
synchronize_checkpoint_boundary: bool = False,
save_full_weights: bool = True,
) -> None:
"""

Expand Down Expand Up @@ -177,11 +178,16 @@ def __init__(
Not supported by all models

synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.

save_full_weights: Gathers weights across all processes before saving to disk
when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
rather than individual sharded weight files.
Disable to save sharded states individually. (Default: True)
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
"To use the DeepSpeed plugin, you must have DeepSpeed installed."
" pip install deepspeed mpi4py"
" pip install deepspeed"
)
super().__init__(
parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
Expand All @@ -205,11 +211,13 @@ def __init__(
allgather_partitions=allgather_partitions,
reduce_scatter=reduce_scatter,
allgather_bucket_size=allgather_bucket_size,
reduce_bucket_size=reduce_bucket_size
reduce_bucket_size=reduce_bucket_size,
)
self._config_initialized = False
deepspeed.utils.logging.logger.setLevel(logging_level)

self.save_full_weights = save_full_weights

# default FP16 parameters.
self.loss_scale = loss_scale
self.initial_scale_power = initial_scale_power
Expand Down Expand Up @@ -472,17 +480,27 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.

Args:
checkpoint: The checkpoint state dictionary
filepath: write-target file's path
weights_only: saving model weights only
"""
if self.world_size > 1 and self.zero_stage_3:
if self.save_full_weights:
# todo: expose this as general function in deepspeed
state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
if self.is_global_zero:
# State dict keys will include reference to wrapper LightningDeepSpeedModule
# Delete `module` prefix before saving.
state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
checkpoint['state_dict'] = state_dict
return super().save_checkpoint(checkpoint, filepath)
return

# Use deepspeed's internal checkpointing function to handle partitioned weights across processes
# dump states as a checkpoint dictionary object
save_dir = self._filepath_to_dir(filepath)
_exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)

else:
super().save_checkpoint(checkpoint, filepath)

Expand All @@ -491,7 +509,8 @@ def restore_model_state_from_ckpt_path(
ckpt_path: str,
map_location: Callable = lambda storage, loc: storage,
) -> Tuple[Dict, bool]:
if self.world_size > 1:
if not self.save_full_weights and self.world_size > 1:
# Rely on deepspeed to load the checkpoint and necessary information
from pytorch_lightning.trainer.states import TrainerState
stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
save_dir = self._filepath_to_dir(ckpt_path)
Expand All @@ -511,6 +530,10 @@ def restore_model_state_from_ckpt_path(
# hook: give user access to checkpoint if needed.
self.lightning_module.on_load_checkpoint(client_state)
return client_state, False

# Broadcast to ensure we load from the rank 0 checkpoint
# This doesn't have to be the case when using deepspeed sharded checkpointing
ckpt_path = self.broadcast(ckpt_path)
return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
Expand Down
1 change: 0 additions & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ hydra-core>=1.0
# todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs
https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip
jsonargparse[signatures]>=3.3.1
deepspeed>=0.3.13
37 changes: 26 additions & 11 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any
from typing import Any, Dict

import pytest
import torch
Expand Down Expand Up @@ -28,6 +28,9 @@ def __init__(self):
def configure_sharded_model(self) -> None:
self.linear = torch.nn.Linear(32, 2)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.configure_sharded_model()


def test_deepspeed_lightning_module(tmpdir):
"""
Expand Down Expand Up @@ -456,23 +459,17 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
trainer.fit(model)
trainer.test(model)

# todo (tchaton) Currently load_from_checkpoint is not support for zero-v3
# _assert_save_model_is_equal(model, tmpdir, trainer)
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
"""
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
and see convergence.
"""
def run_checkpoint_test(tmpdir, save_full_weights):
seed_everything(42)
model = ModelParallelClassificationModel()
dm = ClassifDataModule()
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
trainer = Trainer(
max_epochs=10,
plugins=[DeepSpeedPlugin(stage=3)],
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
default_root_dir=tmpdir,
gpus=2,
precision=16,
Expand All @@ -490,7 +487,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):

trainer = Trainer(
max_epochs=10,
plugins=[DeepSpeedPlugin(stage=3)],
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
default_root_dir=tmpdir,
gpus=2,
precision=16,
Expand All @@ -506,6 +503,24 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
assert results[-1] > 0.7


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir):
"""
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
and see convergence.
"""
run_checkpoint_test(tmpdir, save_full_weights=False)


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):
"""
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
where we save the full weights to one file.
"""
run_checkpoint_test(tmpdir, save_full_weights=True)


@RunIf(min_gpus=2, deepspeed=True, special=True)
@pytest.mark.parametrize('cpu_offload', [True, False])
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload):
Expand Down