diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
index 3122f070fd57f..1088397fbaa44 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -129,12 +129,47 @@ Status AllGather::ComputeInternal(OpKernelContext* context) const {
   return Status::OK();
 }
 
-ONNX_OPERATOR_KERNEL_EX(AllReduce, kMSDomain, 1, kCudaExecutionProvider,
-                        (*KernelDefBuilder::Create())
-                            .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
-                            .AllocateInputsContiguously()
-                            .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
-                        AllReduce);
+AllToAll::AllToAll(const OpKernelInfo& info) : NcclKernel(info) {
+  info.GetAttrOrDefault("group_size", &group_size_, static_cast<int64_t>(1));
+}
+
+Status AllToAll::ComputeInternal(OpKernelContext* context) const {
+  const ncclComm_t comm = nccl_->Comm();
+  auto input_tensor = context->Input<Tensor>(0);
+  const char* input_data = static_cast<const char*>(input_tensor->DataRaw());
+  const auto in_shape = input_tensor->Shape();
+  const int64_t input_count = in_shape.Size();
+  auto out_shape = in_shape;
+  const int64_t element_size = input_tensor->DataType()->Size();
+  const int64_t rank_stride = input_count / group_size_;
+  const ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType());
+
+  char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());
+
+#ifdef ORT_USE_NCCL
+  NCCL_RETURN_IF_ERROR(ncclGroupStart());
+  for (int32_t r = 0; r < group_size_; r++) {
+    NCCL_RETURN_IF_ERROR(ncclSend(input_data, rank_stride, dtype, r, comm, Stream(context)));
+    NCCL_RETURN_IF_ERROR(ncclRecv(output_data, rank_stride, dtype, r, comm, Stream(context)));
+    input_data += (rank_stride * element_size);
+    output_data += (rank_stride * element_size);
+  }
+  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+#endif
+
+  return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+    AllReduce,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
+        .AllocateInputsContiguously()
+        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
+    AllReduce);
 
 ONNX_OPERATOR_KERNEL_EX(
     AllGather,
@@ -146,6 +181,17 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
     AllGather);
 
+ONNX_OPERATOR_KERNEL_EX(
+    AllToAll,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
+        .AllocateInputsContiguously()
+        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
+    AllToAll);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
index 1576f674106e2..4f7093dd49363 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -66,7 +66,17 @@ class AllGather final : public NcclKernel {
   Status ComputeInternal(OpKernelContext* context) const override;
 
  private:
-  int64_t group_size_;
+  int64_t group_size_ = -1;
+};
+
+class AllToAll final : public NcclKernel {
+ public:
+  explicit AllToAll(const OpKernelInfo& info);
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  int64_t group_size_ = -1;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 1cefd44844f39..8b8df46a15191 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -137,6 +137,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
 #endif
 
 template <>
@@ -278,6 +279,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
 #endif
   };
 
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index 252f943c43df3..d420feac15a88 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -112,6 +112,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll);
 #endif
 
 template <>
@@ -234,6 +235,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
+     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll)>,
 #endif
   };
 
diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
index 167b80238a3d6..c4815a75dbd53 100644
--- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -57,6 +57,23 @@ void RegisterCollectiveOps() {
           *output_type->mutable_tensor_type()->mutable_shape() = shape;
         }
       });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(AllToAll)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .Attr("group_size",
+            "total size in the group that need to participate.",
+            AttributeProto::INT,
+            static_cast<int64_t>(1))
+      .Input(0, "input", "tensors to be sent", "T", OpSchema::Variadic)
+      .Output(0, "output", "collected tensors", "T", OpSchema::Variadic)
+      .TypeConstraint(
+          "T",
+          {"tensor(float16)", "tensor(float)", "tensor(double)"},
+          "Constrain to float, float16 and double tensors.")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        propagateShapeAndTypeFromFirstInput(ctx);
+      });
 }
 
 }  // namespace contrib
diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py
index c2031498bd55a..8329ab36337c2 100644
--- a/onnxruntime/test/python/onnxruntime_test_collective.py
+++ b/onnxruntime/test/python/onnxruntime_test_collective.py
@@ -41,6 +41,19 @@ def _create_allgather_ut_model(self, shape):
         )
         return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")
 
+    def _create_alltoall_ut_model(self, shape):
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)
+        _, size = self._get_rank_size()
+        node_def = helper.make_node("AllToAll", ["X"], ["Y"], domain="com.microsoft", group_size=size)
+        graph_def = helper.make_graph(
+            [node_def],
+            "",
+            [X],
+            [Y],
+        )
+        return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")
+
     def test_all_reduce(self):
         model = self._create_allreduce_ut_model((128, 128))
         rank, size = self._get_rank_size()
@@ -72,6 +85,26 @@ def test_all_gather(self):
 
         assert np.allclose(outputs[0], expected_output)
 
+    def test_all_to_all(self):
+        model = self._create_alltoall_ut_model((128, 128))
+        rank, size = self._get_rank_size()
+        ort_sess = ort.InferenceSession(
+            model.SerializeToString(),
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            provider_options=[{"device_id": str(rank)}, {}],
+        )
+
+        input = np.ones((128, 128), dtype=np.float32) * rank
+        outputs = ort_sess.run(None, {"X": input})
+
+        expected_output = np.zeros((int(128 / size), 128), dtype=np.float32)
+        for _ in range(size - 1):
+            expected_output = np.concatenate(
+                (expected_output, np.ones((int(128 / size), 128), dtype=np.float32) * (_ + 1))
+            )
+
+        assert np.allclose(outputs[0], expected_output)
+
 
 if __name__ == "__main__":
     unittest.main(module=__name__, buffer=True)