Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for torch.compile() with Fabric.setup() #16977

Merged
merged 3 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 45 additions & 32 deletions tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@
from tests_fabric.helpers.runif import RunIf


class EmptyFabric(Fabric):
def run(self):
pass


class BoringModel(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -81,7 +76,7 @@ def run(self, *args, **kwargs):
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_module(ddp_mock, setup_method):
"""Test that the setup method lets the strategy wrap the model, but keeps a reference to the original model."""
fabric = EmptyFabric(accelerator="cpu", strategy="ddp", devices=2)
fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
model = nn.Linear(1, 2)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(model)
Expand All @@ -91,6 +86,24 @@ def test_setup_module(ddp_mock, setup_method):
assert fabric_model.forward != model.forward


@RunIf(min_torch="2.0.0")
@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
def test_setup_compiled_module(setup_method):
"""Test that an `OptimizedModule` can be passed to the setup method."""
from torch._dynamo.eval_frame import OptimizedModule

fabric = Fabric()
model = nn.Linear(1, 2)
compiled_model = torch.compile(model)
assert isinstance(compiled_model, OptimizedModule)
setup_method = getattr(fabric, setup_method)
fabric_model = setup_method(compiled_model)

assert fabric_model.module == compiled_model
# Attributes get passed through
assert fabric_model.weight is model.weight


@pytest.mark.parametrize(
"accelerator, initial_device, target_device",
[
Expand All @@ -111,7 +124,7 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator,
target_device = torch.device(target_device)
expected_device = target_device if move_to_device else initial_device

fabric = EmptyFabric(accelerator=accelerator, devices=1)
fabric = Fabric(accelerator=accelerator, devices=1)
model = nn.Linear(1, 2)
model.to(initial_device)
setup_method = getattr(fabric, setup_method)
Expand All @@ -134,7 +147,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi
device0 = torch.device("cpu")
device1 = torch.device("cuda", 0)

fabric = EmptyFabric(accelerator="cuda", devices=1)
fabric = Fabric(accelerator="cuda", devices=1)

module0 = nn.Linear(1, 2).to(device0)
module1 = nn.Linear(1, 2).to(device1)
Expand All @@ -157,7 +170,7 @@ def test_setup_module_parameters_on_different_devices(setup_method, move_to_devi

def test_setup_module_and_optimizers():
"""Test that `setup()` can handle no optimizers, one optimizer, or multiple optimizers."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
Expand Down Expand Up @@ -186,7 +199,7 @@ def test_setup_module_and_optimizers():

def test_setup_optimizers():
"""Test that `setup_optimizers()` can handle one or more optimizers."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
Expand All @@ -206,7 +219,7 @@ def test_setup_optimizers():

def test_setup_twice_fails():
"""Test that calling `setup` with a model or optimizer that is already wrapped fails."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

Expand All @@ -221,7 +234,7 @@ def test_setup_twice_fails():

def test_setup_module_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)

fabric_model = fabric.setup_module(model)
Expand All @@ -231,7 +244,7 @@ def test_setup_module_twice_fails():

def test_setup_optimizers_twice_fails():
"""Test that calling `setup_module` with a model that is already wrapped fails."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

Expand All @@ -244,7 +257,7 @@ def test_setup_optimizers_twice_fails():
def test_setup_optimizers_not_supported(strategy_cls):
"""Test that `setup_optimizers` validates the strategy supports setting up model and optimizers
independently."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())
fabric._strategy = Mock(spec=strategy_cls)
Expand All @@ -254,7 +267,7 @@ def test_setup_optimizers_not_supported(strategy_cls):

def test_setup_tracks_num_models():
"""Test that setup() tracks how many times it has setup a model."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(1, 2)
optimizer = torch.optim.Adam(model.parameters())

Expand All @@ -271,7 +284,7 @@ def test_setup_tracks_num_models():

def test_setup_dataloaders_unsupported_input():
"""Test that the setup_dataloaders method fails when provided with non-DataLoader objects."""
fabric = EmptyFabric()
fabric = Fabric()
with pytest.raises(ValueError, match="`setup_dataloaders` requires at least one dataloader"):
fabric.setup_dataloaders()
with pytest.raises(TypeError, match="Only PyTorch DataLoader are currently supported"):
Expand All @@ -280,7 +293,7 @@ def test_setup_dataloaders_unsupported_input():

def test_setup_dataloaders_return_type():
"""Test that the setup method returns the dataloaders wrapped as FabricDataLoader and in the right order."""
fabric = EmptyFabric()
fabric = Fabric()

# single dataloader
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(2)))
Expand Down Expand Up @@ -315,7 +328,7 @@ def run(self):
def test_setup_dataloaders_raises_for_unknown_custom_args():
"""Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's
run method."""
fabric = EmptyFabric()
fabric = Fabric()

class CustomDataLoader(DataLoader):
def __init__(self, new_arg, *args, **kwargs):
Expand All @@ -335,7 +348,7 @@ def __init__(self, new_arg, *args, **kwargs):

def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
fabric = EmptyFabric()
fabric = Fabric()
dataloader = DataLoader(range(2))
fabric_dataloader = fabric.setup_dataloaders(dataloader)

Expand All @@ -350,12 +363,12 @@ def test_setup_dataloaders_twice_fails():
)
def test_setup_dataloaders_move_to_device(fabric_device_mock):
"""Test that the setup configures FabricDataLoader to move the data to the device automatically."""
fabric = EmptyFabric()
fabric = Fabric()
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=False)
assert all(dl.device is None for dl in fabric_dataloaders)
fabric_device_mock.assert_not_called()

fabric = EmptyFabric()
fabric = Fabric()
fabric_dataloaders = fabric.setup_dataloaders(DataLoader(Mock()), DataLoader(Mock()), move_to_device=True)
assert all(dl.device == torch.device("cuda", 1) for dl in fabric_dataloaders)
fabric_device_mock.assert_called()
Expand All @@ -367,7 +380,7 @@ def test_setup_dataloaders_distributed_sampler_not_needed():
dataloader = DataLoader(Mock(), sampler=custom_sampler)

# keep the custom sampler when not needed to replace
fabric = EmptyFabric()
fabric = Fabric()
fabric_dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)
assert fabric_dataloader.sampler is custom_sampler

Expand Down Expand Up @@ -440,9 +453,9 @@ def fetch_epoch(loader):
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_everything():
"""Test that seed everything is static and sets the worker init function on the dataloader."""
EmptyFabric.seed_everything(3)
Fabric.seed_everything(3)

fabric = EmptyFabric()
fabric = Fabric()
fabric_dataloader = fabric.setup_dataloaders(DataLoader(Mock()))

assert fabric_dataloader.worker_init_fn.func is pl_worker_init_function
Expand All @@ -466,7 +479,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
dataloader = DataLoader(Mock(), sampler=custom_sampler)

# explicitly asking to replace when a custom sampler is already configured raises an exception
fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)
Expand All @@ -489,7 +502,7 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
@pytest.mark.parametrize("shuffle", [True, False])
def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
"""Test that Fabric replaces the default samplers with DistributedSampler automatically."""
fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
fabric = Fabric(accelerator="cpu", strategy=strategy, devices=2)
is_distributed = hasattr(fabric.strategy, "distributed_sampler_kwargs")
fabric_dataloader = fabric.setup_dataloaders(DataLoader(range(3), shuffle=shuffle))
assert not is_distributed or isinstance(fabric_dataloader.sampler, DistributedSampler)
Expand Down Expand Up @@ -535,7 +548,7 @@ def run(self):

def test_rank_properties():
"""Test that the rank properties are determined by the strategy."""
fabric = EmptyFabric()
fabric = Fabric()
fabric._strategy = Mock(spec=Strategy)
fabric._strategy.world_size = 1000
assert fabric.world_size == 1000
Expand All @@ -549,7 +562,7 @@ def test_rank_properties():

def test_backward():
"""Test that backward() calls into the precision plugin."""
fabric = EmptyFabric()
fabric = Fabric()
fabric._precision = Mock(spec=Precision)
loss = Mock()
fabric.backward(loss, "arg", keyword="kwarg")
Expand All @@ -559,7 +572,7 @@ def test_backward():
@RunIf(deepspeed=True, mps=False)
def test_backward_model_input_required():
"""Test that when using deepspeed and multiple models, backward() requires the model as input."""
fabric = EmptyFabric(strategy="deepspeed")
fabric = Fabric(strategy="deepspeed")

model0 = nn.Linear(1, 2)
model1 = nn.Linear(1, 2)
Expand All @@ -580,7 +593,7 @@ def test_backward_model_input_required():

def test_autocast():
"""Test that the Fabric autocast context manager lets the precision plugin handle casting."""
fabric = EmptyFabric()
fabric = Fabric()
fabric._precision.forward_context = MagicMock()

fabric._precision.forward_context().__enter__.assert_not_called()
Expand All @@ -591,7 +604,7 @@ def test_autocast():

def test_no_backward_sync():
"""Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
fabric = EmptyFabric()
fabric = Fabric()
model = nn.Linear(3, 3)
with pytest.raises(TypeError, match="You need to set up the model first"):
with fabric.no_backward_sync(model):
Expand Down Expand Up @@ -678,7 +691,7 @@ def run(self):
def test_module_sharding_context():
"""Test that the sharding context manager gets applied when the strategy supports it and is a no-op
otherwise."""
fabric = EmptyFabric()
fabric = Fabric()
fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock())
with fabric.sharded_model():
pass
Expand Down
7 changes: 1 addition & 6 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@
from tests_fabric.helpers.runif import RunIf


class EmptyFabric(Fabric):
def run(self):
pass


def test_fabric_module_wraps():
"""Test that the wrapped module is accessible via the property."""
module = Mock()
Expand Down Expand Up @@ -137,7 +132,7 @@ def __init__(self):
)
def test_fabric_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str):
"""Test that the FabricModule performs autocasting on the input tensors and during forward()."""
fabric = EmptyFabric(precision=precision, accelerator=accelerator, devices=1)
fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
device = torch.device(device_str)

def check_autocast(forward_input):
Expand Down