Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New fabric parity tests #16899

Merged
merged 98 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 68 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
64ba17e
experimental
awaelchli Feb 27, 2023
98cd00d
wip
awaelchli Feb 27, 2023
0974d89
wip
awaelchli Feb 27, 2023
ba307f6
wip
awaelchli Feb 27, 2023
ed453d0
wip
awaelchli Feb 28, 2023
f6273db
fix
awaelchli Feb 28, 2023
164c994
update
awaelchli Feb 28, 2023
5fd9f5c
update
awaelchli Feb 28, 2023
c2ec0d7
update
awaelchli Feb 28, 2023
52c0f3f
update
awaelchli Feb 28, 2023
49313f0
update
awaelchli Feb 28, 2023
c08e16c
update
awaelchli Feb 28, 2023
d2f6184
update
awaelchli Feb 28, 2023
0cf71fb
update
awaelchli Feb 28, 2023
c713106
update
awaelchli Feb 28, 2023
b04d381
update
awaelchli Feb 28, 2023
2b47e9c
update
awaelchli Feb 28, 2023
bdc3055
update
awaelchli Feb 28, 2023
8747031
update
awaelchli Feb 28, 2023
14bb8d9
update
awaelchli Feb 28, 2023
2445026
update
awaelchli Feb 28, 2023
da23916
update
awaelchli Feb 28, 2023
0ea1496
update
awaelchli Feb 28, 2023
ba84ba7
update
awaelchli Feb 28, 2023
caa7c03
update
awaelchli Feb 28, 2023
2b57493
update
awaelchli Feb 28, 2023
c14e2c4
update
awaelchli Feb 28, 2023
43b17e9
refactor
awaelchli Feb 28, 2023
436a5e6
debug
awaelchli Feb 28, 2023
826be20
Revert "debug"
awaelchli Feb 28, 2023
c9d5f19
Revert "refactor"
awaelchli Feb 28, 2023
ddbb113
update
awaelchli Feb 28, 2023
e8b79a5
update
awaelchli Feb 28, 2023
c1dff21
update
awaelchli Feb 28, 2023
ad369b8
update
awaelchli Feb 28, 2023
0b35ef8
update
awaelchli Feb 28, 2023
68b4888
update
awaelchli Feb 28, 2023
41e7a25
update
awaelchli Feb 28, 2023
929d604
update
awaelchli Feb 28, 2023
71c77ca
update
awaelchli Feb 28, 2023
727515f
update
awaelchli Feb 28, 2023
1771443
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 28, 2023
bc37136
delete
awaelchli Feb 28, 2023
3d8ad31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2023
76b1676
update
awaelchli Mar 1, 2023
130880f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2023
11d5099
benchmark
awaelchli Mar 1, 2023
667b174
Merge remote-tracking branch 'origin/fabric/framework-overhead' into …
awaelchli Mar 1, 2023
20c6672
update
awaelchli Mar 1, 2023
c5fa2bc
tuning
awaelchli Mar 1, 2023
2d85b0d
run on gpu
awaelchli Mar 1, 2023
72faa64
memory
awaelchli Mar 1, 2023
0de9ba2
tolerance
awaelchli Mar 1, 2023
905c5d6
memory
awaelchli Mar 1, 2023
719088b
refactor
awaelchli Mar 1, 2023
0bbe2cd
refactor
awaelchli Mar 1, 2023
6f41053
safer check
awaelchli Mar 1, 2023
1f6e987
reset peak
awaelchli Mar 1, 2023
33d7c01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2023
8ade03a
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 1, 2023
5462c24
empty cache
awaelchli Mar 1, 2023
d7d3739
Merge remote-tracking branch 'origin/fabric/framework-overhead' into …
awaelchli Mar 1, 2023
cbf24c1
Update tests/tests_fabric/parity/test_parity_simple.py
awaelchli Mar 1, 2023
42facb7
Update tests/tests_fabric/parity/test_parity_simple.py
awaelchli Mar 1, 2023
d6e5227
cuda
awaelchli Mar 1, 2023
80d5919
Experiment with tracking mode by @carmocca
awaelchli Mar 1, 2023
a703991
Revert "Experiment with tracking mode by @carmocca"
awaelchli Mar 1, 2023
af7a7e4
move assertions top
awaelchli Mar 1, 2023
3162108
reset cuda memory stats before test
awaelchli Mar 1, 2023
24c9f1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2023
b59f51d
assertions across all devices
awaelchli Mar 1, 2023
94e2777
Merge remote-tracking branch 'origin/fabric/framework-overhead' into …
awaelchli Mar 1, 2023
d42ab34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2023
726d462
slow cpu
awaelchli Mar 1, 2023
b7a82b5
add requirement
awaelchli Mar 1, 2023
1697904
tolerance
awaelchli Mar 1, 2023
0c75321
bf16 skip windows
awaelchli Mar 1, 2023
3a8b048
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 1, 2023
603ec15
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 3, 2023
e5c836a
parity on cpu
awaelchli Mar 3, 2023
ea67af7
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 3, 2023
8201fed
Merge branch 'master' into fabric/framework-overhead
Borda Mar 3, 2023
a52061f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2023
1dd94b6
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 4, 2023
5970e8c
update
awaelchli Mar 4, 2023
c7f865b
Update tests/tests_fabric/parity/test_parity_ddp.py
awaelchli Mar 6, 2023
1b82ed5
Update tests/tests_fabric/conftest.py
awaelchli Mar 6, 2023
6221b28
Update tests/tests_fabric/parity/test_parity_ddp.py
awaelchli Mar 6, 2023
f38c95d
parametrize backend
awaelchli Mar 6, 2023
c6a45c8
use equality
awaelchli Mar 6, 2023
62a28de
add barrier
awaelchli Mar 6, 2023
f2ce109
comment about reusing processes
awaelchli Mar 6, 2023
f0edeba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2023
71342dc
Merge remote-tracking branch 'origin/fabric/framework-overhead' into …
awaelchli Mar 6, 2023
14bfc37
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 6, 2023
0a42768
Merge branch 'master' into fabric/framework-overhead
awaelchli Mar 6, 2023
4964d75
use the utility to clear cuda cache
awaelchli Mar 6, 2023
be6eda7
guard
awaelchli Mar 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/tests_fabric/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def restore_env_variables():
"POPLAR_ENGINE_OPTIONS", # set by IPUStrategy
"CUDA_MODULE_LOADING", # leaked since PyTorch 1.13
"CRC32C_SW_MODE", # set by tensorboardX
"CUBLAS_WORKSPACE_CONFIG", # handled by the `reset_deterministic_algorithm` fixture below
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
Expand All @@ -72,9 +73,18 @@ def teardown_process_group():
def reset_deterministic_algorithm():
"""Ensures that torch determinism settings are reset before the next test runs."""
yield
os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
torch.use_deterministic_algorithms(False)


@pytest.fixture
def reset_cudnn_benchmark():
"""Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs."""
benchmark = torch.backends.cudnn.benchmark
yield
torch.backends.cudnn.benchmark = benchmark
awaelchli marked this conversation as resolved.
Show resolved Hide resolved


def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
monkeypatch.setattr(lightning.fabric.accelerators.tpu, "_XLA_AVAILABLE", value)
monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
Expand Down
Empty file.
83 changes: 83 additions & 0 deletions tests/tests_fabric/parity/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader, TensorDataset


class ParityModel(ABC, nn.Module):
"""Defines the interface for a model in a Fabric-PyTorch parity test."""

# Benchmarking parameters that should be model-specific
batch_size = 1
num_steps = 1

@abstractmethod
def get_optimizer(self, *args, **kwargs) -> Optimizer:
pass

@abstractmethod
def get_dataloader(self, *args, **kwargs) -> DataLoader:
pass

@abstractmethod
def get_loss_function(self) -> Callable:
pass


class ConvNet(ParityModel):
batch_size = 4
num_steps = 1000

def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

def get_optimizer(self):
return torch.optim.SGD(self.parameters(), lr=0.0001)

def get_dataloader(self):
# multiply * 8 just in case world size is larger than 1
dataset_size = self.num_steps * self.batch_size * 8
inputs = torch.rand(dataset_size, 3, 32, 32)
labels = torch.randint(0, 10, (dataset_size,))
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=2,
)
return dataloader

def get_loss_function(self):
return F.cross_entropy
157 changes: 157 additions & 0 deletions tests/tests_fabric/parity/test_parity_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from copy import deepcopy

import pytest
import torch
import torch.distributed
import torch.nn.functional
from tests_fabric.helpers.runif import RunIf
from tests_fabric.parity.models import ConvNet
from tests_fabric.parity.utils import is_cuda_memory_close, is_state_dict_equal, is_timing_close, make_deterministic
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from lightning.fabric.fabric import Fabric


def train_torch_ddp(
rank,
world_size,
device=torch.device("cpu"),
):
make_deterministic()
memory_stats = {}

os.environ["LOCAL_RANK"] = str(rank)
if torch.distributed.is_available() and not torch.distributed.is_initialized():
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

model = ConvNet().to(device)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
initial_state_dict = deepcopy(model.state_dict())

ddp_model = DistributedDataParallel(model.to(device), device_ids=([rank] if device.type == "cuda" else None))
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

dataloader = model.get_dataloader()
sampler = DistributedSampler(dataloader.dataset, rank=rank, num_replicas=world_size, drop_last=False, shuffle=False)
dataloader = DataLoader(dataloader.dataset, sampler=sampler, batch_size=model.batch_size)
optimizer = model.get_optimizer()
loss_fn = model.get_loss_function()

memory_stats["start"] = torch.cuda.memory_stats()

ddp_model.train()
iteration_timings = []
iterator = iter(dataloader)
for _ in range(model.num_steps):
t0 = time.perf_counter()

inputs, labels = next(iterator)
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()

t1 = time.perf_counter()
iteration_timings.append(t1 - t0)

memory_stats["end"] = torch.cuda.memory_stats()

# check that the model has changed
assert not is_state_dict_equal(initial_state_dict, ddp_model.module.state_dict())

return ddp_model.module.state_dict(), torch.tensor(iteration_timings), memory_stats


def train_fabric_ddp(fabric):
make_deterministic()
memory_stats = {}

model = ConvNet()
initial_state_dict = deepcopy(model.state_dict())

optimizer = model.get_optimizer()
model, optimizer = fabric.setup(model, optimizer)

dataloader = model.get_dataloader()
dataloader = fabric.setup_dataloaders(dataloader)
loss_fn = model.get_loss_function()

memory_stats["start"] = torch.cuda.memory_stats()

model.train()
iteration_timings = []
iterator = iter(dataloader)
for _ in range(model.num_steps):
t0 = time.perf_counter()

inputs, labels = next(iterator)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
fabric.backward(loss)
optimizer.step()

t1 = time.perf_counter()
iteration_timings.append(t1 - t0)

memory_stats["end"] = torch.cuda.memory_stats()

# check that the model has changed
assert not is_state_dict_equal(initial_state_dict, model.state_dict())

return model.state_dict(), torch.tensor(iteration_timings), memory_stats


@RunIf(standalone=True)
@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
@pytest.mark.parametrize(
"accelerator, devices",
[
("cpu", 2),
pytest.param("cuda", 2, marks=RunIf(min_cuda_gpus=2)),
],
)
def test_parity_ddp(accelerator, devices):
# Train with Fabric
fabric = Fabric(accelerator=accelerator, strategy="ddp", devices=devices)
fabric.launch()
carmocca marked this conversation as resolved.
Show resolved Hide resolved
state_dict_fabric, timings_fabric, memory_fabric = train_fabric_ddp(fabric)

if accelerator == "cuda":
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Train with raw PyTorch
state_dict_torch, timings_torch, memory_torch = train_torch_ddp(
rank=fabric.global_rank,
world_size=fabric.world_size,
device=fabric.device,
)

# Compare the final weights
assert is_state_dict_equal(state_dict_torch, state_dict_fabric)

# Compare the time per iteration
assert is_timing_close(timings_torch, timings_fabric, rtol=1e-3, atol=1e-3)

# Compare memory usage
if accelerator == "cuda":
assert is_cuda_memory_close(memory_torch["start"], memory_fabric["start"])
assert is_cuda_memory_close(memory_torch["end"], memory_fabric["end"])
Loading