From 3f4790bd27196d8cdd926ce1db928714f4172d0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 4 Jul 2023 11:58:38 -0700 Subject: [PATCH] Validate selected device indices in `DeepSpeedStrategy` (#17952) --- src/lightning/fabric/CHANGELOG.md | 3 + src/lightning/fabric/connector.py | 1 - src/lightning/fabric/strategies/deepspeed.py | 13 ++++ src/lightning/pytorch/CHANGELOG.md | 4 ++ src/lightning/pytorch/strategies/deepspeed.py | 4 +- .../connectors/accelerator_connector.py | 1 - .../tests_fabric/strategies/test_deepspeed.py | 18 +++++- .../strategies/test_deepspeed_integration.py | 22 ------- .../strategies/test_deepspeed_strategy.py | 63 ++++++------------- 9 files changed, 59 insertions(+), 70 deletions(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index b715248cf583f..213626513ef8a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Automatically call `xla_model.mark_step()` after `optimizer.step()` with XLA ([#17883](https://github.com/Lightning-AI/lightning/pull/17883)) +- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952)) + + ### Changed - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331)) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index f99418185712f..d007ff26231cf 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -377,7 +377,6 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - # TODO: Change this once gpu accelerator was renamed to cuda accelerator if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") ): diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 66b4c3ca6efa4..75d8ed1546f4f 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -598,6 +598,8 @@ def _setup_distributed(self) -> None: f"The DeepSpeed strategy is only supported on CUDA GPUs but `{self.accelerator.__class__.__name__}`" " is used." ) + assert self.parallel_devices is not None + _validate_device_index_selection(self.parallel_devices) reset_seed() self._set_world_ranks() self._init_deepspeed_distributed() @@ -831,3 +833,14 @@ def _validate_state_keys(state: Dict[str, Any]) -> None: " values being overwritten by DeepSpeed. Consider changing the name of these keys to something else: " + ", ".join(colliding_keys) ) + + +def _validate_device_index_selection(parallel_devices: List[torch.device]) -> None: + selected_device_indices = [device.index for device in parallel_devices] + expected_device_indices = list(range(len(parallel_devices))) + if selected_device_indices != expected_device_indices: + raise RuntimeError( + f"The selected device indices {selected_device_indices!r} don't match the local rank values of processes." + " If you need to select GPUs at a specific index, set the `CUDA_VISIBLE_DEVICES` environment variable" + f" instead. For example: `CUDA_VISIBLE_DEVICES={','.join(str(i) for i in selected_device_indices)}`." + ) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index fefe9f738533a..d956f5215b6be 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -59,6 +59,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Automatically call `xla_model.mark_step()` before saving checkpoints with XLA ([#17882](https://github.com/Lightning-AI/lightning/pull/17882)) + +- Added validation against misconfigured device selection when using the DeepSpeed strategy ([#17952](https://github.com/Lightning-AI/lightning/pull/17952)) + + ### Changed - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309)) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 4a8a30a18fe27..516e02ff79eeb 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -30,7 +30,7 @@ import lightning.pytorch as pl from lightning.fabric.plugins import ClusterEnvironment from lightning.fabric.strategies import _StrategyRegistry -from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE +from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE, _validate_device_index_selection from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH, LRScheduler, ReduceLROnPlateau @@ -325,6 +325,8 @@ def _load_config(self, config: Optional[Union[_PATH, Dict[str, Any]]]) -> Option return config def setup_distributed(self) -> None: + assert self.parallel_devices is not None + _validate_device_index_selection(self.parallel_devices) reset_seed() self.set_world_ranks() self._init_deepspeed_distributed() diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 42f46cd75047d..7e21f6216c1e3 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -444,7 +444,6 @@ def _choose_strategy(self) -> Union[Strategy, str]: if self._num_nodes_flag > 1: return "ddp" if len(self._parallel_devices) <= 1: - # TODO: Change this once gpu accelerator was renamed to cuda accelerator if isinstance(self._accelerator_flag, (CUDAAccelerator, MPSAccelerator)) or ( isinstance(self._accelerator_flag, str) and self._accelerator_flag in ("cuda", "gpu", "mps") ): diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 6a542ad65d73c..c0c14ad429a70 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -21,7 +21,7 @@ import torch from torch.optim import Optimizer -from lightning.fabric.accelerators import CPUAccelerator +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator from lightning.fabric.strategies import DeepSpeedStrategy from tests_fabric.helpers.runif import RunIf @@ -348,3 +348,19 @@ def test_deepspeed_save_filter(tmp_path): strategy = DeepSpeedStrategy() with pytest.raises(TypeError, match="manages the state serialization internally"): strategy.save_checkpoint(path=tmp_path, state={}, filter={}) + + +@RunIf(deepspeed=True) +@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]]) +def test_validate_parallel_devices_indices(device_indices): + """Test that the strategy validates that it doesn't support selecting specific devices by index. + + DeepSpeed doesn't support it and needs the index to match to the local rank of the process. + """ + strategy = DeepSpeedStrategy( + accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices] + ) + with pytest.raises( + RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes") + ): + strategy.setup_environment() diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 2a1cbc0167356..fa6bca967314c 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -264,28 +264,6 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform): assert os.environ["LOCAL_RANK"] == str(strategy.local_rank) -@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) -def test_deepspeed_specific_gpu_device_index(): - """Test that the DeepSpeed strategy can run on specific device indices.""" - fabric = Fabric(accelerator="cuda", devices=[1], strategy="deepspeed") - fabric.launch() - assert fabric.device.type == "cuda" - assert fabric.device.index == 1 - - model = nn.Linear(32, 2) - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - model, optimizer = fabric.setup(model, optimizer) - assert model.device.index == 1 - - batch = torch.rand(2, 32, device=fabric.device) - assert batch.device.index == 1 - - loss = model(batch).sum() - fabric.backward(loss) - optimizer.step() - optimizer.zero_grad() - - @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True) def test_deepspeed_with_bfloat16_precision(): """Test that the DeepSpeed strategy works with bfloat16 precision.""" diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 569f9fcafe81e..6aaf93034957c 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -15,6 +15,7 @@ import json import logging import os +from re import escape from typing import Any, Dict from unittest import mock @@ -26,12 +27,12 @@ from torchmetrics import Accuracy from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.plugins import DeepSpeedPrecisionPlugin -from lightning.pytorch.strategies import DeepSpeedStrategy -from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE +from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 from tests_pytorch.helpers.datamodules import ClassifDataModule @@ -1155,48 +1156,6 @@ def test_deepspeed_gradient_clip_by_value(tmpdir): trainer.fit(model) -@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) -def test_specific_gpu_device_id(tmpdir): - class TestCallback(Callback): - def on_train_start(self, *_) -> None: - assert model.device.index == 1 - - def on_train_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - *_, - ) -> None: - assert batch.device.index == 1 - - def on_test_start(self, *_) -> None: - assert model.device.index == 1 - - def on_test_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - *_, - ) -> None: - assert batch.device.index == 1 - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=True, - accelerator="gpu", - devices=[1], - strategy="deepspeed", - callbacks=TestCallback(), - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) - trainer.test(model) - - @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multi_save_same_filepath(tmpdir): """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old @@ -1307,3 +1266,19 @@ def transfer_batch_to_device(self, batch, *args, **kwargs): batch = trainer.strategy.batch_to_device(batch) assert batch.is_cuda assert batch.dtype is torch.float16 + + +@RunIf(deepspeed=True) +@pytest.mark.parametrize("device_indices", [[1], [1, 0], [0, 2], [3, 2, 1]]) +def test_validate_parallel_devices_indices(device_indices): + """Test that the strategy validates that it doesn't support selecting specific devices by index. + + DeepSpeed doesn't support it and needs the index to match to the local rank of the process. + """ + strategy = DeepSpeedStrategy( + accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices] + ) + with pytest.raises( + RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes") + ): + strategy.setup_environment()