diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py
index 7c39f94e66969..86da43ec41ec5 100644
--- a/src/lightning_lite/strategies/fairscale.py
+++ b/src/lightning_lite/strategies/fairscale.py
@@ -59,11 +59,10 @@ def __init__(
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
-            process_group_backen=process_group_backend,
+            process_group_backend=process_group_backend,
             timeout=timeout,
             **kwargs,
         )
-        super().__init__()
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
@@ -78,6 +77,11 @@ def setup_module_and_optimizers(
            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
        """
        optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
+       for optimizer in optimizers:
+           # This forces buckets to be rebuilt on the first forward pass
+           # We are not sure why this is needed, but it prevents an error resulting from buckets having a different
+           # device than the params
+           optimizer._clear_cache()
        model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers
 
@@ -131,11 +135,10 @@ def __init__(
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
-            process_group_backen=process_group_backend,
+            process_group_backend=process_group_backend,
             timeout=timeout,
             **kwargs,
         )
-        super().__init__()
         if "reduce_buffer_size" not in self._ddp_kwargs:
             # For multi-node training, enabling bucketing will improve performance.
             self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
@@ -150,6 +153,11 @@ def setup_module_and_optimizers(
            and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
        """
        optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
+       for optimizer in optimizers:
+           # This forces buckets to be rebuilt on the first forward pass
+           # We are not sure why this is needed, but it prevents an error resulting from buckets having a different
+           # device than the params
+           optimizer._clear_cache()
        model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
        return model, optimizers
 
@@ -180,15 +188,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             description=cls.__class__.__name__,
         )
 
-    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        for x, optimizer in enumerate(optimizers):
-            if not isinstance(optimizer, OSS):
-                optim_class = type(optimizer)
-                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
-                optimizers[x] = zero_optimizer
-                del optimizer
-        return optimizers
-
 
 def _reinit_optimizers_with_oss(
     optimizers: List[Optimizer], precision_plugin: Precision, num_nodes: int
diff --git a/src/lightning_lite/strategies/strategy.py b/src/lightning_lite/strategies/strategy.py
index abff9515dda0a..bd87501a3bc1a 100644
--- a/src/lightning_lite/strategies/strategy.py
+++ b/src/lightning_lite/strategies/strategy.py
@@ -248,6 +248,13 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         Allows for syncing/collating optimizer state from processes in custom plugins.
         """
+        if hasattr(optimizer, "consolidate_state_dict"):
+            # there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
+            # states, and to avoid OOM we consolidate the full state on rank 0 only
+            optimizer.consolidate_state_dict()
+            return optimizer.state_dict() if self.is_global_zero else {}
+
+        # for optimizers that are not sharded, we return the state dict on all ranks
         return optimizer.state_dict()
 
     def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py
index 8b9ccdc20462b..cdf6853935c3a 100644
--- a/src/pytorch_lightning/strategies/sharded.py
+++ b/src/pytorch_lightning/strategies/sharded.py
@@ -19,8 +19,7 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
-from lightning_lite.utilities.enums import PrecisionType
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _reinit_optimizers_with_oss
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
@@ -104,24 +103,8 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
         assert self.lightning_module is not None
         if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
-
-        return self._reinit_optimizers_with_oss(optimizers)
-
-    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        for x, optimizer in enumerate(optimizers):
-            if isinstance(optimizer, LightningOptimizer):
-                optimizer = optimizer._optimizer
-            if not isinstance(optimizer, OSS):
-                optim_class = type(optimizer)
-                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
-                is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
-                # For multi-node training, compressing the model shards in fp16 before broadcasting
-                # improves performance. When using PyTorch AMP, it will not degrade
-                # the model performance.
-                zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
-                optimizers[x] = zero_optimizer
-                del optimizer
-        return optimizers
+        optimizers = [o._optimizer if isinstance(o, LightningOptimizer) else o for o in optimizers]
+        return _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py
index 934cf680de0f4..7266f6c91bcee 100644
--- a/src/pytorch_lightning/strategies/sharded_spawn.py
+++ b/src/pytorch_lightning/strategies/sharded_spawn.py
@@ -19,8 +19,9 @@
 from torch.optim import Optimizer
 
 import pytorch_lightning as pl
-from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
+from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE, _reinit_optimizers_with_oss
 from lightning_lite.utilities.optimizer import optimizers_to_device
+from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import TrainerFn
@@ -68,21 +69,12 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
         return model, optimizers
 
-    def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
-        for x, optimizer in enumerate(optimizers):
-            if not isinstance(optimizer, OSS):
-                optim_class = type(optimizer)
-                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
-                optimizers[x] = zero_optimizer
-                del optimizer
-        return optimizers
-
     def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
         assert self.lightning_module
         if self.model is not None and self.lightning_module.trainer.state.fn != TrainerFn.FITTING:
             return optimizers
-
-        return self._reinit_optimizers_with_oss(optimizers)
+        optimizers = [o._optimizer if isinstance(o, LightningOptimizer) else o for o in optimizers]
+        return _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
 
     @contextmanager
     def block_backward_sync(self) -> Generator:
diff --git a/tests/tests_lite/helpers/models.py b/tests/tests_lite/helpers/models.py
index 5eef87866abe5..abe40de22767a 100644
--- a/tests/tests_lite/helpers/models.py
+++ b/tests/tests_lite/helpers/models.py
@@ -36,6 +36,9 @@ class BoringLite(LightningLite):
     def get_model(self) -> Module:
         return nn.Linear(32, 2)
 
+    def get_optimizer(self, module: Module) -> Optimizer:
+        return torch.optim.SGD(module.parameters(), lr=0.1)
+
     def get_dataloader(self) -> DataLoader:
         return DataLoader(RandomDataset(32, 64))
 
@@ -52,12 +55,18 @@ def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
 
     def run(self) -> None:
         model = self.get_model()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+        optimizer = self.get_optimizer(model)
         dataloader = self.get_dataloader()
         model, optimizer = self.setup(model, optimizer)
         dataloader = self.setup_dataloaders(dataloader)
 
+        self.model = model
+        self.optimizer = optimizer
+        self.dataloader = dataloader
+
+        model.train()
+
         data_iter = iter(dataloader)
         batch = next(data_iter)
         loss = self.step(model, batch)
diff --git a/tests/tests_lite/strategies/test_fairscale.py b/tests/tests_lite/strategies/test_fairscale.py
index 77ee5cb3d0f53..b5eb5ba67ffb1 100644
--- a/tests/tests_lite/strategies/test_fairscale.py
+++ b/tests/tests_lite/strategies/test_fairscale.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 from unittest import mock
 
+import pytest
+import torch.nn as nn
+import torch.optim
 from tests_lite.helpers.runif import RunIf
 
 from lightning_lite.strategies import DDPShardedStrategy
-from lightning_lite.strategies.fairscale import ShardedDataParallel
+from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy, ShardedDataParallel
 
 
 @RunIf(fairscale=True)
@@ -26,3 +29,45 @@ def test_block_backward_sync():
     with strategy.block_backward_sync(model):
         pass
     model.no_sync.assert_called_once()
+
+
+@RunIf(fairscale=True)
+@mock.patch("lightning_lite.strategies.fairscale._reinit_optimizers_with_oss", autospec=True)
+@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
+def test_fairscale_custom_kwargs(_, cls):
+    """Test that if custom kwargs are passed, they are set correctly."""
+    strategy = cls(reduce_fp16=True)
+    assert strategy._ddp_kwargs["reduce_fp16"] is True
+
+    model = nn.Linear(3, 3)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+    with mock.patch("lightning_lite.strategies.fairscale.ShardedDataParallel", autospec=True) as mock_sharded:
+        strategy.setup_module_and_optimizers(model, [optimizer])
+    args, kwargs = mock_sharded.call_args
+    assert kwargs["reduce_fp16"] is True
+
+
+@RunIf(fairscale=True)
+@mock.patch("lightning_lite.strategies.fairscale._reinit_optimizers_with_oss", autospec=True)
+@pytest.mark.parametrize("kwargs, expected_buffer_size", [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
+@pytest.mark.parametrize("num_nodes", [1, 2])
+def test_fairscale_custom_kwargs_reduce_buffer_size(_, kwargs, expected_buffer_size, num_nodes):
+    """Test that ``reduce_buffer_size`` is correctly set based on provided kwargs."""
+    strategy = DDPShardedStrategy(**kwargs)
+    strategy.num_nodes = num_nodes
+
+    model = nn.Linear(3, 3)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+    with mock.patch("lightning_lite.strategies.fairscale.ShardedDataParallel", autospec=True) as mock_sharded:
+        strategy.setup_module_and_optimizers(model, [optimizer])
+
+    args, kwargs = mock_sharded.call_args
+    assert "reduce_buffer_size" in kwargs
+
+    if num_nodes > 1 and len(kwargs) == 0:
+        # If the user has not specified a buffer size and we're using multiple nodes, check that the default is set
+        assert kwargs["reduce_buffer_size"] == DDPShardedStrategy._REDUCE_BUFFER_SIZE_DEFAULT
+    else:
+        assert kwargs["reduce_buffer_size"] == expected_buffer_size
diff --git a/tests/tests_lite/strategies/test_fairscale_integration.py b/tests/tests_lite/strategies/test_fairscale_integration.py
new file mode 100644
index 0000000000000..484eddc070785
--- /dev/null
+++ b/tests/tests_lite/strategies/test_fairscale_integration.py
@@ -0,0 +1,83 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+from tests_lite.helpers.models import BoringLite
+from tests_lite.helpers.runif import RunIf
+
+
+class ShardedSaveAndLoad(BoringLite):
+    def get_optimizer(self, module):
+        optimizer = super().get_optimizer(module)
+        if self.with_fairscale_oss:
+            from fairscale.optim import OSS
+
+            optimizer = OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
+        return optimizer
+
+    def run(self, tmpdir, with_fairscale_oss=False):
+        self.with_fairscale_oss = with_fairscale_oss
+
+        super().run()
+
+        from fairscale.nn import ShardedDataParallel
+        from fairscale.optim import OSS
+
+        # the model and optimizer are wrapped correctly
+        assert isinstance(self.model._forward_module, ShardedDataParallel)
+        assert isinstance(self.optimizer.optimizer, OSS)
+
+        self.model.cpu()
+
+        checkpoint_path = os.path.join(tmpdir, "checkpoint.ckpt")
+        checkpoint = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}
+        self.save(checkpoint, checkpoint_path)
+
+        self.barrier()
+
+        loaded_checkpoint = self.load(checkpoint_path)
+        new_model = self.get_model()
+        new_model.load_state_dict(loaded_checkpoint["model"])
+
+        # Assert model parameters are identical after loading
+        for trained_param, loaded_param in zip(self.model.parameters(), new_model.parameters()):
+            assert torch.equal(trained_param, loaded_param)
+
+
+@RunIf(fairscale=True)
+@pytest.mark.parametrize("accelerator", ["cpu", pytest.param("cuda", marks=RunIf(min_cuda_gpus=2))])
+@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
+@pytest.mark.parametrize("with_fairscale_oss", (True, False))
+def test_fairscale_multi_process_checkpoint_state_consolidation(with_fairscale_oss, strategy, accelerator, tmpdir):
+    """Test that the sharded optimizer states get consolidated when saving the checkpoint, and that the loaded
+    weights are identical to the saved ones."""
+    lite = ShardedSaveAndLoad(strategy=strategy, accelerator=accelerator, devices=2)
+    lite.run(tmpdir, with_fairscale_oss=with_fairscale_oss)
+
+
+@pytest.mark.parametrize(
+    "strategy, expected_find_unused_parameters",
+    [
+        ("ddp_sharded", None),
+        ("ddp_sharded_find_unused_parameters_false", False),
+        ("ddp_sharded_spawn", None),
+        ("ddp_sharded_spawn_find_unused_parameters_false", False),
+    ],
+)
+def test_fairscale_find_unused_parameters_from_registry(strategy, expected_find_unused_parameters):
+    lite = BoringLite(strategy=strategy)
+    if expected_find_unused_parameters is not None:
+        assert lite._strategy._ddp_kwargs["find_unused_parameters"] is False
diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py
index 3f2ff0b55c8c9..73a22fc473c3f 100644
--- a/tests/tests_lite/test_connector.py
+++ b/tests/tests_lite/test_connector.py
@@ -252,7 +252,7 @@ def test_ipython_compatible_strategy_ddp_fork(monkeypatch):
 )
 @pytest.mark.parametrize("devices", [1, 2])
 @mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
-def test_accelerator_choice_multi_node_gpu(_, strategy, strategy_class, devices):
+def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices):
     connector = _Connector(num_nodes=2, accelerator="gpu", strategy=strategy, devices=devices)
     assert isinstance(connector.strategy, strategy_class)
 
@@ -376,6 +376,19 @@ def test_strategy_choice_gpu_str(strategy, strategy_class):
     assert isinstance(connector.strategy, strategy_class)
 
 
+@RunIf(fairscale=True)
+@pytest.mark.parametrize(
+    "strategy,expected_strategy", [("ddp_sharded", DDPShardedStrategy), ("ddp_sharded_spawn", DDPSpawnShardedStrategy)]
+)
+@pytest.mark.parametrize(
+    "precision,expected_precision", [(16, NativeMixedPrecision), (32, Precision), ("bf16", NativeMixedPrecision)]
+)
+def test_strategy_choice_sharded(strategy, expected_strategy, precision, expected_precision):
+    connector = _Connector(strategy=strategy, devices=1, precision=precision)
+    assert isinstance(connector.strategy, expected_strategy)
+    assert isinstance(connector.precision_plugin, expected_precision)
+
+
 @RunIf(min_cuda_gpus=2)
 @pytest.mark.parametrize("strategy_class", [DDPSpawnStrategy, DDPStrategy])
 def test_strategy_choice_gpu_instance(strategy_class):
diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py
index a2b3775eb6708..34436e47ba2f4 100644
--- a/tests/tests_pytorch/strategies/test_sharded_strategy.py
+++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py
@@ -8,6 +8,7 @@
 from lightning_lite.strategies.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from tests_pytorch.helpers.runif import RunIf
@@ -56,6 +57,7 @@ def test_ddp_choice_sharded_amp(strategy, expected):
     """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
     trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=1, precision=16, strategy=strategy)
     assert isinstance(trainer.strategy, expected)
+    assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
 
 
 @RunIf(fairscale=True)