New fabric parity tests (#16899)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>

(cherry picked from commit 3e04353)
awaelchli authored and Borda committed Mar 30, 2023
1 parent a665920 commit 581662b
Showing 8 changed files with 467 additions and 226 deletions.
1 change: 1 addition & 0 deletions requirements/fabric/test.txt
@@ -2,6 +2,7 @@ coverage==6.5.0
codecov==2.1.12
pytest==7.2.0
pytest-cov==4.0.0
pytest-rerunfailures==10.3
pre-commit==2.20.0
click==8.1.3
tensorboardX>=2.2, <=2.5.1 # min version is set by torch.onnx missing attribute
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -72,6 +72,7 @@ def teardown_process_group():
def reset_deterministic_algorithm():
    """Ensures that torch determinism settings are reset before the next test runs."""
    yield
    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    torch.use_deterministic_algorithms(False)
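
Note: the new parity tests below run with determinism enabled, and on CUDA `torch.use_deterministic_algorithms(True)` requires `CUBLAS_WORKSPACE_CONFIG` to be set, which is why this fixture also pops that variable afterwards. A minimal sketch of the corresponding setup helper follows; the real one lives in tests_fabric/parity/utils.py, which is not part of this excerpt, so the body here is an assumption:

import os

import torch


def make_deterministic(seed=1):
    # Illustrative sketch only: the actual helper is defined in
    # tests_fabric/parity/utils.py, which this excerpt does not show.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required by cuBLAS when determinism is on
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)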


Empty file.
83 changes: 83 additions & 0 deletions tests/tests_fabric/parity/models.py
@@ -0,0 +1,83 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset


class ParityModel(ABC, nn.Module):
    """Defines the interface for a model in a Fabric-PyTorch parity test."""

    # Benchmarking parameters that should be model-specific
    batch_size = 1
    num_steps = 1

    @abstractmethod
    def get_optimizer(self, *args, **kwargs) -> Optimizer:
        pass

    @abstractmethod
    def get_dataloader(self, *args, **kwargs) -> DataLoader:
        pass

    @abstractmethod
    def get_loss_function(self) -> Callable:
        pass


class ConvNet(ParityModel):
    batch_size = 4
    num_steps = 1000

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def get_optimizer(self):
        return torch.optim.SGD(self.parameters(), lr=0.0001)

    def get_dataloader(self):
        # multiply * 8 just in case world size is larger than 1
        dataset_size = self.num_steps * self.batch_size * 8
        inputs = torch.rand(dataset_size, 3, 32, 32)
        labels = torch.randint(0, 10, (dataset_size,))
        dataset = TensorDataset(inputs, labels)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=2,
        )
        return dataloader

    def get_loss_function(self):
        return F.cross_entropy
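
For orientation, here is a minimal sketch of how this interface is consumed in a plain single-device PyTorch loop; it condenses the pattern used by the parity test in the next file rather than adding anything new:

import torch

from tests_fabric.parity.models import ConvNet


def train_plain_pytorch(device=torch.device("cpu")):
    # Illustrative only: mirrors how the parity tests drive a ParityModel.
    model = ConvNet().to(device)
    optimizer = model.get_optimizer()
    loss_fn = model.get_loss_function()
    iterator = iter(model.get_dataloader())
    model.train()
    for _ in range(model.num_steps):
        inputs, labels = next(iterator)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()
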
169 changes: 169 additions & 0 deletions tests/tests_fabric/parity/test_parity_ddp.py
@@ -0,0 +1,169 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from copy import deepcopy

import pytest
import torch
import torch.distributed
import torch.nn.functional
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from lightning_fabric.fabric import Fabric
from tests_fabric.helpers.runif import RunIf
from tests_fabric.parity.models import ConvNet
from tests_fabric.parity.utils import (
    cuda_reset,
    is_cuda_memory_close,
    is_state_dict_equal,
    is_timing_close,
    make_deterministic,
)


def train_torch_ddp(
    rank,
    world_size,
    device=torch.device("cpu"),
    backend="nccl",
):
    make_deterministic()
    memory_stats = {}

    os.environ["LOCAL_RANK"] = str(rank)
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)

    model = ConvNet().to(device)
    initial_state_dict = deepcopy(model.state_dict())

    ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None))

    dataloader = model.get_dataloader()
    sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
    dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
    optimizer = model.get_optimizer()
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    ddp_model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())

    return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats


def train_fabric_ddp(fabric):
    make_deterministic()
    memory_stats = {}

    model = ConvNet()
    initial_state_dict = deepcopy(model.state_dict())

    optimizer = model.get_optimizer()
    model, optimizer = fabric.setup(model, optimizer)

    dataloader = model.get_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        fabric.backward(loss)
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, model.state_dict())

    return model.state_dict(), torch.tensor(iteration_timings), memory_stats


@pytest.mark.flaky(reruns=3)
@RunIf(standalone=True)
@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
@pytest.mark.parametrize(
    "accelerator, devices, tolerance",
    [
        ("cpu", 2, 0.01),
        pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_parity_ddp(accelerator, devices, tolerance):
    cuda_reset()

    # Launch processes with Fabric and re-use them for the PyTorch training for convenience
    fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
    fabric.launch()

    # Train with Fabric
    state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)

    fabric.barrier()
    cuda_reset()
    torch.distributed.destroy_process_group()

    # Train with raw PyTorch
    state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
        rank=fabric.global_rank,
        world_size=fabric.world_size,
        device=fabric.device,
        backend=fabric.strategy._process_group_backend,
    )

    # Compare the final weights
    assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric)))

    # Compare the time per iteration
    assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance)))

    # Compare memory usage
    if accelerator == "cuda":
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])))
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])))
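
The assertions above use comparison helpers from tests_fabric/parity/utils.py, which is not included in this excerpt. A rough sketch of what such helpers could look like is shown below; the names match the imports above, but the bodies and tolerances are assumptions, not the repository's actual implementation:

import torch


def is_state_dict_equal(state0, state1):
    # Assumed implementation: element-wise comparison of all parameters and buffers.
    return all(torch.equal(v.cpu(), state1[k].cpu()) for k, v in state0.items())


def is_timing_close(timings0, timings1, rtol=1e-3, atol=1e-3):
    # Assumed implementation: compare median iteration times within a tolerance.
    return bool(torch.isclose(timings0.median(), timings1.median(), rtol=rtol, atol=atol))


def is_cuda_memory_close(stats0, stats1):
    # Assumed implementation: compare peak allocated bytes reported by torch.cuda.memory_stats().
    return stats0.get("allocated_bytes.all.peak", 0) == stats1.get("allocated_bytes.all.peak", 0)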