Port fabric tests
carmocca committed Jul 13, 2023
1 parent 3b78097 commit c875e55
Showing 8 changed files with 165 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/tests_fabric/plugins/precision/test_deepspeed.py
@@ -78,7 +78,7 @@ def test_selected_dtype(precision, expected_dtype):
        ("16-true", torch.float16),
    ],
)
-def test_module_init_context(precision, expected_dtype):
+def test_init_context(precision, expected_dtype):
    plugin = DeepSpeedPrecision(precision=precision)
    with plugin.init_context():
        model = torch.nn.Linear(2, 2)
2 changes: 1 addition & 1 deletion tests/tests_fabric/plugins/precision/test_half.py
@@ -37,7 +37,7 @@ def test_selected_dtype(precision, expected_dtype):
        ("16-true", torch.half),
    ],
)
-def test_module_init_context(precision, expected_dtype):
+def test_init_context(precision, expected_dtype):
    plugin = HalfPrecision(precision=precision)
    with plugin.init_context():
        model = torch.nn.Linear(2, 2)
8 changes: 8 additions & 0 deletions tests/tests_pytorch/plugins/test_double_plugin.py
@@ -164,3 +164,11 @@ def test_double_precision_pickle():
    plugin = DoublePrecisionPlugin()
    model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
    pickle.dumps(model)


def test_init_context():
    plugin = DoublePrecisionPlugin()
    with plugin.init_context():
        model = torch.nn.Linear(2, 2)
        assert torch.get_default_dtype() == torch.double
    assert model.weight.dtype == torch.double
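
For context, a minimal sketch (not part of the diff above) of the pattern that `test_init_context` exercises: building a model in double precision through the plugin's context manager. It only assumes the `DoublePrecisionPlugin` import path that appears elsewhere in this commit.

import torch
from lightning.pytorch.plugins import DoublePrecisionPlugin

plugin = DoublePrecisionPlugin()
with plugin.init_context():
    # inside the context, newly created tensors default to float64
    model = torch.nn.Linear(2, 2)
# the default dtype is restored on exit, but the parameters keep their dtype
assert model.weight.dtype == torch.double
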
41 changes: 41 additions & 0 deletions tests/tests_pytorch/strategies/test_common.py
@@ -11,10 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock

import pytest
import torch

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.plugins import DoublePrecisionPlugin, PrecisionPlugin
from lightning.pytorch.strategies import SingleDeviceStrategy
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -47,3 +52,39 @@ def test_evaluate(tmpdir, trainer_kwargs):
    # make sure weights didn't change
    new_weights = model.layer_0.weight.clone().detach().cpu()
    torch.testing.assert_close(old_weights, new_weights)


@RunIf(min_torch="1.13")
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps:0", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize(
    ("precision", "dtype"),
    [
        (PrecisionPlugin(), torch.float32),
        pytest.param(DoublePrecisionPlugin(), torch.float64, marks=RunIf(mps=False)),
    ],
)
@pytest.mark.parametrize("empty_init", [None, True, False])
def test_module_init_context(device, precision, dtype, empty_init, monkeypatch):
    """Test that the module under the init-module-context gets moved to the right device and dtype."""
    init_mock = Mock()
    monkeypatch.setattr(torch.Tensor, "uniform_", init_mock)

    device = torch.device(device)
    strategy = SingleDeviceStrategy(device=device, precision_plugin=precision)  # surrogate class to test base class
    with strategy.tensor_init_context(empty_init=empty_init):
        module = torch.nn.Linear(2, 2)

    expected_device = device if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == dtype
    if not empty_init:
        init_mock.assert_called()
    else:
        init_mock.assert_not_called()
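
For context, a minimal sketch (not part of the diff above) of the behavior this test parametrizes: `tensor_init_context` creates modules with the strategy's precision, and `empty_init=True` is expected to skip the parameter init routines, which is what the mocked `torch.Tensor.uniform_` assertion checks. It reuses only the constructor arguments and import paths shown in the diff and assumes a recent PyTorch, as the `min_torch="1.13"` marker implies.

import torch
from lightning.pytorch.plugins import PrecisionPlugin
from lightning.pytorch.strategies import SingleDeviceStrategy

strategy = SingleDeviceStrategy(device=torch.device("cpu"), precision_plugin=PrecisionPlugin())
with strategy.tensor_init_context(empty_init=True):
    # parameters are allocated but their init routines are skipped
    module = torch.nn.Linear(2, 2)
assert module.weight.dtype == torch.float32  # default precision plugin keeps float32
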
27 changes: 27 additions & 0 deletions tests/tests_pytorch/strategies/test_ddp.py
@@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import pytest
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

import lightning.pytorch as pl
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.pytorch import seed_everything, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins import DoublePrecisionPlugin, PrecisionPlugin
from lightning.pytorch.strategies import DDPStrategy
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -158,3 +161,27 @@ def root_device(self):
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs, mps_count_0):
    trainer = Trainer(strategy=strategy_name)
    assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs


@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        (PrecisionPlugin(), torch.float32),
        (DoublePrecisionPlugin(), torch.float64),
    ],
)
@mock.patch.dict(os.environ, {"LOCAL_RANK": "1"})
def test_tensor_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    parallel_devices = [torch.device("cuda", 0), torch.device("cuda", 1)]
    expected_device = parallel_devices[1] if _TORCH_GREATER_EQUAL_2_0 else torch.device("cpu")

    strategy = DDPStrategy(
        parallel_devices=parallel_devices, precision_plugin=precision, cluster_environment=LightningEnvironment()
    )
    assert strategy.local_rank == 1
    with strategy.tensor_init_context():
        module = torch.nn.Linear(2, 2)
    assert module.weight.device == module.bias.device == expected_device
    assert module.weight.dtype == module.bias.dtype == expected_dtype
31 changes: 31 additions & 0 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -18,6 +18,7 @@
from re import escape
from typing import Any, Dict
from unittest import mock
from unittest.mock import ANY

import pytest
import torch
@@ -1282,3 +1283,33 @@ def test_validate_parallel_devices_indices(device_indices):
        RuntimeError, match=escape(f"device indices {device_indices!r} don't match the local rank values of processes")
    ):
        strategy.setup_environment()


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
@pytest.mark.parametrize("empty_init", [None, True])
def test_deepspeed_init_module_with_stage_3(empty_init):
    """Tests how `.init_module()` behaves with ZeRO stage 3."""
    trainer = Trainer(
        accelerator="cuda", devices=2, strategy="deepspeed_stage_3", precision="bf16-mixed", fast_dev_run=1
    )
    model = ModelParallelBoringModel()
    with mock.patch("deepspeed.zero.Init") as zero_init_mock:
        trainer.fit(model)

    zero_init_mock.assert_called_once_with(
        remote_device="cpu", pin_memory=True, config_dict_or_path=ANY, dtype=torch.bfloat16
    )


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True, bf16_cuda=True)
@pytest.mark.parametrize("stage", [1, 2])
def test_deepspeed_init_module_with_stages_1_2(stage):
    """Tests how `.init_module()` behaves with ZeRO stages 1 and 2."""
    strategy = DeepSpeedStrategy(stage=stage)
    trainer = Trainer(accelerator="cuda", devices=2, strategy=strategy, precision="bf16-mixed", fast_dev_run=1)
    model = ModelParallelBoringModel()
    with mock.patch("deepspeed.zero.Init") as zero_init_mock:
        trainer.fit(model)

    zero_init_mock.assert_not_called()
    assert model.layer.weight.dtype == torch.bfloat16
42 changes: 36 additions & 6 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -26,14 +26,9 @@

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap
else:
size_based_auto_wrap_policy = object

from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
_FSDPPolicy = object


class TestFSDPModel(BoringModel):
@@ -627,3 +622,38 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
    torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

    trainer.strategy.barrier()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
    ],
)
def test_configure_model(precision, expected_dtype):
    """Test that the module under configure_model gets moved to the right device and dtype."""
    trainer = Trainer(
        accelerator="cuda",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        precision=precision,
        fast_dev_run=1,
    )

    class MyModel(BoringModel):
        def configure_model(self):
            self.layer = torch.nn.Linear(32, 2)
            # The model is on the CPU until `.setup()`
            # TODO: Support initialization on meta device
            expected_device = torch.device("cpu")
            assert self.layer.weight.device == expected_device
            assert self.layer.weight.dtype == expected_dtype

        def on_fit_start(self):
            # Parameters get sharded in `.setup()` and moved to the target device
            assert self.layer.weight.device == torch.device("cuda", self.local_rank)
            assert self.layer.weight.dtype == expected_dtype

    model = MyModel()
    trainer.fit(model)
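
For context, a minimal sketch (not part of the diff above) of the `configure_model` hook this test exercises: layers created in the hook start on the CPU (meta-device initialization is still a TODO per the comment in the test) and are sharded and moved to the target device during `setup()`. The strategy configuration mirrors the test.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import always_wrap_policy


class LazyInitModel(BoringModel):
    def configure_model(self):
        # created on CPU here; FSDP shards it and moves it to the GPU in `.setup()`
        self.layer = torch.nn.Linear(32, 2)


trainer = Trainer(accelerator="cuda", devices=2, strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy))
trainer.fit(LazyInitModel())
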
20 changes: 20 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -32,6 +32,7 @@
from torch.optim import SGD
from torch.utils.data import DataLoader, IterableDataset

import lightning.fabric
import tests_pytorch.helpers.utils as tutils
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.seed import seed_everything
@@ -54,6 +55,7 @@
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.warnings import PossibleUserWarning
from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
@@ -2044,3 +2046,21 @@ def on_fit_start(self):
    ):
        trainer.fit(ExceptionModel())
    on_exception_mock.assert_called_once_with(exception)


def test_init_module_context(monkeypatch):
    """Test that the strategy returns the context manager for initializing the module."""
    trainer = Trainer(accelerator="cpu", devices=1)
    strategy = SingleDeviceStrategy(device=torch.device("cuda"))
    strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
    trainer._accelerator_connector.strategy = strategy
    with trainer.init_module():
        pass
    strategy.tensor_init_context.assert_called_once_with(empty_init=None)
    strategy.tensor_init_context.reset_mock()

    # Pretend we are using PyTorch < 2.0
    monkeypatch.setattr(lightning.pytorch.trainer.trainer, "_TORCH_GREATER_EQUAL_2_0", False)
    with pytest.warns(PossibleUserWarning, match="can't place .* on the device"), trainer.init_module():
        pass
    strategy.tensor_init_context.assert_called_once()
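
For orientation, a minimal usage sketch (not part of the diff above) of `Trainer.init_module`, which this test verifies is routed to the strategy's `tensor_init_context`:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

trainer = Trainer(accelerator="cpu", devices=1)
with trainer.init_module(empty_init=None):
    # on PyTorch >= 2.0 the module is created directly on the strategy's device and dtype
    model = BoringModel()
trainer.fit(model)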
