Skip to content

Commit

Permalink
[NPU] add elementwise_min_grad_op_npu,test=develop (#34731)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyjiang1116 authored Aug 11, 2021
1 parent addd5fc commit 45af4f2
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 43 deletions.
176 changes: 173 additions & 3 deletions paddle/fluid/operators/elementwise/elementwise_min_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ limitations under the License. */
#include <memory>
#include <string>

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/npu_op_runner.h"

namespace paddle {
Expand All @@ -27,31 +29,199 @@ template <typename DeviceContext, typename T>
class ElementwiseMinNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");

auto* out = ctx.Output<Tensor>("Out");

auto place = ctx.GetPlace();

out->mutable_data<T>(place);

int axis = ctx.Attr<int>("axis");
bool direct_compute = false;
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
if (x_dims.size() >= y_dims.size()) {
direct_compute =
y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
} else {
direct_compute =
x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
}
Tensor transformed_x, transformed_y;
if (direct_compute) {
transformed_x.ShareDataWith(*x);
transformed_y.ShareDataWith(*y);
} else {
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
}
const auto& runner =
NpuOpRunner("Minimum", {transformed_x, transformed_y}, {*out}, {});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();

const auto& runner = NpuOpRunner("Minimum", {*x, *y}, {*out}, {});
runner.Run(stream);
}
};

template <typename DeviceContext, typename T>
class ElementwiseMinGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis);
auto stream = dev_ctx.stream();
if (dx && dy) {
// dx
dx->mutable_data<T>(ctx.GetPlace());
Tensor tmp_x;
tmp_x.ShareDataWith(*dx);
if (dx->dims() != dout->dims()) {
std::vector<int> dst_dims_vec_x;
std::vector<int> reduce_axes_x;
auto src_dims_x = dx->dims();
auto dout_dims = dout->dims();

int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) ||
(dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) {
reduce_axes_x.push_back(ax);
} else {
dst_dims_vec_x.push_back(dout_dims[ax]);
}
}
if (!reduce_axes_x.empty()) {
tmp_x.Resize(framework::make_ddim(dst_dims_vec_x));
}
}
// dy
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_y;
tmp_y.ShareDataWith(*dy);
if (dy->dims() != dout->dims()) {
std::vector<int> dst_dims_vec_y;
std::vector<int> reduce_axes_y;
auto src_dims_y = dy->dims();
auto dout_dims = dout->dims();

int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) ||
(dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) {
reduce_axes_y.push_back(ax);
} else {
dst_dims_vec_y.push_back(dout_dims[ax]);
}
}
if (!reduce_axes_y.empty()) {
tmp_y.Resize(framework::make_ddim(dst_dims_vec_y));
}
}

const auto& runner =
NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, tmp_y},
{{"grad_x", true}, {"grad_y", true}});
runner.Run(stream);

} else if (dx) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(y->dims(), ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
// dx
dx->mutable_data<T>(ctx.GetPlace());
Tensor tmp_x;
tmp_x.ShareDataWith(*dx);
if (dx->dims() != dout->dims()) {
std::vector<int> dst_dims_vec_x;
std::vector<int> reduce_axes_x;
auto src_dims_x = dx->dims();
auto dout_dims = dout->dims();

int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) ||
(dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) {
reduce_axes_x.push_back(ax);
} else {
dst_dims_vec_x.push_back(dout_dims[ax]);
}
}
if (!reduce_axes_x.empty()) {
tmp_x.Resize(framework::make_ddim(dst_dims_vec_x));
}
}

const auto& runner =
NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, zero_tensor},
{{"grad_x", true}, {"grad_y", true}});
runner.Run(stream);

} else if (dy) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(x->dims(), ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));

// dy
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_y;
tmp_y.ShareDataWith(*dy);
if (dy->dims() != dout->dims()) {
std::vector<int> dst_dims_vec_y;
std::vector<int> reduce_axes_y;
auto src_dims_y = dy->dims();
auto dout_dims = dout->dims();

int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) ||
(dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) {
reduce_axes_y.push_back(ax);
} else {
dst_dims_vec_y.push_back(dout_dims[ax]);
}
}
if (!reduce_axes_y.empty()) {
tmp_y.Resize(framework::make_ddim(dst_dims_vec_y));
}
}

const auto& runner =
NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {zero_tensor, tmp_y},
{{"grad_x", true}, {"grad_y", true}});
runner.Run(stream);

} else {
std::cout << "error" << std::endl;
}
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(
elementwise_min,
ops::ElementwiseMinNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseMinNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);

REGISTER_OP_NPU_KERNEL(
elementwise_min_grad,
ops::ElementwiseMinGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseMinGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
132 changes: 92 additions & 40 deletions python/paddle/fluid/tests/unittests/npu/test_elementwise_min_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,81 +18,133 @@
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle.fluid.core as core

paddle.enable_static()
SEED = 2021


class TestElementwiseMin(OpTest):
class TestElementwiseMinOp(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "elementwise_min"
self.place = paddle.NPUPlace(0)

self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.minimum(x, y)

self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {}
self.outputs = {'Out': out}
self.outputs = {'Out': self.out}
self.attrs = {'axis': self.axis}

def set_npu(self):
self.__class__.use_npu = True

def init_input_output(self):
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype)
self.y = self.x + self.sgn * np.random.uniform(
0.1, 1, [13, 17]).astype(self.dtype)
self.out = np.minimum(self.x, self.y)
self.axis = -1

def init_dtype(self):
self.dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place)

# TODO(ascendrc): Min grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
def test_check_grad_normal(self):
if self.dtype == np.float16:
return

self.check_grad_with_place(
self.place,
['X', 'Y'],
'Out', )

class TestElementwiseMinFp16(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "elementwise_min"
self.place = paddle.NPUPlace(0)
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return

self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.minimum(x, y)
self.check_grad_with_place(
self.place,
['Y'],
'Out',
no_grad_set=set("X"), )

self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.attrs = {}
self.outputs = {'Out': out}
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return

self.check_grad_with_place(
self.place,
['X'],
'Out',
no_grad_set=set("Y"), )

def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True

class TestElementwiseMinOpFp16(TestElementwiseMinOp):
def init_dtype(self):
self.dtype = np.float16

def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)

class TestElementwiseMinOp_Vector(TestElementwiseMinOp):
def init_input_output(self):
self.x = np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x + self.sgn * np.random.uniform(0.1, 1, (
100, )).astype(self.dtype)
self.out = np.minimum(self.x, self.y)
self.axis = -1


class TestElementwiseMinOpFp16_Vector(TestElementwiseMinOp_Vector):
def init_dtype(self):
self.dtype = np.float16


@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseMinOp_scalar(TestElementwiseMinOp):
def init_input_output(self):
self.x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(self.dtype)
self.y = np.array([0.5]).astype(self.dtype)
self.out = np.minimum(self.x, self.y)
self.axis = -1


@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseMinOpFp16_scalar(TestElementwiseMinOp_scalar):
def init_dtype(self):
self.dtype = np.float16


class TestElementwiseMinOp_broadcast(TestElementwiseMinOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype)
self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[0, 0, :] + self.sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.minimum(self.x, self.y.reshape(1, 1, 100))
self.axis = -1


class TestElementwiseMinOpFp16_broadcast(TestElementwiseMinOp_broadcast):
def init_dtype(self):
self.dtype = np.float16


class TestElementwiseMinNet(unittest.TestCase):
class TestElementwiseMinOpNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
Expand Down

0 comments on commit 45af4f2

Please sign in to comment.