// Patch summary: CAllGatherOpMLUKernel::Compute gains an INT64 path. The input is cast to INT32 into a temporary tensor, cnclAllGather runs on the temporaries, and the gathered result is cast back into `out`; other dtypes share the original buffers via ShareDataWith. One more kernel instantiation is added to REGISTER_OP_MLU_KERNEL.
// NOTE(review): angle-bracket template arguments look stripped in this copy (e.g. `template class`, `ctx.Input("X")`, `reinterpret_cast(...)`) - confirm against the upstream file before applying.
// NOTE(review): presumably INT64 values outside the INT32 range do not survive the INT64 -> INT32 -> INT64 round trip - confirm callers only gather values that fit in INT32.
diff --git a/paddle/fluid/operators/collective/c_allgather_op_mlu.cc b/paddle/fluid/operators/collective/c_allgather_op_mlu.cc index 7bd30ecadc8c8..347349ac7a49b 100644 --- a/paddle/fluid/operators/collective/c_allgather_op_mlu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op_mlu.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_allgather_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" #if defined(PADDLE_WITH_CNCL) #include "paddle/fluid/platform/collective_helper.h" @@ -27,15 +28,14 @@ template class CAllGatherOpMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); #if defined(PADDLE_WITH_CNCL) auto x = ctx.Input("X"); auto out = ctx.Output("Out"); - cnclDataType_t dtype = - platform::ToCNCLDataType(framework::TransToProtoVarType(x->dtype())); int nranks = ctx.Attr("nranks"); int rid = ctx.Attr("ring_id"); - auto place = ctx.GetPlace(); auto comm = platform::CNCLCommContext::Instance().Get(rid, place); PADDLE_ENFORCE_EQ( nranks, @@ -48,19 +48,56 @@ class CAllGatherOpMLUKernel : public framework::OpKernel { out->mutable_data(out_dims, place); uint32_t send_numel = x->numel(); - void* send_buff = reinterpret_cast(const_cast(x->data())); - void* recv_buff = reinterpret_cast(out->data()); + void* send_buff; + void* recv_buff; + phi::DenseTensor in_tensor, out_tensor; + if (framework::TransToProtoVarType(x->dtype()) == + framework::proto::VarType::INT64) { + // cast from int64 to int32 since CNCL does not support int64 + in_tensor.mutable_data(x->dims(), place); + out_tensor.mutable_data(out->dims(), place); + MLUCnnlTensorDesc x_int64_desc(*x); + MLUCnnlTensorDesc x_int32_desc(in_tensor); + cnnlCastDataType_t cast_type = GetCastDataType(VT::INT64, VT::INT32); + 
MLUCnnl::Cast(ctx, + cast_type, + x_int64_desc.get(), + GetBasePtr(x), + x_int32_desc.get(), + GetBasePtr(&in_tensor)); + send_buff = reinterpret_cast(in_tensor.data()); + recv_buff = reinterpret_cast(out_tensor.data()); + } else { + in_tensor.ShareDataWith(*x); + out_tensor.ShareDataWith(*out); + send_buff = reinterpret_cast(in_tensor.data()); + recv_buff = reinterpret_cast(out_tensor.data()); + } mluStream stream = nullptr; if (ctx.Attr("use_calc_stream")) { - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); stream = static_cast(dev_ctx)->stream(); } else { stream = comm->stream(); } + cnclDataType_t dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(in_tensor.dtype())); PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather( send_buff, recv_buff, send_numel, dtype, comm->comm(), stream)); + if (framework::TransToProtoVarType(x->dtype()) == + framework::proto::VarType::INT64) { + // cast back from the int32 out_tensor to the int64 out + MLUCnnlTensorDesc out_int64_desc(*out); + MLUCnnlTensorDesc out_int32_desc(out_tensor); + cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64); + MLUCnnl::Cast(ctx, + cast_type, + out_int32_desc.get(), + GetBasePtr(&out_tensor), + out_int64_desc.get(), + GetBasePtr(out)); + } #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with MLU.")); @@ -80,4 +117,5 @@ REGISTER_OP_MLU_KERNEL(c_allgather, ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel, + ops::CAllGatherOpMLUKernel, ops::CAllGatherOpMLUKernel);