diff --git a/torchft/process_group.py b/torchft/process_group.py
index f6ac2d4..e44ac11 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -611,6 +611,114 @@ def reduce_scatter_tensor_coalesced(
     )
 
 
+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._base = base
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        self._pgs = []
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def getBackendName(self) -> str:
+        return f"{self._base.getBackendName()}-parallel"
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 4c3455d..e31041a 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -40,6 +40,7 @@ from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ParallelProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
@@ -690,6 +691,29 @@ def test_baby_gloo_apis(self) -> None:
         with self.assertRaisesRegex(OSError, "handle is closed"):
             a.allreduce([t], AllreduceOptions()).wait()
 
+    def test_parallel_gloo_apis(self) -> None:
+        dummy_init_pg()
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ParallelProcessGroup(
+            base=ProcessGroupGloo(),
+            count=4,
+        )
+        a.configure(store_addr, 0, 1)
+        a.register("test_parallel_gloo_apis")
+
+        _test_pg(
+            a,
+            skip=("reduce_scatter_tensor_coalesced",),
+        )
+
+        a.unregister()
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipUnless(torch.cuda.is_available(), "needs CUDA")
     def test_baby_nccl_apis(self) -> None:
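
For context (not part of the patch): a minimal single-rank usage sketch of the new ParallelProcessGroup, mirroring test_parallel_gloo_apis above. The count=4 and the tensor size are illustrative choices; the unit test additionally calls dummy_init_pg(), register() and unregister(), which are omitted here.

# Usage sketch (assumes a single rank, the gloo backend, and a local TCPStore as in the test).
import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
store_addr = f"localhost:{store.port}/prefix"

# count=4 inner process groups; each collective is split across all of them.
pg = ParallelProcessGroup(base=ProcessGroupGloo(), count=4)
pg.configure(store_addr, 0, 1)  # rank 0, world size 1

t = torch.arange(12, dtype=torch.float32)
# _split_tensors views the tensor as 1-D and tensor_splits it into 4 chunks,
# one per inner group; wait() blocks until all 4 underlying works complete.
pg.allreduce([t], ReduceOp.SUM).wait()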