Validate selected device indices in DeepSpeedStrategy #17952

Merged · 15 commits · Jul 4, 2023
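In short: DeepSpeed requires each process's device index to equal its local rank, so selecting GPUs at arbitrary indices (for example `devices=[1]`) cannot work and previously failed in non-obvious ways. This PR rejects such a selection up front and points users to `CUDA_VISIBLE_DEVICES`. A minimal sketch of the new user-facing behavior, using standard `Fabric` arguments (the error text comes from the validation added below):

```python
from lightning.fabric import Fabric

# Selecting a specific device index is now rejected during environment setup:
fabric = Fabric(accelerator="cuda", devices=[1], strategy="deepspeed")
fabric.launch()  # raises RuntimeError: "The selected device indices [1] don't match ..."

# The supported alternative is to remap visibility outside the program, e.g.
#   CUDA_VISIBLE_DEVICES=1 python train.py
# so that the one visible GPU appears as index 0 and matches local rank 0.
```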
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA ([#17883](https://github.com/Lightning-AI/lightning/pull/17883))
 
+
+- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952))
+
 
 ### Changed
 
 - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
1 change: 0 additions & 1 deletion src/lightning/fabric/connector.py
@@ -377,7 +377,6 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
             if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                 isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
             ):
13 changes: 13 additions & 0 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -598,6 +598,8 @@ def _setup_distributed(self) -> None:
                 f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`"
                 " is used."
             )
+        assert self.parallel_devices is not None
+        _validate_device_index_selection(self.parallel_devices)
         reset_seed()
         self._set_world_ranks()
         self._init_deepspeed_distributed()
@@ -831,3 +833,14 @@ def _validate_state_keys(state: Dict[str, Any]) -> None:
             " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: "
             + ", ".join(colliding_keys)
         )
+
+
+def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None:
+    selected_device_indices = [device.index for device in parallel_devices]
+    expected_device_indices = list(range(len(parallel_devices)))
+    if selected_device_indices != expected_device_indices:
+        raise RuntimeError(
+            f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes."
+            " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable"
+            f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`."
+        )
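A quick sketch of the helper's contract (illustration only, not part of the diff; it assumes the private helper remains importable from `lightning.fabric.strategies.deepspeed`, as the PyTorch strategy below imports it):

```python
import torch

from lightning.fabric.strategies.deepspeed import _validate_device_index_selection

# Indices equal to the local ranks 0..N-1 pass silently
# (constructing torch.device objects does not touch the GPU):
_validate_device_index_selection([torch.device("cuda", 0), torch.device("cuda", 1)])

# Any other selection raises and suggests the CUDA_VISIBLE_DEVICES workaround:
_validate_device_index_selection([torch.device("cuda", 1)])  # RuntimeError
```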
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -59,6 +59,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882))
 
+
+- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952))
+
+
 ### Changed
 
 - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
4 changes: 3 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
@@ -30,7 +30,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau
@@ -325,6 +325,8 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option
         return config
 
     def setup_distributed(self) -> None:
+        assert self.parallel_devices is not None
+        _validate_device_index_selection(self.parallel_devices)
         reset_seed()
         self.set_world_ranks()
         self._init_deepspeed_distributed()
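On the Trainer side, the same check now runs at the top of `setup_distributed`, before `_init_deepspeed_distributed`, so a misconfigured selection fails fast. A sketch of the effect with standard `Trainer` arguments (assumes a machine with at least two visible GPUs):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(accelerator="gpu", devices=[1], strategy="deepspeed")
trainer.fit(BoringModel())
# RuntimeError: The selected device indices [1] don't match the local rank values of processes. ...
```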
@@ -444,7 +444,6 @@ def _choose_strategy(self) -> Union[Strategy, str]:
         if self._num_nodes_flag > 1:
             return "ddp"
         if len(self._parallel_devices) <= 1:
-            # TODO: Change this once gpu accelerator was renamed to cuda accelerator
             if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or (
                 isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps")
             ):
18 changes: 17 additions & 1 deletion tests/tests_fabric/strategies/test_deepspeed.py
@@ -22,7 +22,7 @@
 from torch.optim import Optimizer
 
 from lightning.fabric import Fabric
-from lightning.fabric.accelerators import CPUAccelerator
+from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator
 from lightning.fabric.strategies import DeepSpeedStrategy
 from tests_fabric.helpers.runif import RunIf

@@ -349,3 +349,19 @@ def test_deepspeed_save_filter(tmp_path):
     fabric = Fabric(devices=1, strategy="deepspeed")
     with pytest.raises(TypeError, match="manages the state serialization internally"):
         fabric.save(tmp_path, {}, filter={})
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
+def test_validate_parallel_devices_indices(device_indices):
+    """Test that the strategy raises when specific device indices are selected.
+
+    DeepSpeed does not support this; the device indices must match the local ranks of the processes.
+    """
+    strategy = DeepSpeedStrategy(
+        accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
+    )
+    with pytest.raises(
+        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
+    ):
+        strategy.setup_environment()
16 changes: 0 additions & 16 deletions tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -264,22 +264,6 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform):
     assert os.environ["LOCAL_RANK"] == str(strategy.local_rank)
 
 
-@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_deepspeed_specific_gpu_device_index():
-    """Test that the DeepSpeed strategy can run on specific device indices."""
-
-    class RunFabric(BoringFabric):
-        def step(self, model, batch):
-            assert self.device.type == "cuda"
-            assert self.device.index == 1
-            assert batch.device.index == 1
-            assert model.device.index == 1
-            return super().step(model, batch)
-
-    fabric = RunFabric(accelerator="cuda", devices=[1], strategy="deepspeed")
-    fabric.run()
-
-
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
 def test_deepspeed_with_bfloat16_precision():
     """Test that the DeepSpeed strategy works with bfloat16 precision."""
63 changes: 19 additions & 44 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -15,6 +15,7 @@
 import json
 import logging
 import os
+from re import escape
 from typing import Any, Dict
 from unittest import mock

@@ -26,12 +27,12 @@
 from torchmetrics import Accuracy
 
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.accelerators import CUDAAccelerator
 from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin
-from lightning.pytorch.strategies import DeepSpeedStrategy
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
 from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -1155,48 +1156,6 @@ def test_deepspeed_gradient_clip_by_value(tmpdir):
     trainer.fit(model)
 

-@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
-def test_specific_gpu_device_id(tmpdir):
-    class TestCallback(Callback):
-        def on_train_start(self, *_) -> None:
-            assert model.device.index == 1
-
-        def on_train_batch_start(
-            self,
-            trainer: Trainer,
-            pl_module: LightningModule,
-            batch: Any,
-            *_,
-        ) -> None:
-            assert batch.device.index == 1
-
-        def on_test_start(self, *_) -> None:
-            assert model.device.index == 1
-
-        def on_test_batch_start(
-            self,
-            trainer: Trainer,
-            pl_module: LightningModule,
-            batch: Any,
-            *_,
-        ) -> None:
-            assert batch.device.index == 1
-
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        accelerator="gpu",
-        devices=[1],
-        strategy="deepspeed",
-        callbacks=TestCallback(),
-        enable_progress_bar=False,
-        enable_model_summary=False,
-    )
-    trainer.fit(model)
-    trainer.test(model)
-
-

 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multi_save_same_filepath(tmpdir):
     """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old
@@ -1307,3 +1266,19 @@ def transfer_batch_to_device(self, batch, *args, **kwargs):
     batch = trainer.strategy.batch_to_device(batch)
     assert batch.is_cuda
     assert batch.dtype is torch.float16
+
+
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]])
+def test_validate_parallel_devices_indices(device_indices):
+    """Test that the strategy raises when specific device indices are selected.
+
+    DeepSpeed does not support this; the device indices must match the local ranks of the processes.
+    """
+    strategy = DeepSpeedStrategy(
+        accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices]
+    )
+    with pytest.raises(
+        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
+    ):
+        strategy.setup_environment()
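For reference, a small sketch of what the new error suggests for each parametrized case above, derived from the f-string in `_validate_device_index_selection` (illustration, not library code):

```python
# Reproduce the CUDA_VISIBLE_DEVICES suggestion embedded in the error message:
for device_indices in ([1], [1, 0], [0, 2], [3, 2, 1]):
    suggestion = ",".join(str(i) for i in device_indices)
    print(f"devices={device_indices!r} -> set CUDA_VISIBLE_DEVICES={suggestion}")
```

After remapping with `CUDA_VISIBLE_DEVICES`, the visible GPUs are renumbered from zero, so passing a device count instead of an index list selects them with indices that match the local ranks.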