From 45af4f2aa7abe0bfe98ceec55f97064a2c298981 Mon Sep 17 00:00:00 2001 From: andyjpaddle <87074272+andyjpaddle@users.noreply.github.com> Date: Wed, 11 Aug 2021 14:19:09 +0800 Subject: [PATCH] [NPU] add elementwise_min_grad_op_npu,test=develop (#34731) --- .../elementwise/elementwise_min_op_npu.cc | 176 +++++++++++++++++- .../npu/test_elementwise_min_op_npu.py | 132 +++++++++---- 2 files changed, 265 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_min_op_npu.cc index 48ac3905f32bd..84ff28bb3a0e4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op_npu.cc @@ -15,7 +15,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_npu.h" #include "paddle/fluid/operators/npu_op_runner.h" namespace paddle { @@ -27,31 +29,199 @@ template class ElementwiseMinNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* out = ctx.Output("Out"); - auto place = ctx.GetPlace(); out->mutable_data(place); + int axis = ctx.Attr("axis"); + bool direct_compute = false; + auto x_dims = x->dims(); + auto y_dims = y->dims(); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + if (x_dims.size() >= y_dims.size()) { + direct_compute = + y_dims == framework::slice_ddim(x_dims, axis, x_dims.size()); + } else { + direct_compute = + x_dims == framework::slice_ddim(y_dims, axis, y_dims.size()); + } + Tensor transformed_x, transformed_y; + if (direct_compute) { + transformed_x.ShareDataWith(*x); + transformed_y.ShareDataWith(*y); + } else { + NpuElementWiseOpBroadcast(dev_ctx, x, y, axis, &transformed_x, + &transformed_y); + } + const auto& runner = + NpuOpRunner("Minimum", {transformed_x, transformed_y}, {*out}, {}); auto stream = ctx.template device_context() .stream(); - - const auto& runner = NpuOpRunner("Minimum", {*x, *y}, {*out}, {}); runner.Run(stream); } }; +template +class ElementwiseMinGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis); + auto stream = dev_ctx.stream(); + if (dx && dy) { + // dx + dx->mutable_data(ctx.GetPlace()); + Tensor tmp_x; + tmp_x.ShareDataWith(*dx); + if (dx->dims() != dout->dims()) { + std::vector dst_dims_vec_x; + std::vector reduce_axes_x; + auto src_dims_x = dx->dims(); + auto dout_dims = dout->dims(); + + int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) || + (dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) { + reduce_axes_x.push_back(ax); + } else { + dst_dims_vec_x.push_back(dout_dims[ax]); + } + } + if (!reduce_axes_x.empty()) { + tmp_x.Resize(framework::make_ddim(dst_dims_vec_x)); + } + } + // dy + dy->mutable_data(ctx.GetPlace()); + Tensor tmp_y; + tmp_y.ShareDataWith(*dy); + if (dy->dims() != dout->dims()) { + std::vector dst_dims_vec_y; + std::vector reduce_axes_y; + auto src_dims_y = dy->dims(); + auto dout_dims = dout->dims(); + + int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) || + (dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) { + reduce_axes_y.push_back(ax); + } else { + dst_dims_vec_y.push_back(dout_dims[ax]); + } + } + if (!reduce_axes_y.empty()) { + tmp_y.Resize(framework::make_ddim(dst_dims_vec_y)); + } + } + + const auto& runner = + NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, tmp_y}, + {{"grad_x", true}, {"grad_y", true}}); + runner.Run(stream); + + } else if (dx) { + Tensor zero_tensor(dout->type()); + zero_tensor.mutable_data(y->dims(), ctx.GetPlace()); + FillNpuTensorWithConstant(&zero_tensor, static_cast(0)); + // dx + dx->mutable_data(ctx.GetPlace()); + Tensor tmp_x; + tmp_x.ShareDataWith(*dx); + if (dx->dims() != dout->dims()) { + std::vector dst_dims_vec_x; + std::vector reduce_axes_x; + auto src_dims_x = dx->dims(); + auto dout_dims = dout->dims(); + + int src_axis_x = (src_dims_x.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis_x || ax >= src_axis_x + src_dims_x.size()) || + (dout_dims[ax] > 1 && src_dims_x[ax - src_axis_x] == 1)) { + reduce_axes_x.push_back(ax); + } else { + dst_dims_vec_x.push_back(dout_dims[ax]); + } + } + if (!reduce_axes_x.empty()) { + tmp_x.Resize(framework::make_ddim(dst_dims_vec_x)); + } + } + + const auto& runner = + NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {tmp_x, zero_tensor}, + {{"grad_x", true}, {"grad_y", true}}); + runner.Run(stream); + + } else if (dy) { + Tensor zero_tensor(dout->type()); + zero_tensor.mutable_data(x->dims(), ctx.GetPlace()); + FillNpuTensorWithConstant(&zero_tensor, static_cast(0)); + + // dy + dy->mutable_data(ctx.GetPlace()); + Tensor tmp_y; + tmp_y.ShareDataWith(*dy); + if (dy->dims() != dout->dims()) { + std::vector dst_dims_vec_y; + std::vector reduce_axes_y; + auto src_dims_y = dy->dims(); + auto dout_dims = dout->dims(); + + int src_axis_y = (src_dims_y.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis_y || ax >= src_axis_y + src_dims_y.size()) || + (dout_dims[ax] > 1 && src_dims_y[ax - src_axis_y] == 1)) { + reduce_axes_y.push_back(ax); + } else { + dst_dims_vec_y.push_back(dout_dims[ax]); + } + } + if (!reduce_axes_y.empty()) { + tmp_y.Resize(framework::make_ddim(dst_dims_vec_y)); + } + } + + const auto& runner = + NpuOpRunner("MinimumGrad", {*dout, *x, *y}, {zero_tensor, tmp_y}, + {{"grad_x", true}, {"grad_y", true}}); + runner.Run(stream); + + } else { + std::cout << "error" << std::endl; + } + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_NPU_KERNEL( elementwise_min, ops::ElementwiseMinNPUKernel, ops::ElementwiseMinNPUKernel); + +REGISTER_OP_NPU_KERNEL( + elementwise_min_grad, + ops::ElementwiseMinGradNPUKernel, + ops::ElementwiseMinGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_min_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_min_op_npu.py index 2034a12c5c0fe..51cf5cdaf6d1a 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_min_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_min_op_npu.py @@ -18,81 +18,133 @@ import unittest import sys sys.path.append("..") -from op_test import OpTest +from op_test import OpTest, skip_check_grad_ci import paddle import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import paddle.fluid.core as core paddle.enable_static() SEED = 2021 -class TestElementwiseMin(OpTest): +class TestElementwiseMinOp(OpTest): def setUp(self): self.set_npu() self.op_type = "elementwise_min" self.place = paddle.NPUPlace(0) - self.init_dtype() - np.random.seed(SEED) - x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) - y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) - out = np.minimum(x, y) - + self.init_input_output() self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x), - 'Y': OpTest.np_dtype_to_fluid_dtype(y) + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.attrs = {} - self.outputs = {'Out': out} + self.outputs = {'Out': self.out} + self.attrs = {'axis': self.axis} def set_npu(self): self.__class__.use_npu = True + def init_input_output(self): + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype) + self.y = self.x + self.sgn * np.random.uniform( + 0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + def init_dtype(self): self.dtype = np.float32 def test_check_output(self): self.check_output_with_place(self.place) - # TODO(ascendrc): Min grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') - # + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, + ['X', 'Y'], + 'Out', ) -class TestElementwiseMinFp16(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "elementwise_min" - self.place = paddle.NPUPlace(0) + def test_check_grad_ingore_x(self): + if self.dtype == np.float16: + return - self.init_dtype() - np.random.seed(SEED) - x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) - y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) - out = np.minimum(x, y) + self.check_grad_with_place( + self.place, + ['Y'], + 'Out', + no_grad_set=set("X"), ) - self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(x), - 'Y': OpTest.np_dtype_to_fluid_dtype(y) - } - self.attrs = {} - self.outputs = {'Out': out} + def test_check_grad_ingore_y(self): + if self.dtype == np.float16: + return + + self.check_grad_with_place( + self.place, + ['X'], + 'Out', + no_grad_set=set("Y"), ) - def set_npu(self): - self.__class__.use_npu = True - self.__class__.no_need_check_grad = True +class TestElementwiseMinOpFp16(TestElementwiseMinOp): def init_dtype(self): self.dtype = np.float16 - def test_check_output(self): - self.check_output_with_place(self.place, atol=1e-5) + +class TestElementwiseMinOp_Vector(TestElementwiseMinOp): + def init_input_output(self): + self.x = np.random.uniform(1, 2, (100, )).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype) + self.y = self.x + self.sgn * np.random.uniform(0.1, 1, ( + 100, )).astype(self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + + +class TestElementwiseMinOpFp16_Vector(TestElementwiseMinOp_Vector): + def init_dtype(self): + self.dtype = np.float16 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseMinOp_scalar(TestElementwiseMinOp): + def init_input_output(self): + self.x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(self.dtype) + self.y = np.array([0.5]).astype(self.dtype) + self.out = np.minimum(self.x, self.y) + self.axis = -1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseMinOpFp16_scalar(TestElementwiseMinOp_scalar): + def init_dtype(self): + self.dtype = np.float16 + + +class TestElementwiseMinOp_broadcast(TestElementwiseMinOp): + def init_input_output(self): + self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype) + self.sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype) + self.y = self.x[0, 0, :] + self.sgn * \ + np.random.uniform(1, 2, (100, )).astype(self.dtype) + self.out = np.minimum(self.x, self.y.reshape(1, 1, 100)) + self.axis = -1 + + +class TestElementwiseMinOpFp16_broadcast(TestElementwiseMinOp_broadcast): + def init_dtype(self): + self.dtype = np.float16 -class TestElementwiseMinNet(unittest.TestCase): +class TestElementwiseMinOpNet(unittest.TestCase): def _test(self, run_npu=True): main_prog = paddle.static.Program() startup_prog = paddle.static.Program()