diff --git a/CMakeLists.txt b/CMakeLists.txt
index 57f64e4dbea1a..75db94cba4dcf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -276,6 +276,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
 cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
 cmake_dependent_option(USE_NCCL "Use NCCL" ON
     "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
+cmake_dependent_option(USE_XCCL "Use XCCL" ON
+    "USE_XPU;UNIX;NOT APPLE" OFF)
 cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
 cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
 cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
@@ -352,6 +354,8 @@ cmake_dependent_option(
     USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF)
 cmake_dependent_option(
     USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF)
+cmake_dependent_option(
+    USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF)
 cmake_dependent_option(
     USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF)
 cmake_dependent_option(
diff --git a/build_variables.bzl b/build_variables.bzl
index 9cb351a4a090f..1000459a044d9 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -704,6 +704,10 @@ libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_s
     "torch/csrc/cuda/nccl.cpp",
 ]
 
+libtorch_xpu_distributed_extra_sources = [
+    "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp",
+]
+
 torch_cpp_srcs = [
     "torch/csrc/api/src/cuda.cpp",  # this just forwards stuff, no real CUDA
     "torch/csrc/api/src/data/datasets/mnist.cpp",
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index dbd765ab44b13..02cc44c66d62f 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1049,6 +1049,13 @@ elseif(USE_CUDA)
 endif()
 
 if(USE_XPU)
+  # If the SYCL runtime and the oneCCL runtime are both installed on the system,
+  # then building with USE_XPU=ON defaults USE_XCCL=ON and USE_C10D_XCCL=ON,
+  # and the XCCL backend will be built into libtorch_xpu.
+  # Manually set `USE_XCCL=OFF` to disable building the XCCL backend.
+  if(USE_XCCL)
+    append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS)
+  endif()
   list(APPEND Caffe2_XPU_SRCS ${GENERATED_CXX_TORCH_XPU})
   add_library(torch_xpu ${Caffe2_XPU_SRCS})
   torch_compile_options(torch_xpu)  # see cmake/public/utils.cmake
@@ -1118,6 +1125,10 @@ if(USE_XPU)
     include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS})
   endif()
 
+  if(USE_XCCL)
+    target_link_libraries(torch_xpu PRIVATE torch::xccl)
+    target_compile_definitions(torch_xpu PRIVATE USE_XCCL)
+  endif()
 endif()
 
 if(NOT MSVC AND USE_XNNPACK)
@@ -1404,6 +1415,9 @@ if(USE_DISTRIBUTED)
       target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL)
     endif()
   endif()
+  if(USE_XPU AND USE_C10D_XCCL)
+    target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+  endif()
   if(USE_MPI AND USE_C10D_MPI)
     if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
       set_source_files_properties(
diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in
index 2929f105b31fa..e5398a83cad94 100644
--- a/caffe2/core/macros.h.in
+++ b/caffe2/core/macros.h.in
@@ -45,6 +45,7 @@
   {"USE_CUDNN", "${USE_CUDNN}"}, \
   {"CUDNN_VERSION", "${CUDNN_VERSION}"}, \
   {"USE_NCCL", "${USE_NCCL}"}, \
+  {"USE_XCCL", "${USE_XCCL}"}, \
   {"USE_MPI", "${USE_MPI}"}, \
   {"USE_GFLAGS", "${USE_GFLAGS}"}, \
   {"USE_GLOG", "${USE_GLOG}"}, \
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index ba9625edf876a..a009033ba0aa2 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1123,6 +1123,24 @@ if(USE_CUDA)
   include_directories(SYSTEM ${CUB_INCLUDE_DIRS})
 endif()
 
+# ---[ XCCL
+if(USE_XCCL)
+  if(NOT USE_XPU)
+    message(WARNING
+      "Not using XPU, so disabling USE_XCCL. Suppress this warning with "
+      "-DUSE_XCCL=OFF.")
+    caffe2_update_option(USE_XCCL OFF)
+  elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
+    message(WARNING "USE_XCCL is currently only supported under Linux.")
+    caffe2_update_option(USE_XCCL OFF)
+  else()
+    include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake)
+    if(NOT XCCL_FOUND)
+      caffe2_update_option(USE_XCCL OFF)
+    endif()
+  endif()
+endif()
+
 if(USE_DISTRIBUTED AND USE_TENSORPIPE)
   if(MSVC)
     message(WARNING "Tensorpipe cannot be used on Windows.")
diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake
new file mode 100644
index 0000000000000..acb7cee87593e
--- /dev/null
+++ b/cmake/External/xccl.cmake
@@ -0,0 +1,15 @@
+if(NOT __XCCL_INCLUDED)
+  set(__XCCL_INCLUDED TRUE)
+
+  # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake.
+  find_package(XCCL REQUIRED)
+  if(XCCL_FOUND)
+    add_library(torch::xccl INTERFACE IMPORTED)
+    set_property(
+      TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${XCCL_INCLUDE_DIR})
+    set_property(
+      TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES
+      ${XCCL_LIBRARY})
+  endif()
+endif()
diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake
new file mode 100644
index 0000000000000..18f7ac642d54e
--- /dev/null
+++ b/cmake/Modules/FindXCCL.cmake
@@ -0,0 +1,69 @@
+# This will define the following variables:
+# XCCL_FOUND : True if the system has the XCCL library.
+# XCCL_INCLUDE_DIR : Include directories needed to use XCCL.
+# XCCL_LIBRARY_DIR : The path to the directory containing the XCCL library.
+# XCCL_LIBRARY : The full path to the XCCL library.
+ +include(FindPackageHandleStandardArgs) + +set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") +if (NOT EXISTS "${XCCL_ROOT}") + message(STATUS "Default OneCCL not found, using current environment OneAPI") + set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) +endif() + +string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) +if(nocclfound) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "OneCCL library not found!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +# Find include path from binary. +find_file( + XCCL_INCLUDE_DIR + NAMES include + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find include/oneapi path from include path. +find_file( + XCCL_INCLUDE_ONEAPI_DIR + NAMES oneapi + HINTS ${XCCL_ROOT}/include/ + NO_DEFAULT_PATH +) + +list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR}) + +# Find library directory from binary. +find_file( + XCCL_LIBRARY_DIR + NAMES lib + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find XCCL library fullname. +find_library( + XCCL_LIBRARY + NAMES ccl + HINTS ${XCCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "OneCCL library not found!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +find_package_handle_standard_args( + XCCL + FOUND_VAR XCCL_FOUND + REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY + REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}" +) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 72be8b17d8911..f3d52995a45de 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -159,6 +159,12 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") endif() message(STATUS " USE_ITT : ${USE_ITT}") + message(STATUS " USE_XCCL : ${USE_XCCL}") + if(${USE_XCCL}) + message(STATUS " USE_C10D_XCCL : ${USE_C10D_XCCL}") + message(STATUS " XCCL include path : ${XCCL_INCLUDE_DIR}") + message(STATUS " XCCL library : ${XCCL_LIBRARY}") + endif() message(STATUS " USE_NCCL : ${USE_NCCL}") if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") diff --git a/setup.py b/setup.py index 26c72625cbb72..07464d308eaec 100644 --- a/setup.py +++ b/setup.py @@ -658,6 +658,10 @@ def run(self): report("-- Building NCCL library") else: report("-- Not using NCCL") + if cmake_cache_vars["USE_XCCL"]: + report("-- Building XCCL library") + else: + report("-- Not using XCCL") if cmake_cache_vars["USE_DISTRIBUTED"]: if IS_WINDOWS: report("-- Building without distributed package") diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 2542ecf864da6..2096ce9ed68a4 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -31,6 +31,7 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, + get_device_count, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -60,14 +61,15 @@ torch.backends.cuda.matmul.allow_tf32 = False -def gpus_for_rank(world_size): +def gpus_for_rank(world_size, backend): """Multigpu tests are designed to simulate the multi nodes with multi GPUs on each node. Nccl backend requires equal #GPUs in each process. On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. 
""" - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + device_count = get_device_count(backend) + visible_devices = list(range(device_count)) + gpus_per_process = device_count // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( @@ -828,7 +830,7 @@ def update_parameters(model): def _gpu_model_with_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False, state=None ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -845,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook( def _gpu_model_with_builtin_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -1834,6 +1836,9 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + elif backend == dist.Backend.XCCL: + if not dist.is_xccl_available(): + continue # Multi-threaded PG is defined as a pure python class. # Its pg.name() does not going through Pybind, so its backend name # is still "threaded" instead of "custom". diff --git a/test/distributed/test_c10d_ops_xccl.py b/test/distributed/test_c10d_ops_xccl.py new file mode 100644 index 0000000000000..6a600aa595f7e --- /dev/null +++ b/test/distributed/test_c10d_ops_xccl.py @@ -0,0 +1,831 @@ +# Owner(s): ["oncall: distributed"] +# This test file contains positive tests for c10d with XCCL backend. +# During the test, it is expected that ProcessGroup will not be aborted, destroyed or incur fatal error. +# Please be mindful of this when adding tests here. +# If you need to add tests for group creation, abort or destroy, please add tests in test_c10d_xccl.py. + +# There are two ways to launch tests in this file: +# 1. Run this file directly with `python test_c10d_ops_xccl.py` +# 2. Use multi-process launcher, e.g. 
`torchrun --standalone --nproc-per-node 2 test_c10d_ops_xccl.py` + +import math +import os +import sys +import tempfile + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + + +import torch.distributed as dist +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcContinousTest, + requires_xccl, +) +from torch.testing._internal.common_utils import ( + skip_but_pass_in_sandcastle_if, + skipIfRocm, + TEST_WITH_DEV_DBG_ASAN, + TEST_XPU, +) + + +if TEST_WITH_DEV_DBG_ASAN: + print( + "Skip ASAN as torch + multiprocessing spawn have known issues", file=sys.stderr + ) + sys.exit(0) + +TEST_MULTIGPU = TEST_XPU and torch.xpu.device_count() >= 2 + + +class ProcessGroupXCCLOpTest(MultiProcContinousTest): + @classmethod + def backend_str(cls) -> str: + return "xccl" + + # @classmethod + # def opts(cls): + # opts = c10d.ProcessGroupXCCL.Options() + # return opts + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_empty_tensors(self): + pg = self.pg + local_device_idx = self.rank_to_GPU[self.rank][0] + + xs = [torch.FloatTensor([]).xpu(local_device_idx)] + pg.broadcast(xs).wait() + self.assertEqual(0, xs[0].numel()) + + pg.allreduce(xs).wait() + self.assertEqual(0, xs[0].numel()) + + pg.reduce(xs).wait() + self.assertEqual(0, xs[0].numel()) + + ys = [ + [ + torch.FloatTensor([]).xpu(local_device_idx) + for _ in range(self.world_size) + ] + ] + pg.allgather(ys, xs).wait() + for y in ys[0]: + self.assertEqual(0, y.numel()) + + ys = [torch.FloatTensor([]).xpu(local_device_idx)] + xs = [ + [ + torch.FloatTensor([]).xpu(local_device_idx) + for _ in range(self.world_size) + ] + ] + pg.reduce_scatter(ys, xs).wait() + self.assertEqual(0, ys[0].numel()) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_broadcast_ops(self): + pg = self.pg + + def broadcast(xs, rootRank, rootTensor): + opts = c10d.BroadcastOptions() + opts.rootRank = rootRank + opts.rootTensor = rootTensor + work = pg.broadcast(xs, opts) + work.wait() + return xs + + # Every rank is root once + for i in range(self.world_size): + # Run with 1 input tensor + x = torch.tensor([self.rank]).xpu(self.rank_to_GPU[self.rank][0]) + output = broadcast([x], i, 0) + self.assertEqual(torch.tensor([i]), output[0]) + + expected_tensor = torch.empty([i + 1, i + 1]).fill_(i + 1) + xs = [ + torch.empty([i + 1, i + 1]).fill_(-1).xpu(device=device_idx) + for device_idx in self.rank_to_GPU[self.rank] + ] + + # test with multiple input tensors (multiple gpu in one rank) + for j in range(len(xs)): + if self.rank == i: + xs[j] = expected_tensor.xpu(device=self.rank_to_GPU[self.rank][j]) + + broadcast(xs, i, j) + + for tensor in xs: + self.assertEqual(tensor, expected_tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allreduce_ops(self): + device_count = torch.xpu.device_count() + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allreduce(tensors, op): + opts = c10d.AllreduceOptions() + opts.reduceOp = op + work = pg.allreduce(tensors, opts) + work.wait() + + # Sum + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, 
c10d.ReduceOp.SUM) + + ndev = self.world_size + self.assertEqual( + torch.tensor([ndev * (ndev + 1) // 2]), + tensors[0], + ) + + # Avg + tensors = [torch.tensor([self.rank + 1.0]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.AVG) + ndev = self.world_size + self.assertEqual( + torch.tensor([ndev * (ndev + 1.0) / (2.0 * ndev)]), + tensors[0], + ) + + # Product + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.PRODUCT) + self.assertEqual(torch.tensor([math.factorial(self.world_size)]), tensors[0]) + + # Min + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.MIN) + self.assertEqual(torch.tensor([1]), tensors[0]) + + # Max + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + allreduce(tensors, c10d.ReduceOp.MAX) + self.assertEqual(torch.tensor([self.world_size]), tensors[0]) + + for op, err in zip( + (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR), + ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"), + ): + with self.assertRaisesRegex(ValueError, "Cannot use " + err + " with XCCL"): + allreduce(tensors, op) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_alltoall_ops_with_xpufree_race(self): + pg = self.pg + opts = c10d.AllToAllOptions() + local_device = f"xpu:{self.rank_to_GPU[self.rank][0]}" + torch.xpu.set_device(local_device) + input = torch.rand(1000, 1000, device=local_device) + output = torch.rand(1000, 1000, device=local_device) + race_tensors = [] + # create some tensors to race with alltoall collective + for _ in range(10): + tmp = [] + for i in range(5): + tmp.append(torch.rand(10 ** (3 + i), device=local_device)) + race_tensors.append(tmp) + + for i in range(10): + race_tensors.pop() + work = pg.alltoall_base(output, input, [], [], opts) + # this triggers xpuFree + torch.xpu.empty_cache() + work.wait() + torch.xpu.synchronize(device=local_device) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_ops(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def reduce(xs, rootRank, rootTensor, op=None): + opts = c10d.ReduceOptions() + opts.rootRank = rootRank + opts.rootTensor = rootTensor + if op: + opts.reduceOp = op + work = pg.reduce(xs, opts) + work.wait() + + # for every root tensor + for rt in range(self.world_size): + tensors = [torch.tensor([self.rank + 1]).xpu(local_device_id)] + + reduce(tensors, rt, 0) + + if self.rank == rt: + self.assertEqual( + torch.tensor([self.world_size * (self.world_size + 1) // 2]), + tensors[0], + ) + else: + self.assertEqual( + torch.tensor([self.rank + 1]), + tensors[0], + ) + + for op, err in zip( + (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR), + ("ReduceOp.BAND", "ReduceOp.BOR", "ReduceOp.BXOR"), + ): + with self.assertRaisesRegex( + ValueError, "Cannot use " + err + " with XCCL" + ): + reduce(tensors, self.rank, rt, op) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + + def allgather(output_ts, input_ts): + work = pg.allgather(output_ts, input_ts) + return work.wait() + + tensors = [torch.empty(2, 2).fill_(2).xpu(device=i) for i in local_device_ids] + output_tensors = [] + expected_output = [] + + output_per_gpu = ( + [torch.empty(2, 2).fill_(-1)] * len(local_device_ids) * 
self.world_size + ) + expected_per_gpu = ( + [torch.empty(2, 2).fill_(2)] * len(local_device_ids) * self.world_size + ) + + for gpu in local_device_ids: + output_tensors.append([t.xpu(device=gpu) for t in output_per_gpu]) + expected_output.append([t.xpu(device=gpu) for t in expected_per_gpu]) + + result = allgather(output_tensors, tensors) + + # Verification + self.assertEqual(output_tensors, expected_output) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_base_ops(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + # allgather_base is GPU number agnostic. + # Each rank contribute one tensor regardless of GPU counts + tensor = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size), dtype=tensor.dtype).xpu( + local_device_id + ) + + allgather_base(output_t, tensor) + + # Verification + self.assertEqual(torch.arange(self.world_size), output_t) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_allgather_base_basics(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def allgather_base(output_t, input_t): + work = pg._allgather_base(output_t, input_t) + work.wait() + + # anticipate an error + with self.assertRaisesRegex( + ValueError, + "output tensor size must be equal to world_size times input tensor size", + ): + tensor = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).xpu( + local_device_id + ) + # fails the check because output_t is not correctly sized + allgather_base(output_t, tensor) + + # anticipate an error + with self.assertRaisesRegex( + TypeError, "output tensor must have the same type as input tensor" + ): + tensor = torch.tensor([self.rank], dtype=torch.float).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=torch.long).xpu( + local_device_id + ) + # fails the check because the dtype is different + allgather_base(output_t, tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_gather_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def gather(output_t, input_t, rootRank): + opts = c10d.GatherOptions() + opts.rootRank = rootRank + if rootRank == self.rank: + work = pg.gather(output_t, input_t, opts) + else: + work = pg.gather([], input_t, opts) + work.wait() + + # init input + tensors = [] + for device_id in local_device_ids: + tensors.append(torch.tensor([self.rank]).xpu(device_id)) + + # init output + output_ts = [] + for idx in range(num_gpus): + gpu_idx = local_device_ids[idx] + output_ts.append([]) + for rank in range(self.world_size): + output_ts[idx].append(torch.tensor([-1]).xpu(gpu_idx)) + + expected = [[torch.tensor([rank]) for rank in range(self.world_size)]] + for rank in range(self.world_size): + gather(output_ts, tensors, rank) + if rank == self.rank: + self.assertEqual(expected, output_ts) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_gather_stress(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def gather(output_t, input_t, rootRank): + opts = c10d.GatherOptions() + opts.rootRank = 
rootRank + if rootRank == self.rank: + work = pg.gather(output_t, input_t, opts) + else: + work = pg.gather([], input_t, opts) + work.wait() + + stress_length = 1000 + + # init input + tensors = [] + for i in range(stress_length): + tensors.append([]) + for device_id in local_device_ids: + tensors[i].append(torch.tensor([self.rank]).xpu(device_id)) + + # init output + output_ts = [] + for i in range(stress_length): + output_ts.append([[] for _ in range(num_gpus)]) + for idx, ls in enumerate(output_ts[i]): + gpu_idx = local_device_ids[idx] + for _ in range(self.world_size): + ls.append(torch.tensor([-1]).xpu(gpu_idx)) + + expected = [[torch.tensor([rank]) for rank in range(self.world_size)]] + for i in range(stress_length): + for rank in range(self.world_size): + gather(output_ts[i], tensors[i], rank) + # Verification + if rank == self.rank: + self.assertEqual(output_ts[i], expected) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_gather_checks(self): + pg = self.pg + device_id = self.rank_to_GPU[self.rank][0] + + # init input + tensor = torch.tensor([self.rank]).xpu(device_id) + + # init output + output_ts = [] + for rank in range(self.world_size): + output_ts.append(torch.tensor([-1]).xpu(device_id)) + + with self.assertRaisesRegex(ValueError, "invalid root rank"): + opts = c10d.GatherOptions() + opts.rootRank = -1 + pg.gather([output_ts], [tensor], opts) + + with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + pg.gather([output_ts], [tensor], 0) + + with self.assertRaisesRegex(ValueError, "invalid root rank"): + opts = c10d.GatherOptions() + opts.rootRank = self.world_size + pg.gather([output_ts], [tensor], opts) + + with self.assertRaisesRegex( + # throws error message from dispatcher + RuntimeError, + "There were no tensor arguments to this function", + ): + opts = c10d.GatherOptions() + opts.rootRank = 0 + pg.gather([output_ts], [], opts) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_scatter_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def scatter(output_t, input_t, rootRank): + opts = c10d.ScatterOptions() + opts.rootRank = rootRank + if rootRank == self.rank: + work = pg.scatter(output_t, input_t, opts) + else: + work = pg.scatter(output_t, [], opts) + work.wait() + + # init output + tensors = [] + for device_id in local_device_ids: + tensors.append(torch.tensor([-1]).xpu(device_id)) + + # init input + scatter_list = [] + for idx in range(num_gpus): + gpu_idx = local_device_ids[idx] + scatter_list.append([]) + for rank in range(self.world_size): + scatter_list[idx].append(torch.tensor([rank]).xpu(gpu_idx)) + + # test each rank to scatter + expected = [torch.tensor([self.rank])] + for rank in range(self.world_size): + scatter(tensors, scatter_list, rank) + self.assertEqual(expected, tensors) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_scatter_stress(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def scatter(output_t, input_t, rootRank): + opts = c10d.ScatterOptions() + opts.rootRank = rootRank + if rootRank == self.rank: + work = pg.scatter(output_t, input_t, opts) + else: + work = pg.scatter(output_t, [], opts) + work.wait() + + stress_length = 1000 + + # init output + tensors = [] + for i in range(stress_length): + 
tensors.append([]) + for device_id in local_device_ids: + tensors[i].append(torch.tensor([-1]).xpu(device_id)) + + # init input + scatter_list = [] + for i in range(stress_length): + scatter_list.append([[] for _ in range(num_gpus)]) + for idx, ls in enumerate(scatter_list[i]): + gpu_idx = local_device_ids[idx] + for rank in range(self.world_size): + ls.append(torch.tensor([rank]).xpu(gpu_idx)) + + # test each rank to scatter + expected = [torch.tensor([self.rank])] + for i in range(stress_length): + for rank in range(self.world_size): + scatter(tensors[i], scatter_list[i], rank) + # Verification + self.assertEqual(tensors[i], expected) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_scatter_checks(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + # init output + tensors = [] + for device_id in local_device_ids: + tensors.append(torch.tensor([-1]).xpu(device_id)) + + # init input + scatter_list = [] + for idx in range(num_gpus): + gpu_idx = local_device_ids[idx] + scatter_list.append([]) + for rank in range(self.world_size): + scatter_list[idx].append(torch.tensor([rank]).xpu(gpu_idx)) + + with self.assertRaisesRegex(ValueError, "invalid root rank"): + opts = c10d.ScatterOptions() + opts.rootRank = -1 + pg.scatter(tensors, scatter_list, opts) + + with self.assertRaisesRegex(TypeError, "incompatible function arguments"): + pg.scatter(tensors, scatter_list, 0) + + with self.assertRaisesRegex(ValueError, "invalid root rank"): + opts = c10d.ScatterOptions() + opts.rootRank = self.world_size + pg.scatter(tensors, scatter_list, opts) + + with self.assertRaisesRegex( + # throws error message from dispatcher + RuntimeError, + "There were no tensor arguments to this function", + ): + opts = c10d.ScatterOptions() + opts.rootRank = 0 + pg.scatter([], scatter_list, opts) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_base_basics(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # anticipate an error + with self.assertRaisesRegex( + ValueError, + "input tensor must be the same size as output size times world size", + ): + input_t = torch.tensor([self.rank]).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).xpu( + local_device_id + ) + # fails the check because output_t is not correctly sized + reduce_scatter_base(output_t, input_t) + + # anticipate an error + with self.assertRaisesRegex( + TypeError, "input tensor must be the same type as the output tensor." 
+ ): + tensor = torch.tensor([self.rank], dtype=torch.float).xpu(local_device_id) + output_t = torch.empty((self.world_size + 1), dtype=torch.long).xpu( + local_device_id + ) + # fails the check because the dtype is different + reduce_scatter_base(output_t, tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_ops(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + num_gpus = len(local_device_ids) + + def reduce_scatter(outputs, input_lists, op): + opts = c10d.ReduceScatterOptions() + opts.reduceOp = op + work = pg.reduce_scatter(outputs, input_lists, opts) + work.wait() + + output = [torch.tensor([0]).xpu(i) for i in local_device_ids] + + # GPU/rank + # 0 [1], [2], [3], [4] + # 1 [2], [3], [4], [5] + # 2 [3], [4], [5], [6] + # 3 [4], [5], [6], [7] + + # Sum + tensor_lists = [] + input_per_gpu = [] + + for i in range(self.world_size): + input_per_gpu.append(torch.tensor([self.rank + i + 1])) + + for gpu in local_device_ids: + tensor_lists.append([t.xpu(device=gpu) for t in input_per_gpu]) + + reduce_scatter(output, tensor_lists, c10d.ReduceOp.SUM) + + for i in range(num_gpus): + expected = torch.tensor( + [ + (1 + self.world_size) * self.world_size // 2 + + self.world_size * self.rank + ] + ) + + self.assertEqual(expected, output[i]) + + # Min + reduce_scatter(output, tensor_lists, c10d.ReduceOp.MIN) + + for i in range(num_gpus): + expected = torch.tensor([self.rank + 1 + i]) + self.assertEqual(expected, output[i]) + + # Max + reduce_scatter(output, tensor_lists, c10d.ReduceOp.MAX) + + for i in range(num_gpus): + expected = torch.tensor([self.rank + self.world_size + i]) + self.assertEqual(expected, output[i]) + + # Product + reduce_scatter(output, tensor_lists, c10d.ReduceOp.PRODUCT) + + # math package don't have math.perm until python 3.8, so + # we implement a naive version here. + def perm(n, k): + prod_val = n + for val in range(n - k + 1, n): + prod_val *= val + return prod_val + + for i in range(num_gpus): + prod_val = perm(self.rank + self.world_size, self.world_size) + + expected = torch.tensor([prod_val]) + self.assertEqual(expected, output[i]) + + # Test the input params overridden scenarios, aka, when the input is + # a list and output is just one tensor. 
+ # Sum + output_tensor = torch.empty_like(input_per_gpu[0][0]).xpu(self.rank) + input_list = [tensor[0].xpu(self.rank) for tensor in input_per_gpu] + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.SUM).wait() + expected = torch.tensor( + (1 + self.world_size) * self.world_size // 2 + self.world_size * self.rank + ) + self.assertEqual(expected, output_tensor) + + # Min + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MIN).wait() + expected = torch.tensor(self.rank + 1) + self.assertEqual(expected, output_tensor) + + # Max + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.MAX).wait() + expected = torch.tensor(self.rank + self.world_size) + self.assertEqual(expected, output_tensor) + + # Product + pg.reduce_scatter(output_tensor, input_list, c10d.ReduceOp.PRODUCT).wait() + prod_val = self.rank + 1 + for k in range(1, self.world_size): + prod_val = prod_val * (self.rank + 1 + k) + expected = torch.tensor(prod_val) + self.assertEqual(expected, output_tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_reduce_scatter_base_ops(self): + pg = self.pg + local_device_id = self.rank_to_GPU[self.rank][0] + + def reduce_scatter_base(output_t, input_t): + work = pg._reduce_scatter_base(output_t, input_t) + work.wait() + + # reduce_scatter_base is GPU number agnostic. + # Each rank contribute one tensor regardless of GPU counts + output_t = torch.empty([1]).xpu(local_device_id) + tensor = torch.arange(self.world_size, dtype=output_t.dtype).xpu( + local_device_id + ) + + reduce_scatter_base(output_t, tensor) + + # Verification + self.assertEqual(output_t[0], self.rank * self.world_size) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_barrier(self): + pg = self.pg + local_device_ids = self.rank_to_GPU[self.rank] + + def allreduce(tensors): + opts = c10d.AllreduceOptions() + work = pg.allreduce(tensors, opts) + return work + + # Making the collective to operate on + # 1, 2, 3, 4, .... 
len(local_device_ids) GPUs + tensors_list = [[] for _ in range(len(local_device_ids))] + + for i in range(1, len(local_device_ids) + 1): + for j in range(i): + tensors_list[i - 1].append( + torch.tensor([j + 1]).xpu(local_device_ids[j]) + ) + + works = [] + for tensors in tensors_list: + work = allreduce(tensors) + works.append(work) + + # Barrier will ensure that all previous work is completed + pg.barrier().wait() + + for i in range(1, len(local_device_ids) + 1): + for j in range(i): + self.assertEqual( + torch.tensor([(j + 1) * self.world_size]), tensors_list[i - 1][j] + ) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_send_recv(self): + pg = self.pg + device = self.rank_to_GPU[self.rank][0] + + # Generate the same random tensor + torch.manual_seed(0) + send_tensor = torch.rand(10, 10, device=device) + if self.rank == 0: + dist.send(send_tensor, 1) + if self.rank == 1: + recv_tensor = torch.rand(10, 10, device=device) + dist.recv(recv_tensor, 0) + self.assertEqual(send_tensor, recv_tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_send_recv_complex(self): + pg = self.pg + device = self.rank_to_GPU[self.rank][0] + + # Generate the same random tensor + torch.manual_seed(0) + send_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device) + if self.rank == 0: + dist.send(send_tensor, 1) + if self.rank == 1: + recv_tensor = torch.rand(10, 10, dtype=torch.cfloat, device=device) + dist.recv(recv_tensor, 0) + self.assertEqual(send_tensor, recv_tensor) + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs") + def test_send_recv_object_list(self): + device = self.rank_to_GPU[self.rank][0] + + val = 99 if self.rank == 0 else None + object_list = [val] * self.world_size + if self.rank == 0: + dist.send_object_list(object_list, 1, device=device) + if self.rank == 1: + dist.recv_object_list(object_list, 0, device=device) + self.assertEqual(object_list[0], 99) + + +if __name__ == "__main__": + rank = int(os.getenv("RANK", -1)) + world_size = int(os.getenv("WORLD_SIZE", 2)) + + if rank != -1: + # Launched with torchrun or other multi-proc launchers. Directly run the test. + ProcessGroupXCCLOpTest.run_rank(rank, world_size) + else: + # Launched as a single process. Spawn subprocess to run the tests. + # Also need a rendezvous file for `init_process_group` purpose. 
+ rdvz_file = tempfile.NamedTemporaryFile(delete=False).name + torch.multiprocessing.spawn( + ProcessGroupXCCLOpTest.run_rank, + nprocs=world_size, + args=(world_size, rdvz_file), + ) diff --git a/test/distributed/test_c10d_xccl.py b/test/distributed/test_c10d_xccl.py new file mode 100644 index 0000000000000..3503f6059f282 --- /dev/null +++ b/test/distributed/test_c10d_xccl.py @@ -0,0 +1,1675 @@ +# Owner(s): ["oncall: distributed"] + +import copy +import math +import os +import random +import sys +import time +from datetime import timedelta +from enum import auto, Enum +from itertools import product +from unittest import mock + +from test_c10d_common import DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook + +import torch +import torch.distributed as c10d +import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default +import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD +import torch.nn.functional as F +from torch import nn +from torch.nn.parallel import DistributedDataParallel + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +import test_c10d_common + +import torch.distributed as dist +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcessTestCase, + requires_xccl, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + retry_on_connect_failures, + run_tests, + skip_but_pass_in_sandcastle_if, + TEST_XPU, + TestCase, +) + + +def simple_reduce_tests(rank, world_size): + tests = [ + ( + c10d.ReduceOp.SUM, + torch.tensor([rank + 1.0]), + torch.tensor([float(world_size * (world_size + 1) / 2)]), + ), + ( + c10d.ReduceOp.PRODUCT, + torch.tensor([rank + 1.0]), + torch.tensor([float(math.factorial(world_size))]), + ), + ( + c10d.ReduceOp.MIN, + torch.tensor([rank + 1.0]), + torch.tensor([1.0]), + ), + ( + c10d.ReduceOp.MAX, + torch.tensor([rank + 1.0]), + torch.tensor([world_size]), + ), + ] + + return tests + + +TEST_MULTIXPU = torch.xpu.device_count() > 1 + + +class RendezvousEnvTest(TestCase): + @retry_on_connect_failures + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_common_errors(self): + vars = { + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(common.find_free_port()), + } + + class Env: + def __init__(self, vars): + self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) + + def __enter__(self): + self.env_patcher.start() + + def __exit__(self, type, value, traceback): + self.env_patcher.stop() + + def without(d, key): + d = d.copy() + d.pop(key) + return d + + def withouts(d, keys): + d = d.copy() + for key in keys: + d.pop(key) + return d + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + with self.assertRaisesRegex(ValueError, "RANK expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", rank=0) + 
self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + c10d.init_process_group(backend="xccl", rank=0, world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(vars): + c10d.init_process_group(backend="xccl") + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "MASTER_ADDR")): + self.assertEqual(None, os.environ.get("MASTER_ADDR")) + with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "MASTER_PORT")): + self.assertEqual(None, os.environ.get("MASTER_PORT")) + with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?world_size={1}") + _, _, size = next(gen) + self.assertEqual(size, 1) + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + gen = c10d.rendezvous(f"env://?rank={0}") + _, rank, _ = next(gen) + self.assertEqual(rank, 0) + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}") + _, rank, size = next(gen) + self.assertEqual(rank, 0) + self.assertEqual(size, 1) + + +class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): + @requires_xccl() + @retry_on_connect_failures + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_default_store_timeout_nccl(self): + self._test_default_store_timeout("xccl") + + +class ProcessGroupXCCLTest(MultiProcessTestCase): + def _create_process_group_xccl( + self, timeout=timedelta(seconds=600), device_id=None + ): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + timeout=timeout, + device_id=device_id, + ) + pg = c10d.distributed_c10d._get_default_group() + return pg + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + return 2 + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. 
+ pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + def test_set_process_group_desc(self): + device = torch.device(f"xpu:{self.rank}") + pg_default = self._create_process_group_xccl(device_id=device) + self.assertEqual(pg_default.group_desc, "default_pg") + pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + self.assertEqual(pg_1.group_desc, "test_purpose") + pg_2 = c10d.new_group([0, 1]) + self.assertEqual(pg_2.group_desc, "undefined") + + +class DistributedDataParallelTest( + test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase +): + def setUp(self): + super().setUp() + self._spawn_processes() + + def _get_process_group(self): + store = self._get_store() + c10d.init_process_group( + "xccl", store=store, rank=self.rank, world_size=self.world_size + ) + return c10d.distributed_c10d._get_default_group() + + def _test_xccl_backend( + self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False + ): + process_group = self._get_process_group() + self._test_ddp_with_process_group( + process_group, devices, device_ids, multi_device, gradient_as_bucket_view + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_xccl_backend_multi_device_ids_not_allowed(self): + int_devices = list(range(torch.xpu.device_count())) + devices = [torch.device("xpu:" + str(i)) for i in int_devices] + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." 
+ ): + self._test_xccl_backend(devices, int_devices) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_ddp_multi_device_module_config(self): + gpus = gpus_for_rank(self.world_size, "xccl")[self.rank] + + self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process") + + process_group = self._get_process_group() + + gpus = gpus[:2] + model = DoubleGpuNet(gpus) + + with self.assertRaisesRegex( + ValueError, + "DistributedDataParallel device_ids and output_device arguments only work with " + "single-device/multiple-device GPU modules or CPU modules", + ): + ddp_model = DistributedDataParallel( + model, output_device=gpus[1], process_group=process_group + ) + + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." + ): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group + ) + + with self.assertRaisesRegex( + ValueError, "input module must be on the same type of devices" + ): + model.fc1 = model.fc1.cpu() + ddp_model = DistributedDataParallel(model, process_group=process_group) + + model = model.cpu() + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." + ): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group + ) + + def _test_fp16(self, gradient_as_bucket_view=False): + process_group = self._get_process_group() + + gpus = gpus_for_rank(self.world_size, "xccl")[self.rank] + model = nn.Linear(1, 1, bias=False).xpu(gpus[0]).half() + nn.init.constant_(model.weight, 1) + ddp_model = DistributedDataParallel( + model, + device_ids=[gpus[0]], + process_group=process_group, + bucket_cap_mb=0.001, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + # Input 2**15, so that the gradients will overflow with a + # world_size of 2, unless we normalize the gradient by the + # world_size before the reduction + input = torch.tensor([[2**15]]).xpu(gpus[0]).half() + + # Step model + ddp_model.train() + output = ddp_model(input) + loss = output.sum() + loss.backward() + + self.assertFalse(any(torch.isinf(p.grad).any() for p in ddp_model.parameters())) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_fp16(self): + self._test_fp16() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_fp16_grad_is_view(self): + self._test_fp16(gradient_as_bucket_view=True) + + def _test_arbitrary_forward_return_value(self, gradient_as_bucket_view=False): + """ + Note: this test can be sped up by only running it on a CPU module + once DistributedDataParallel supports them. + """ + process_group = self._get_process_group() + + class ForwardReturnValueModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 4, bias=False) + self.fc3 = nn.Linear(4, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x, fn): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + # The first softmax does NOT include fc3 in its autograd graph + # whereas the second softmax DOES. If we pass only the first + # tensor we see in the output to the reducer, it marks the + # gradient for fc3 as ready (because it doesn't show up). If + # downstream uses of this return value choose to differentiate + # against the second output tensor, it would still receive a + # gradient and a callback for this tensor, resulting in a crash. 
+ return fn( + F.softmax(x, dim=1), + F.softmax(self.fc3(x), dim=1), + ) + + device_id = gpus_for_rank(self.world_size, "xccl")[self.rank][0] + model = DistributedDataParallel( + ForwardReturnValueModule().float().to(device_id), + device_ids=[device_id], + process_group=process_group, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + batch_size = 4 + criterion = nn.CrossEntropyLoss() + input = torch.rand([batch_size, 2], dtype=torch.float) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) + + # Always run "backward" to ensure the reducer is called by autograd. + # If we don't correctly capture the output tensors from the return value, + # the reducer won't see a hook for the unused parameter, and throw an error. + # The correct capture is what we're testing in this function. + def test(box, unbox): + output = model(input, fn=box) + loss = criterion(unbox(output), target) + loss.backward() + + # Test with identity return value + test( + box=lambda x, y: (x, y), + unbox=lambda obj: obj[1], + ) + + # Test with list return value + test( + box=lambda x, y: ["foo", x, "bar", y], + unbox=lambda obj: obj[3], + ) + + # Test with tuple return value + test( + box=lambda x, y: ("foo", x, "bar", y), + unbox=lambda obj: obj[3], + ) + + # Test with dict return value + test( + box=lambda x, y: {"foo": "bar", "a": x, "b": y}, + unbox=lambda obj: obj["b"], + ) + + # Test with list with dict return value + test( + box=lambda x, y: ["foo", "bar", {"a": x, "b": y}], + unbox=lambda obj: obj[2]["b"], + ) + + # Test with dict with list return value + test( + box=lambda x, y: {"foo": "bar", "list": [0, x, 1, y]}, + unbox=lambda obj: obj["list"][3], + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_arbitrary_forward_return_value(self): + self._test_arbitrary_forward_return_value() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_arbitrary_forward_return_value_grad_is_view(self): + self._test_arbitrary_forward_return_value(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_with_lazy_parameters(self): + process_group = self._get_process_group() + with self.assertRaisesRegex( + RuntimeError, "Modules with uninitialized parameters" + ): + DistributedDataParallel( + torch.nn.LazyLinear(10), process_group=process_group + ) + + def _test_multiple_outputs_multiple_backward(self, gradient_as_bucket_view=False): + """ + Note: this test can be sped up by only running it on a CPU module + once DistributedDataParallel supports them. 
+ """ + process_group = self._get_process_group() + + class MultipleOutputModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + def define_module(): + return nn.Sequential( + nn.Linear(2, 10, bias=False), + nn.ReLU(), + nn.Linear(10, 4, bias=False), + nn.ReLU(), + ) + + self.module0 = define_module() + self.module1 = define_module() + + def forward(self, x): + return ( + F.softmax(self.module0(x), dim=1), + F.softmax(self.module1(x), dim=1), + ) + + device_id = gpus_for_rank(self.world_size, "xccl")[self.rank][0] + model = DistributedDataParallel( + MultipleOutputModule().float().to(device_id), + device_ids=[device_id], + process_group=process_group, + gradient_as_bucket_view=gradient_as_bucket_view, + ) + + batch_size = 4 + criterion = nn.CrossEntropyLoss() + input = torch.rand([batch_size, 2], dtype=torch.float) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) + + # Compute loss and gradients for both outputs + output1, output2 = model(input) + loss1 = criterion(output1, target) + loss1.backward() + loss2 = criterion(output2, target) + loss2.backward() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_multiple_outputs_multiple_backward(self): + self._test_multiple_outputs_multiple_backward() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_multiple_outputs_multiple_backward_grad_is_view(self): + self._test_multiple_outputs_multiple_backward(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_no_grad(self): + """ + Note: this test can be sped up by only running it on a CPU module + once DistributedDataParallel supports them. + """ + process_group = self._get_process_group() + + class NoGradModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + return F.softmax(x, dim=1) + + device_id = gpus_for_rank(self.world_size, "xccl")[self.rank][0] + model = DistributedDataParallel( + NoGradModule().float().to(device_id), + device_ids=[device_id], + process_group=process_group, + ) + + batch_size = 4 + input = torch.rand([batch_size, 2], dtype=torch.float) + + def check_no_grads(): + for p in model.parameters(): + self.assertTrue(p.requires_grad) + self.assertIsNone(p.grad) + + # After initialization, no parameter has their gradient set. + check_no_grads() + + # Run `forward` function with torch.no_grad() + with torch.no_grad(): + output = model(input) + self.assertTrue(isinstance(output, torch.Tensor)) + + # No parameter should have their gradient set. + check_no_grads() + + def _test_accumulate_gradients_module(self, gradient_as_bucket_view=False): + # This is NOT the recommended way to implement accumulating grads, but + # we would like to make sure DDP does not mess up with the underlying + # module. 
+ int_devices = gpus_for_rank(self.world_size, "xccl")[self.rank][:1] + devices = [torch.device("xpu:" + str(i)) for i in int_devices] + process_group = self._get_process_group() + global_batch_size = self.world_size + + model, ddp_model, input, target = self._prepare_single_device_module( + process_group, devices, devices, global_batch_size, gradient_as_bucket_view + ) + + def step_model(model, input, target): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + loss.backward() + + # ensure accumulate grads works with no_grad + with torch.no_grad(): + ddp_model.train() + ddp_model.module(input) + + # Check two model parameters over 4 iterations. + # Use 4 iterations because we alternate between reducing and + # not reducing and want to make sure we switch both ways. + for iteration in range(4): + step_model(model, input, target) + + if iteration % 2 == 0: + # Skip gradients sync without calling prepare_for_backward + step_model( + ddp_model.module, + input[self.rank : (self.rank + 1)], + target[self.rank : (self.rank + 1)], + ) + for i, j in zip(model.parameters(), ddp_model.parameters()): + self.assertNotEqual(i.grad, j.grad) + else: + step_model( + ddp_model, + input[self.rank : (self.rank + 1)], + target[self.rank : (self.rank + 1)], + ) + for i, j in zip(model.parameters(), ddp_model.parameters()): + self.assertEqual(i.grad, j.grad, rtol=1.3e-06, atol=5e-5) + + # Shuffle the input so that DDP input is different + torch.manual_seed(1337 + iteration) + input = input[torch.randperm(global_batch_size)] + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_failure_recovery(self): + process_group = self._get_process_group() + + # need to create a separate file for the recovered FileStore, because + # the original one will be deleted when destructing the first FileStore. 
+ recovery_filename = self.file_name + "_recovery" + + if self.rank == 0: + # the file will be deleted by the recovered FileStore + open(recovery_filename, "w").close() + + # not necessary to run barrier here, as DDP will synchronize + + class TestModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + return F.softmax(x, dim=1) + + device_id = gpus_for_rank(self.world_size, "xccl")[self.rank][0] + model = TestModel().float().to(device_id) + ddp = DistributedDataParallel( + model, + device_ids=[device_id], + process_group=process_group, + ) + + batch_size = 4 + criterion = nn.CrossEntropyLoss() + input = torch.rand([batch_size, 2], dtype=torch.float) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) + + for _ in range(6): + output = ddp(input) + loss = criterion(output, target) + loss.backward() + + del ddp + c10d.destroy_process_group(process_group) + + store = c10d.FileStore(recovery_filename, self.world_size) + c10d.init_process_group( + "xccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = c10d.distributed_c10d._get_default_group() + ddp = DistributedDataParallel( + model, + device_ids=[device_id], + process_group=process_group, + ) + + input = torch.rand([batch_size, 2], dtype=torch.float) + target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to( + device_id + ) + for _ in range(6): + output = ddp(input) + loss = criterion(output, target) + loss.backward() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_pass_default_pg(self): + dist.init_process_group( + "xccl", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + + default_pg = c10d.distributed_c10d._get_default_group() + dist.destroy_process_group(default_pg) + self.assertFalse(dist.is_initialized()) + + def _gpu_model_with_ddp_comm_hook( + self, + process_group, + hook=None, + gradient_as_bucket_view=False, + state=None, + static_graph=False, + ): + device_id = gpus_for_rank(self.world_size, "xccl")[self.rank][0] + gpu_model = DistributedDataParallel( + ModuleForDdpCommHook().to(device_id), + device_ids=[device_id], + process_group=process_group, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + # Register a DDP communication hook if any. + if hook is not None: + gpu_model.register_comm_hook(state, hook) + + return gpu_model + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_future_passing_gpu_xccl(self): + """ + This unit test verifies whether the Future object is passed properly using xccl backend. + The hook callback function creates a Future object and sets a value to it. + """ + process_group = self._get_process_group() + + # Get GPU model with simple_hook registered. + gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook) + + # check whether the grads are equal to what simple_hook's then callback returns. + # without the comm_hook, result would be 0.25 * torch.ones(2, 2). 
+ self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2)) + + def _test_ddp_comm_hook_allreduce_hook_xccl( + self, gradient_as_bucket_view=False, static_graph=False + ): + """ + This unit test verifies whether a DDP communication hook that just calls + allreduce gives the same result with the case of no hook registered. + Without the then callback, the future_value in reducer is no longer + a PyObject, and this unit test verifies future_value is properly checked. + """ + process_group = self._get_process_group() + + def allreduce_hook( + state: object, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + tensors = [bucket.buffer() / self.world_size] + return ( + process_group.allreduce(tensors) + .get_future() + .then(lambda fut: fut.value()[0]) + ) + + # Get GPU model with allreduce_hook registered. + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, allreduce_hook, gradient_as_bucket_view, static_graph + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_default_ddp_comm_hooks_xccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether default Python DDP communication hooks ALLREDUCE, FP16_COMPRESS + and BF16_COMPRESS, can give the same result with the case of no hook registered. + """ + process_group = self._get_process_group() + + # For these default DDP comm hooks, the only state is process group. + state = process_group + hook_options = [default.allreduce_hook, default.fp16_compress_hook] + if c10d.is_xccl_available(): + hook_options.append(default.bf16_compress_hook) + for hook in hook_options: + # Get GPU model with the hook registered. + # The first arg 'process_group' is used for initializing the test environment, + # so it cannot be replaced by 'state', although they have the same value. + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, hook, gradient_as_bucket_view, state + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_fp16_compress_wrapper(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether wrapping the ALLREDUCE and POWER_SGD hooks with + the FP16_WRAPPER can give the same result as when there is no hook registered. + """ + process_group = self._get_process_group() + powerSGD_state = powerSGD.PowerSGDState(process_group=process_group) + + hook_args = [ + (powerSGD.powerSGD_hook, powerSGD_state), + (default.allreduce_hook, process_group), + ] + + for hook, state in hook_args: + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, + default.fp16_compress_wrapper(hook), + gradient_as_bucket_view, + state, + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_bf16_compress_wrapper(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether wrapping the ALLREDUCE and POWER_SGD hooks with + the BF16_WRAPPER can give the same result as when there is no hook registered. 
+ """ + process_group = self._get_process_group() + powerSGD_state = powerSGD.PowerSGDState(process_group=process_group) + + hook_args = [ + (powerSGD.powerSGD_hook, powerSGD_state), + (default.allreduce_hook, process_group), + ] + + for hook, state in hook_args: + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, + default.bf16_compress_wrapper(hook), + gradient_as_bucket_view, + state, + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_powerSGD_ddp_comm_hook_xccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether Python DDP communication hook POWER_SGD + can give the same result with the case of no hook registered. + """ + process_group = self._get_process_group() + + # Get GPU model with the hook registered. + # Test the hook with different algorithmic configs. + for use_error_feedback, warm_start, batch_tensors_with_same_shape in product( + [True, False], + [True, False], + [True, False], + ): + state = powerSGD.PowerSGDState( + process_group=process_group, + matrix_approximation_rank=1, + use_error_feedback=use_error_feedback, + warm_start=warm_start, + batch_tensors_with_same_shape=batch_tensors_with_same_shape, + ) + for hook in [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]: + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, hook, gradient_as_bucket_view, state + ) + + # check whether the grads are equal to what DDP without hook would return. + self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + def _test_builtin_ddp_comm_hooks_xccl(self, gradient_as_bucket_view=False): + """ + This unit test verifies whether built-in C++ DDP communication hooks ALLREDUCE and FP16_COMPRESS + can give the same result with the case of no hook registered. + """ + process_group = self._get_process_group() + + for comm_hook_type in [ + dist.BuiltinCommHookType.ALLREDUCE, + dist.BuiltinCommHookType.FP16_COMPRESS, + ]: + # Get GPU model with the built-in communication hook. + gpu_model = self._gpu_model_with_builtin_ddp_comm_hook( + process_group, comm_hook_type, gradient_as_bucket_view + ) + + # check whether the grads are equal to what DDP without hook would return. 
+ self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2)) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_allreduce_hook_xccl(self): + self._test_ddp_comm_hook_allreduce_hook_xccl() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_default_ddp_comm_hooks_xccl(self): + self._test_default_ddp_comm_hooks_xccl() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_fp16_compress_wrapper_xccl(self): + self._test_fp16_compress_wrapper() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_bf16_compress_wrapper_xccl(self): + self._test_bf16_compress_wrapper() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_builtin_ddp_comm_hooks_xccl(self): + self._test_builtin_ddp_comm_hooks_xccl() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_powerSGD_ddp_comm_hook_xccl(self): + self._test_powerSGD_ddp_comm_hook_xccl() + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_allreduce_hook_xccl_grad_is_view(self): + self._test_ddp_comm_hook_allreduce_hook_xccl(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_allreduce_hook_xccl_static_graph(self): + self._test_ddp_comm_hook_allreduce_hook_xccl(static_graph=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_default_ddp_comm_hooks_xccl_is_view(self): + self._test_default_ddp_comm_hooks_xccl(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_fp16_compress_wrapper_is_view(self): + self._test_fp16_compress_wrapper(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_bf16_compress_wrapper_is_view(self): + self._test_bf16_compress_wrapper(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_builtin_ddp_comm_hooks_xccl_grad_is_view(self): + self._test_builtin_ddp_comm_hooks_xccl(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_powerSGD_ddp_comm_hook_xccl_grad_is_view(self): + self._test_powerSGD_ddp_comm_hook_xccl(gradient_as_bucket_view=True) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_allreduce_with_then_hook_xccl(self): + """ + This unit test verifies whether a DDP communication hook that calls allreduce and then + multiplies the result by ten and divides by two gives the expected result. + """ + process_group = self._get_process_group() + + def allreduce_with_then_hook( + state: object, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: + tensors = [bucket.buffer() / self.world_size] + fut = process_group.allreduce(tensors).get_future() + + def mult(fut): + # Multiply the result by 10. + return 10 * fut.value()[0] + + def div(fut): + # Divide the result by 2. + return 0.5 * fut.value() + + return fut.then(mult).then(div) + + # Get GPU model with allreduce_with_then_hook registered. + gpu_model = self._gpu_model_with_ddp_comm_hook( + process_group, allreduce_with_then_hook + ) + + # check whether the grads are equal to what allreduce returns multiplied by 5. + # without the comm_hook, result would be still 0.25 * torch.ones(2, 2). 
+ self._run_and_verify_hook(gpu_model, 8, 1.25 * torch.ones(2, 2)) + + class AcceptsParam(torch.nn.Module): + def __init__(self, p, factor): + super().__init__() + self.a = p + self.f = factor + + def forward(self, input): + return input + self.a * self.f + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_weight_sharing(self): + process_group = self._get_process_group() + + size = 2048 * 2048 + dev = self.rank + world = self.world_size + + p = torch.nn.Parameter(torch.randn(size, requires_grad=True)) + + for try_set_to_none, use_bucket_view in product((False, True), (False, True)): + m = torch.nn.Sequential( + self.AcceptsParam(p, dev + 1), self.AcceptsParam(p, dev + 1) + ).xpu(dev) + + m = torch.nn.parallel.DistributedDataParallel( + m, + bucket_cap_mb=1, + gradient_as_bucket_view=use_bucket_view, + device_ids=[dev], + process_group=process_group, + ) + + for i in range(3): + m.zero_grad(set_to_none=try_set_to_none) + m(1).sum().backward() + + # Each param value is multiplied by "rank + 1" twice in forward, so the grad + # values produced by a particular rank should be 2. * (rank + 1). + # Summing these over ranks and dividing by world size gives the expected result: + analytic = torch.full_like( + p, 2.0 * (world * (world + 1.0) / 2.0) / world, device=dev + ) + for name, p in m.named_parameters(): + self.assertEqual( + p.grad, + analytic, + "mismatch at " + + name + + ".grad for " + + f"set_to_none = {try_set_to_none}, use_bucket_view = {use_bucket_view}", + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_ddp_packed_sequence(self): + """ + Tests that DDP with ``device_ids`` specified can run a forward and + backward pass with ``PackedSequence`` s with parity compared to a local + version of the model. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = dist.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + seqs = ["sequence_sequence", "seq", "sequence"] + vocab = [""] + sorted({ch for seq in seqs for ch in seq}) + vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs] + # Set the seed to make the embedding and LSTM deterministic (even + # across ranks since DDP broadcasts parameters from rank 0) + torch.manual_seed(0) + embed = nn.Embedding(len(vocab), 4) # keep on CPU + lstm = nn.LSTM(input_size=4, hidden_size=2, batch_first=True).to(self.rank) + lstm_ddp = DistributedDataParallel( + copy.deepcopy(lstm), + device_ids=[self.rank], + process_group=process_group, + ) + for p1, p2 in zip(lstm.parameters(), lstm_ddp.module.parameters()): + self.assertEqual(p1, p2) + seq_lengths = torch.LongTensor(list(map(len, vectorized_seqs))) + seq_tensor = torch.Tensor( + torch.zeros((len(vectorized_seqs), seq_lengths.max())) + ).long() + for i, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)): + seq_tensor[i, :seq_len] = torch.LongTensor(seq) + seq_lengths, permutation_idx = seq_lengths.sort(0, descending=True) + seq_tensor = seq_tensor[permutation_idx] + embedded_seq_tensor = embed(seq_tensor) + packed_input = torch.nn.utils.rnn.pack_padded_sequence( + embedded_seq_tensor, + seq_lengths, + batch_first=True, + ) + packed_input_ddp = torch.nn.utils.rnn.pack_padded_sequence( + embedded_seq_tensor.detach().clone(), + seq_lengths, + batch_first=True, + ) + # Move the input to GPU explicitly for the local model + packed_output, (ht, ct) = lstm(packed_input.to(self.rank)) + # Let DDP move the input to GPU internally + packed_output_ddp, (ht_ddp, ct_ddp) = lstm_ddp(packed_input_ddp) + 
self.assertEqual(packed_output.data, packed_output_ddp.data) + self.assertEqual(ht, ht_ddp) + self.assertEqual(ct, ct_ddp) + packed_output.data.sum().backward() + packed_output_ddp.data.sum().backward() + for p1, p2 in zip(lstm.parameters(), lstm_ddp.parameters()): + self.assertEqual(p1.grad, p2.grad) + + # error: input dense tensor has to be contiguous + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_channels_last_contig(self): + process_group = self._get_process_group() + device = torch.device(f"xpu:{self.rank}") + tensor = torch.ones((2, 16, 768, 1152), dtype=torch.float32, device=device).to( + memory_format=torch.channels_last + ) + process_group.broadcast([tensor]).wait() + + +class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): + @property + def device(self): + return f"xpu:{self.rank}" + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + def _test_broadcast_coalesced(self, process_group, device, root_rank): + half = torch.float16 + + # No support for float16 for CPU tensors + if device == torch.device("cpu"): + half = torch.float32 + + target = torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) + target += torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float64, device=device).chunk(5) + target += torch.arange(60, dtype=half, device=device).chunk(5) + target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) + + # The tensors to pass to broadcast are identical to the target + # only on the process that is the root of the broadcast. + if self.rank == root_rank: + tensors = [tensor.clone() for tensor in target] + else: + tensors = [torch.zeros_like(tensor) for tensor in target] + + if self.rank != root_rank: + self.assertNotEqual(tensors, target) + + c10d._broadcast_coalesced( + process_group, tensors, buffer_size=256, src=root_rank + ) + + if self.rank != root_rank: + self.assertEqual(tensors, target) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_broadcast_coalesced_xccl(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = c10d.distributed_c10d._get_default_group() + device = torch.device("xpu:%d" % self.rank) + ranks = [0, 1] + for root_rank in ranks: + self._test_broadcast_coalesced(process_group, device, root_rank) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_all_reduce_coalesced_xccl(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = c10d.distributed_c10d._get_default_group() + device = torch.device("xpu:%d" % self.rank) + tensors = [ + torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float) + for i in range(5) + ] + torch.distributed.all_reduce_coalesced(tensors, group=process_group) + for i, t in enumerate(tensors): + self.assertEqual( + t, + torch.full_like( + t, self.world_size * (i + (self.world_size + 1.0) / 2.0) + ), + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_all_reduce_coalesced_manager_xccl(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", store=store, rank=self.rank, world_size=self.world_size + ) + process_group = 
c10d.distributed_c10d._get_default_group() + device = torch.device("xpu:%d" % self.rank) + tensors = [ + torch.full((60 + i,), self.rank + 1 + i, device=device, dtype=torch.float) + for i in range(5) + ] + with torch.distributed._coalescing_manager( + group=process_group, device=device, async_ops=True + ) as cm: + for tensor in tensors: + torch.distributed.all_reduce(tensor) + self.assertEqual(len(cm.works), 1) + cm.wait() + for i, t in enumerate(tensors): + self.assertEqual( + t, + torch.full_like( + t, self.world_size * (i + (self.world_size + 1.0) / 2.0) + ), + ) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_xccl_barrier(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + + t = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + c10d.all_reduce(t) + expected_tensor = torch.tensor([3] * 10).xpu(2 * self.rank) + self.assertEqual(expected_tensor, t) + + # Test with new_group + pg = c10d.new_group([0, 1]) + t = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + pg = c10d.new_group([0]) + if self.rank == 0: + t = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + expected_tensor = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + pg = c10d.new_group([1]) + if self.rank == 1: + t = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + expected_tensor = torch.tensor([self.rank + 1] * 10).xpu(2 * self.rank) + pg.allreduce(t).wait() + self.assertEqual(expected_tensor, t) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_xccl_barrier_device_ids(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + + c10d.barrier(device_ids=[self.rank]) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_xccl_barrier_device_ids_function_argument(self): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + + with self.assertRaisesRegex(TypeError, "Invalid function argument"): + c10d.barrier(device_ids=self.rank) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_base_k(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensor = torch.zeros(2, dtype=torch.int64).to(self.rank) + input_tensors = torch.arange(self.world_size * 2, dtype=torch.int64).to( + self.rank + ) + input_tensors = torch.reshape(input_tensors, (self.world_size, 2)) + dist.reduce_scatter_tensor(output_tensor, input_tensors) + self.assertEqual(output_tensor, input_tensors[self.rank] * self.world_size) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_reduce_scatter_tensor_coalesced(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + output_tensors = torch.zeros(2, 2).to(self.rank) + input_tensors = [torch.ones(2, 2).to(self.rank) for _ in range(self.world_size)] + with dist._coalescing_manager(): + for i in range(self.world_size): + dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) + self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size) + 
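For orientation, the coalesced reduce-scatter path exercised by test_reduce_scatter_tensor_coalesced above maps onto a small amount of user-facing code. The sketch below is illustrative only and not part of the patch: it assumes a build with USE_XCCL=ON, one XPU per rank, and a torchrun-style launcher that supplies RANK and WORLD_SIZE for the default env:// rendezvous (the tests above use a FileStore instead); the script name is hypothetical.

    # Hypothetical standalone driver for the coalesced reduce-scatter path.
    # Launch with, e.g.: torchrun --nproc_per_node=2 xccl_reduce_scatter_demo.py
    import torch
    import torch.distributed as dist

    def main() -> None:
        # "xccl" is the backend name this patch registers for XPU devices.
        dist.init_process_group("xccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        device = torch.device(f"xpu:{rank}")

        # One reduce_scatter_tensor per output row, batched by the coalescing
        # manager, mirroring test_reduce_scatter_tensor_coalesced above.
        outputs = torch.zeros(world_size, 2, device=device)
        inputs = [torch.ones(world_size, 2, device=device) for _ in range(world_size)]
        with dist._coalescing_manager():
            for i in range(world_size):
                dist.reduce_scatter_tensor(outputs[i], inputs[i])

        # Every rank contributes all-ones, so each reduced row equals world_size.
        assert torch.equal(outputs, torch.full_like(outputs, float(world_size)))

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()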
+ +class SetDeviceMethod(Enum): + TORCH_XPU_SET = auto() # torch.xpu.set_device + COLLECTIVE_ARGUMENT = auto() # broadcast_object_list(device=) + + +class XCCLProcessGroupWithDispatchedCollectivesTests( + test_c10d_common.ProcessGroupWithDispatchedCollectivesTests +): + @requires_xccl() + @skip_if_lt_x_gpu(1) + def test_collectives(self): + self._test_collectives(backend="xccl") + + @requires_xccl() + @skip_if_lt_x_gpu(1) + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="xccl") + + @requires_xccl() + @skip_if_lt_x_gpu(1) + def test_all_to_all_single(self): + self._test_all_to_all_single(backend="xccl") + + @requires_xccl() + @skip_if_lt_x_gpu(1) + def test_allgather_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "xpu" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) + + +class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase): + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def device(self): + return self.rank + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_new_group_local_sync(self): + self._test_new_group_local_sync(backend="xccl") + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_new_group_local_sync_sanity_check(self): + self._test_new_group_local_sync_sanity_check(backend="xccl") + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_new_group_local_sync_duplicated_pg(self): + self._test_new_group_local_sync_duplicate_pg(backend="xccl") + + def _init_two_pg2_subgroups(self, world_size: int = 4): + if world_size != 4: + raise NotImplementedError( + f"need world size of 4 to get 2 subgroup PGs, but got world size of {world_size}" + ) + store = c10d.FileStore(self.file_name, world_size) + c10d.init_process_group( + backend="xccl", store=store, rank=self.rank, world_size=world_size + ) + # every rank creates the same sub groups + # including unused sub groups in the current rank + a_group = c10d.new_group([0, 1]) + b_group = c10d.new_group([2, 3]) + return a_group if self.rank < 2 else b_group + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_gather_subgroup(self): + world_size = 4 + if self.rank >= world_size: + # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later + return + + subgroup = self._init_two_pg2_subgroups(world_size) + device = torch.device("xpu:%d" % self.rank) + input = torch.ones((10,), device=device) * self.rank + if self.rank == 0 or self.rank == 2: + gather_list = [torch.empty_like(input) for _ in range(subgroup.size())] + torch.distributed.gather( + input, + gather_list=gather_list, + dst=self.rank, + group=subgroup, + async_op=False, + ) + for src in range(len(gather_list)): + expected = (torch.ones_like(input) * self.rank) + src + self.assertEqual(gather_list[src], expected) + else: + torch.distributed.gather( + input, + gather_list=None, + dst=self.rank - 1, + group=subgroup, + async_op=False, + ) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_gather_object_subgroup(self): + world_size = 4 + if self.rank >= world_size: + # just easier to write the test for exactly 4 gpus, even if this test class 
increased to 8gpu later + return + + subgroup = self._init_two_pg2_subgroups(world_size) + + # discrepancy #1 + # have to set device or else gather_object gets wrong device from 'current_device = _get_pg_default_device(group) + torch.xpu.set_device(self.rank) + + input = {"rank": self.rank} + if self.rank == 0 or self.rank == 2: + # discrepancy #2 + # another weird thing- what's the point of making me specify some empty objects in my list? + # empty list should be valid imo. (but it throws an error) + gather_list = [{}, {}] + torch.distributed.gather_object( + input, object_gather_list=gather_list, dst=self.rank, group=subgroup + ) + for src in range(len(gather_list)): + self.assertEqual(gather_list[src]["rank"], self.rank + src) + else: + torch.distributed.gather_object( + input, object_gather_list=None, dst=self.rank - 1, group=subgroup + ) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_reduce_subgroup(self): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + device = torch.device("xpu:%d" % self.rank) + x = torch.ones((10,), device=device) * self.rank + if self.rank == 0 or self.rank == 2: + expected = x + torch.ones((10,), device=device) * (self.rank + 1) + c10d.reduce(x, dst=self.rank, group=subgroup, async_op=False) + self.assertEqual(x, expected) + else: + c10d.reduce(x, dst=self.rank - 1, group=subgroup, async_op=False) + + # error: RuntimeError: Point-to-point communication as the first call is not supported now + @requires_xccl() + @skip_if_lt_x_gpu(4) + @parametrize("async_op", [True, False]) + def test_send_recv_subgroup(self, async_op): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + device = torch.device("xpu:%d" % self.rank) + if self.rank == 0 or self.rank == 2: + x = torch.empty((10,), device=device) + if async_op: + c10d.irecv(x, src=self.rank + 1, group=subgroup).wait() + else: + c10d.recv(x, src=self.rank + 1, group=subgroup) + expected = torch.ones((10,), device=device) * (self.rank + 1) + self.assertEqual(x, expected) + else: + x = torch.ones((10,), device=device) * self.rank + if async_op: + c10d.isend(x, dst=self.rank - 1, group=subgroup).wait() + else: + c10d.send(x, dst=self.rank - 1, group=subgroup) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_broadcast_subgroup(self): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + device = torch.device("xpu:%d" % self.rank) + if self.rank == 0 or self.rank == 2: + x = torch.empty((10,), device=device) + c10d.broadcast(x, src=self.rank + 1, group=subgroup) + expected = torch.ones((10,), device=device) * (self.rank + 1) + self.assertEqual(x, expected) + else: + x = torch.ones((10,), device=device) * self.rank + c10d.broadcast(x, src=self.rank, group=subgroup) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + @parametrize( + "set_device", + [SetDeviceMethod.TORCH_XPU_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT], + ) + def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + if set_device == SetDeviceMethod.TORCH_XPU_SET: + torch.xpu.set_device(self.rank) + device = None + else: + device = torch.device("xpu:%d" % self.rank) + if self.rank == 0 or self.rank == 2: + x = [{}] + c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device) + expected = [{"rank": self.rank + 1}] + 
self.assertEqual(x, expected) + else: + x = [{"rank": self.rank}] + c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + @parametrize( + "set_device", + [SetDeviceMethod.TORCH_XPU_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT], + ) + def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + if set_device == SetDeviceMethod.TORCH_XPU_SET: + torch.xpu.set_device(self.rank) + device = None + else: + device = torch.device("xpu:%d" % self.rank) + if self.rank == 0 or self.rank == 2: + x = [{}] + c10d.broadcast_object_list( + x, src=self.rank + 1, group=subgroup, device=device + ) + expected = [{"rank": self.rank + 1}] + self.assertEqual(x, expected) + else: + x = [{"rank": self.rank}] + c10d.broadcast_object_list(x, src=self.rank, group=subgroup, device=device) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_scatter_subgroup(self): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + device = torch.device("xpu:%d" % self.rank) + x = torch.empty((10,), device=device) + expected = torch.ones((10,), device=device) * self.rank + if self.rank == 0 or self.rank == 2: + c10d.scatter(x, scatter_list=None, src=self.rank + 1, group=subgroup) + else: + scatter_list = [ + torch.ones((10,), device=device) * (self.rank - 1), + torch.ones((10,), device=device) * self.rank, + ] + c10d.scatter(x, scatter_list=scatter_list, src=self.rank, group=subgroup) + self.assertEqual(x, expected) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_scatter_object_list_subgroup(self): + world_size = 4 + if self.rank >= world_size: + return + subgroup = self._init_two_pg2_subgroups(world_size) + torch.xpu.set_device(self.rank) + scatter_object_output_list = [None] + expected = [{"rank": self.rank}] + if self.rank == 0 or self.rank == 2: + c10d.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=None, + src=self.rank + 1, + group=subgroup, + ) + + else: + scatter_object_input_list = [ + {"rank": self.rank - 1}, + {"rank": self.rank}, + ] + c10d.scatter_object_list( + scatter_object_output_list=scatter_object_output_list, + scatter_object_input_list=scatter_object_input_list, + src=self.rank, + group=subgroup, + ) + self.assertEqual(scatter_object_output_list, expected) + + +instantiate_parametrized_tests(LargeCommTest) + +if __name__ == "__main__": + run_tests() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index b123023d2fd3c..5ac5f1358cd18 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -282,6 +282,9 @@ if(USE_DISTRIBUTED) if(USE_NCCL) list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) endif() + if(USE_XCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::xccl) + endif() # Same for MPI. 
if(USE_MPI) list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) @@ -356,6 +359,10 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() + if(USE_XPU AND USE_C10D_XCCL) + target_compile_definitions(torch_python PRIVATE USE_C10D_XCCL) + endif() + if(USE_DISTRIBUTED) target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) endif() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index c14e195cd4367..beb16f0c402d2 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -300,6 +300,7 @@ class ProcessGroup: UNDEFINED = ... GLOO = ... NCCL = ... + XCCL = ... UCC = ... MPI = ... CUSTOM = ... @@ -679,3 +680,11 @@ class _SymmetricMemory: def stream_write_value32( tensor: torch.Tensor, offset: int, val: int ) -> torch.Tensor: ... + +class ProcessGroupXCCL(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + ): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 6251bfa1817dd..9d4cadf492334 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -79,6 +79,7 @@ namespace { } IMPL_SEND(CPU) +IMPL_SEND(XPU) IMPL_SEND(CUDA) IMPL_SEND(PrivateUse1) @@ -94,6 +95,7 @@ IMPL_SEND(PrivateUse1) } IMPL_RECV(CPU) +IMPL_RECV(XPU) IMPL_RECV(CUDA) IMPL_RECV(PrivateUse1) @@ -108,6 +110,7 @@ IMPL_RECV(PrivateUse1) } IMPL_RECV_ANY_SOURCE(CPU) +IMPL_RECV_ANY_SOURCE(XPU) IMPL_RECV_ANY_SOURCE(CUDA) IMPL_RECV_ANY_SOURCE(PrivateUse1) @@ -131,6 +134,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) } IMPL_REDUCE(CPU) +IMPL_REDUCE(XPU) IMPL_REDUCE(CUDA) IMPL_REDUCE(PrivateUse1) @@ -156,6 +160,7 @@ IMPL_REDUCE(PrivateUse1) } IMPL_BROADCAST(CPU) +IMPL_BROADCAST(XPU) IMPL_BROADCAST(CUDA) IMPL_BROADCAST(PrivateUse1) @@ -181,6 +186,7 @@ IMPL_BROADCAST(PrivateUse1) IMPL_ALLREDUCE(CPU) IMPL_ALLREDUCE(CUDA) +IMPL_ALLREDUCE(XPU) IMPL_ALLREDUCE(PrivateUse1) #define IMPL_ALLREDUCE_COALESCED(DEV) \ @@ -198,6 +204,7 @@ IMPL_ALLREDUCE(PrivateUse1) } IMPL_ALLREDUCE_COALESCED(CPU) +IMPL_ALLREDUCE_COALESCED(XPU) IMPL_ALLREDUCE_COALESCED(CUDA) IMPL_ALLREDUCE_COALESCED(PrivateUse1) @@ -222,6 +229,7 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast) IMPL_ALLGATHER(CPU) +IMPL_ALLGATHER(XPU) IMPL_ALLGATHER(CUDA) IMPL_ALLGATHER(PrivateUse1) @@ -242,6 +250,7 @@ IMPL_ALLGATHER(PrivateUse1) } IMPL__ALLGATHER_BASE(CPU) +IMPL__ALLGATHER_BASE(XPU) IMPL__ALLGATHER_BASE(CUDA) IMPL__ALLGATHER_BASE(PrivateUse1) @@ -258,6 +267,7 @@ IMPL__ALLGATHER_BASE(PrivateUse1) } IMPL_ALLGATHER_COALESCED(CPU) +IMPL_ALLGATHER_COALESCED(XPU) IMPL_ALLGATHER_COALESCED(CUDA) IMPL_ALLGATHER_COALESCED(PrivateUse1) @@ -273,6 +283,7 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) +IMPL_ALLGATHER_INTO_TENSOR_COALESCED(XPU) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) @@ -296,6 +307,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) } IMPL_REDUCE_SCATTER(CPU) +IMPL_REDUCE_SCATTER(XPU) IMPL_REDUCE_SCATTER(CUDA) IMPL_REDUCE_SCATTER(PrivateUse1) @@ -320,6 +332,7 @@ IMPL_REDUCE_SCATTER(PrivateUse1) } IMPL__REDUCE_SCATTER_BASE(CPU) +IMPL__REDUCE_SCATTER_BASE(XPU) IMPL__REDUCE_SCATTER_BASE(CUDA) IMPL__REDUCE_SCATTER_BASE(PrivateUse1) @@ -341,6 +354,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) +IMPL_REDUCE_SCATTER_TENSOR_COALESCED(XPU) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA) 
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) @@ -360,6 +374,7 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) } IMPL_GATHER(CPU) +IMPL_GATHER(XPU) IMPL_GATHER(CUDA) IMPL_GATHER(PrivateUse1) @@ -382,6 +397,7 @@ IMPL_GATHER(PrivateUse1) } IMPL_SCATTER(CPU) +IMPL_SCATTER(XPU) IMPL_SCATTER(CUDA) IMPL_SCATTER(PrivateUse1) @@ -403,6 +419,7 @@ IMPL_SCATTER(PrivateUse1) } IMPL_ALLTOALL(CPU) +IMPL_ALLTOALL(XPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) @@ -424,6 +441,7 @@ IMPL_ALLTOALL(PrivateUse1) } IMPL_ALLTOALL_BASE(CPU) +IMPL_ALLTOALL_BASE(XPU) IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) @@ -440,6 +458,7 @@ IMPL_ALLTOALL_BASE(PrivateUse1) } IMPL_BARRIER(CPU) +IMPL_BARRIER(XPU) IMPL_BARRIER(CUDA) IMPL_BARRIER(PrivateUse1) // NOLINTEND(performance-unnecessary-value-param) @@ -494,6 +513,7 @@ namespace { #define REGISTER_C10D_OP(FUNC) \ REGISTER_C10D_OP1(FUNC, CPU) \ REGISTER_C10D_OP1(FUNC, CUDA) \ + REGISTER_C10D_OP1(FUNC, XPU) \ REGISTER_C10D_OP1(FUNC, PrivateUse1) // Now we start to register ops with the three device keys diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 5cba3a39629d4..7f4e929d23020 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -77,7 +77,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { NCCL = 2, UCC = 3, MPI = 4, - CUSTOM = 5, + XCCL = 5, + CUSTOM = 6, }; static std::string backendTypeToString(const BackendType& type) { @@ -86,6 +87,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return "gloo"; case BackendType::NCCL: return "nccl"; + case BackendType::XCCL: + return "xccl"; case BackendType::UCC: return "ucc"; case BackendType::MPI: @@ -106,6 +109,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return BackendType::GLOO; } else if (backend == "nccl") { return BackendType::NCCL; + } else if (backend == "xccl") { + return BackendType::XCCL; } else if (backend == "ucc") { return BackendType::UCC; } else if (backend == "mpi") { @@ -152,6 +157,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendType_; } + inline bool backendSupportsSequenceNumbers(BackendType backendType) { + if (backendType == BackendType::GLOO || backendType == BackendType::NCCL || + backendType == BackendType::XCCL || backendType == BackendType::UCC) + return true; + return false; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -634,9 +646,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { virtual void setSequenceNumberForGroup() { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { getDefaultBackend()->setSequenceNumberForGroup(); } else { TORCH_CHECK( @@ -655,9 +665,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. 
- if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { return getDefaultBackend()->getSequenceNumberForGroup(); } else { TORCH_CHECK( @@ -752,6 +760,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { tensor = at::empty( {1}, at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte)); + } else if (backendType_ == c10d::ProcessGroup::BackendType::XCCL) { + // set xpu tensor for override cpu dispatch + tensor = at::empty( + {1}, + at::TensorOptions().device(at::DeviceType::XPU).dtype(at::kByte)); } else { // Default to using cpu implementation tensor = at::empty( diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp new file mode 100644 index 0000000000000..b2a900c92b8c0 --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -0,0 +1,1794 @@ +#ifdef USE_C10D_XCCL + +#include +#include +#include + +namespace c10d { + +namespace { +const std::map xcclOps = { + {ReduceOp::MIN, ccl::reduction::min}, + {ReduceOp::MAX, ccl::reduction::max}, + {ReduceOp::SUM, ccl::reduction::sum}, + {ReduceOp::PRODUCT, ccl::reduction::prod}, +}; + +const std::map xcclDatatypes = { + {at::kByte, ccl::datatype::uint8}, + {at::kChar, ccl::datatype::int8}, + {at::kInt, ccl::datatype::int32}, + {at::kLong, ccl::datatype::int64}, + {at::kHalf, ccl::datatype::float16}, + {at::kFloat, ccl::datatype::float32}, + {at::kDouble, ccl::datatype::float64}, + {at::kBFloat16, ccl::datatype::bfloat16}, + {at::kBool, ccl::datatype::uint8}, + // use for allgather + {at::kFloat8_e5m2, ccl::datatype::uint8}, + {at::kFloat8_e4m3fn, ccl::datatype::uint8}, + {at::kFloat8_e4m3fnuz, ccl::datatype::uint8}, + {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, +}; + +bool computeLengthsAndCheckAndGetFlat( + const std::vector& tensors, + std::vector& lengths, + at::Tensor& flatTensor, + int64_t& flatLength) { + int64_t groupSize = tensors.size(); + auto firstTensor = tensors[0]; + int64_t totalSize = 0; + bool isFlat = true; + + auto storage = firstTensor.storage(); + int64_t firstStorageOffset = firstTensor.storage_offset(); + + for (int i = 0; i < groupSize; i++) { + auto& curTensor = tensors[i]; + int64_t length = curTensor.numel(); + lengths[i] = length; + totalSize += length; + + if (isFlat && + (!storage.is_alias_of(curTensor.storage()) || + curTensor.storage_offset() != + firstStorageOffset + totalSize - length)) { + isFlat = false; + } + } + + flatLength = totalSize; + + if (isFlat) { + flatTensor = firstTensor; + } else { + flatTensor = at::empty({totalSize}, firstTensor.options()); + } + + return isFlat; +} + +bool check_same_size(const std::vector& input_tensors) { + for (const auto& input_tensor : input_tensors) { + if (!input_tensors[0].is_same_size(input_tensor)) { + return false; + } + } + return true; +} + +void check_xpu_single_tensor( + const at::Tensor& tensor, + const bool p2p = false // whether operation is a P2P operation +) { + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR( + ValueError, "Tensors must be XPU and dense and non-complex"); + + // Skip the following requirements for P2P operations + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + if (p2p) { + TORCH_WARN_ONCE( + "Detected non-contiguous tensor in P2P operations. 
It is user "
+          "responsibility to guarantee that source and destination tensors have "
+          "the same contiguity format.");
+    } else {
+      C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+    }
+  }
+}
+
+int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
+  TORCH_CHECK_WITH(
+      ValueError, tensors.size() != 0, "Tensor list must be nonempty");
+
+  const auto& first = tensors.front();
+
+  int64_t total_numel = 0;
+  for (const auto& t : tensors) {
+    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
+      C10_THROW_ERROR(
+          ValueError, "Tensors must be XPU and dense and non-complex");
+    }
+    if (t.scalar_type() != first.scalar_type()) {
+      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
+    }
+    TORCH_CHECK_WITH(
+        ValueError,
+        t.get_device() == tensors[0].get_device(),
+        "Expected list of tensors on the same device");
+    total_numel += t.numel();
+  }
+
+  return total_numel;
+}
+
+ccl::datatype getXcclDataType(
+    at::ScalarType type,
+    bool is_reduction_op = false) {
+  if (is_reduction_op)
+    TORCH_CHECK(
+        !isFloat8Type(type),
+        "Float8 dtypes are not currently supported for XCCL reductions");
+  auto it = xcclDatatypes.find(type);
+  TORCH_CHECK_WITH(
+      TypeError,
+      it != xcclDatatypes.end(),
+      "Input tensor data type is not supported for XCCL process group: ",
+      type);
+  return it->second;
+}
+
+ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
+  try {
+    if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
+      // Map sum to max for bool tensors to avoid overflow issues with sum.
+      return ccl::reduction::max;
+    }
+    // Workaround: oneCCL does not support AVG, so fall back to SUM here.
+    if (reduceOp == ReduceOp::AVG) {
+      return ccl::reduction::sum;
+    }
+    return xcclOps.at(reduceOp);
+  } catch (const std::out_of_range&) {
+    C10_THROW_ERROR(
+        ValueError,
+        "Cannot use ReduceOp."
+ reduceOpToString(reduceOp) + " with XCCL"); + } +} + +void syncStream( + at::Device& device, + at::xpu::XPUEvent& xcclEvent, + at::xpu::XPUStream& xcclStream) { + xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); + xcclEvent.block(xcclStream); +} + +} // namespace + +constexpr int64_t kSynchronizeBusyWaitMillis = 10; +thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0; + +ProcessGroupXCCL::WorkXCCL::WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle, + const std::optional>& inputs) + : Work(rank, opType, profilingTitle, inputs), + device_(device), + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq) { + xcclEndEvent_ = std::make_shared(); +} + +ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) + : Work(w.rank_, w.opType_), + device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), + blockingWait_(w.blockingWait_), + workStartTime_(w.workStartTime_), + seq_(w.seq_) {} + +ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; + +bool ProcessGroupXCCL::WorkXCCL::isCompleted() { + if (xcclEndEvent_ && xcclEndEvent_->query()) { + return true; + } + return false; +} + +void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); + if (blockingWait_) { + while (!isCompleted()) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran time out after ", timeElapsed.count(), " milliseconds."); + TORCH_CHECK(false, exceptionMsg) + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } + if (barrierTensor_.defined()) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + currentStream.synchronize(); + } +} + +bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { + synchronizeInternal(timeout); + return true; +} + +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple"; + +ProcessGroupXCCL::ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size) + : Backend(rank, size), store_(store), xcclCommCounter_(0) { + blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + init(); +} + +ProcessGroupXCCL::~ProcessGroupXCCL() = default; + +void ProcessGroupXCCL::setSequenceNumberForGroup() {} + +uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() { + return seqCollective_; +} + +c10::intrusive_ptr ProcessGroupXCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle, + const std::vector& inputs, + const std::vector& outputs) { + auto r = c10::make_intrusive( + device, + rank, + opType, + seqCollective_, + profilingTitle, + std::optional>(inputs)); + return r; +} + +std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the XCCL Communicator since " + "the devices are empty "); + } + + usedDeviceIdxs_.insert(device.index()); + + { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + return 
devXCCLCommMap_[deviceKey]; + } + } + + std::shared_ptr XCCLComm; + + bool batchP2P = xcclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + int numRanks, rank; + if (!singleP2POp) { + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + numRanks = 1; + rank = 0; + } else { + numRanks = 2; + rank = p2pRank; + } + + c10::impl::VirtualGuardImpl impl(device.type()); + c10::Stream stream = + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); + + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + + auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); + XCCLComm = std::make_shared(std::move(comms[0])); + + RECORD_PARAM_COMMS( + 0, // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize + + std::lock_guard lock(mutex_); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); + + return XCCLComm; +} + +void ProcessGroupXCCL::groupStart() { + ccl::group_start(); + ++xcclActiveGroupCounter_; +} + +void ProcessGroupXCCL::groupEnd() { + ccl::group_end(); + --xcclActiveGroupCounter_; +} + +// TODO: wait p2p enable +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; +void ProcessGroupXCCL::startCoalescing() { + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. 
Did you call end_coalescing before start_coalescing?"); + + auto comm = coalescedComm_; + auto device = coalescedDevice_; + + const auto key = std::to_string(device.index()); + auto stream = xcclStreamsMap_.at(key); + + auto work = initWork(device, rank_, optype); + work->blockingWait_ = blockingWait_; + + groupEnd(); + + work->xcclEndEvent_->record(stream); + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle) { + seqCollective_++; + auto device = inputs[0].device(); + const auto key = std::to_string(device.index()); + auto comm = getXCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); + + c10::intrusive_ptr work; + work = initWork(device, rank_, opType); + + work->outputs_ = std::make_shared>(outputs); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + pre(stream, work); + + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], *comm, stream); + } + + post(stream, work); + + if (!coalescing_state_) { + work->xcclEndEvent_->record(stream); + } + + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle) { + auto device = tensor.device(); + std::string key; + int p2pRank = 0, p2pTargetRank = 0; + bool isSendRecvSelf = false; + + bool batchP2P = xcclActiveGroupCounter_ > 0; + if (batchP2P) { + key = std::to_string(device.index()); + p2pRank = rank_; + p2pTargetRank = peer; + } else { + int lowRank = rank_ < peer ? rank_ : peer; + int highRank = rank_ < peer ? peer : rank_; + key = std::to_string(lowRank) + ":" + std::to_string(highRank); + p2pRank = rank_ <= peer ? 0 : 1; + isSendRecvSelf = rank_ == peer; + p2pTargetRank = isSendRecvSelf ? 
0 : 1 - p2pRank; + if (!coalescing_state_) { + seqP2P_++; + } + } + + auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalP2P; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } + + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); + + if (!coalescing_state_) { + c10::intrusive_ptr work; + work = initWork(device, rank_, opType); + work->outputs_ = std::make_shared>(); + work->outputs_->push_back(tensor); + + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, *comm, stream, p2pTargetRank); + + work->xcclEndEvent_->record(stream); + work->blockingWait_ = blockingWait_; + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + return work; + } else { + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, *comm, stream, p2pTargetRank); + + return nullptr; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::send( + std::vector& tensors, + int dstRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + dstRank, // dst rank + "send", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& input, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + int dst) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::send( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + dst, + comm, + ccl::create_stream(stream.queue())); + return; + }, + dstRank, + OpType::SEND, + c10::str("xccl:send ", rank_, "->", dstRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupXCCL::recv( + std::vector& tensors, + int srcRank, + int /* unused */) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor, true); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + srcRank, // src rank + "recv", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + 
this->getSize()); // worldSize + + auto ret = pointToPoint( + tensor, + [&](at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream, + int src) { + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::recv( + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + src, + comm, + ccl::create_stream(stream.queue())); + return; + }, + srcRank, + OpType::RECV, + c10::str("xccl:recv ", rank_, "<-", srcRank).c_str()); + return ret; +} + +c10::intrusive_ptr ProcessGroupXCCL::gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::gather: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + + std::vector outputs; + + if (getRank() == opts.rootRank) { + if (outputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element output list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (outputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect output list size " << outputTensors[0].size() + << ". Output list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = inputTensor.options(); + const auto& sizes = inputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); + outputs = outputTensors[0]; + } else { + // if not in the root rank, initialize outputs as empty list + if (outputTensors.size() != 0) { + invalidArgument("requires empty output on non-root"); + } + outputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + outputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * this->getSize(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + auto inputs = std::vector{inputTensor}; + return collective( + inputs, + outputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const auto root = opts.rootRank; + if (getRank() == root) { + for (auto output : outputs) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + } + } + { + auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do receive + ccl::recv( + outputs[r].data_ptr(), + (size_t)inputTensor.numel(), + xcclDataType, + r, + comm, + ccl::create_stream(stream.queue())); + } else { + // on its own rank, simply copy from the input + outputs[r].copy_(inputTensor); + } + } + } else { + // do send + ccl::send( + inputTensor.data_ptr(), + (size_t)inputTensor.numel(), + xcclDataType, + root, + 
comm, + ccl::create_stream(stream.queue())); + } + return; + } + }, + OpType::GATHER); +} + +c10::intrusive_ptr ProcessGroupXCCL::scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg); + }; + + assertRootRank(invalidArgument, opts.rootRank, size_); + + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto outputTensor = outputTensors.back(); + + std::vector inputs; + + if (getRank() == opts.rootRank) { + if (inputTensors.size() != 1) { + std::stringstream ss; + ss << "requires a single-element input list containing a list with " + << getSize() << " tensors."; + invalidArgument(ss.str()); + } else if (inputTensors[0].size() != static_cast(getSize())) { + std::stringstream ss; + ss << "Incorrect input list size " << inputTensors[0].size() + << ". Input list size should be " << getSize() + << ", same as size of the process group."; + invalidArgument(ss.str()); + } + + const auto& options = outputTensor.options(); + const auto& sizes = outputTensor.sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); + inputs = inputTensors[0]; + } else { + // if not in the root rank, initialize inputTensors as empty place holder + // with an empty list + if (inputTensors.size() != 0) { + invalidArgument("requires empty input on non-root"); + } + inputs = {}; + // append a empty tensor to the list, we don't use it but the + // `collective` template function requires it to invoke its function + inputs.emplace_back(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + opts.rootRank, // root rank + "scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + const auto root = opts.rootRank; + + auto outputs = std::vector{outputTensor}; + return collective( + outputs, + inputs, // just to fit the collective interface + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + if (getRank() == root) { + for (auto input : inputs) { + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + } + } + { + if (rank_ == root) { + for (const auto r : c10::irange(size_)) { + if (r != root) { + // do send + size_t send_count = inputs[r].numel(); + auto send_type = getXcclDataType(inputs[r].scalar_type()); + ccl::send( + inputs[r].data_ptr(), + send_count, + send_type, + r, + comm, + ccl::create_stream(stream.queue())); + } else { + // on its own rank, simply copy from the input + outputTensor.copy_(inputs[r]); + } + } + } else { + // do receive + size_t recv_count = outputTensor.numel(); + auto recv_type = getXcclDataType(outputTensor.scalar_type()); + ccl::recv( + outputTensor.data_ptr(), + recv_count, + recv_type, + root, + comm, + ccl::create_stream(stream.queue())); + } + + return; + } + }, + OpType::SCATTER); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& 
input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        auto ccl_stream = ccl::create_stream(stream.queue());
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl_stream);
+        // Workaround: oneCCL does not support ReduceOp::AVG; divide by the
+        // world size after the reduction.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
+        return;
+      },
+      OpType::ALLREDUCE,
+      "xccl:all_reduce");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
+    std::vector<at::Tensor>& tensors,
+    const AllreduceOptions& opts) {
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  auto tensor = tensors.back();
+  check_xpu_single_tensor(tensor);
+
+  // @lint-ignore CLANGTIDY
+  RECORD_PARAM_COMMS_DATA(
+      static_cast<int>(
+          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      tensors, // inputTensors
+      tensors, // outputTensors
+      rank_, // rank
+      "allreduce", // collective name
+      tensor.numel(), // inNelems
+      tensor.numel(), // outNelems
+      tensor.scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1, // globalRankStart
+      -1, // globalRankStride
+      size_); // worldSize
+
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL does not support ReduceOp::AVG; divide by the
+        // world size after the reduction.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
+        return;
+      },
+      OpType::ALLREDUCE,
+      "xccl:all_reduce");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
+    std::vector<at::Tensor>& tensors,
+    const AllreduceCoalescedOptions& opts) {
+  auto total_numel = check_xpu_tensors_same_device(tensors);
+
+  // @lint-ignore CLANGTIDY
+  RECORD_PARAM_COMMS_DATA(
+      static_cast<int>(
+          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      tensors, // inputTensors
+      tensors, // outputTensors
+      rank_, // rank
+      "allreduce_coalesced", // collective name
+      total_numel, // inNelems
+      total_numel, // outNelems
+      tensors[0].scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1, // globalRankStart
+      -1, // globalRankStride
+      this->getSize()); // worldSize
+
+  return collectiveCoalesced(
+      tensors,
+      tensors,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL does not support ReduceOp::AVG; divide by the
+        // world size after the reduction.
+        if (opts.reduceOp == ReduceOp::AVG) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
+        return;
+      },
+      OpType::COALESCED,
+      "xccl:allreduce_coalesced");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
+    std::vector<at::Tensor>& tensors,
+    const BroadcastOptions& opts) {
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  auto tensor = tensors.back();
+  check_xpu_single_tensor(tensor);
+
+  // @lint-ignore CLANGTIDY
+  RECORD_PARAM_COMMS_DATA(
+      static_cast<int>(
+          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      tensors, // inputTensors
+      tensors, // outputTensors
+      opts.rootRank, // root rank
+      "broadcast", // collective name
+      tensor.numel(), // inNelems
+      tensor.numel(), // outNelems
+      tensor.scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1, // globalRankStart
+      -1, // globalRankStride
+      this->getSize()); // worldSize
+
+  const auto root = opts.rootRank + opts.rootTensor;
+
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type());
+        ccl::broadcast(
+            input.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            root,
+            comm,
+            ccl::create_stream(stream.queue()));
+        return;
+      },
+      OpType::BROADCAST,
+      "xccl:broadcast");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const BroadcastOptions& opts) {
+  if (outputTensor.numel() != inputTensor.numel()) {
+    C10_THROW_ERROR(
+        ValueError,
+        "Tensor input and output of _broadcast_oop must have the same number of elements");
+  }
+  const auto root = opts.rootRank + opts.rootTensor;
+  return collective(
+      inputTensor,
+      outputTensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type());
+        ccl::broadcast(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            root,
+            comm,
+            ccl::create_stream(stream.queue()));
+        return;
+      },
+      OpType::BROADCAST,
+      "xccl:_broadcast_oop");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
+    std::vector<at::Tensor>& tensors,
+    const ReduceOptions& opts) {
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  auto tensor = tensors.back();
+  check_xpu_single_tensor(tensor);
+
+  RECORD_PARAM_COMMS_DATA(
+      static_cast<int>(
+          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      tensors, // inputTensors
+      tensors, // outputTensors
+      opts.rootRank, // root rank
+      "reduce", // collective name
+      tensor.numel(), // inNelems
+      tensor.numel(), // outNelems
+      tensor.scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      -1, // globalRankStart
+      -1, // globalRankStride
+      this->getSize()); // worldSize
+
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        const int root = opts.rootRank + opts.rootTensor;
+        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::reduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            root,
+            comm,
+            ccl::create_stream(stream.queue()));
+        // Workaround: oneCCL does not support ReduceOp::AVG; divide on the
+        // root rank after the reduction.
+        if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
+          auto divisor = getSize();
+          output.div_(divisor);
+        }
+        return;
+      },
+      OpType::REDUCE,
+      "xccl:reduce");
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const ReduceOptions&
opts) { + TORCH_CHECK_WITH( + ValueError, + outputTensor.numel() == inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::REDUCE, + "xccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + check_xpu_single_tensor(inputTensor); + // @lint-ignore CLANGTIDY + std::vector& outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "xccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? 
inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + check_xpu_single_tensor(input_tensor); + check_xpu_single_tensor(output_tensor); + + TORCH_CHECK_WITH( + TypeError, + input_tensor.dtype() == output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH( + ValueError, + input_tensor.numel() * size_ == output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::_ALLGATHER_BASE, + "xccl:_all_gather_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::COALESCED, + "xccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + check_xpu_single_tensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = check_same_size(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, 
stacked tensor. + at::Tensor inputFlattened = newLikeFlat(inputTensors_); + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the input tensors to the flattened inputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(inputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), Stream); + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + OpType::REDUCE_SCATTER, + "xccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + TORCH_CHECK_WITH( + TypeError, + inputTensor.dtype() == outputTensor.dtype(), + "input tensor must be the same type as the output tensor."); + TORCH_CHECK_WITH( + ValueError, + inputTensor.numel() == outputTensor.numel() * size_, + "input tensor must be the same size as output size times world size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::_REDUCE_SCATTER_BASE, + "xccl:_reduce_scatter_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + 
std::vector& inputs, + const ReduceScatterOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // WA due to oneCCL not support AVG + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::COALESCED, + "xccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + // Device to use for barrier + int barDevIdx = -1; + + // See nccl barrier comments + if (!opts.device_ids.empty()) { + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + barDevIdx = + static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + } + + // todo: use barrier instead of allreduce + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. 
"); + auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); + + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + auto work = allreduce_impl(barrierTensor); + + auto xcclWork = dynamic_cast(work.get()); + TORCH_CHECK(xcclWork); + xcclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& /* unused */) { + check_xpu_single_tensor(outputTensor, true); + check_xpu_single_tensor(inputTensor, true); + if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_all", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + TORCH_CHECK( + outputTensor.numel() == inputTensor.numel() && + outputTensor.scalar_type() == inputTensor.scalar_type(), + "xpu_alltoall_base: tensors are not equal in size or data type"); + TORCH_CHECK( + outputTensor.size(0) % size_ == 0, + "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::alltoall( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel() / comm.size(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } else { + c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); + c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "all_to_allv", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + inputTensor.scalar_type(), // dType + inputSplitSizes, // inSplitSizes + outputSplitSizes, // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + std::vector sendCounts(size_); + std::vector recvCounts(size_); + bool inputSplitsEqual = inputSplitSizes.size() == 0; + bool outputSplitsEqual = outputSplitSizes.size() == 0; + + size_t inLen = input.numel(); + size_t outLen = output.numel(); + if (inLen) + inLen /= (inputSplitsEqual ? size_ : input.size(0)); + if (outLen) + outLen /= (outputSplitsEqual ? size_ : output.size(0)); + + for (int i = 0; i < size_; i++) { + sendCounts[i] = + (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); + recvCounts[i] = + (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); + } + auto xcclDataType = getXcclDataType(output.scalar_type()); + ccl::alltoallv( + input.data_ptr(), + sendCounts, + output.data_ptr(), + recvCounts, + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::ALLTOALL_BASE, + "xccl:all_to_all"); + } +} + +c10::intrusive_ptr ProcessGroupXCCL::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& /* unused */) { + auto device = outputTensors[0].device(); + int64_t total_numel = 0; + for (const auto r : c10::irange(outputTensors.size())) { + check_xpu_single_tensor(outputTensors[r], true); + check_xpu_single_tensor(inputTensors[r], true); + TORCH_CHECK( + device == outputTensors[r].device() && + device == inputTensors[r].device(), + "Tensors must be on the same device") + total_numel += inputTensors[r].numel(); + } + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_to_all", // collective name + total_numel, // inNelems + total_numel, // outNelems + inputTensors.front().scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensors, + outputTensors, + [&](at::Tensor& /* unused */, + at::Tensor& /* unused */, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::OptionalStreamGuard stream_guard(stream.unwrap()); + at::Tensor flatInput; + at::Tensor flatOutput; + + std::vector sendCounts(size_); + std::vector recvCounts(size_); + + int64_t flatSendCount; + int64_t flatRecvCount; + + bool isInputFlat = computeLengthsAndCheckAndGetFlat( + inputTensors, sendCounts, flatInput, flatSendCount); + bool isOutputFlat = computeLengthsAndCheckAndGetFlat( + outputTensors, recvCounts, flatOutput, flatRecvCount); + if (!isInputFlat) { + auto flatInputSplits = flatInput.split_with_sizes( + c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()), + 0); + + for (int i = 0; i < size_; i++) { + flatInputSplits[i].copy_(inputTensors[i].view({-1})); + } + } + + auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::alltoallv( + flatInput.data_ptr(), + sendCounts, + flatOutput.data_ptr(), + recvCounts, + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + + if (!isOutputFlat) { + ret_evt.wait(); + auto flatOutputSplits = flatOutput.split_with_sizes( + c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), + 0); + + for (int i = 0; i < size_; i++) { + outputTensors[i].view({-1}).copy_(flatOutputSplits[i]); + } + } + stream.synchronize(); + return; + }, + OpType::ALLTOALL, + "xccl:all_to_all"); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp new file mode 100644 index 0000000000000..c30ca603c7ba0 --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -0,0 +1,371 @@ +#pragma once + +#ifdef USE_C10D_XCCL +// We will define those flags in XCCL backend file instead of passing to gcc +// compiler. 
+#define CCL_ENABLE_ZE +#define CCL_ENABLE_SYCL + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace c10d { + +static std::vector TORCH_XCCL_BLOCKING_WAIT = { + "TORCH_XCCL_BLOCKING_WAIT", + "XCCL_BLOCKING_WAIT"}; + +using xcclComm_t = ccl::communicator; +using XCCL_KVS = ccl::shared_ptr_class; +constexpr const char* XCCL_BACKEND_NAME = "xccl"; + +class TORCH_API ProcessGroupXCCL : public Backend { + public: + class WorkXCCL : public Work { + public: + WorkXCCL( + at::Device& device, + int rank, + OpType opType, + uint64_t seq, + const char* profilingTitle = nullptr, + const std::optional>& inputs = std::nullopt); + WorkXCCL(const WorkXCCL& w); + ~WorkXCCL() override; + + bool isCompleted() override; + + void abort() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); + } + + void synchronize() override; + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + c10::intrusive_ptr getFuture() override { + return future_; + } + + uint64_t getSequencenumber() const override { + return seq_; + } + + std::vector result() override { + return *outputs_; + } + + protected: + at::Device device_; + std::shared_ptr xcclEndEvent_; + at::Tensor barrierTensor_; + bool blockingWait_ = false; + std::chrono::time_point workStartTime_; + uint64_t seq_; + + private: + void synchronizeInternal(std::chrono::milliseconds timeout); + std::shared_ptr> outputs_; + c10::intrusive_ptr future_; + friend class ProcessGroupXCCL; + }; + + ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + + C10_DEPRECATED ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName) + : ProcessGroupXCCL(store, rank, size) {} + + ~ProcessGroupXCCL() override; + + const std::string getBackendName() const override { + return std::string(XCCL_BACKEND_NAME); + } + + void startCoalescing() override; + + c10::intrusive_ptr endCoalescing() override; + + c10::intrusive_ptr endCoalescing(OpType optype); + + std::shared_ptr getXCCLComm( + const std::string& deviceKey, + at::Device& device, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); + + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const char* profilingTitle = nullptr, + const std::vector& inputs = {}, + const std::vector& outputs = {}); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType, profilingTitle); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + 
std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr); + + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + ccl::group_start(); + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + ccl::group_end(); + }, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr pointToPoint( + at::Tensor& tensor, + Fn fn, + int peer, + OpType opType, + const char* profilingTitle = nullptr); + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts = AllreduceOptions()); + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override; + + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts); + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override; + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override; + + void groupStart(); + + void groupEnd(); + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override; + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override; + + void 
setSequenceNumberForGroup() override; + + uint64_t getSequenceNumberForGroup() override; + + protected: + std::unordered_map xcclStreamsMap_; + std::unordered_map xcclEventsMap_; + std::unordered_map> devXCCLCommMap_; + c10::intrusive_ptr store_; + uint64_t xcclCommCounter_{0}; + std::mutex mutex_; + std::set usedDeviceIdxs_; + int coalescing_state_ = 0; + at::Device coalescedDevice_ = at::Device("xpu"); + std::shared_ptr coalescedComm_ = nullptr; + bool blockingWait_ = false; + static thread_local uint64_t xcclActiveGroupCounter_; + uint64_t seqCollective_{0}; + uint64_t seqP2P_{0}; + + private: + std::mutex kvs_mutex; + + ccl::shared_ptr_class get_kvs( + int rank, + c10d::Store& store, + bool singleP2POp = false, + const std::string& p2pKey = "", + int p2pRank = 0) { + std::lock_guard lock(kvs_mutex); + ccl::shared_ptr_class kvs; + std::string storeKey; + if (!singleP2POp) { + storeKey = std::to_string(xcclCommCounter_++); + } else { + storeKey = p2pKey; + } + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0 || (singleP2POp && p2pRank == 0)) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } +}; +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35f..e27ec363ba1cc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets( return offset; } +inline std::string reduceOpToString(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} + using RankType = uint32_t; using SizeType = uint64_t; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 01fc8cb45a333..c01f2b4f4e208 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -38,6 +38,10 @@ #include #endif +#ifdef USE_C10D_XCCL +#include +#endif + #include #include #include @@ -2311,6 +2315,7 @@ The hook must have the following signature: .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED) .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO) .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL) + .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL) .value("UCC", ::c10d::ProcessGroup::BackendType::UCC) .value("MPI", ::c10d::ProcessGroup::BackendType::MPI) .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) @@ -2946,6 +2951,23 @@ Example:: py::call_guard()); #endif +#ifdef USE_C10D_XCCL + auto processGroupXCCL = + 
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>( + module, "ProcessGroupXCCL", backend) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size) { + return c10::make_intrusive<::c10d::ProcessGroupXCCL>( + store, rank, size); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::call_guard()); +#endif + py::enum_<::c10d::OpType>(module, "OpType") .value("BROADCAST", ::c10d::OpType::BROADCAST) .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 8c8d7f2a8d8ee..3736f616b3326 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -89,6 +89,7 @@ "is_nccl_available", "is_torchelastic_launched", "is_ucc_available", + "is_xccl_available", "isend", "monitored_barrier", "new_group", @@ -132,6 +133,7 @@ _NCCL_AVAILABLE = True _GLOO_AVAILABLE = True _UCC_AVAILABLE = True +_XCCL_AVAILABLE = True _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -195,6 +197,14 @@ def _export_c_types() -> None: except ImportError: _UCC_AVAILABLE = False +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + logger = logging.getLogger(__name__) PG_WRAPPER_STORE_PREFIX = "pg_wrapper" @@ -224,7 +234,7 @@ class Backend(str): """ An enum-like class for backends. - Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. The values of this class are lowercase strings, e.g., ``"gloo"``. They can be accessed as attributes, e.g., ``Backend.NCCL``. @@ -244,21 +254,24 @@ class Backend(str): NCCL = "nccl" UCC = "ucc" MPI = "mpi" + XCCL = "xccl" _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) _plugins: Dict[str, _BackendPlugin] = {} - backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { "cpu": GLOO, "cuda": NCCL, + "xpu": XCCL, } backend_capability: Dict[str, List[str]] = { GLOO: ["cpu", "cuda"], NCCL: ["cuda"], + XCCL: ["xpu"], UCC: ["cpu", "cuda"], MPI: ["cpu", "cuda"], } @@ -267,6 +280,7 @@ class Backend(str): UNDEFINED: ProcessGroup.BackendType.UNDEFINED, GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, MPI: ProcessGroup.BackendType.MPI, } @@ -1185,6 +1199,11 @@ def is_ucc_available() -> bool: return _UCC_AVAILABLE +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + def is_backend_available(backend: str) -> bool: """ Check backend availability. 
@@ -1437,6 +1456,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> backends.add(backend) # type: ignore[arg-type] elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): backends.add(backend) # type: ignore[arg-type] + if torch.device("xpu") in devices and is_xccl_available(): + backend = group._get_backend(torch.device("xpu")) + if isinstance(backend, ProcessGroupXCCL): + backends.add(backend) # type: ignore[arg-type] if len(backends) == 0: warnings.warn("Set timeout is now only supported for either nccl or gloo.") for backend in backends: @@ -1472,7 +1495,7 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on - build-time configurations, valid values include ``mpi``, ``gloo``, + build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``, ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` and ``nccl`` backend will be created, see notes below for how multiple backends are managed. This field can be given as a lowercase string @@ -1752,10 +1775,9 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + if device_id is not None and device_id.index is None: raise ValueError( - "init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu" + "init_process_group device_id parameter must be a device with an index" ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value @@ -1885,6 +1907,17 @@ def _new_process_group_helper( backend_prefix_store, group_rank, group_size, timeout=timeout ) backend_type = ProcessGroup.BackendType.UCC + elif backend_str == Backend.XCCL: + if not is_xccl_available(): + raise RuntimeError("Distributed package doesn't have XCCL built in") + if backend_options is not None: + assert isinstance( + backend_options, ProcessGroupXCCL.Options + ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options" + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size + ) + backend_type = ProcessGroup.BackendType.XCCL else: assert ( backend_str.upper() in Backend._plugins diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 7a5053fe8b8b4..eb7c6f1e9aa04 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -94,8 +94,9 @@ class DistTestCases: # Sets showing that something is implemented backend_feature = {} - backend_feature["gpu"] = {"nccl", "gloo", "ucc"} + backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"} backend_feature["cuda"] = {"nccl", "gloo", "ucc"} + backend_feature["xpu"] = {"xccl"} backend_feature["ddp"] = {"nccl", "gloo", "ucc"} backend_feature["subgroup"] = {"nccl", "gloo", "ucc"} backend_feature["plugin"] = set() @@ -187,7 +188,8 @@ def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= x: + if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \ + (torch.xpu.is_available() and torch.xpu.device_count() >= x): return func(*args, **kwargs) sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) @@ -327,6 +329,12 @@ def requires_nccl(): "c10d was not compiled with the NCCL backend", ) +def requires_xccl(): + return skip_but_pass_in_sandcastle_if( + not 
c10d.is_xccl_available(), + "c10d was not compiled with the XCCL backend", + ) + def requires_ucc(): return skip_but_pass_in_sandcastle_if( not c10d.is_ucc_available(), @@ -478,6 +486,15 @@ def compute_sum(fn, world_size: int): ] ] +# Returns the number of GPUs, currently only for CUDA and XPU. +def get_device_count(backend: str): + assert c10d.is_backend_available(backend) + if backend in DistTestCases.backend_feature.get("cuda", set()): + return torch.cuda.device_count() + elif backend in DistTestCases.backend_feature.get("xpu", set()): + return torch.xpu.device_count() + else: + raise ValueError(f"Unsupported backend: {backend}") # HELPER FOR MULTIGPU TESTS def init_multigpu_helper(world_size: int, backend: str): @@ -486,7 +503,7 @@ def init_multigpu_helper(world_size: int, backend: str): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - nGPUs = torch.cuda.device_count() + nGPUs = get_device_count(backend) visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's