Implement AllToAll collective operator #14926

Merged 1 commit on Mar 7, 2023.
58 changes: 52 additions & 6 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -129,12 +129,47 @@ Status AllGather::ComputeInternal(OpKernelContext* context) const {
   return Status::OK();
 }
 
-ONNX_OPERATOR_KERNEL_EX(AllReduce, kMSDomain, 1, kCudaExecutionProvider,
-                        (*KernelDefBuilder::Create())
-                            .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
-                            .AllocateInputsContiguously()
-                            .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
-                        AllReduce);
+AllToAll::AllToAll(const OpKernelInfo& info) : NcclKernel(info) {
+  info.GetAttrOrDefault("group_size", &group_size_, static_cast<int64_t>(1));
+}
+
+Status AllToAll::ComputeInternal(OpKernelContext* context) const {
+  const ncclComm_t comm = nccl_->Comm();
+  auto input_tensor = context->Input<Tensor>(0);
+  const char* input_data = static_cast<const char*>(input_tensor->DataRaw());
+  const auto in_shape = input_tensor->Shape();
+  const int64_t input_count = in_shape.Size();
+  auto out_shape = in_shape;
+  const int64_t element_size = input_tensor->DataType()->Size();
+  const int64_t rank_stride = input_count / group_size_;
+  const ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType());
+
+  char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());
+
+#ifdef ORT_USE_NCCL
+  NCCL_RETURN_IF_ERROR(ncclGroupStart());
+  for (int32_t r = 0; r < group_size_; r++) {
+    NCCL_RETURN_IF_ERROR(ncclSend(input_data, rank_stride, dtype, r, comm, Stream(context)));
+    NCCL_RETURN_IF_ERROR(ncclRecv(output_data, rank_stride, dtype, r, comm, Stream(context)));
+    input_data += (rank_stride * element_size);
+    output_data += (rank_stride * element_size);
+  }
+  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+#endif
+
+  return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+    AllReduce,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
+        .AllocateInputsContiguously()
+        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
+    AllReduce);
 
 ONNX_OPERATOR_KERNEL_EX(
     AllGather,
@@ -146,6 +181,17 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
     AllGather);
 
+ONNX_OPERATOR_KERNEL_EX(
+    AllToAll,
+    kMSDomain,
+    1,
+    kCudaExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
+        .AllocateInputsContiguously()
+        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
+    AllToAll);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
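The loop above implements AllToAll as group_size paired point-to-point transfers inside a single ncclGroupStart/ncclGroupEnd group: chunk r of the local input buffer goes to rank r, and the chunk arriving from rank r lands at offset r * rank_stride of the output. A rough NumPy sketch of that data movement (illustration only, not ONNX Runtime code):

```python
import numpy as np

def alltoall_reference(inputs):
    """inputs[j] is rank j's flat buffer; returns every rank's output buffer."""
    group_size = len(inputs)
    # chunks[j][r] is the slice of rank_stride elements that rank j sends to rank r.
    chunks = [np.split(buf, group_size) for buf in inputs]
    # Rank r's output concatenates the chunk addressed to it from each peer, in rank order.
    return [np.concatenate([chunks[j][r] for j in range(group_size)]) for r in range(group_size)]

# With rank j holding all j's (as in the unit test below), every rank ends up
# with [0, 0, 1, 1, 2, 2, 3, 3] for group_size = 4 and 8 elements per rank.
ins = [np.full(8, j, dtype=np.float32) for j in range(4)]
outs = alltoall_reference(ins)
assert all((o == np.repeat(np.arange(4, dtype=np.float32), 2)).all() for o in outs)
```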
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -66,7 +66,17 @@ class AllGather final : public NcclKernel {
   Status ComputeInternal(OpKernelContext* context) const override;
 
  private:
-  int64_t group_size_;
+  int64_t group_size_ = -1;
+};
+
+class AllToAll final : public NcclKernel {
+ public:
+  explicit AllToAll(const OpKernelInfo& info);
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+  int64_t group_size_ = -1;
 };
 
 }  // namespace cuda
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -137,6 +137,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
 #endif
 
 template <>
@@ -278,6 +279,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
 #endif
 
 };
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -112,6 +112,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll);
 #endif
 
 template <>
@@ -234,6 +235,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
 #if defined(USE_MPI) && defined(ORT_USE_NCCL)
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
     BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
+    BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll)>,
 #endif
 };
 
17 changes: 17 additions & 0 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -57,6 +57,23 @@ void RegisterCollectiveOps() {
           *output_type->mutable_tensor_type()->mutable_shape() = shape;
         }
       });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(AllToAll)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .Attr("group_size",
+            "total size in the group that need to participate.",
+            AttributeProto::INT,
+            static_cast<int64_t>(1))
+      .Input(0, "input", "tensors to be sent", "T", OpSchema::Variadic)
+      .Output(0, "output", "collected tensors", "T", OpSchema::Variadic)
+      .TypeConstraint(
+          "T",
+          {"tensor(float16)", "tensor(float)", "tensor(double)"},
+          "Constrain to float, float16 and double tensors.")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        propagateShapeAndTypeFromFirstInput(ctx);
+      });
 }
 
 }  // namespace contrib
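The schema gives group_size a default of 1 and declares variadic inputs and outputs, while the CUDA kernel above reads only input 0 and derives rank_stride by integer division. A small worked example of that chunking arithmetic, using the test's (128, 128) float32 input and a hypothetical group of 4 (note the kernel itself does not validate divisibility):

```python
# Chunking arithmetic from AllToAll::ComputeInternal, worked for a
# (128, 128) float32 tensor and group_size = 4 (hypothetical values).
input_count = 128 * 128                  # in_shape.Size() -> 16384 elements
group_size = 4
element_size = 4                         # sizeof(float)
rank_stride = input_count // group_size  # 4096 elements sent to each peer (32 rows)
step = rank_stride * element_size        # 16384 bytes: pointer advance per iteration
assert input_count % group_size == 0     # not checked by the kernel; a remainder would be silently dropped
```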
33 changes: 33 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_collective.py
@@ -41,6 +41,19 @@ def _create_allgather_ut_model(self, shape):
         )
         return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")
 
+    def _create_alltoall_ut_model(self, shape):
+        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)
+        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)
+        _, size = self._get_rank_size()
+        node_def = helper.make_node("AllToAll", ["X"], ["Y"], domain="com.microsoft", group_size=size)
+        graph_def = helper.make_graph(
+            [node_def],
+            "",
+            [X],
+            [Y],
+        )
+        return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")
+
     def test_all_reduce(self):
         model = self._create_allreduce_ut_model((128, 128))
         rank, size = self._get_rank_size()
@@ -72,6 +85,26 @@ def test_all_gather(self):
 
         assert np.allclose(outputs[0], expected_output)
 
+    def test_all_to_all(self):
+        model = self._create_alltoall_ut_model((128, 128))
+        rank, size = self._get_rank_size()
+        ort_sess = ort.InferenceSession(
+            model.SerializeToString(),
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+            provider_options=[{"device_id": str(rank)}, {}],
+        )
+
+        input = np.ones((128, 128), dtype=np.float32) * rank
+        outputs = ort_sess.run(None, {"X": input})
+
+        expected_output = np.zeros((int(128 / size), 128), dtype=np.float32)
+        for _ in range(size - 1):
+            expected_output = np.concatenate(
+                (expected_output, np.ones((int(128 / size), 128), dtype=np.float32) * (_ + 1))
+            )
+
+        assert np.allclose(outputs[0], expected_output)
+
 
 if __name__ == "__main__":
     unittest.main(module=__name__, buffer=True)
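Like the existing AllReduce and AllGather tests, this presumably runs with one process per GPU under an MPI launcher, e.g. `mpirun -n <size> python onnxruntime_test_collective.py` (a hypothetical invocation; the launch command is not part of this diff). Because every rank feeds in a uniform tensor of its own rank value, each rank receives block j filled with the value j, so the expected output is identical on all ranks. The test's concatenation loop builds that array block by block; an equivalent, arguably clearer construction for reference:

```python
# Equivalent construction of the test's expected output (a hypothetical
# rewrite for illustration, not part of the PR): the j-th block of
# 128 // size rows holds the value j, for j = 0 .. size - 1.
import numpy as np

def expected_alltoall_output(size: int) -> np.ndarray:
    return np.vstack([np.full((128 // size, 128), j, dtype=np.float32) for j in range(size)])
```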