Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fairscale integration tests for Lite #14921

Merged
merged 16 commits into from
Sep 29, 2022
1 change: 1 addition & 0 deletions src/lightning_lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
has_iterable_dataset,
)
from lightning_lite.utilities.distributed import DistributedSamplerWrapper
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import seed_everything
from lightning_lite.utilities.warnings import PossibleUserWarning
from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
Expand Down
25 changes: 12 additions & 13 deletions src/lightning_lite/strategies/fairscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,10 @@ def __init__(
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
process_group_backen=process_group_backend,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
process_group_backend=process_group_backend,
timeout=timeout,
**kwargs,
)
super().__init__()
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
Expand All @@ -78,6 +77,11 @@ def setup_module_and_optimizers(
and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
"""
optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
for optimizer in optimizers:
# This forces buckets to be rebuilt on the first forward pass
# We are not sure why this is needed, but it prevents an error resulting from buckets having a different
# device than the params
optimizer._clear_cache()
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

Expand Down Expand Up @@ -131,11 +135,10 @@ def __init__(
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
process_group_backen=process_group_backend,
process_group_backend=process_group_backend,
timeout=timeout,
**kwargs,
)
super().__init__()
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
Expand All @@ -150,6 +153,11 @@ def setup_module_and_optimizers(
and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
"""
optimizers = _reinit_optimizers_with_oss(optimizers, self.precision_plugin, self.num_nodes)
for optimizer in optimizers:
# This forces buckets to be rebuilt on the first forward pass
# We are not sure why this is needed, but it prevents an error resulting from buckets having a different
# device than the params
optimizer._clear_cache()
model = ShardedDataParallel(module, sharded_optimizer=optimizers, **self._ddp_kwargs)
return model, optimizers

Expand Down Expand Up @@ -180,15 +188,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
description=cls.__class__.__name__,
)

def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
carmocca marked this conversation as resolved.
Show resolved Hide resolved
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
optimizers[x] = zero_optimizer
del optimizer
return optimizers


def _reinit_optimizers_with_oss(
optimizers: List[Optimizer], precision_plugin: Precision, num_nodes: int
Expand Down
7 changes: 7 additions & 0 deletions src/lightning_lite/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:

Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}

# for optimizers that are not sharded, we return the state dict on all ranks
return optimizer.state_dict()

def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
Expand Down
11 changes: 10 additions & 1 deletion tests/tests_lite/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class BoringLite(LightningLite):
def get_model(self) -> Module:
return nn.Linear(32, 2)

def get_optimizer(self, module: Module) -> Optimizer:
return torch.optim.SGD(module.parameters(), lr=0.1)

def get_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))

Expand All @@ -52,12 +55,18 @@ def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:

def run(self) -> None:
model = self.get_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = self.get_optimizer(model)
dataloader = self.get_dataloader()

model, optimizer = self.setup(model, optimizer)
dataloader = self.setup_dataloaders(dataloader)

self.model = model
self.optimizer = optimizer
self.dataloader = dataloader

model.train()

data_iter = iter(dataloader)
batch = next(data_iter)
loss = self.step(model, batch)
Expand Down
47 changes: 46 additions & 1 deletion tests/tests_lite/strategies/test_fairscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.
from unittest import mock

import pytest
import torch.nn as nn
import torch.optim
from tests_lite.helpers.runif import RunIf

from lightning_lite.strategies import DDPShardedStrategy
from lightning_lite.strategies.fairscale import ShardedDataParallel
from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy, ShardedDataParallel


@RunIf(fairscale=True)
Expand All @@ -26,3 +29,45 @@ def test_block_backward_sync():
with strategy.block_backward_sync(model):
pass
model.no_sync.assert_called_once()


@RunIf(fairscale=True)
@mock.patch("lightning_lite.strategies.fairscale._reinit_optimizers_with_oss", autospec=True)
@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
def test_fairscale_custom_kwargs(_, cls):
"""Test that if custom kwargs are passed, they are set correctly."""
strategy = cls(reduce_fp16=True)
assert strategy._ddp_kwargs["reduce_fp16"] is True

model = nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with mock.patch("lightning_lite.strategies.fairscale.ShardedDataParallel", autospec=True) as mock_sharded:
strategy.setup_module_and_optimizers(model, [optimizer])
args, kwargs = mock_sharded.call_args
assert kwargs["reduce_fp16"] is True


@RunIf(fairscale=True)
@mock.patch("lightning_lite.strategies.fairscale._reinit_optimizers_with_oss", autospec=True)
@pytest.mark.parametrize("kwargs, expected_buffer_size", [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
@pytest.mark.parametrize("num_nodes", [1, 2])
def test_fairscale_custom_kwargs_reduce_buffer_size(_, kwargs, expected_buffer_size, num_nodes):
"""Test that ``reduce_buffer_size`` is correctly set based on provided kwargs."""
strategy = DDPShardedStrategy(**kwargs)
strategy.num_nodes = num_nodes

model = nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with mock.patch("lightning_lite.strategies.fairscale.ShardedDataParallel", autospec=True) as mock_sharded:
strategy.setup_module_and_optimizers(model, [optimizer])

args, kwargs = mock_sharded.call_args
assert "reduce_buffer_size" in kwargs

if num_nodes > 1 and len(kwargs) == 0:
# If user has not specified a buffer size, and we're using multiple nodes, check to see if default is set
assert kwargs["reduce_buffer_size"] == DDPShardedStrategy._REDUCE_BUFFER_SIZE_DEFAULT
else:
assert kwargs["reduce_buffer_size"] == expected_buffer_size
83 changes: 83 additions & 0 deletions tests/tests_lite/strategies/test_fairscale_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from tests_lite.helpers.models import BoringLite
from tests_lite.helpers.runif import RunIf


class ShardedSaveAndLoad(BoringLite):
def get_optimizer(self, module):
optimizer = super().get_optimizer(module)
if self.with_fairscale_oss:
from fairscale.optim import OSS

optimizer = OSS(params=optimizer.param_groups, optim=type(optimizer), **optimizer.defaults)
return optimizer

def run(self, tmpdir, with_fairscale_oss=False):
self.with_fairscale_oss = with_fairscale_oss

super().run()

from fairscale.nn import ShardedDataParallel
from fairscale.optim import OSS

# the model and optimizer is wrapped correctly
assert isinstance(self.model._forward_module, ShardedDataParallel)
assert isinstance(self.optimizer.optimizer, OSS)

self.model.cpu()

checkpoint_path = os.path.join(tmpdir, "checkpoint.ckpt")
checkpoint = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()}
self.save(checkpoint, checkpoint_path)

self.barrier()

loaded_checkpoint = self.load(checkpoint_path)
new_model = self.get_model()
new_model.load_state_dict(loaded_checkpoint["model"])

# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(self.model.parameters(), new_model.parameters()):
assert torch.equal(trained_param, loaded_param)


@RunIf(fairscale=True)
@pytest.mark.parametrize("accelerator", ["cpu", pytest.param("cuda", marks=RunIf(min_cuda_gpus=2))])
@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
@pytest.mark.parametrize("with_fairscale_oss", (True, False))
def test_fairscale_multi_process_checkpoint_state_consolidation(with_fairscale_oss, strategy, accelerator, tmpdir):
"""Test that the sharded optimizer states get consolidated when saving the checkpoint, and that the loaded
weights is identical to the saved one."""
lite = ShardedSaveAndLoad(strategy=strategy, accelerator=accelerator, devices=2)
lite.run(tmpdir, with_fairscale_oss=with_fairscale_oss)


@pytest.mark.parametrize(
"strategy, expected_find_unused_parameters",
[
("ddp_sharded", None),
("ddp_sharded_find_unused_parameters_false", False),
("ddp_sharded_spawn", None),
("ddp_sharded_spawn_find_unused_parameters_false", False),
],
)
def test_fairscale_find_unused_parameters_from_registry(strategy, expected_find_unused_parameters):
lite = BoringLite(strategy=strategy)
if expected_find_unused_parameters is not None:
assert lite._strategy._ddp_kwargs["find_unused_parameters"] is False
15 changes: 14 additions & 1 deletion tests/tests_lite/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_ipython_compatible_strategy_ddp_fork(monkeypatch):
)
@pytest.mark.parametrize("devices", [1, 2])
@mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2)
def test_accelerator_choice_multi_node_gpu(_, strategy, strategy_class, devices):
def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices):
connector = _Connector(num_nodes=2, accelerator="gpu", strategy=strategy, devices=devices)
assert isinstance(connector.strategy, strategy_class)

Expand Down Expand Up @@ -376,6 +376,19 @@ def test_strategy_choice_gpu_str(strategy, strategy_class):
assert isinstance(connector.strategy, strategy_class)


@RunIf(fairscale=True)
@pytest.mark.parametrize(
"strategy,expected_strategy", [("ddp_sharded", DDPShardedStrategy), ("ddp_sharded_spawn", DDPSpawnShardedStrategy)]
)
@pytest.mark.parametrize(
"precision,expected_precision", [(16, NativeMixedPrecision), (32, Precision), ("bf16", NativeMixedPrecision)]
)
def test_strategy_choice_sharded(strategy, expected_strategy, precision, expected_precision):
connector = _Connector(strategy=strategy, devices=1, precision=precision)
assert isinstance(connector.strategy, expected_strategy)
assert isinstance(connector.precision_plugin, expected_precision)


@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize("strategy_class", [DDPSpawnStrategy, DDPStrategy])
def test_strategy_choice_gpu_instance(strategy_class):
Expand Down
Loading