Test parametrization utils for native funcol migration (pytorch#119950)
```
Between the time we switch to the native funcol by default and the time when
we are confident that we can remove the legacy implementation, we want to
ensure that the legacy funcol remains covered by unit tests. This is to
prepare for any potential (but unlikely) reverts. The following utilities
help achieve this goal.

run_with_{native,legacy}_funcol - mark a test to run with only
{native,legacy} funcol. These decorators are for impl specific tests (e.g.
verifying generated code with FileCheck).

run_with_both_funcol_impls - parametrize a test to run with both legacy and
native funcol.

run_with_both_funcol_impls_with_arg - same as run_with_both_funcol_impls, but
passes `use_native_funcol` to the test so impl specific checks can be
carried out.
```

This PR also marks some tests we want to cover in this fashion. More tests will be marked in subsequent PRs.

Pull Request resolved: pytorch#119950
Approved by: https://github.com/wanchaol
ghstack dependencies: pytorch#119881
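
For illustration, here is a minimal sketch of how a test class might opt into these utilities; the class name `FuncolCoverageTest` and both test bodies are hypothetical, while the decorators and helpers are the ones this PR exercises in the diff below.

```
# Illustrative sketch only -- class and test names are hypothetical.
import torch

from torch.testing._internal.common_distributed import (
    run_with_both_funcol_impls,
    run_with_both_funcol_impls_with_arg,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TestCase,
)


@instantiate_parametrized_tests
class FuncolCoverageTest(TestCase):
    @run_with_both_funcol_impls
    def test_impl_agnostic(self):
        # Parametrized to run twice: once with legacy funcol, once with native funcol.
        self.assertTrue(torch.distributed.is_available())

    @run_with_both_funcol_impls_with_arg
    def test_impl_specific(self, use_native_funcol):
        # The flag lets the test branch on the active implementation,
        # e.g. to FileCheck implementation-specific generated code.
        if use_native_funcol:
            pass  # native funcol specific assertions
        else:
            pass  # legacy funcol specific assertions


if __name__ == "__main__":
    run_tests()
```

The `run_with_both_funcol_impls_with_arg` variant is the one to reach for when the expected generated code or operator set differs between the two implementations.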
yifuwang authored and pytorchmergebot committed Feb 19, 2024
1 parent 40786ca commit 637cf4a
Showing 7 changed files with 164 additions and 96 deletions.
92 changes: 41 additions & 51 deletions test/distributed/_tensor/test_dtensor_compile.py
@@ -10,7 +10,6 @@
import torch._dynamo
import torch._dynamo.testing
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch._C import FileCheck
from torch._inductor.utils import run_and_get_triton_code
@@ -34,7 +33,11 @@
PrepareModuleOutput,
RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
run_with_both_funcol_impls,
run_with_both_funcol_impls_with_arg,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -81,6 +84,7 @@ def extract_graph(fx_g, _, graph_cell):
)


@instantiate_parametrized_tests
class TestDTensorCompile(torch._dynamo.test_case.TestCase):
def setUp(self):
super().setUp()
@@ -101,6 +105,7 @@ def device_type(self) -> str:
def world_size(self) -> int:
return 2

@run_with_both_funcol_impls
def test_placement_compile(self):
def fn(x):
a = 0
@@ -127,6 +132,7 @@ def fn(x):
compiled_out = compiled_fn(x)
self.assertEqual(opt_fn, compiled_out)

@run_with_both_funcol_impls
def test_device_mesh_compile(self):
def fn(x):
# test size()
@@ -147,6 +153,7 @@ def fn(x):
compiled_out = compiled_fn(mesh)
self.assertEqual(opt_fn, compiled_out)

@run_with_both_funcol_impls
def test_fakify_dtensor(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -161,6 +168,7 @@ def fn(x):
res = opt_fn(x)
self.assertEqual(res, ref)

@run_with_both_funcol_impls
def test_dynamo_dtensor(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -175,6 +183,7 @@ def fn(x):
res = opt_fn(x)
self.assertEqual(res, ref)

@run_with_both_funcol_impls
def test_dtensor_attribute_access_on_intermediate(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -192,6 +201,7 @@ def fn(x):
res = opt_fn(x)
self.assertEqual(res, ref)

@run_with_both_funcol_impls
def test_dynamo_dtensor_from_local(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -235,6 +245,7 @@ def from_local_kwargs_fn(x):
self.assertEqual(res, ref)
self.assertEqual(cnt.frame_count, 2)

@run_with_both_funcol_impls
def test_dynamo_dtensor_from_local_redistribute(self):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

@@ -271,7 +282,8 @@ def redistribute_kwargs_fn(x):
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
def test_tp_compile_comm_reordering(self):
@run_with_both_funcol_impls_with_arg
def test_tp_compile_comm_reordering(self, use_native_funcol):
class FakeAttention(nn.Module):
def __init__(self):
super().__init__()
@@ -333,28 +345,35 @@ def forward(self, input):
self.assertEqual(cnt.frame_count, 1)

code = run_and_get_triton_code(compiled_model, inp)
# Check that `buf2` is correctly waited on before first use.
# fmt: off
FileCheck() \
.check("buf1_work = dist.all_gather_into_tensor(buf1[0]") \
.check("buf2 = buf1[0]") \
.check("buf2 = _wait_tensor(buf2)") \
.check("extern_kernels.mm(buf2,") \
.run(code)
if use_native_funcol:
FileCheck().check(
"buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
).check("buf1 = torch.ops._c10d_functional.wait_tensor.default(buf0").check(
"extern_kernels.mm(buf0,"
).run(
code
)
else:
# Check that `buf2` is correctly waited on before first use.
# fmt: off
FileCheck() \
.check("buf1_work = dist.all_gather_into_tensor(buf1[0]") \
.check("buf2 = buf1[0]") \
.check("buf2 = _wait_tensor(buf2)") \
.check("extern_kernels.mm(buf2,") \
.run(code)


@instantiate_parametrized_tests
class TestDTensorCompileE2E(DTensorTestBase):
@property
def world_size(self):
return 4

@with_comms
@parametrize("is_seq_parallel", [True, False])
@parametrize("use_native_funcol", [True, False])
def test_tp_compile_fullgraph(self, is_seq_parallel, use_native_funcol):
if use_native_funcol:
funcol.enable_native_funcol()

@run_with_both_funcol_impls
def test_tp_compile_fullgraph(self, is_seq_parallel):
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

model = SimpleModel(self.device_type)
@@ -414,11 +433,8 @@ def test_tp_compile_fullgraph(self, is_seq_parallel, use_native_funcol):

@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("use_native_funcol", [True, False])
def test_2d_fsdp_tp_compile(self, use_native_funcol):
if use_native_funcol:
funcol.enable_native_funcol()

@run_with_both_funcol_impls
def test_2d_fsdp_tp_compile(self):
data_parallel_size = 2
model = SimpleModel(self.device_type)
model_copy = copy.deepcopy(model)
@@ -469,11 +485,8 @@ def test_2d_fsdp_tp_compile(self, use_native_funcol):

@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("use_native_funcol", [True, False])
def test_2d_fsdp_tp_ac_compile(self, use_native_funcol):
if use_native_funcol:
funcol.enable_native_funcol()

@run_with_both_funcol_impls
def test_2d_fsdp_tp_ac_compile(self):
dp_degree = 2
tp_degree = self.world_size // dp_degree
model = SimpleModel(self.device_type)
@@ -524,11 +537,8 @@ def test_2d_fsdp_tp_ac_compile(self, use_native_funcol):

@with_comms
@skip_if_lt_x_gpu(4)
@parametrize("use_native_funcol", [True, False])
def test_compile_dtensor_redistribute_backward(self, use_native_funcol):
if use_native_funcol:
funcol.enable_native_funcol()

@run_with_both_funcol_impls
def test_compile_dtensor_redistribute_backward(self):
mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size))

def fn(x, y):
@@ -558,25 +568,5 @@ def fn(x, y):
self.assertEqual(y_ref.grad, y.grad)


class TestDTensorCompileWithNativeFunCol(TestDTensorCompile):
def setUp(self) -> None:
self._prev_native_funcol_enabled = funcol.native_funcol_enabled()
funcol.enable_native_funcol()
super().setUp()

def tearDown(self) -> None:
super().tearDown()
if not self._prev_native_funcol_enabled:
funcol.disable_native_funcol()

def test_tp_compile_comm_reordering(self):
# Bypass this test for now. The native funcols have different
# IRs, so the reordering pass needs to be reworked.
pass


instantiate_parametrized_tests(TestDTensorCompileE2E)


if __name__ == "__main__":
run_tests()
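
The decorators themselves come from torch/testing/_internal/common_distributed.py, which is among the seven changed files but is not expanded in this view. Purely as a hypothetical sketch (not the implementation this PR adds), one way `run_with_both_funcol_impls_with_arg` could be layered on top of the existing `parametrize` helper, using the `funcol.enable_native_funcol()`/`disable_native_funcol()` toggles that appear elsewhere in this diff:

```
# Hypothetical sketch only -- not the implementation added by this PR.
import functools

import torch.distributed._functional_collectives as funcol
from torch.testing._internal.common_utils import parametrize


def run_with_both_funcol_impls_with_arg(fn):
    # Parametrize the test over both implementations and forward the flag.
    @parametrize("use_native_funcol", [True, False])
    @functools.wraps(fn)
    def wrapper(self, *args, use_native_funcol, **kwargs):
        # Switch implementations for the duration of the test, then restore.
        was_native = funcol.native_funcol_enabled()
        if use_native_funcol:
            funcol.enable_native_funcol()
        else:
            funcol.disable_native_funcol()
        try:
            return fn(self, *args, use_native_funcol=use_native_funcol, **kwargs)
        finally:
            if was_native:
                funcol.enable_native_funcol()
            else:
                funcol.disable_native_funcol()

    return wrapper
```

Under the same assumptions, `run_with_both_funcol_impls` would be the flag-swallowing counterpart, and `run_with_{native,legacy}_funcol` would pin a single implementation instead of parametrizing.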