From 637cf4a3f2cfdd364005681636ca885bdc4d5887 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Sun, 18 Feb 2024 11:51:44 -0800 Subject: [PATCH] Test parametrization utils for native funcol migration (#119950) ``` Between the time we switch to the native funcol by default and the time when we are confident that we can remove the legacy implementation, we want to ensure that the legacy funcol remains covered by unit tests. This is to prepare for any potential (but unlikely) reverts. The following utilities help achieve this goal. run_with_{native,legacy}_funcol - mark a test to run with only {native,legacy} funcol. These decorators are for impl specific tests (e.g. verifying generated code with FileCheck). run_with_both_funcol_impls - parametrize a test to run with both legacy and native funcol. run_with_both_funcol_impls_with_arg - same as run_with_both_funcol_impls, but passes `enable_native_funcol` to the test so impl specific checks can be carried out. ``` This PR also marks some tests we want to cover in this fashion. More tests will be marked in subsequent PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119950 Approved by: https://github.com/wanchaol ghstack dependencies: #119881 --- .../_tensor/test_dtensor_compile.py | 92 +++++++++---------- .../test_c10d_functional_native.py | 25 ++++- test/distributed/test_device_mesh.py | 63 ++++++++----- torch/distributed/_functional_collectives.py | 2 - .../_functional_collectives_impl.py | 15 --- torch/testing/_internal/common_distributed.py | 62 +++++++++++++ .../distributed/_tensor/common_dtensor.py | 1 + 7 files changed, 164 insertions(+), 96 deletions(-) diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index fcb9ff568a74c4..918c5bb865d2d5 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -10,7 +10,6 @@ import torch._dynamo import torch._dynamo.testing import torch.distributed as dist -import torch.distributed._functional_collectives as funcol import torch.nn as nn from torch._C import FileCheck from torch._inductor.utils import run_and_get_triton_code @@ -34,7 +33,11 @@ PrepareModuleOutput, RowwiseParallel, ) -from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + run_with_both_funcol_impls, + run_with_both_funcol_impls_with_arg, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -81,6 +84,7 @@ def extract_graph(fx_g, _, graph_cell): ) +@instantiate_parametrized_tests class TestDTensorCompile(torch._dynamo.test_case.TestCase): def setUp(self): super().setUp() @@ -101,6 +105,7 @@ def device_type(self) -> str: def world_size(self) -> int: return 2 + @run_with_both_funcol_impls def test_placement_compile(self): def fn(x): a = 0 @@ -127,6 +132,7 @@ def fn(x): compiled_out = compiled_fn(x) self.assertEqual(opt_fn, compiled_out) + @run_with_both_funcol_impls def test_device_mesh_compile(self): def fn(x): # test size() @@ -147,6 +153,7 @@ def fn(x): compiled_out = compiled_fn(mesh) self.assertEqual(opt_fn, compiled_out) + @run_with_both_funcol_impls def test_fakify_dtensor(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -161,6 +168,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @run_with_both_funcol_impls def test_dynamo_dtensor(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -175,6 +183,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @run_with_both_funcol_impls def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -192,6 +201,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @run_with_both_funcol_impls def test_dynamo_dtensor_from_local(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -235,6 +245,7 @@ def from_local_kwargs_fn(x): self.assertEqual(res, ref) self.assertEqual(cnt.frame_count, 2) + @run_with_both_funcol_impls def test_dynamo_dtensor_from_local_redistribute(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -271,7 +282,8 @@ def redistribute_kwargs_fn(x): # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True) - def test_tp_compile_comm_reordering(self): + @run_with_both_funcol_impls_with_arg + def test_tp_compile_comm_reordering(self, use_native_funcol): class FakeAttention(nn.Module): def __init__(self): super().__init__() @@ -333,16 +345,26 @@ def forward(self, input): self.assertEqual(cnt.frame_count, 1) code = run_and_get_triton_code(compiled_model, inp) - # Check that `buf2` is correctly waited on before first use. - # fmt: off - FileCheck() \ - .check("buf1_work = dist.all_gather_into_tensor(buf1[0]") \ - .check("buf2 = buf1[0]") \ - .check("buf2 = _wait_tensor(buf2)") \ - .check("extern_kernels.mm(buf2,") \ - .run(code) + if use_native_funcol: + FileCheck().check( + "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal" + ).check("buf1 = torch.ops._c10d_functional.wait_tensor.default(buf0").check( + "extern_kernels.mm(buf0," + ).run( + code + ) + else: + # Check that `buf2` is correctly waited on before first use. + # fmt: off + FileCheck() \ + .check("buf1_work = dist.all_gather_into_tensor(buf1[0]") \ + .check("buf2 = buf1[0]") \ + .check("buf2 = _wait_tensor(buf2)") \ + .check("extern_kernels.mm(buf2,") \ + .run(code) +@instantiate_parametrized_tests class TestDTensorCompileE2E(DTensorTestBase): @property def world_size(self): @@ -350,11 +372,8 @@ def world_size(self): @with_comms @parametrize("is_seq_parallel", [True, False]) - @parametrize("use_native_funcol", [True, False]) - def test_tp_compile_fullgraph(self, is_seq_parallel, use_native_funcol): - if use_native_funcol: - funcol.enable_native_funcol() - + @run_with_both_funcol_impls + def test_tp_compile_fullgraph(self, is_seq_parallel): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) model = SimpleModel(self.device_type) @@ -414,11 +433,8 @@ def test_tp_compile_fullgraph(self, is_seq_parallel, use_native_funcol): @with_comms @skip_if_lt_x_gpu(4) - @parametrize("use_native_funcol", [True, False]) - def test_2d_fsdp_tp_compile(self, use_native_funcol): - if use_native_funcol: - funcol.enable_native_funcol() - + @run_with_both_funcol_impls + def test_2d_fsdp_tp_compile(self): data_parallel_size = 2 model = SimpleModel(self.device_type) model_copy = copy.deepcopy(model) @@ -469,11 +485,8 @@ def test_2d_fsdp_tp_compile(self, use_native_funcol): @with_comms @skip_if_lt_x_gpu(4) - @parametrize("use_native_funcol", [True, False]) - def test_2d_fsdp_tp_ac_compile(self, use_native_funcol): - if use_native_funcol: - funcol.enable_native_funcol() - + @run_with_both_funcol_impls + def test_2d_fsdp_tp_ac_compile(self): dp_degree = 2 tp_degree = self.world_size // dp_degree model = SimpleModel(self.device_type) @@ -524,11 +537,8 @@ def test_2d_fsdp_tp_ac_compile(self, use_native_funcol): @with_comms @skip_if_lt_x_gpu(4) - @parametrize("use_native_funcol", [True, False]) - def test_compile_dtensor_redistribute_backward(self, use_native_funcol): - if use_native_funcol: - funcol.enable_native_funcol() - + @run_with_both_funcol_impls + def test_compile_dtensor_redistribute_backward(self): mesh = DeviceMesh(device_type="cuda", mesh=torch.arange(self.world_size)) def fn(x, y): @@ -558,25 +568,5 @@ def fn(x, y): self.assertEqual(y_ref.grad, y.grad) -class TestDTensorCompileWithNativeFunCol(TestDTensorCompile): - def setUp(self) -> None: - self._prev_native_funcol_enabled = funcol.native_funcol_enabled() - funcol.enable_native_funcol() - super().setUp() - - def tearDown(self) -> None: - super().tearDown() - if not self._prev_native_funcol_enabled: - funcol.disable_native_funcol() - - def test_tp_compile_comm_reordering(self): - # Bypass this test for now. The native funcols have different - # IRs, so the reordering pass needs to be reworked. - pass - - -instantiate_parametrized_tests(TestDTensorCompileE2E) - - if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index a2af8fba5f6577..d52dbf50171fd1 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -20,9 +20,9 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, requires_nccl, + run_with_native_funcol, skip_if_lt_x_gpu, ) - from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] run_tests, TestCase, @@ -57,7 +57,6 @@ def load_test_module(name): class C10DFunctionalNativeTest(MultiProcessTestCase): def setUp(self) -> None: super().setUp() - funcol.enable_native_funcol() self._spawn_processes() @property @@ -88,6 +87,7 @@ def _init_process_group(self) -> None: torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_reduce_single(self) -> None: self._init_process_group() @@ -114,6 +114,7 @@ def test_all_reduce_single(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_reduce_single_(self) -> None: self._init_process_group() @@ -129,6 +130,7 @@ def test_all_reduce_single_(self) -> None: assert output.eq(expect).all() @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_reduce_coalesced(self) -> None: self._init_process_group() @@ -158,6 +160,7 @@ def test_all_reduce_coalesced(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_reduce_coalesced_(self) -> None: self._init_process_group() @@ -176,6 +179,7 @@ def test_all_reduce_coalesced_(self) -> None: assert output.eq(sum(self.ranks) / self.world_size * i).all() @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_gather_into_tensor_single(self) -> None: self._init_process_group() @@ -207,6 +211,7 @@ def test_all_gather_into_tensor_single(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_gather_into_tensor_coalesced(self) -> None: self._init_process_group() @@ -243,6 +248,7 @@ def test_all_gather_into_tensor_coalesced(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_reduce_scatter_tensor_single(self) -> None: self._init_process_group() @@ -269,6 +275,7 @@ def test_reduce_scatter_tensor_single(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_reduce_scatter_tensor_coalesced(self) -> None: self._init_process_group() @@ -296,6 +303,7 @@ def test_reduce_scatter_tensor_coalesced(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_all_to_all_single(self) -> None: self._init_process_group() torch.cuda.set_device(self.device) @@ -331,6 +339,7 @@ def test_all_to_all_single(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_broadcast(self) -> None: self._init_process_group() @@ -357,6 +366,7 @@ def test_broadcast(self) -> None: assert output.completed @skip_if_lt_x_gpu(2) + @run_with_native_funcol def test_unwaited(self) -> None: # Verify that the process can terminate gracefully # even with unwaited tensors @@ -387,7 +397,6 @@ def setUp(self): rank=self.rank, store=store, ) - funcol.enable_native_funcol() def tearDown(self): funcol.disable_native_funcol() @@ -395,6 +404,7 @@ def tearDown(self): @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_all_reduce_single(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = arg + 42 @@ -431,6 +441,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_all_reduce_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: bufs = [arg + 42 for arg in args] @@ -476,6 +487,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_inplace_op_on_view(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = (arg + 10)[:2] @@ -503,6 +515,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_reuse_buffer_after_inplace_collective(self): def func(arg: torch.Tensor) -> torch.Tensor: # Expect allocation @@ -537,6 +550,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_all_gather_into_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: ag0 = funcol.all_gather_tensor(arg, 0, "0") @@ -563,6 +577,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_all_gather_into_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: ag0 = funcol.all_gather_into_tensor_coalesced(args, "0") @@ -597,6 +612,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_reduce_scatter_tensor_single(self): def func(arg: torch.Tensor) -> torch.Tensor: rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "0") @@ -623,6 +639,7 @@ def func(arg: torch.Tensor) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_reduce_scatter_tensor_coalesced(self): def func(args: List[torch.Tensor]) -> torch.Tensor: rs0 = funcol.reduce_scatter_tensor_coalesced( @@ -659,6 +676,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_all_to_all_single(self): def _tolist_with_constrain_as_size(tensor): lst = tensor.tolist() @@ -707,6 +725,7 @@ def func( @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() + @run_with_native_funcol def test_inductor_broadcast(self): def func(arg: torch.Tensor) -> torch.Tensor: buf0 = arg + 42 diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 9e1f12b36b6817..28f8a44bf6f69c 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -21,7 +21,11 @@ is_nccl_available, ProcessGroup, ) -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_distributed import run_with_both_funcol_impls +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -48,11 +52,13 @@ def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0): os.environ["RANK"] = f"{rank}" +@instantiate_parametrized_tests class DeviceMeshTest(DTensorTestBase): @property def world_size(self): return 4 + @run_with_both_funcol_impls def test_init_process_group(self): device_type = _get_device_type(self.world_size) mesh_tensor = torch.arange(4).reshape(2, 2) @@ -63,6 +69,7 @@ def test_init_process_group(self): self.destroy_pg() @with_comms + @run_with_both_funcol_impls def test_get_group(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( @@ -83,6 +90,7 @@ def test_get_group(self): self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group()) @with_comms + @run_with_both_funcol_impls def test_get_local_rank_raises_exception(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( @@ -96,6 +104,7 @@ def test_get_local_rank_raises_exception(self): local_rank = mesh_2d.get_local_rank() @with_comms + @run_with_both_funcol_impls def test_get_local_rank(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( @@ -110,6 +119,7 @@ def test_get_local_rank(self): self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp")) @with_comms + @run_with_both_funcol_impls def test_device_mesh_2d(self): mesh_tensor = torch.arange(4).reshape(2, 2) # construct a cuda device mesh @@ -134,6 +144,7 @@ def test_device_mesh_2d(self): ) self.assertEqual(global_ranks, current_rank_expected_group_ranks) + @run_with_both_funcol_impls def test_fake_pg_device_mesh(self): fake_store = FakeStore() init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size) @@ -153,6 +164,7 @@ def world_size(self): return 8 @with_comms + @run_with_both_funcol_impls def test_device_mesh_nd(self): # construct a cuda device mesh mesh_tensor = torch.arange(8).reshape(2, 2, 2) @@ -176,6 +188,7 @@ def test_device_mesh_nd(self): self.assertEqual(global_ranks, ranks.tolist()) @with_comms + @run_with_both_funcol_impls def test_device_mesh_hash(self): mesh_tensor_2d = torch.arange(8).reshape(4, 2) mesh = DeviceMesh(self.device_type, mesh_tensor_2d) @@ -193,6 +206,7 @@ def world_size(self): return 8 @with_comms + @run_with_both_funcol_impls def test_init_device_mesh(self): mesh_shape = (2, 4) ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape)) @@ -210,6 +224,7 @@ def test_init_device_mesh(self): self.assertEqual(mesh_2d, ref_mesh) @with_comms + @run_with_both_funcol_impls def test_raises_duplicate_mesh_dim_names(self): with self.assertRaisesRegex( RuntimeError, @@ -222,6 +237,7 @@ def test_raises_duplicate_mesh_dim_names(self): ) @with_comms + @run_with_both_funcol_impls def test_raises_mesh_shape_mesh_dim_names_mismatch(self): with self.assertRaisesRegex( RuntimeError, @@ -234,18 +250,21 @@ def test_raises_mesh_shape_mesh_dim_names_mismatch(self): ) +@instantiate_parametrized_tests class TestDeviceMeshGetItem(DTensorTestBase): @property def world_size(self): return 8 @with_comms + @run_with_both_funcol_impls def test_raises_no_mesh_dim_found(self): with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."): mesh = init_device_mesh(self.device_type, (2, 4)) child_mesh = mesh["DP"] @with_comms + @run_with_both_funcol_impls def test_raises_invalid_mesh_dim_name(self): child_mesh_dim_name = "PP" with self.assertRaisesRegex( @@ -258,6 +277,7 @@ def test_raises_invalid_mesh_dim_name(self): child_mesh = mesh[child_mesh_dim_name] @with_comms + @run_with_both_funcol_impls def test_get_item(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") @@ -281,6 +301,7 @@ def test_get_item(self): self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx]) @with_comms + @run_with_both_funcol_impls def test_get_item_1d(self): mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",)) # Make sure slicing out 1D mesh from a 1D mesh works. @@ -292,8 +313,10 @@ def test_get_item_1d(self): dp_mesh = mesh["dim0"] +@instantiate_parametrized_tests class TestMeshEnv(DTensorTestBase): @with_comms + @run_with_both_funcol_impls def test_get_parent_mesh(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") @@ -313,6 +336,7 @@ def test_get_parent_mesh(self): self.assertEqual(_mesh_resources.get_parent_mesh(mesh_1_3), None) @with_comms + @run_with_both_funcol_impls def test_get_parent_mesh_dim_exist(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") @@ -324,6 +348,7 @@ def test_get_parent_mesh_dim_exist(self): self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["TP"]), 1) @with_comms + @run_with_both_funcol_impls def test_get_parent_mesh_dim_not_exist(self): mesh_shape = (self.world_size,) mesh = init_device_mesh(self.device_type, mesh_shape) @@ -331,6 +356,7 @@ def test_get_parent_mesh_dim_not_exist(self): self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh), None) @with_comms + @run_with_both_funcol_impls def test_get_mesh_dim_by_name(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") @@ -342,12 +368,14 @@ def test_get_mesh_dim_by_name(self): self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1) +@instantiate_parametrized_tests class DeviceMeshCollectiveTest(DTensorTestBase): @property def world_size(self): return 8 @with_comms + @run_with_both_funcol_impls def test_broadcast_1d(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank @@ -355,6 +383,7 @@ def test_broadcast_1d(self): self.assertEqual(local_tensor, torch.zeros(3, 3)) @with_comms + @run_with_both_funcol_impls def test_scatter_1d(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) scatter_tensor_shape = [3, 3, 3] @@ -373,6 +402,7 @@ def test_scatter_1d(self): self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) @with_comms + @run_with_both_funcol_impls def test_scatter_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() @@ -418,6 +448,7 @@ def test_scatter_uneven(self): self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) @with_comms + @run_with_both_funcol_impls def test_all_gather_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() @@ -454,6 +485,7 @@ def test_all_gather_uneven(self): self.assertEqual(all_gathered_tensor, tensor_to_split) @with_comms + @run_with_both_funcol_impls def test_reduce_scatter_contiguous(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() @@ -496,6 +528,7 @@ def test_reduce_scatter_contiguous(self): self.assertEqual(new_tensor_local, expected_tensor) @with_comms + @run_with_both_funcol_impls def test_reduce_scatter_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() @@ -558,6 +591,7 @@ def test_reduce_scatter_uneven(self): ) @with_comms + @run_with_both_funcol_impls def test_broadcast_nd(self): mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) @@ -576,6 +610,7 @@ def test_broadcast_nd(self): self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) @with_comms + @run_with_both_funcol_impls def test_scatter_nd(self): mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) @@ -598,6 +633,7 @@ def test_scatter_nd(self): self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) @with_comms + @run_with_both_funcol_impls def test_all_to_all_1d(self): # transpose on a 2D tensor distributed over N nodes: mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -625,6 +661,7 @@ def test_all_to_all_1d(self): self.assertEqual(output_tensor, expected_tensor) @with_comms + @run_with_both_funcol_impls def test_all_to_all_nd(self): mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) @@ -662,29 +699,5 @@ def test_all_to_all_nd(self): self.assertEqual(output_tensor, expected_tensor) -class DeviceMeshTestWithNativeFunCol(DeviceMeshTest): - def setUp(self) -> None: - self._prev_native_funcol_enabled = funcol.native_funcol_enabled() - funcol.enable_native_funcol() - super().setUp() - - def tearDown(self) -> None: - super().tearDown() - if not self._prev_native_funcol_enabled: - funcol.disable_native_funcol() - - -class DeviceMeshCollectiveTestWithNativeFunCol(DeviceMeshCollectiveTest): - def setUp(self) -> None: - self._prev_native_funcol_enabled = funcol.native_funcol_enabled() - funcol.enable_native_funcol() - super().setUp() - - def tearDown(self) -> None: - super().tearDown() - if not self._prev_native_funcol_enabled: - funcol.disable_native_funcol() - - if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index c3f523ed49e8e8..841f41a02e6224 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -12,8 +12,6 @@ from . import _functional_collectives_impl as fun_col_impl from ._functional_collectives_impl import ( # noqa: F401 _register_tensor_wrapper, - disable_native_funcol, - enable_native_funcol, native_funcol_enabled, ) diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index 6c54cd415568e0..f598723730759b 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -27,25 +27,10 @@ _use_native_funcol = "_USE_NATIVE_C10D_FUNCTIONAL" in os.environ -# These are for testing purposes only and will be removed after we fully -# migrate to native funcol. def native_funcol_enabled(): return _use_native_funcol -def enable_native_funcol(): - global _use_native_funcol - os.environ["_USE_NATIVE_C10D_FUNCTIONAL"] = "1" - _use_native_funcol = True - - -def disable_native_funcol(): - global _use_native_funcol - if "_USE_NATIVE_C10D_FUNCTIONAL" in os.environ: - del os.environ["_USE_NATIVE_C10D_FUNCTIONAL"] - _use_native_funcol = False - - logger = logging.getLogger(__name__) data_ptr_to_work: Dict[int, "_WaitRegistration"] = dict() diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 100693c76cab8f..2ceae282e27331 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -38,6 +38,10 @@ TEST_WITH_TSAN, TestCase, ) +from torch.testing._internal.common_utils import ( + parametrize, + subtest, +) from torch.testing._internal.distributed.multi_threaded_pg import ( _install_threaded_pg, _uninstall_threaded_pg, @@ -1256,3 +1260,61 @@ def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None: self.rank = rank self.file_name = file_name self.run_test(test_name, parent_pipe) + + +# NOTE [test parametrization utils for native funcol migration] +# +# Between the time we switch to the native funcol by default and the time when +# we are confident that we can remove the legacy implementation, we want to +# ensure that the legacy funcol remains covered by unit tests. This is to +# prepare for any potential (but unlikely) reverts. The following utilities +# help achieve this goal. +# +# run_with_{native,legacy}_funcol - mark a test to run with only +# {native,legacy} funcol. These decorators are for impl specific tests (e.g. +# verifying generated code with FileCheck). +# +# run_with_both_funcol_impls - parametrize a test to run with both legacy and +# native funcol. +# +# run_with_both_funcol_impls_with_arg - same as run_with_both_funcol_impls, but +# passes `enable_native_funcol` to the test so impl specific checks can be +# carried out. +def with_native_funcol(use_native_funcol: bool, remove_arg: bool): + import torch.distributed._functional_collectives_impl as funcol_impl + + def decorator(fn): + def inner(*args, **kwargs): + if remove_arg: + del kwargs["use_native_funcol"] + prev = funcol_impl._use_native_funcol + funcol_impl._use_native_funcol = use_native_funcol + try: + return fn(*args, **kwargs) + finally: + funcol_impl._use_native_funcol = prev + + return inner + + return decorator + + +run_with_native_funcol = with_native_funcol(True, remove_arg=False) +run_with_legacy_funcol = with_native_funcol(False, remove_arg=False) + + +run_with_both_funcol_impls = parametrize( + "use_native_funcol", + [ + subtest(True, decorators=[with_native_funcol(True, remove_arg=True)]), + subtest(False, decorators=[with_native_funcol(False, remove_arg=True)]), + ] +) + +run_with_both_funcol_impls_with_arg = parametrize( + "use_native_funcol", + [ + subtest(True, decorators=[with_native_funcol(True, remove_arg=False)]), + subtest(False, decorators=[with_native_funcol(False, remove_arg=False)]), + ] +) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 3386cf797f62c5..e3380d6a43231f 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -31,6 +31,7 @@ skip_if_lt_x_gpu, ) + from torch.distributed._tensor import ( DeviceMesh, Shard,