From d076a5465afa7b93bd5f3cdf60b09827b661c529 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Thu, 6 Feb 2025 12:39:01 -0800 Subject: [PATCH 1/7] initial commit for reduce_scatter --- torchft/process_group.py | 66 ++++++++++++++++++++++++++++++++--- torchft/process_group_test.py | 2 ++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index 4790352e..a5de50e9 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -58,6 +58,7 @@ AllreduceOptions, BroadcastOptions, ReduceOp, + ReduceScatterOptions, Work, ) from torch.futures import Future @@ -180,6 +181,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work: opts.rootRank = root return self.broadcast([tensor], opts) + # pyre-fixme[14]: inconsistent override + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: ReduceScatterOptions, + ) -> Work: + """ + Reduces, then scatters a list of tensors to all processes in a group. + + See torch.distributed.reduce_scatter for more details. + """ + raise NotImplementedError("not implemented") + def size(self) -> int: raise NotImplementedError("not implemented") @@ -288,6 +303,14 @@ def allgather( def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: return self.parent.broadcast(tensor_list, opts) + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: object, + ) -> Work: + return self.parent.reduce_scatter(output_tensors, input_tensors, opts) + def size(self) -> int: return self.parent.size() @@ -375,11 +398,6 @@ def __init__(self, rank: int, world: int) -> None: def configure(self, store_addr: str, rank: int, world_size: int) -> None: self.configure_count += 1 - def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: - res = _DummyWork(tensor_list) - self._work.append(res) - return res - def allgather( self, output_tensors: List[List[torch.Tensor]], @@ -398,6 +416,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: self._work.append(res) return res + def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: + res = _DummyWork(tensor_list) + self._work.append(res) + return res + + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[torch.Tensor], + opts: object, + ) -> Work: + for o, i in zip(output_tensors, input_tensors): + o.copy_(i) + + res = _DummyWork(output_tensors) + self._work.append(res) + return res + def size(self) -> int: return self._world @@ -960,6 +996,26 @@ def broadcast( return self._run_func("broadcast", tensor_list, opts) + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], + opts: ReduceScatterOptions, + ) -> Work: + assert isinstance(output_tensors, list), "input must be list" + assert isinstance(input_tensors, list), "input must be list" + + for tensor in output_tensors: + if not tensor.is_shared(): + tensor.share_memory_() + + for tensor_list in input_tensors: + for tensor in tensor_list: + if not tensor.is_shared(): + tensor.share_memory_() + + return self._run_func("reduce_scatter", output_tensors, input_tensors, opts) + def size(self) -> int: return self._world_size diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index f7656259..ee2f04e8 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py 
@@ -23,6 +23,7 @@ AllreduceOptions, BroadcastOptions, ReduceOp, + ReduceScatterOptions, _resolve_process_group, ) from torch.distributed import ( @@ -94,6 +95,7 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), ("broadcast", (tensor_list, BroadcastOptions())), ("broadcast_one", (input_tensor, 0)), + ("reduce_scatter", (output_tensors, [input_tensor], ReduceScatterOptions())), ] works: Dict[str, dist._Work] = {} for coll_str, args in collectives: From a42549318a279ca89c0c83f07ad86034ba6d3516 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Thu, 6 Feb 2025 13:22:01 -0800 Subject: [PATCH 2/7] fixes reduce_scatter function signature, refactors test and adds reduce_scatter test --- torchft/process_group.py | 9 ++++---- torchft/process_group_test.py | 42 ++++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index a5de50e9..f883f546 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -185,7 +185,7 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work: def reduce_scatter( self, output_tensors: List[torch.Tensor], - input_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], opts: ReduceScatterOptions, ) -> Work: """ @@ -306,7 +306,7 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: def reduce_scatter( self, output_tensors: List[torch.Tensor], - input_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], opts: object, ) -> Work: return self.parent.reduce_scatter(output_tensors, input_tensors, opts) @@ -424,10 +424,10 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work: def reduce_scatter( self, output_tensors: List[torch.Tensor], - input_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], opts: object, ) -> Work: - for o, i in zip(output_tensors, input_tensors): + for o, i in zip(output_tensors, input_tensors[0]): o.copy_(i) res = _DummyWork(output_tensors) @@ -1013,7 +1013,6 @@ def reduce_scatter( for tensor in tensor_list: if not tensor.is_shared(): tensor.share_memory_() - return self._run_func("reduce_scatter", output_tensors, input_tensors, opts) def size(self) -> int: diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index ee2f04e8..a49f3352 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -61,6 +61,31 @@ def dummy_init_pg() -> None: ) +def _should_run_collective(collective_str: str, backend_str: str, device: str) -> bool: + """Verify if the collective is supported by the backend and device. + + See https://pytorch.org/docs/stable/distributed.html#backends for the + supported collectives / backends / devices matrix. + + """ + if "nccl" in backend_str.lower(): + # all collectives are supported for NCCL/CUDA but none on CPU. + return device == "cuda" + elif "gloo" in backend_str.lower(): + if device == "cuda": + # GLOO/GPU only supports broadcast and all_reduce. + if collective_str in ["broadcast", "all_reduce"]: + return True + return False + else: # cpu + if collective_str in ["reduce_scatter", "all_to_all"]: + return False + return True + else: + # Non defined backends (e.g. ErrorSwallowing) should continue to work. 
+ return True + + def _test_pg( pg: ProcessGroup, example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32), @@ -95,10 +120,25 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), ("broadcast", (tensor_list, BroadcastOptions())), ("broadcast_one", (input_tensor, 0)), - ("reduce_scatter", (output_tensors, [input_tensor], ReduceScatterOptions())), + ( + "reduce_scatter", + (output_tensors[0], [[input_tensor]], ReduceScatterOptions()), + ), ] works: Dict[str, dist._Work] = {} + + try: + backend_str = pg.getBackendName() + device = example_tensor.device + if type(device) is torch.device: + device = device.type + except NotImplementedError as e: + backend_str = "" + device = "" + for coll_str, args in collectives: + if not _should_run_collective(coll_str, backend_str=backend_str, device=device): + continue coll = getattr(pg, coll_str) work = coll(*args) works[coll_str] = work From 51904141ace08d2af49e373a88a69ebfe88021c3 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Thu, 6 Feb 2025 14:46:06 -0800 Subject: [PATCH 3/7] fixes test --- torchft/process_group.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index f883f546..096e5706 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -1037,7 +1037,15 @@ def safe_args(cls, args: T) -> T: return tuple(cls.safe_args(arg) for arg in args) elif isinstance(args, list): return [cls.safe_args(arg) for arg in args] - elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)): + elif isinstance( + args, + ( + AllreduceOptions, + AllgatherOptions, + BroadcastOptions, + ReduceScatterOptions, + ), + ): return cls.from_torch(args) else: return args From 45fac866c3699baa14113d816e536c666dfb2595 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Mon, 10 Feb 2025 13:15:12 -0800 Subject: [PATCH 4/7] adds explicit NotImplementedError to reduce_scatter in gloo, simplify the test suite --- torchft/process_group.py | 19 +++++++++++ torchft/process_group_test.py | 59 +++++++++-------------------------- 2 files changed, 33 insertions(+), 45 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index 096e5706..8d066cb5 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -1091,6 +1091,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou def getBackendName(self) -> str: return "torchft-baby-gloo" + # pyre-fixme[15]: inconsistent override + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], + opts: ReduceScatterOptions, + ) -> None: + """ + This function is a placeholder for the reduce_scatter operation in the + ProcessGroupGloo class. However, this operation is not supported by the + Gloo backend, and thus, calling this function will raise a + NotImplementedError. + + Raises: + NotImplementedError: Always raised since reduce_scatter is not + supported by ProcessGroupGloo. 
+ """ + raise NotImplementedError("ProcessGroupGloo does not support reduce_scatter.") + class ProcessGroupBabyNCCL(ProcessGroupBaby): """ diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index a49f3352..f4ec36ea 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -61,31 +61,6 @@ def dummy_init_pg() -> None: ) -def _should_run_collective(collective_str: str, backend_str: str, device: str) -> bool: - """Verify if the collective is supported by the backend and device. - - See https://pytorch.org/docs/stable/distributed.html#backends for the - supported collectives / backends / devices matrix. - - """ - if "nccl" in backend_str.lower(): - # all collectives are supported for NCCL/CUDA but none on CPU. - return device == "cuda" - elif "gloo" in backend_str.lower(): - if device == "cuda": - # GLOO/GPU only supports broadcast and all_reduce. - if collective_str in ["broadcast", "all_reduce"]: - return True - return False - else: # cpu - if collective_str in ["reduce_scatter", "all_to_all"]: - return False - return True - else: - # Non defined backends (e.g. ErrorSwallowing) should continue to work. - return True - - def _test_pg( pg: ProcessGroup, example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32), @@ -127,27 +102,21 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] ] works: Dict[str, dist._Work] = {} - try: - backend_str = pg.getBackendName() - device = example_tensor.device - if type(device) is torch.device: - device = device.type - except NotImplementedError as e: - backend_str = "" - device = "" - for coll_str, args in collectives: - if not _should_run_collective(coll_str, backend_str=backend_str, device=device): - continue - coll = getattr(pg, coll_str) - work = coll(*args) - works[coll_str] = work - work.wait() - fut = work.get_future() - fut.wait() - - # Check that all tensor arguments have the expected shapes and dtypes - check_tensors(args) + try: + coll = getattr(pg, coll_str) + work = coll(*args) + works[coll_str] = work + work.wait() + fut = work.get_future() + fut.wait() + # Check that all tensor arguments have the expected shapes and dtypes + check_tensors(args) + except RuntimeError as e: + if f"does not support {coll_str}" in str(e): + # Skip collectives that are not supported by the backend. 
+ continue + raise e print(works) return works From dc448ec613caa341830f1db5888d61254077bb70 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Mon, 10 Feb 2025 13:19:50 -0800 Subject: [PATCH 5/7] fix tests after merge --- torchft/process_group_test.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index 8a2741f2..fb67457d 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -317,7 +317,7 @@ def test_baby_nccl_2gpu(self) -> None: store_addr: str = f"localhost:{store.port}/prefix" - def run(rank: int) -> Tuple[torch.Tensor, Work]: + def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]: a = ProcessGroupBabyNCCL( timeout=timedelta(seconds=10.0), ) @@ -329,19 +329,29 @@ def run(rank: int) -> Tuple[torch.Tensor, Work]: at = torch.tensor([rank + 1], device="cuda") a_work = a.allreduce([at], ReduceOp.SUM) - return at, a_work + return a, at, a_work with ThreadPoolExecutor(max_workers=2) as executor: a_fut = executor.submit(run, 0) b_fut = executor.submit(run, 1) - at, a_work = a_fut.result() - bt, b_work = b_fut.result() - - a_work.wait() - b_work.get_future().wait() + a, at, a_work = a_fut.result() + b, bt, b_work = b_fut.result() - torch.testing.assert_close(at.cpu(), bt.cpu()) + try: + a_work.wait() + b_work.get_future().wait() + torch.testing.assert_close(at.cpu(), bt.cpu()) + finally: + # cleanup - first ensure that babywork is deleted before shutting down PGs + # note futures must be deleted as they hold references to babywork + del a_fut + del b_fut + del a_work + del b_work + gc.collect() + b.shutdown() + a.shutdown() def test_device_mesh(self) -> None: os.environ["MASTER_ADDR"] = "localhost" From f8d2ac595b39390af6edc58c4a2d9605ae986553 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Mon, 10 Feb 2025 13:50:57 -0800 Subject: [PATCH 6/7] add explicit error for ProcessGroupGloo --- torchft/process_group.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index d0af61ba..4da6403f 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -318,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro def getBackendName(self) -> str: return "torchft-gloo" + # pyre-fixme[14,15]: inconsistent override + def reduce_scatter( + self, + output_tensors: List[torch.Tensor], + input_tensors: List[List[torch.Tensor]], + opts: ReduceScatterOptions, + ) -> None: + """ + This function is a placeholder for the reduce_scatter operation in the + ProcessGroupGloo class. However, this operation is not supported by the + Gloo backend, and thus, calling this function will raise a + NotImplementedError. + + Raises: + NotImplementedError: Always raised since reduce_scatter is not + supported by ProcessGroupGloo. + """ + raise NotImplementedError("ProcessGroupGloo does not support reduce_scatter.") + class ProcessGroupNCCL(ProcessGroupWrapper): """ @@ -1118,7 +1137,9 @@ def reduce_scatter( NotImplementedError: Always raised since reduce_scatter is not supported by ProcessGroupGloo. """ - raise NotImplementedError("ProcessGroupGloo does not support reduce_scatter.") + raise NotImplementedError( + "ProcessGroupBabyGloo does not support reduce_scatter." 
+ ) class ProcessGroupBabyNCCL(ProcessGroupBaby): From 7aaf7db5b842690af28062a36f7514ed29028d68 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Mon, 10 Feb 2025 13:55:48 -0800 Subject: [PATCH 7/7] notimplementederror->runtimeerror --- torchft/process_group.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchft/process_group.py b/torchft/process_group.py index 4da6403f..b38d2914 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -329,13 +329,13 @@ def reduce_scatter( This function is a placeholder for the reduce_scatter operation in the ProcessGroupGloo class. However, this operation is not supported by the Gloo backend, and thus, calling this function will raise a - NotImplementedError. + RuntimeError. Raises: - NotImplementedError: Always raised since reduce_scatter is not + RuntimeError: Always raised since reduce_scatter is not supported by ProcessGroupGloo. """ - raise NotImplementedError("ProcessGroupGloo does not support reduce_scatter.") + raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.") class ProcessGroupNCCL(ProcessGroupWrapper): @@ -1131,15 +1131,13 @@ def reduce_scatter( This function is a placeholder for the reduce_scatter operation in the ProcessGroupGloo class. However, this operation is not supported by the Gloo backend, and thus, calling this function will raise a - NotImplementedError. + RuntimeError. Raises: - NotImplementedError: Always raised since reduce_scatter is not + RuntimeError: Always raised since reduce_scatter is not supported by ProcessGroupGloo. """ - raise NotImplementedError( - "ProcessGroupBabyGloo does not support reduce_scatter." - ) + raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.") class ProcessGroupBabyNCCL(ProcessGroupBaby):