[MLU] add int64 support for allgather. #46830

Merged 1 commit on Oct 11, 2022
50 changes: 44 additions & 6 deletions paddle/fluid/operators/collective/c_allgather_op_mlu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/collective_helper.h"
@@ -27,15 +28,14 @@ template <typename T>
class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
#if defined(PADDLE_WITH_CNCL)
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
cnclDataType_t dtype =
platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype()));

int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::CNCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
@@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel<T> {
out->mutable_data<T>(out_dims, place);

uint32_t send_numel = x->numel();
void* send_buff = reinterpret_cast<void*>(const_cast<T*>(x->data<T>()));
void* recv_buff = reinterpret_cast<void*>(out->data<T>());
void* send_buff;
void* recv_buff;
phi::DenseTensor in_tensor, out_tensor;
if (framework::TransToProtoVarType(x->dtype()) ==
framework::proto::VarType::INT64) {
// cast from int64 to int32 since CNCL does not support int64
in_tensor.mutable_data<int32_t>(x->dims(), place);
out_tensor.mutable_data<int32_t>(out->dims(), place);
MLUCnnlTensorDesc x_int64_desc(*x);
MLUCnnlTensorDesc x_int32_desc(in_tensor);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32);
MLUCnnl::Cast(ctx,
cast_type,
x_int64_desc.get(),
GetBasePtr(x),
x_int32_desc.get(),
GetBasePtr(&in_tensor));
send_buff = reinterpret_cast<void*>(in_tensor.data<int32_t>());
recv_buff = reinterpret_cast<void*>(out_tensor.data<int32_t>());
} else {
in_tensor.ShareDataWith(*x);
out_tensor.ShareDataWith(*out);
send_buff = reinterpret_cast<void*>(in_tensor.data<T>());
recv_buff = reinterpret_cast<void*>(out_tensor.data<T>());
}

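// Pick the stream for the collective: the op's compute stream when
// use_calc_stream is set, otherwise the communicator's own stream.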
mluStream stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(in_tensor.dtype()));

PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(
send_buff, recv_buff, send_numel, dtype, comm->comm(), stream));
if (framework::TransToProtoVarType(x->dtype()) ==
framework::proto::VarType::INT64) {
// cast the int32 out_tensor back to the int64 output
MLUCnnlTensorDesc out_int64_desc(*out);
MLUCnnlTensorDesc out_int32_desc(out_tensor);
cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
MLUCnnl::Cast(ctx,
cast_type,
out_int32_desc.get(),
GetBasePtr(&out_tensor),
out_int64_desc.get(),
GetBasePtr(out));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with MLU."));
@@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather,
ops::CAllGatherOpMLUKernel<int>,
ops::CAllGatherOpMLUKernel<int8_t>,
ops::CAllGatherOpMLUKernel<int16_t>,
ops::CAllGatherOpMLUKernel<int64_t>,
ops::CAllGatherOpMLUKernel<plat::float16>);
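
For reference, a minimal standalone sketch of the workaround pattern this kernel uses: when the collective backend has no int64 dtype, cast the input down to int32, run all-gather on the int32 buffers, and cast the gathered result back to int64. The fake_all_gather helper below is a hypothetical stand-in for cnclAllGather, not a Paddle or CNCL API, and the cast assumes the values fit in 32 bits, as the kernel above does implicitly.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for a backend all-gather: every rank contributes an
// identical-sized int32 buffer and receives the concatenation of all of them.
std::vector<int32_t> fake_all_gather(const std::vector<int32_t>& local, int nranks) {
  std::vector<int32_t> gathered;
  gathered.reserve(local.size() * nranks);
  for (int r = 0; r < nranks; ++r) {
    gathered.insert(gathered.end(), local.begin(), local.end());
  }
  return gathered;
}

// Cast down, gather, cast back: the same data flow as the MLU kernel above.
std::vector<int64_t> all_gather_int64(const std::vector<int64_t>& x, int nranks) {
  std::vector<int32_t> x32(x.begin(), x.end());              // int64 -> int32
  std::vector<int32_t> out32 = fake_all_gather(x32, nranks); // collective on int32
  return std::vector<int64_t>(out32.begin(), out32.end());   // int32 -> int64
}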