diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt
index 26395bbdb958f..09bf112d03fcb 100644
--- a/requirements/fabric/test.txt
+++ b/requirements/fabric/test.txt
@@ -2,6 +2,7 @@ coverage==6.5.0
 codecov==2.1.12
 pytest==7.2.0
 pytest-cov==4.0.0
+pytest-rerunfailures==10.3
 pre-commit==2.20.0
 click==8.1.3
 tensorboardX>=2.2, <=2.5.1 # min version is set by torch.onnx missing attribute
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 56d87ed960f1e..bafc5a8e84fc4 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -72,9 +72,17 @@ def teardown_process_group():
 def reset_deterministic_algorithm():
     """Ensures that torch determinism settings are reset before the next test runs."""
     yield
+    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
     torch.use_deterministic_algorithms(False)
 
 
+@pytest.fixture
+def reset_cudnn_benchmark():
+    """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs."""
+    yield
+    torch.backends.cudnn.benchmark = False
+
+
 def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
diff --git a/tests/tests_fabric/parity/__init__.py b/tests/tests_fabric/parity/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tests_fabric/parity/models.py b/tests/tests_fabric/parity/models.py
new file mode 100644
index 0000000000000..16616b255acc2
--- /dev/null
+++ b/tests/tests_fabric/parity/models.py
@@ -0,0 +1,83 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, TensorDataset
+
+
+class ParityModel(ABC, nn.Module):
+    """Defines the interface for a model in a Fabric-PyTorch parity test."""
+
+    # Benchmarking parameters that should be model-specific
+    batch_size = 1
+    num_steps = 1
+
+    @abstractmethod
+    def get_optimizer(self, *args, **kwargs) -> Optimizer:
+        pass
+
+    @abstractmethod
+    def get_dataloader(self, *args, **kwargs) -> DataLoader:
+        pass
+
+    @abstractmethod
+    def get_loss_function(self) -> Callable:
+        pass
+
+
+class ConvNet(ParityModel):
+    batch_size = 4
+    num_steps = 1000
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1) # flatten all dimensions except batch
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+    def get_optimizer(self):
+        return torch.optim.SGD(self.parameters(), lr=0.0001)
+
+    def get_dataloader(self):
+        # Multiply by 8 in case the world size is larger than 1
+        dataset_size = self.num_steps * self.batch_size * 8
+        inputs = torch.rand(dataset_size, 3, 32, 32)
+        labels = torch.randint(0, 10, (dataset_size,))
+        dataset = TensorDataset(inputs, labels)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            num_workers=2,
+        )
+        return dataloader
+
+    def get_loss_function(self):
+        return F.cross_entropy
diff --git a/tests/tests_fabric/parity/test_parity_ddp.py b/tests/tests_fabric/parity/test_parity_ddp.py
new file mode 100644
index 0000000000000..73933742ca069
--- /dev/null
+++ b/tests/tests_fabric/parity/test_parity_ddp.py
@@ -0,0 +1,169 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from copy import deepcopy
+
+import pytest
+import torch
+import torch.distributed
+import torch.nn.functional
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from lightning.fabric.fabric import Fabric
+from tests_fabric.helpers.runif import RunIf
+from tests_fabric.parity.models import ConvNet
+from tests_fabric.parity.utils import (
+    cuda_reset,
+    is_cuda_memory_close,
+    is_state_dict_equal,
+    is_timing_close,
+    make_deterministic,
+)
+
+
+def train_torch_ddp(
+    rank,
+    world_size,
+    device=torch.device("cpu"),
+    backend="nccl",
+):
+    make_deterministic()
+    memory_stats = {}
+
+    os.environ["LOCAL_RANK"] = str(rank)
+    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
+
+    model = ConvNet().to(device)
+    initial_state_dict = deepcopy(model.state_dict())
+
+    ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None))
+
+    dataloader = model.get_dataloader()
+    sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
+    dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
+    optimizer = model.get_optimizer()
+    loss_fn = model.get_loss_function()
+
+    memory_stats["start"] = torch.cuda.memory_stats()
+
+    ddp_model.train()
+    iteration_timings = []
+    iterator = iter(dataloader)
+    for _ in range(model.num_steps):
+        t0 = time.perf_counter()
+
+        inputs, labels = next(iterator)
+        inputs, labels = inputs.to(device), labels.to(device)
+        optimizer.zero_grad()
+        outputs = ddp_model(inputs)
+        loss = loss_fn(outputs, labels)
+        loss.backward()
+        optimizer.step()
+
+        t1 = time.perf_counter()
+        iteration_timings.append(t1 - t0)
+
+    memory_stats["end"] = torch.cuda.memory_stats()
+
+    # check that the model has changed
+    assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())
+
+    return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats
+
+
+def train_fabric_ddp(fabric):
+    make_deterministic()
+    memory_stats = {}
+
+    model = ConvNet()
+    initial_state_dict = deepcopy(model.state_dict())
+
+    optimizer = model.get_optimizer()
+    model, optimizer = fabric.setup(model, optimizer)
+
+    dataloader = model.get_dataloader()
+    dataloader = fabric.setup_dataloaders(dataloader)
+    loss_fn = model.get_loss_function()
+
+    memory_stats["start"] = torch.cuda.memory_stats()
+
+    model.train()
+    iteration_timings = []
+    iterator = iter(dataloader)
+    for _ in range(model.num_steps):
+        t0 = time.perf_counter()
+
+        inputs, labels = next(iterator)
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = loss_fn(outputs, labels)
+        fabric.backward(loss)
+        optimizer.step()
+
+        t1 = time.perf_counter()
+        iteration_timings.append(t1 - t0)
+
+    memory_stats["end"] = torch.cuda.memory_stats()
+
+    # check that the model has changed
+    assert not is_state_dict_equal(initial_state_dict, model.state_dict())
+
+    return model.state_dict(), torch.tensor(iteration_timings), memory_stats
+
+
+@pytest.mark.flaky(reruns=3)
+@RunIf(standalone=True)
+@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
+@pytest.mark.parametrize(
+    "accelerator, devices, tolerance",
+    [
+        ("cpu", 2, 0.01),
+        pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)),
+    ],
+)
+def test_parity_ddp(accelerator, devices, tolerance):
+    cuda_reset()
+
+    # Launch processes with Fabric and re-use them for the PyTorch training for convenience
+    fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
+    fabric.launch()
+
+    # Train with Fabric
+    state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)
+
+    fabric.barrier()
+    cuda_reset()
+    torch.distributed.destroy_process_group()
+
+    # Train with raw PyTorch
+    state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
+        rank=fabric.global_rank,
+        world_size=fabric.world_size,
+        device=fabric.device,
+        backend=fabric.strategy._process_group_backend,
+    )
+
+    # Compare the final weights
+    assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric)))
+
+    # Compare the time per iteration
+    assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance)))
+
+    # Compare memory usage
+    if accelerator == "cuda":
+        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])))
+        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])))
diff --git a/tests/tests_fabric/parity/test_parity_simple.py b/tests/tests_fabric/parity/test_parity_simple.py
new file mode 100644
index 0000000000000..1e2d0ac6d52dd
--- /dev/null
+++ b/tests/tests_fabric/parity/test_parity_simple.py
@@ -0,0 +1,153 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from copy import deepcopy
+from typing import Callable
+
+import pytest
+import torch
+import torch.distributed
+import torch.nn.functional
+
+from lightning.fabric.fabric import Fabric
+from tests_fabric.helpers.runif import RunIf
+from tests_fabric.parity.models import ConvNet
+from tests_fabric.parity.utils import (
+    cuda_reset,
+    get_model_input_dtype,
+    is_cuda_memory_close,
+    is_state_dict_equal,
+    is_timing_close,
+    make_deterministic,
+)
+
+
+def train_torch(
+    move_to_device: Callable,
+    precision_context,
+    input_dtype=torch.float32,
+):
+    make_deterministic()
+    memory_stats = {}
+
+    model = ConvNet()
+    model = move_to_device(model)
+    dataloader = model.get_dataloader()
+    optimizer = model.get_optimizer()
+    loss_fn = model.get_loss_function()
+
+    memory_stats["start"] = torch.cuda.memory_stats()
+
+    model.train()
+    iteration_timings = []
+    iterator = iter(dataloader)
+    for _ in range(model.num_steps):
+        t0 = time.perf_counter()
+
+        inputs, labels = next(iterator)
+        inputs, labels = move_to_device(inputs), move_to_device(labels)
+        optimizer.zero_grad()
+        with precision_context():
+            outputs = model(inputs.to(input_dtype))
+        loss = loss_fn(outputs.float(), labels)
+        loss.backward()
+        optimizer.step()
+
+        t1 = time.perf_counter()
+        iteration_timings.append(t1 - t0)
+
+    memory_stats["end"] = torch.cuda.memory_stats()
+
+    return model.state_dict(), torch.tensor(iteration_timings), memory_stats
+
+
+def train_fabric(fabric):
+    make_deterministic()
+    memory_stats = {}
+
+    model = ConvNet()
+    initial_state_dict = deepcopy(model.state_dict())
+
+    optimizer = model.get_optimizer()
+    model, optimizer = fabric.setup(model, optimizer)
+
+    dataloader = model.get_dataloader()
+    dataloader = fabric.setup_dataloaders(dataloader)
+    loss_fn = model.get_loss_function()
+
+    memory_stats["start"] = torch.cuda.memory_stats()
+
+    model.train()
+    iteration_timings = []
+    iterator = iter(dataloader)
+    for _ in range(model.num_steps):
+        t0 = time.perf_counter()
+
+        inputs, labels = next(iterator)
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = loss_fn(outputs, labels)
+        fabric.backward(loss)
+        optimizer.step()
+
+        t1 = time.perf_counter()
+        iteration_timings.append(t1 - t0)
+
+    memory_stats["end"] = torch.cuda.memory_stats()
+
+    # check that the model has changed
+    assert not is_state_dict_equal(initial_state_dict, model.state_dict())
+
+    return model.state_dict(), torch.tensor(iteration_timings), memory_stats
+
+
+@pytest.mark.flaky(reruns=3)
+@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
+@pytest.mark.parametrize(
+    "precision, accelerator",
+    [
+        (32, "cpu"),
+        pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)),
+        # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)), # TODO: requires GradScaler
+        pytest.param("bf16", "cpu", marks=RunIf(skip_windows=True)),
+        pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
+        pytest.param(32, "mps", marks=RunIf(mps=True)),
+    ],
+)
+def test_parity_single_device(precision, accelerator):
+    input_dtype = get_model_input_dtype(precision)
+
+    cuda_reset()
+
+    # Train with Fabric
+    fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
+    state_dict_fabric, timings_fabric, memory_fabric = train_fabric(fabric)
+
+    cuda_reset()
+
+    # Train with raw PyTorch
+    state_dict_torch, timings_torch, memory_torch = train_torch(
+        fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype
+    )
+
+    # Compare the final weights
+    assert is_state_dict_equal(state_dict_torch, state_dict_fabric)
+
+    # Compare the time per iteration
+    assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3)
+
+    # Compare memory usage
+    if accelerator == "cuda":
+        assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])
+        assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])
diff --git a/tests/tests_fabric/parity/utils.py b/tests/tests_fabric/parity/utils.py
new file mode 100644
index 0000000000000..0248c036f76f0
--- /dev/null
+++ b/tests/tests_fabric/parity/utils.py
@@ -0,0 +1,60 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+
+from lightning.fabric.accelerators.cuda import _clear_cuda_memory
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+
+def is_state_dict_equal(state0, state1):
+    eq_fn = torch.equal if _TORCH_GREATER_EQUAL_1_12 else torch.allclose
+    return all(eq_fn(w0.cpu(), w1.cpu()) for w0, w1 in zip(state0.values(), state1.values()))
+
+
+def is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3):
+    # Drop measurements of the first iterations, as they may be slower than others
+    # The median is more robust to outliers than the mean
+    # Given relative and absolute tolerances, torch.isclose checks: |torch - fabric| <= ATOL + RTOL * fabric
+    return bool(torch.isclose(torch.median(timings_torch[3:]), torch.median(timings_fabric[3:]), rtol=rtol, atol=atol))
+
+
+def is_cuda_memory_close(memory_stats_torch, memory_stats_fabric):
+    # We require Fabric's peak memory usage to be smaller than or equal to that of PyTorch
+    return memory_stats_torch["allocated_bytes.all.peak"] >= memory_stats_fabric["allocated_bytes.all.peak"]
+
+
+def make_deterministic():
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.benchmark = False
+    torch.manual_seed(1)
+    torch.cuda.manual_seed(1)
+
+
+def get_model_input_dtype(precision):
+    if precision in ("16-mixed", "16", 16):
+        return torch.float16
+    elif precision in ("bf16-mixed", "bf16"):
+        return torch.bfloat16
+    elif precision in ("64-true", "64", 64):
+        return torch.double
+    return torch.float32
+
+
+def cuda_reset():
+    if torch.cuda.is_available():
+        _clear_cuda_memory()
+        torch.cuda.reset_peak_memory_stats()
diff --git a/tests/tests_fabric/test_parity.py b/tests/tests_fabric/test_parity.py
deleted file mode 100644
index 87eb7094a251d..0000000000000
--- a/tests/tests_fabric/test_parity.py
+++ /dev/null
@@ -1,226 +0,0 @@
-# Copyright The Lightning AI team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from contextlib import contextmanager
-from copy import deepcopy
-from functools import partial
-from typing import Callable, Generator
-
-import pytest
-import torch
-import torch.distributed
-import torch.multiprocessing as mp
-import torch.nn.functional
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import nn, Tensor
-from torch.nn.parallel.distributed import DistributedDataParallel
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from lightning.fabric.fabric import Fabric
-from lightning.fabric.plugins.environments.lightning import find_free_network_port
-from lightning.fabric.strategies.ddp import DDPStrategy
-from lightning.fabric.utilities.apply_func import move_data_to_device
-from lightning.fabric.utilities.cloud_io import _atomic_save
-from tests_fabric.helpers.models import RandomDataset
-from tests_fabric.helpers.runif import RunIf
-
-
-class BoringModel(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.layer = torch.nn.Linear(32, 2, bias=False)
-
-    def forward(self, x):
-        x = self.layer(x)
-        return torch.nn.functional.mse_loss(x, torch.ones_like(x))
-
-
-def configure_optimizers(module: nn.Module):
-    return torch.optim.SGD(module.parameters(), lr=0.0001)
-
-
-def main(
-    move_to_device: Callable,
-    model: nn.Module,
-    train_dataloader: DataLoader,
-    num_epochs: int = 10,
-):
-    model = move_to_device(model)
-    optimizer = configure_optimizers(model)
-
-    for _ in range(num_epochs):
-        model.train()
-        for batch in train_dataloader:
-            batch = move_to_device(batch)
-            optimizer.zero_grad()
-            loss = model(batch)
-            loss.backward()
-            optimizer.step()
-
-    return model.state_dict()
-
-
-class FabricRunner(Fabric):
-    def run(self, model: nn.Module, train_dataloader: DataLoader, num_epochs: int = 10, tmpdir: str = None):
-        optimizer = configure_optimizers(model)
-        model, optimizer = self.setup(model, optimizer)
-        train_dataloader = self.setup_dataloaders(train_dataloader)
-
-        model.train()
-        for _ in range(num_epochs):
-            for batch in train_dataloader:
-                batch = self.to_device(batch)
-                optimizer.zero_grad()
-                loss = model(batch)
-                self.backward(loss)
-                optimizer.step()
-
-        if isinstance(self._strategy, DDPStrategy) and tmpdir and self.global_rank == 0:
-            checkpoint_path = os.path.join(tmpdir, "model.pt")
-            _atomic_save(model.state_dict(), checkpoint_path)
-            return checkpoint_path
-
-
-@contextmanager
-def precision_context(precision, accelerator) -> Generator[None, None, None]:
-    if precision == 32:
-        yield
-        return
-    if accelerator == "gpu":
-        with torch.cuda.amp.autocast():
-            yield
-    elif accelerator == "cpu":
-        with torch.cpu.amp.autocast():
-            yield
-
-
-@pytest.mark.parametrize(
-    "precision, accelerator",
-    [
-        (32, "cpu"),
-        pytest.param(32, "gpu", marks=RunIf(min_cuda_gpus=1)),
-        pytest.param(16, "gpu", marks=RunIf(min_cuda_gpus=1)),
-        pytest.param("bf16", "gpu", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
-        pytest.param(32, "mps", marks=RunIf(mps=True)),
-    ],
-)
-def test_boring_fabric_model_single_device(precision, accelerator):
-    Fabric.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 8))
-    model = BoringModel()
-    num_epochs = 1
-    state_dict = deepcopy(model.state_dict())
-
-    fabric = FabricRunner(precision=precision, accelerator=accelerator)
-    fabric.run(model, train_dataloader, num_epochs=num_epochs)
-    fabric_state_dict = model.state_dict()
-
-    with precision_context(precision, accelerator):
-        model.load_state_dict(state_dict)
-        pure_state_dict = main(fabric.to_device, model, train_dataloader, num_epochs=num_epochs)
-
-    state_dict = apply_to_collection(state_dict, Tensor, fabric.to_device)
-    for w_pure, w_fabric in zip(state_dict.values(), fabric_state_dict.values()):
-        # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12)
-        assert not torch.allclose(w_pure, w_fabric)
-
-    for w_pure, w_fabric in zip(pure_state_dict.values(), fabric_state_dict.values()):
-        # TODO: This should be torch.equal, but MPS does not yet support this operation (torch 1.12)
-        assert torch.allclose(w_pure, w_fabric)
-
-
-def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir):
-    os.environ["LOCAL_RANK"] = str(rank)
-    if torch.distributed.is_available() and not torch.distributed.is_initialized():
-        torch.distributed.init_process_group("gloo", rank=rank, world_size=2)
-
-    to_device = partial(move_data_to_device, device=torch.device("cuda", rank))
-    model = DistributedDataParallel(
-        to_device(model),
-        device_ids=[rank],
-    )
-    train_dataloader = DataLoader(
-        train_dataloader.dataset,
-        sampler=DistributedSampler(train_dataloader.dataset, rank=rank, num_replicas=2, seed=42, drop_last=False),
-    )
-    with precision_context(precision, accelerator):
-        main(to_device, model, train_dataloader, num_epochs=num_epochs)
-
-    if rank == 0:
-        _atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))
-
-
-@pytest.mark.skip(reason="Skipping as it takes 80 seconds.")
-@RunIf(min_cuda_gpus=2)
-@pytest.mark.parametrize(
-    "precision, strategy, devices, accelerator",
-    [
-        (32, "ddp_spawn", 2, "gpu"),
-    ],
-)
-def test_boring_fabric_model_ddp_spawn(precision, strategy, devices, accelerator, tmpdir):
-    Fabric.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 8))
-    model = BoringModel()
-    num_epochs = 1
-    state_dict = deepcopy(model.state_dict())
-
-    fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator)
-    checkpoint_path = fabric.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir)
-    spawn_model_state_dict = torch.load(checkpoint_path)
-
-    for w_pure, w_fabric in zip(state_dict.values(), spawn_model_state_dict.values()):
-        assert not torch.equal(w_pure.cpu(), w_fabric.cpu())
-
-    model.load_state_dict(state_dict)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(find_free_network_port())
-    mp.spawn(run, args=(model, train_dataloader, num_epochs, precision, accelerator, tmpdir), nprocs=2)
-    spawn_pure_model_state_dict = torch.load(os.path.join(tmpdir, "model_spawn.pt"))
-
-    for w_pure, w_fabric in zip(spawn_pure_model_state_dict.values(), spawn_model_state_dict.values()):
-        assert torch.equal(w_pure.cpu(), w_fabric.cpu())
-
-
-@RunIf(min_cuda_gpus=2, standalone=True)
-@pytest.mark.parametrize(
-    "precision, strategy, devices, accelerator",
-    [
-        (32, "ddp", 2, "gpu"),
-    ],
-)
-def test_boring_fabric_model_ddp(precision, strategy, devices, accelerator, tmpdir):
-    Fabric.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
-    model = BoringModel()
-    num_epochs = 1
-    state_dict = deepcopy(model.state_dict())
-
-    fabric = FabricRunner(precision=precision, strategy=strategy, devices=devices, accelerator=accelerator)
-    fabric.run(model, train_dataloader, num_epochs=num_epochs, tmpdir=tmpdir)
-
-    fabric_model_state_dict = model.state_dict()
-
-    for w_pure, w_fabric in zip(state_dict.values(), fabric_model_state_dict.values()):
-        assert not torch.allclose(w_pure.cpu(), w_fabric.cpu())
-
-    Fabric.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
-    model = BoringModel()
-    run(fabric.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir)
-    pure_model_state_dict = model.state_dict()
-
-    for w_pure, w_fabric in zip(pure_model_state_dict.values(), fabric_model_state_dict.values()):
-        torch.testing.assert_close(w_pure.cpu(), w_fabric.cpu())
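
For reference, the tolerance check added in `is_timing_close` follows standard `torch.isclose` semantics. The sketch below is not part of the patch; the timing values are made up for illustration, but the median-then-isclose comparison mirrors what the helper computes on the collected per-iteration timings:

import torch

# Hypothetical per-iteration timings in seconds (the tests collect these with time.perf_counter())
timings_torch = torch.tensor([0.0120, 0.0101, 0.0100, 0.0099, 0.0101, 0.0100])
timings_fabric = torch.tensor([0.0130, 0.0102, 0.0101, 0.0100, 0.0100, 0.0102])

# Drop the first (warm-up) iterations and compare the medians, as is_timing_close() does
median_torch = torch.median(timings_torch[3:])
median_fabric = torch.median(timings_fabric[3:])

# torch.isclose checks |median_torch - median_fabric| <= atol + rtol * |median_fabric|
assert torch.isclose(median_torch, median_fabric, rtol=1e-3, atol=1e-3)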