diff --git a/torchft/process_group.py b/torchft/process_group.py
index 540633b3..b38d2914 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -57,6 +57,7 @@
     AllreduceOptions,
     BroadcastOptions,
     ReduceOp,
+    ReduceScatterOptions,
     Work,
 )
 from torch.futures import Future
@@ -159,6 +160,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
         opts.rootRank = root
         return self.broadcast([tensor], opts)
 
+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Reduces, then scatters a list of tensors to all processes in a group.
+
+        See torch.distributed.reduce_scatter for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     def size(self) -> int:
         raise NotImplementedError("not implemented")
 
@@ -267,6 +282,14 @@ def allgather(
     def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
         return self.parent.broadcast(tensor_list, opts)
 
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
+
     def size(self) -> int:
         return self.parent.size()
 
@@ -295,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
+    # pyre-fixme[14,15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupGloo class. The operation is not supported by the Gloo
+        backend, so calling this function always raises a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+            supported by ProcessGroupGloo.
+        """
+        raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
+
 
 class ProcessGroupNCCL(ProcessGroupWrapper):
     """
@@ -354,11 +396,6 @@ def __init__(self, rank: int, world: int) -> None:
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         self.configure_count += 1
 
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
-        res = _DummyWork(tensor_list)
-        self._work.append(res)
-        return res
-
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
@@ -377,6 +414,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         self._work.append(res)
         return res
 
+    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+        res = _DummyWork(tensor_list)
+        self._work.append(res)
+        return res
+
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: object,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors[0]):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def size(self) -> int:
         return self._world
 
@@ -970,6 +1025,25 @@ def broadcast(
 
         return self._run_func("broadcast", tensor_list, opts)
 
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "output must be list"
+        assert isinstance(input_tensors, list), "input must be list"
+
+        for tensor in output_tensors:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        for tensor_list in input_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+        return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
+
     def size(self) -> int:
         return self._world_size
 
@@ -992,7 +1066,15 @@ def safe_args(cls, args: T) -> T:
             return tuple(cls.safe_args(arg) for arg in args)
         elif isinstance(args, list):
             return [cls.safe_args(arg) for arg in args]
-        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+        elif isinstance(
+            args,
+            (
+                AllreduceOptions,
+                AllgatherOptions,
+                BroadcastOptions,
+                ReduceScatterOptions,
+            ),
+        ):
             return cls.from_torch(args)
         else:
             return args
@@ -1038,6 +1120,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[List[torch.Tensor]],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter operation in the
+        ProcessGroupBabyGloo class. The operation is not supported by the Gloo
+        backend, so calling this function always raises a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+            supported by ProcessGroupBabyGloo.
+ """ + raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.") + class ProcessGroupBabyNCCL(ProcessGroupBaby): """ diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py index e49f5a47..fb67457d 100644 --- a/torchft/process_group_test.py +++ b/torchft/process_group_test.py @@ -23,6 +23,7 @@ AllreduceOptions, BroadcastOptions, ReduceOp, + ReduceScatterOptions, _resolve_process_group, ) from torch.distributed import ( @@ -94,18 +95,28 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2] ("allgather", (output_tensors, [input_tensor], AllgatherOptions())), ("broadcast", (tensor_list, BroadcastOptions())), ("broadcast_one", (input_tensor, 0)), + ( + "reduce_scatter", + (output_tensors[0], [[input_tensor]], ReduceScatterOptions()), + ), ] works: Dict[str, dist._Work] = {} - for coll_str, args in collectives: - coll = getattr(pg, coll_str) - work = coll(*args) - works[coll_str] = work - work.wait() - fut = work.get_future() - fut.wait() - # Check that all tensor arguments have the expected shapes and dtypes - check_tensors(args) + for coll_str, args in collectives: + try: + coll = getattr(pg, coll_str) + work = coll(*args) + works[coll_str] = work + work.wait() + fut = work.get_future() + fut.wait() + # Check that all tensor arguments have the expected shapes and dtypes + check_tensors(args) + except RuntimeError as e: + if f"does not support {coll_str}" in str(e): + # Skip collectives that are not supported by the backend. + continue + raise e print(works) return works @@ -306,7 +317,7 @@ def test_baby_nccl_2gpu(self) -> None: store_addr: str = f"localhost:{store.port}/prefix" - def run(rank: int) -> Tuple[torch.Tensor, Work]: + def run(rank: int) -> Tuple[ProcessGroupBabyNCCL, torch.Tensor, Work]: a = ProcessGroupBabyNCCL( timeout=timedelta(seconds=10.0), ) @@ -318,19 +329,29 @@ def run(rank: int) -> Tuple[torch.Tensor, Work]: at = torch.tensor([rank + 1], device="cuda") a_work = a.allreduce([at], ReduceOp.SUM) - return at, a_work + return a, at, a_work with ThreadPoolExecutor(max_workers=2) as executor: a_fut = executor.submit(run, 0) b_fut = executor.submit(run, 1) - at, a_work = a_fut.result() - bt, b_work = b_fut.result() - - a_work.wait() - b_work.get_future().wait() + a, at, a_work = a_fut.result() + b, bt, b_work = b_fut.result() - torch.testing.assert_close(at.cpu(), bt.cpu()) + try: + a_work.wait() + b_work.get_future().wait() + torch.testing.assert_close(at.cpu(), bt.cpu()) + finally: + # cleanup - first ensure that babywork is deleted before shutting down PGs + # note futures must be deleted as they hold references to babywork + del a_fut + del b_fut + del a_work + del b_work + gc.collect() + b.shutdown() + a.shutdown() def test_device_mesh(self) -> None: os.environ["MASTER_ADDR"] = "localhost"