Skip to content

Commit

Permalink
[cherry-pick][MLU] support add callback to stream and profiler (#42115)
Browse files Browse the repository at this point in the history
* [MLU] add mlu new profiler (#41138)

* [MLU] add mlu new profiler

* fix format

* [MLU] support add callback to stream (#41831)

* [MLU] add gather mlu kernel (#41969)

* [MLU] add mlu activation kernels (#41751)
  • Loading branch information
fwenguang authored May 10, 2022
1 parent 6c935e1 commit 25124d7
Show file tree
Hide file tree
Showing 42 changed files with 1,721 additions and 112 deletions.
9 changes: 6 additions & 3 deletions cmake/neuware.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@ INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR})
set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so)
set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so)
set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so)
set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so)

generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake")
set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB})

if(WITH_CNCL)
MESSAGE(STATUS "Compile with CNCL!")
ADD_DEFINITIONS(-DPADDLE_WITH_CNCL)
set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so)
TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
else()
TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB})
list(APPEND NEUWARE_LIB_DEPS ${CNCL_LIB})
endif()

TARGET_LINK_LIBRARIES(neuware_lib ${NEUWARE_LIB_DEPS})
8 changes: 0 additions & 8 deletions paddle/fluid/framework/data_device_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,6 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
return;
}

// NOTE(hqp): Special case for CPU->MLU, avoid stream sync.
if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) {
paddle::framework::TensorCopy(
in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place),
out);
return;
}

// NOTE(yy): TransDataDevice should wait for computation of input.
if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(in.place())->Wait();
Expand Down
138 changes: 110 additions & 28 deletions paddle/fluid/operators/activation_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ limitations under the Licnse. */
#include <memory>
#include <string>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace operators {
Expand All @@ -38,20 +34,39 @@ class ActivationMLUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());

MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(input->dtype()));
MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(output->dtype()));

MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(),
reinterpret_cast<const void*>(input->data<T>()),
output_desc.get(),
reinterpret_cast<void*>(output->data<T>()));
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc output_desc(*output);

MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), GetBasePtr(input),
output_desc.get(), GetBasePtr(output));
}
};

// For gelu, leaky_relu
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernel : public framework::OpKernel<T> {
class ActivationGradMLUKernelV1 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

dx->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), x_desc.get(),
GetBasePtr(x), dx_desc.get(), GetBasePtr(dx));
}
};

// For tanh, sigmoid
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
Expand All @@ -61,18 +76,35 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {

dx->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(dout->dtype()));
MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(out->dtype()));
MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(dx->dtype()));
MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(
ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), reinterpret_cast<const void*>(dout->data<T>()),
out_desc.get(), reinterpret_cast<const void*>(out->data<T>()),
dx_desc.get(), reinterpret_cast<void*>(dx->data<T>()));
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, out_desc.get(),
GetBasePtr(out), dout_desc.get(), GetBasePtr(dout),
nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
}
};

// For relu, relu6
template <cnnlActivationMode_t act_mode, typename T>
class ActivationGradMLUKernelV3 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;

dx->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc out_desc(*out);
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnlActivationDesc act_desc(act_mode, alpha);
MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr,
dout_desc.get(), GetBasePtr(dout), out_desc.get(),
GetBasePtr(out), dx_desc.get(), GetBasePtr(dx));
}
};

Expand All @@ -81,10 +113,60 @@ class ActivationGradMLUKernel : public framework::OpKernel<T> {

namespace ops = paddle::operators;

// relu
REGISTER_OP_MLU_KERNEL(
relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
relu_grad, ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU, float>,
ops::ActivationGradMLUKernel<CNNL_ACTIVATION_RELU,
paddle::platform::float16>);
relu_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU, float>,
ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU,
paddle::platform::float16>);

// relu6
REGISTER_OP_MLU_KERNEL(
relu6, ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_RELU6, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
relu6_grad, ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6, float>,
ops::ActivationGradMLUKernelV3<CNNL_ACTIVATION_RELU6,
paddle::platform::float16>);

// sigmoid
REGISTER_OP_MLU_KERNEL(sigmoid,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
sigmoid_grad,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_SIGMOID,
paddle::platform::float16>);

// tanh
REGISTER_OP_MLU_KERNEL(
tanh, ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_TANH, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
tanh_grad, ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH, float>,
ops::ActivationGradMLUKernelV2<CNNL_ACTIVATION_TANH,
paddle::platform::float16>);

// gelu
REGISTER_OP_MLU_KERNEL(
gelu, ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_GELU, paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
gelu_grad, ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_GELU,
paddle::platform::float16>);

// leaky_relu
REGISTER_OP_MLU_KERNEL(
leaky_relu, ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationMLUKernel<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
leaky_relu_grad,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU, float>,
ops::ActivationGradMLUKernelV1<CNNL_ACTIVATION_LEAKYRELU,
paddle::platform::float16>);
16 changes: 7 additions & 9 deletions paddle/fluid/operators/fill_constant_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
}
}
}
const T *value_data = &value;
cnnlPointerMode_t pointer_mode = CNNL_POINTER_MODE_HOST;
if (ctx.HasInput("ValueTensor")) {
auto *value_tensor = ctx.Input<framework::Tensor>("ValueTensor");
PADDLE_ENFORCE_EQ(
Expand All @@ -59,22 +61,18 @@ class FillConstantMLUKernel : public framework::OpKernel<T> {
"When use Tensor as value to set Tensor value in fill_cosntant, "
"value input(ValueTensor) size must be 1, but get %d",
value_tensor->numel()));
const T *tensor_data = value_tensor->data<T>();
framework::Tensor mlu_tensor;
value_data = value_tensor->data<T>();
auto tmp_place = value_tensor->place();
if (platform::is_mlu_place(tmp_place)) {
framework::TensorCopySync(*value_tensor, platform::CPUPlace(),
&mlu_tensor);
tensor_data = mlu_tensor.data<T>();
pointer_mode = CNNL_POINTER_MODE_DEVICE;
}
value = tensor_data[0];
}

auto shape = GetShape(ctx);
out_var->mutable_data<T>(shape, ctx.GetPlace());
MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(out_var->dtype()));
MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var));
MLUCnnlTensorDesc output_desc(*out_var);
MLUCnnl::Fill(ctx, pointer_mode, value_data, output_desc.get(),
GetBasePtr(out_var));
}
};
} // namespace operators
Expand Down
75 changes: 75 additions & 0 deletions paddle/fluid/operators/gather_op_mlu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

template <typename T>
class GatherOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto axis = ctx.Attr<int>("axis");

auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(),
GetBasePtr(x), index_desc.get(), GetBasePtr(index),
out_desc.get(), GetBasePtr(out));
}
};

template <typename T>
class GatherGradOpMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *index = ctx.Input<Tensor>("Index");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());

MLUCnnlTensorDesc dx_desc(*dx);
auto value = static_cast<T>(0);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(),
GetBasePtr(dx));

MLUCnnlTensorDesc index_desc(*index);
MLUCnnlTensorDesc dout_desc(*dout);
const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
GetBasePtr(dout), index_desc.get(),
GetBasePtr(index), mode);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(gather, ops::GatherOpMLUKernel<float>,
ops::GatherOpMLUKernel<paddle::platform::float16>,
ops::GatherOpMLUKernel<int>);

REGISTER_OP_MLU_KERNEL(gather_grad, ops::GatherGradOpMLUKernel<float>,
ops::GatherGradOpMLUKernel<paddle::platform::float16>,
ops::GatherGradOpMLUKernel<int>);
3 changes: 2 additions & 1 deletion paddle/fluid/operators/mean_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class MeanMLUGradKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY,
ToCnnlDataType(mean_var.dtype()));
auto value = static_cast<T>(1.0 / static_cast<float>(input_grad->numel()));
MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var));
MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, mean_var_desc.get(),
GetBasePtr(&mean_var));

// means mul output_grad
MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY,
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/operators/metrics/accuracy_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ class AccuracyMLUKernel : public framework::OpKernel<T> {
// [total]
total->mutable_data<int>(ctx.GetPlace());
MLUCnnlTensorDesc total_desc(*total);
MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples, total_desc.get(),
GetBasePtr(total));

// use `total` of type `float32` for calculating accuracy
Tensor total_fp32(framework::TransToPhiDataType(VT::FP32));
total_fp32.Resize(total->dims());
total_fp32.mutable_data<float>(ctx.GetPlace());
MLUCnnlTensorDesc total_fp32_desc(total_fp32);
MLUCnnl::Fill(ctx, static_cast<float>(num_samples), total_fp32_desc.get(),
GetBasePtr(&total_fp32));
float num_samples_fp32 = static_cast<float>(num_samples);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples_fp32,
total_fp32_desc.get(), GetBasePtr(&total_fp32));

// [accuracy]
accuracy->mutable_data<float>(ctx.GetPlace());
Expand Down
Loading

0 comments on commit 25124d7

Please sign in to comment.