Add test for torch.compile() with Fabric.setup() (#16977)
(cherry picked from commit b6c693d)
awaelchli authored and Borda committed Mar 30, 2023
1 parent 0b9017f commit 92d1c9f
Showing 2 changed files with 46 additions and 38 deletions.
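
Besides adding the new `torch.compile()` test, the diff drops the `EmptyFabric` helper subclass from both test files and instantiates `Fabric` directly, presumably because the tests no longer need to override `run()`. A minimal sketch of that pattern (editor's illustration, not part of the diff; the import path is an assumption and may be `lightning_fabric` instead of `lightning.fabric`):

import torch.nn as nn
from lightning.fabric import Fabric  # assumed import path

fabric = Fabric()                                     # plain instantiation, no run() override required
fabric_model = fabric.setup_module(nn.Linear(1, 2))  # wrap a module as the tests below do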
77 changes: 45 additions & 32 deletions tests/tests_fabric/test_fabric.py
@@ -42,11 +42,6 @@
from tests_fabric.helpers.runif import RunIf


- class EmptyFabric(Fabric):
-     def run(self):
-         pass


class BoringModel(nn.Module):
def __init__(self):
super().__init__()
@@ -81,7 +76,7 @@ def run(self, *args, **kwargs):
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module(ddp_mock, setup_method):
"""Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
- fabric = EmptyFabric(accelerator="cpu", strategy="ddp", devices=2)
+ fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
model = nn.Linear(1, 2)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(model)
@@ -91,6 +86,24 @@ def test_setup_module(ddp_mock, setup_method):
assert fabric_model.forward != model.forward


+ @RunIf(min_torch="2.0.0", skip_windows=True)
+ @pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+ def test_setup_compiled_module(setup_method):
+     """Test that an `OptimizedModule` can be passed to the setup method."""
+     from torch._dynamo.eval_frame import OptimizedModule
+
+     fabric = Fabric()
+     model = nn.Linear(1, 2)
+     compiled_model = torch.compile(model)
+     assert isinstance(compiled_model, OptimizedModule)
+     setup_method = getattr(fabric, setup_method)
+     fabric_model = setup_method(compiled_model)
+
+     assert fabric_model.module == compiled_model
+     # Attributes get passed through
+     assert fabric_model.weight is model.weight
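
The added test above mirrors the user-facing workflow of compiling a model before handing it to Fabric. A rough usage sketch (editor's illustration, not part of the diff; the import path and constructor arguments are assumptions, and torch >= 2.0 is required):

import torch
import torch.nn as nn
from lightning.fabric import Fabric  # assumed import path

fabric = Fabric(accelerator="cpu", devices=1)
model = nn.Linear(1, 2)
compiled = torch.compile(model)              # wraps the model in an OptimizedModule
fabric_model = fabric.setup(compiled)        # Fabric accepts the compiled module directly
assert fabric_model.weight is model.weight   # attribute access falls through to the original module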


@pytest.mark.parametrize(
"accelerator, initial_device, target_device",
[
@@ -111,7 +124,7 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator,
target_device = torch.device(target_device)
expected_device = target_device if move_to_device else initial_device

- fabric = EmptyFabric(accelerator=accelerator, devices=1)
+ fabric = Fabric(accelerator=accelerator, devices=1)
model = nn.Linear(1, 2)
model.to(initial_device)
setup_method = getattr(fabric, setup_method)
@@ -134,7 +147,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi
device0 = torch.device("cpu")
device1 = torch.device("cuda", 0)

- fabric = EmptyFabric(accelerator="cuda", devices=1)
+ fabric = Fabric(accelerator="cuda", devices=1)

module0 = nn.Linear(1, 2).to(device0)
module1 = nn.Linear(1, 2).to(device1)
@@ -157,7 +170,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi

def test_setup_module_and_optimizers():
"""Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -186,7 +199,7 @@ def test_setup_module_and_optimizers():

def test_setup_optimizers():
"""Test that `setup_optimizers()` can handle one or more optimizers."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
@@ -206,7 +219,7 @@ def test_setup_optimizers():

def test_setup_twice_fails():
"""Test that calling `setup` with a model or optimizer that is already wrapped fails."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

@@ -221,7 +234,7 @@ def test_setup_twice_fails():

def test_setup_module_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)

fabric_model = fabric.setup_module(model)
@@ -231,7 +244,7 @@ def test_setup_module_twice_fails():

def test_setup_optimizers_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

@@ -244,7 +257,7 @@ def test_setup_optimizers_twice_fails():
def test_setup_optimizers_not_supported(strategy_cls):
"""Test that `setup_optimizers` validates the strategy supports setting up model and optimizers
independently."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
fabric._strategy = Mock(spec=strategy_cls)
@@ -254,7 +267,7 @@ def test_setup_optimizers_not_supported(strategy_cls):

def test_setup_tracks_num_models():
"""Test that setup() tracks how many times it has setup a model."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

@@ -271,7 +284,7 @@ def test_setup_tracks_num_models():

def test_setup_dataloaders_unsupported_input():
"""Test that the setup_dataloaders method fails when provided with non-DataLoader objects."""
- fabric = EmptyFabric()
+ fabric = Fabric()
with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
fabric.setup_dataloaders()
with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
@@ -280,7 +293,7 @@ def test_setup_dataloaders_unsupported_input():

def test_setup_dataloaders_return_type():
"""Test that the setup method returns the dataloaders wrapped as FabricDataLoader and in the right order."""
- fabric = EmptyFabric()
+ fabric = Fabric()

# single dataloader
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(2)))
@@ -315,7 +328,7 @@ def run(self):
def test_setup_dataloaders_raises_for_unknown_custom_args():
"""Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's
run method."""
- fabric = EmptyFabric()
+ fabric = Fabric()

class CustomDataLoader(DataLoader):
def __init__(self, new_arg, *args, **kwargs):
@@ -335,7 +348,7 @@ def __init__(self, new_arg, *args, **kwargs):

def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
- fabric = EmptyFabric()
+ fabric = Fabric()
dataloader = DataLoader(range(2))
fabric_dataloader = fabric.setup_dataloaders(dataloader)

@@ -350,12 +363,12 @@ def test_setup_dataloaders_twice_fails():
)
def test_setup_dataloaders_move_to_device(fabric_device_mock):
"""Test that the setup configures FabricDataLoader to move the data to the device automatically."""
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False)
assert all(dl.device is None for dl in fabric_dataloaders)
fabric_device_mock.assert_not_called()

- fabric = EmptyFabric()
+ fabric = Fabric()
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True)
assert all(dl.device == torch.device("cuda", 1) for dl in fabric_dataloaders)
fabric_device_mock.assert_called()
@@ -367,7 +380,7 @@ def test_setup_dataloaders_distributed_sampler_not_needed():
dataloader = DataLoader(Mock(), sampler=custom_sampler)

# keep the custom sampler when not needed to replace
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric_dataloader = fabric.setup_dataloaders(dataloader, replace_sampler=True)
assert fabric_dataloader.sampler is custom_sampler

@@ -440,9 +453,9 @@ def fetch_epoch(loader):
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_everything():
"""Test that seed everything is static and sets the worker init function on the dataloader."""
- EmptyFabric.seed_everything(3)
+ Fabric.seed_everything(3)

- fabric = EmptyFabric()
+ fabric = Fabric()
fabric_dataloader = fabric.setup_dataloaders(DataLoader(Mock()))

assert fabric_dataloader.worker_init_fn.func is pl_worker_init_function
@@ -466,7 +479,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
dataloader = DataLoader(Mock(), sampler=custom_sampler)

# explicitly asking to replace when a custom sampler is already configured raises an exception
- fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
+ fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
fabric.setup_dataloaders(dataloader, replace_sampler=True)
@@ -489,7 +502,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
@pytest.mark.parametrize("shuffle", [True, False])
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
"""Test that Fabric replaces the default samplers with DistributedSampler automatically."""
- fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
+ fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)
@@ -535,7 +548,7 @@ def run(self):

def test_rank_properties():
"""Test that the rank properties are determined by the strategy."""
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric._strategy = Mock(spec=Strategy)
fabric._strategy.world_size = 1000
assert fabric.world_size == 1000
@@ -549,7 +562,7 @@ def test_rank_properties():

def test_backward():
"""Test that backward() calls into the precision plugin."""
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric._precision = Mock(spec=Precision)
loss = Mock()
fabric.backward(loss, "arg", keyword="kwarg")
@@ -559,7 +572,7 @@ def test_backward():
@RunIf(deepspeed=True, mps=False)
def test_backward_model_input_required():
"""Test that when using deepspeed and multiple models, backward() requires the model as input."""
- fabric = EmptyFabric(strategy="deepspeed")
+ fabric = Fabric(strategy="deepspeed")

model0 = nn.Linear(1, 2)
model1 = nn.Linear(1, 2)
@@ -580,7 +593,7 @@ def test_backward_model_input_required():

def test_autocast():
"""Test that the Fabric autocast context manager lets the precision plugin handle casting."""
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric._precision.forward_context = MagicMock()

fabric._precision.forward_context().__enter__.assert_not_called()
@@ -591,7 +604,7 @@ def test_autocast():

def test_no_backward_sync():
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
- fabric = EmptyFabric()
+ fabric = Fabric()
model = nn.Linear(3, 3)
with pytest.raises(TypeError, match="You need to set up the model first"):
with fabric.no_backward_sync(model):
@@ -678,7 +691,7 @@ def run(self):
def test_module_sharding_context():
"""Test that the sharding context manager gets applied when the strategy supports it and is a no-op
otherwise."""
- fabric = EmptyFabric()
+ fabric = Fabric()
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
with fabric.sharded_model():
pass
7 changes: 1 addition & 6 deletions tests/tests_fabric/test_wrappers.py
@@ -25,11 +25,6 @@
from tests_fabric.helpers.runif import RunIf


- class EmptyFabric(Fabric):
-     def run(self):
-         pass


def test_fabric_module_wraps():
"""Test that the wrapped module is accessible via the property."""
module = Mock()
@@ -137,7 +132,7 @@ def __init__(self):
)
def test_fabric_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str):
"""Test that the FabricModule performs autocasting on the input tensors and during forward()."""
- fabric = EmptyFabric(precision=precision, accelerator=accelerator, devices=1)
+ fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
device = torch.device(device_str)

def check_autocast(forward_input):
