Implement AllToAll collective op
shaahji committed Mar 7, 2023
1 parent 150043f commit 66101c0
Showing 6 changed files with 117 additions and 7 deletions.
58 changes: 52 additions & 6 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -129,12 +129,47 @@ Status AllGather::ComputeInternal(OpKernelContext* context) const {
  return Status::OK();
}

AllToAll::AllToAll(const OpKernelInfo& info) : NcclKernel(info) {
  info.GetAttrOrDefault("group_size", &group_size_, static_cast<int64_t>(1));
}

Status AllToAll::ComputeInternal(OpKernelContext* context) const {
  const ncclComm_t comm = nccl_->Comm();
  auto input_tensor = context->Input<Tensor>(0);
  const char* input_data = static_cast<const char*>(input_tensor->DataRaw());
  const auto in_shape = input_tensor->Shape();
  const int64_t input_count = in_shape.Size();
  auto out_shape = in_shape;
  const int64_t element_size = input_tensor->DataType()->Size();
  const int64_t rank_stride = input_count / group_size_;
  const ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType());

  char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());

#ifdef ORT_USE_NCCL
  NCCL_RETURN_IF_ERROR(ncclGroupStart());
  for (int32_t r = 0; r < group_size_; r++) {
    NCCL_RETURN_IF_ERROR(ncclSend(input_data, rank_stride, dtype, r, comm, Stream(context)));
    NCCL_RETURN_IF_ERROR(ncclRecv(output_data, rank_stride, dtype, r, comm, Stream(context)));
    input_data += (rank_stride * element_size);
    output_data += (rank_stride * element_size);
  }
  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
#endif

  return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(
    AllReduce,
    kMSDomain,
    1,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
        .AllocateInputsContiguously()
        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
    AllReduce);

ONNX_OPERATOR_KERNEL_EX(
    AllGather,
@@ -146,6 +181,17 @@ ONNX_OPERATOR_KERNEL_EX(
        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
    AllGather);

ONNX_OPERATOR_KERNEL_EX(
    AllToAll,
    kMSDomain,
    1,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
        .AllocateInputsContiguously()
        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
    AllToAll);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
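A note on the pattern above: NCCL does not provide an all-to-all primitive, so the kernel composes one from grouped point-to-point calls. Inside the ncclGroupStart()/ncclGroupEnd() pair, each rank posts one ncclSend and one ncclRecv per peer, advancing through its input and output buffers in rank_stride-element chunks, and NCCL executes the posted operations as a single fused exchange. The NumPy sketch below (not part of the commit; simulate_all_to_all is a hypothetical name) simulates the resulting data movement in one process:

import numpy as np

def simulate_all_to_all(inputs):
    """inputs[r] is rank r's flattened input; returns every rank's output."""
    group_size = len(inputs)
    # Each rank splits its input into group_size equal chunks; chunk j goes to rank j.
    chunks = [np.split(x, group_size) for x in inputs]
    # Rank k's output is chunk k from every rank, concatenated in rank order.
    return [np.concatenate([chunks[r][k] for r in range(group_size)]) for k in range(group_size)]

# Two ranks, four elements each: rank 0 holds all zeros, rank 1 holds all ones.
outs = simulate_all_to_all([np.zeros(4), np.ones(4)])
assert all((out == [0.0, 0.0, 1.0, 1.0]).all() for out in outs)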
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -66,7 +66,17 @@ class AllGather final : public NcclKernel {
  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int64_t group_size_ = -1;
};

class AllToAll final : public NcclKernel {
 public:
  explicit AllToAll(const OpKernelInfo& info);

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int64_t group_size_ = -1;
};

} // namespace cuda
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -137,6 +137,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
#endif

template <>
@@ -278,6 +279,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
#endif

};
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -112,6 +112,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll);
#endif

template <>
@@ -234,6 +235,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
#if defined(USE_MPI) && defined(ORT_USE_NCCL)
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll)>,
#endif
};

17 changes: 17 additions & 0 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -57,6 +57,23 @@ void RegisterCollectiveOps() {
          *output_type->mutable_tensor_type()->mutable_shape() = shape;
        }
      });

  ONNX_CONTRIB_OPERATOR_SCHEMA(AllToAll)
      .SetDomain(kMSDomain)
      .SinceVersion(1)
      .Attr("group_size",
            "total number of ranks in the group that participate in this collective.",
            AttributeProto::INT,
            static_cast<int64_t>(1))
      .Input(0, "input", "tensors to be sent", "T", OpSchema::Variadic)
      .Output(0, "output", "collected tensors", "T", OpSchema::Variadic)
      .TypeConstraint(
          "T",
          {"tensor(float16)", "tensor(float)", "tensor(double)"},
          "Constrain to float, float16 and double tensors.")
      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
        propagateShapeAndTypeFromFirstInput(ctx);
      });
}

} // namespace contrib
33 changes: 33 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_collective.py
@@ -41,6 +41,19 @@ def _create_allgather_ut_model(self, shape):
        )
        return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")

    def _create_alltoall_ut_model(self, shape):
        X = helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)
        Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)
        _, size = self._get_rank_size()
        node_def = helper.make_node("AllToAll", ["X"], ["Y"], domain="com.microsoft", group_size=size)
        graph_def = helper.make_graph(
            [node_def],
            "",
            [X],
            [Y],
        )
        return helper.make_model(graph_def, producer_name="ort-distributed-inference-unittest")

    def test_all_reduce(self):
        model = self._create_allreduce_ut_model((128, 128))
        rank, size = self._get_rank_size()
@@ -72,6 +85,26 @@ def test_all_gather(self):

        assert np.allclose(outputs[0], expected_output)

    def test_all_to_all(self):
        model = self._create_alltoall_ut_model((128, 128))
        rank, size = self._get_rank_size()
        ort_sess = ort.InferenceSession(
            model.SerializeToString(),
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            provider_options=[{"device_id": str(rank)}, {}],
        )

        input = np.ones((128, 128), dtype=np.float32) * rank
        outputs = ort_sess.run(None, {"X": input})

        # Rank r sends a (128 // size, 128) block of r's to every peer, so each
        # rank's output stacks the blocks 0, 1, ..., size - 1 in rank order.
        expected_output = np.concatenate(
            [np.full((128 // size, 128), r, dtype=np.float32) for r in range(size)]
        )

        assert np.allclose(outputs[0], expected_output)


if __name__ == "__main__":
    unittest.main(module=__name__, buffer=True)
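
The tests assume one process per GPU, launched under MPI (for example, mpirun -n 2 python onnxruntime_test_collective.py; the exact launcher depends on the environment). The _get_rank_size helper is defined earlier in this file, outside the hunks shown; a minimal sketch of such a helper, assuming mpi4py supplies the process group:

from mpi4py import MPI

def _get_rank_size():
    comm = MPI.COMM_WORLD  # communicator spanning every launched rank
    return comm.Get_rank(), comm.Get_size()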
