Commit
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
(cherry picked from commit 3e04353)
Showing 8 changed files with 467 additions and 226 deletions.
Empty file.
@@ -0,0 +1,83 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset


class ParityModel(ABC, nn.Module):
    """Defines the interface for a model in a Fabric-PyTorch parity test."""

    # Benchmarking parameters that should be model-specific
    batch_size = 1
    num_steps = 1

    @abstractmethod
    def get_optimizer(self, *args, **kwargs) -> Optimizer:
        pass

    @abstractmethod
    def get_dataloader(self, *args, **kwargs) -> DataLoader:
        pass

    @abstractmethod
    def get_loss_function(self) -> Callable:
        pass


class ConvNet(ParityModel):
    batch_size = 4
    num_steps = 1000

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
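        # with the 32x32 inputs from get_dataloader: 32 -> 28 (conv1, 5x5) -> 14 (pool)
        # -> 10 (conv2, 5x5) -> 5 (pool), hence 16 * 5 * 5 flattened features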
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def get_optimizer(self):
        return torch.optim.SGD(self.parameters(), lr=0.0001)

    def get_dataloader(self):
        # multiply by 8 in case the world size is larger than 1
        dataset_size = self.num_steps * self.batch_size * 8
        inputs = torch.rand(dataset_size, 3, 32, 32)
        labels = torch.randint(0, 10, (dataset_size,))
        dataset = TensorDataset(inputs, labels)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=2,
        )
        return dataloader

    def get_loss_function(self):
        return F.cross_entropy
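The interface above is what the parity benchmark trains against: each model supplies its own optimizer, dataloader, and loss. As a rough illustration only (not part of this commit; the name `train_single_process` is hypothetical), a plain single-process loop over this interface would look like:

# Illustrative only: a single-process training loop over the ParityModel interface.
def train_single_process(model: ParityModel, device=torch.device("cpu")):
    model = model.to(device)
    optimizer = model.get_optimizer()
    dataloader = model.get_dataloader()
    loss_fn = model.get_loss_function()

    model.train()
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        inputs, labels = next(iterator)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()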
@@ -0,0 +1,169 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from copy import deepcopy

import pytest
import torch
import torch.distributed
import torch.nn.functional
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from lightning_fabric.fabric import Fabric
from tests_fabric.helpers.runif import RunIf
from tests_fabric.parity.models import ConvNet
from tests_fabric.parity.utils import (
    cuda_reset,
    is_cuda_memory_close,
    is_state_dict_equal,
    is_timing_close,
    make_deterministic,
)


def train_torch_ddp(
    rank,
    world_size,
    device=torch.device("cpu"),
    backend="nccl",
):
    make_deterministic()
    memory_stats = {}

    os.environ["LOCAL_RANK"] = str(rank)
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)

    model = ConvNet().to(device)
    initial_state_dict = deepcopy(model.state_dict())

    ddp_model = DistributedDataParallel(model, device_ids=([rank] if device.type == "cuda" else None))

    dataloader = model.get_dataloader()
    sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
    dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
    optimizer = model.get_optimizer()
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    ddp_model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())

    return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats
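# Note (illustration, not part of this commit): `train_torch_ddp` initializes the
# process group with the default env:// method, so MASTER_ADDR and MASTER_PORT must
# be set. In `test_parity_ddp` below the processes come from `fabric.launch()`; run
# on its own, the baseline could be spawned roughly like this:
#
#   import torch.multiprocessing as mp
#   os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
#   os.environ.setdefault("MASTER_PORT", "29500")
#   mp.spawn(train_torch_ddp, args=(2, torch.device("cpu"), "gloo"), nprocs=2)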
def train_fabric_ddp(fabric):
    make_deterministic()
    memory_stats = {}

    model = ConvNet()
    initial_state_dict = deepcopy(model.state_dict())

    optimizer = model.get_optimizer()
    model, optimizer = fabric.setup(model, optimizer)

    dataloader = model.get_dataloader()
    dataloader = fabric.setup_dataloaders(dataloader)
    loss_fn = model.get_loss_function()

    memory_stats["start"] = torch.cuda.memory_stats()

    model.train()
    iteration_timings = []
    iterator = iter(dataloader)
    for _ in range(model.num_steps):
        t0 = time.perf_counter()

        inputs, labels = next(iterator)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        fabric.backward(loss)
        optimizer.step()

        t1 = time.perf_counter()
        iteration_timings.append(t1 - t0)

    memory_stats["end"] = torch.cuda.memory_stats()

    # check that the model has changed
    assert not is_state_dict_equal(initial_state_dict, model.state_dict())

    return model.state_dict(), torch.tensor(iteration_timings), memory_stats


@pytest.mark.flaky(reruns=3)
@RunIf(standalone=True)
@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
@pytest.mark.parametrize(
    "accelerator, devices, tolerance",
    [
        ("cpu", 2, 0.01),
        pytest.param("cuda", 2, 0.005, marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_parity_ddp(accelerator, devices, tolerance):
    cuda_reset()

    # Launch processes with Fabric and re-use them for the PyTorch training for convenience
    fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
    fabric.launch()

    # Train with Fabric
    state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)

    fabric.barrier()
    cuda_reset()
    torch.distributed.destroy_process_group()

    # Train with raw PyTorch
    state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
        rank=fabric.global_rank,
        world_size=fabric.world_size,
        device=fabric.device,
        backend=fabric.strategy._process_group_backend,
    )

    # Compare the final weights
    assert all(fabric.all_gather(is_state_dict_equal(state_dict_torch, state_dict_fabric)))

    # Compare the time per iteration
    assert all(fabric.all_gather(is_timing_close(timings_torch, timings_fabric, rtol=tolerance, atol=tolerance)))

    # Compare memory usage
    if accelerator == "cuda":
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])))
        assert all(fabric.all_gather(is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])))
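The helpers cuda_reset, make_deterministic, is_state_dict_equal, is_timing_close, and is_cuda_memory_close are imported from tests_fabric.parity.utils, which is not included in this excerpt. The sketch below is only a plausible reading of what those comparisons check, inferred from how they are called above; the actual implementations in the commit may differ.

# Assumed shapes for the parity helpers (not the code from this commit): the real
# implementations live in tests_fabric.parity.utils, which is not shown here.
import torch


def is_state_dict_equal(state0, state1):
    # compare the parameter tensors key by key
    return all(torch.equal(v.cpu(), state1[k].cpu()) for k, v in state0.items())


def is_timing_close(timings0, timings1, rtol=1e-3, atol=1e-3):
    # compare a robust statistic of the per-iteration times
    return torch.isclose(timings0.median(), timings1.median(), rtol=rtol, atol=atol).item()


def is_cuda_memory_close(stats0, stats1):
    # compare the peak allocated bytes reported by torch.cuda.memory_stats()
    return stats0.get("allocated_bytes.all.peak", 0) == stats1.get("allocated_bytes.all.peak", 0)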