From 35c48c75316cbefc0fcc3a582ea57deac70660ec Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 3 Jan 2023 07:22:06 +0000 Subject: [PATCH 1/4] support 0D for paddle.sort/argsort --- paddle/phi/infermeta/unary.cc | 32 +++++---- paddle/phi/kernels/cpu/argsort_grad_kernel.cc | 7 ++ paddle/phi/kernels/cpu/argsort_kernel.cc | 9 +++ paddle/phi/kernels/gpu/argsort_grad_kernel.cu | 8 +++ paddle/phi/kernels/gpu/argsort_kernel.cu | 8 +++ .../tests/unittests/test_zero_dim_tensor.py | 72 +++++++++++++++++++ .../unittests/xpu/test_zero_dim_tensor_xpu.py | 36 ++++++++++ 7 files changed, 160 insertions(+), 12 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c3b96b813b8c3..895d50c7bbd2a 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -220,18 +220,26 @@ void ArgsortInferMeta(const MetaTensor& input, MetaTensor* indices) { auto in_dims = input.dims(); auto num_dims = in_dims.size(); - PADDLE_ENFORCE_GE( - axis, - -num_dims, - phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to" - " -num_dims(%d).", - axis, - -num_dims)); - PADDLE_ENFORCE_LT( - axis, - num_dims, - phi::errors::InvalidArgument( - "'axis'(%d) must be less than num_dims(%d).", axis, num_dims)); + if (num_dims > 0) { + PADDLE_ENFORCE_GE(axis, + -num_dims, + phi::errors::InvalidArgument( + "'axis'(%d) must be greater than or equal to" + " -num_dims(%d).", + axis, + -num_dims)); + PADDLE_ENFORCE_LT( + axis, + num_dims, + phi::errors::InvalidArgument( + "'axis'(%d) must be less than num_dims(%d).", axis, num_dims)); + } else { // 0-dim tensor + PADDLE_ENFORCE_EQ( + axis == 0 || axis == -1, + 1, + phi::errors::InvalidArgument( + "'axis'(%d) must be 0 or -1 if input tensor is 0-dim.", axis)); + } output->share_dims(input); output->set_dtype(input.dtype()); diff --git a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc index 1e60847232c70..f866b62a2bd4d 100644 --- a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc @@ -18,6 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { @@ -58,6 +59,7 @@ void ArgsortGradKernel(const Context& dev_ctx, bool descending, DenseTensor* in_grad) { auto in_dims = indices.dims(); + auto rank = input.dims().size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; dev_ctx.template Alloc(in_grad); auto dxt = EigenVector::Flatten(*in_grad); @@ -65,6 +67,11 @@ void ArgsortGradKernel(const Context& dev_ctx, dxt.device(place) = dxt.constant(static_cast(0)); if (out_grad.numel() == 0) return; + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + return; + } + // Do full assign if (axis == -1 || axis + 1 == in_dims.size()) { const int64_t input_height = diff --git a/paddle/phi/kernels/cpu/argsort_kernel.cc b/paddle/phi/kernels/cpu/argsort_kernel.cc index 8621a717e1018..07b61fa3f3579 100644 --- a/paddle/phi/kernels/cpu/argsort_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_kernel.cc @@ -18,6 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { @@ -75,9 +76,17 @@ void ArgsortKernel(const Context& dev_ctx, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; T* out_data = dev_ctx.template Alloc(output); + // For 0D Tensor + if (rank == 0) { + phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + phi::funcs::set_constant(dev_ctx, indices, 0); + return; + } + // Do full sort if (axis == -1 || axis + 1 == in_dims.size()) { const int64_t input_height = diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index a2d149cb2e438..f28da8704cbff 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -28,6 +28,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/transpose_kernel.h" @@ -141,11 +142,18 @@ void ArgsortGradKernel(const Context& dev_ctx, bool descending, DenseTensor* in_grad) { dev_ctx.template Alloc(in_grad); + phi::funcs::set_constant(dev_ctx, in_grad, 0.0); if (out_grad.numel() == 0) return; auto in_dims = in_grad->dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; int64_t size = in_grad->numel(); + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + return; + } + // Parallel acceleration when the input size is equal to the length of the // ‘axis’ dimension. // Compared to 'special case for full sort' below, the gradient calculation diff --git a/paddle/phi/kernels/gpu/argsort_kernel.cu b/paddle/phi/kernels/gpu/argsort_kernel.cu index 1c3825b90e210..13455a7639cdb 100644 --- a/paddle/phi/kernels/gpu/argsort_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_kernel.cu @@ -30,6 +30,7 @@ namespace cub = hipcub; #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/transpose_kernel.h" @@ -396,6 +397,7 @@ void ArgsortKernel(const Context &dev_ctx, DenseTensor *output, DenseTensor *indices) { auto in_dims = input.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T *in_data = input.data(); @@ -403,6 +405,12 @@ void ArgsortKernel(const Context &dev_ctx, T *out_data = dev_ctx.template Alloc(output); int64_t *ids_data = dev_ctx.template Alloc(indices); + if (rank == 0) { + phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + phi::funcs::set_constant(dev_ctx, indices, 0); + return; + } + // Use thrust for parallel acceleration when the input size is equal to the // length of the ‘axis’ dimension. // Compared to the following 'Special case for full sort', ascending sort is diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index b8a115104800a..c99aabbf9b5a1 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -747,6 +747,42 @@ def test_floor_divide(self): np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) + def test_sort(self): + x1 = paddle.rand([]) + x2 = paddle.rand([]) + x1.stop_gradient = False + x2.stop_gradient = False + out1 = paddle.sort(x1, axis=-1) + out2 = paddle.sort(x2, axis=0) + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + + def test_argsort(self): + x1 = paddle.rand([]) + x2 = paddle.rand([]) + x1.stop_gradient = False + x2.stop_gradient = False + out1 = paddle.argsort(x1, axis=-1) + out2 = paddle.argsort(x2, axis=0) + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -990,6 +1026,42 @@ def test_floor_divide(self): np.testing.assert_array_equal(out3_1, out3_2) np.testing.assert_array_equal(out3_2, np.asarray(1)) + @prog_scope() + def test_sort(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + out1 = paddle.sort(x1, axis=-1) + paddle.static.append_backward(out1) + + x2 = paddle.rand([]) + x2.stop_gradient = False + out2 = paddle.sort(x2, axis=0) + paddle.static.append_backward(out2) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out1, out2]) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + + @prog_scope() + def test_argsort(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + out1 = paddle.argsort(x1, axis=-1) + paddle.static.append_backward(out1) + + x2 = paddle.rand([]) + x2.stop_gradient = False + out2 = paddle.argsort(x2, axis=0) + paddle.static.append_backward(out2) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out1, out2]) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index b561b775f29d0..ebd4354593dba 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -556,6 +556,42 @@ def test_floor_divide(self): np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy()) np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1)) + def test_sort(self): + x1 = paddle.rand([]) + x2 = paddle.rand([]) + x1.stop_gradient = False + x2.stop_gradient = False + out1 = paddle.sort(x1, axis=-1) + out2 = paddle.sort(x2, axis=0) + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + + def test_argsort(self): + x1 = paddle.rand([]) + x2 = paddle.rand([]) + x1.stop_gradient = False + x2.stop_gradient = False + out1 = paddle.argsort(x1, axis=-1) + out2 = paddle.argsort(x2, axis=0) + + out1.backward() + out2.backward() + + self.assertEqual(out1.shape, []) + self.assertEqual(out2.shape, []) + self.assertEqual(out1.grad.shape, []) + self.assertEqual(out2.grad.shape, []) + self.assertEqual(x1.grad.shape, []) + self.assertEqual(x2.grad.shape, []) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): From e6254e35563db38edad5bc3e51054257c35062ca Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 3 Jan 2023 11:25:43 +0000 Subject: [PATCH 2/4] support 0D tensor for paddle.sort/argsort in xpu --- paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 7 +++++++ paddle/phi/kernels/xpu/argsort_kernel.cc | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 371cc7d39c290..00c679f0ab999 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -29,6 +30,7 @@ void ArgsortGradKernel(const Context& dev_ctx, bool descending, DenseTensor* in_grad) { auto in_dims = indices.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; dev_ctx.template Alloc(in_grad); @@ -40,6 +42,11 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; + if (rank == 0) { + phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + return; + } + bool is_need_transpose = true; if (axis == -1 || axis + 1 == in_dims.size()) { is_need_transpose = false; diff --git a/paddle/phi/kernels/xpu/argsort_kernel.cc b/paddle/phi/kernels/xpu/argsort_kernel.cc index 0a71ec71463d4..4fdb42f69fd87 100644 --- a/paddle/phi/kernels/xpu/argsort_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -171,6 +172,7 @@ void ArgsortKernel(const Context& dev_ctx, DenseTensor* output, DenseTensor* indices) { auto in_dims = input.dims(); + auto rank = in_dims.size(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; int n = in_dims[axis]; @@ -178,6 +180,12 @@ void ArgsortKernel(const Context& dev_ctx, auto output_data = dev_ctx.template Alloc(output); auto indices_data = dev_ctx.template Alloc(indices); + if (rank == 0) { + phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + phi::funcs::set_constant(dev_ctx, indices, 0); + return; + } + int len_before = phi::product(phi::slice_ddim(in_dims, 0, axis)); int len_after = phi::product(phi::slice_ddim(in_dims, axis + 1, in_dims.size())); From 7d315faaf511c13f2921558846e1f039d58a9c89 Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Tue, 3 Jan 2023 12:31:06 +0000 Subject: [PATCH 3/4] fix bug --- .../tests/unittests/test_zero_dim_tensor.py | 73 +++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 19a4711fc5b68..5aec0c8010f3f 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -1147,42 +1147,6 @@ def test_floor_divide(self): np.testing.assert_array_equal(out3_2, np.asarray(1)) @prog_scope() -<<<<<<< HEAD - def test_sort(self): - x1 = paddle.rand([]) - x1.stop_gradient = False - out1 = paddle.sort(x1, axis=-1) - paddle.static.append_backward(out1) - - x2 = paddle.rand([]) - x2.stop_gradient = False - out2 = paddle.sort(x2, axis=0) - paddle.static.append_backward(out2) - - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out1, out2]) - - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, ()) - - @prog_scope() - def test_argsort(self): - x1 = paddle.rand([]) - x1.stop_gradient = False - out1 = paddle.argsort(x1, axis=-1) - paddle.static.append_backward(out1) - - x2 = paddle.rand([]) - x2.stop_gradient = False - out2 = paddle.argsort(x2, axis=0) - paddle.static.append_backward(out2) - - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[out1, out2]) - - self.assertEqual(res[0].shape, ()) - self.assertEqual(res[1].shape, ()) -======= def test_reshape_list(self): x1 = paddle.rand([]) x2 = paddle.rand([]) @@ -1253,7 +1217,42 @@ def test_reverse(self): res1, res2 = self.exe.run(program, fetch_list=[x, out]) self.assertEqual(res1.shape, ()) self.assertEqual(res2.shape, ()) ->>>>>>> c123dd1e4032efdbfff0bf0c35a58155f2d6e1d9 + + @prog_scope() + def test_sort(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + out1 = paddle.sort(x1, axis=-1) + paddle.static.append_backward(out1) + + x2 = paddle.rand([]) + x2.stop_gradient = False + out2 = paddle.sort(x2, axis=0) + paddle.static.append_backward(out2) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out1, out2]) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) + + @prog_scope() + def test_argsort(self): + x1 = paddle.rand([]) + x1.stop_gradient = False + out1 = paddle.argsort(x1, axis=-1) + paddle.static.append_backward(out1) + + x2 = paddle.rand([]) + x2.stop_gradient = False + out2 = paddle.argsort(x2, axis=0) + paddle.static.append_backward(out2) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out1, out2]) + + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, ()) # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. From d486046fab5d4222fe3287a47df643af2c46afbd Mon Sep 17 00:00:00 2001 From: DesmonDay <908660116@qq.com> Date: Wed, 4 Jan 2023 08:39:53 +0000 Subject: [PATCH 4/4] fix grad and add value assertion --- paddle/phi/kernels/cpu/argsort_grad_kernel.cc | 3 +-- paddle/phi/kernels/cpu/argsort_kernel.cc | 1 + paddle/phi/kernels/gpu/argsort_grad_kernel.cu | 2 +- paddle/phi/kernels/xpu/argsort_grad_kernel.cc | 3 +-- .../paddle/fluid/tests/unittests/test_zero_dim_tensor.py | 8 ++++++++ .../fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py | 8 ++++++++ 6 files changed, 20 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc index f866b62a2bd4d..81616dafc0a8b 100644 --- a/paddle/phi/kernels/cpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_grad_kernel.cc @@ -18,7 +18,6 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { @@ -68,7 +67,7 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; if (rank == 0) { - phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad); return; } diff --git a/paddle/phi/kernels/cpu/argsort_kernel.cc b/paddle/phi/kernels/cpu/argsort_kernel.cc index 07b61fa3f3579..97f8fb67ed1d6 100644 --- a/paddle/phi/kernels/cpu/argsort_kernel.cc +++ b/paddle/phi/kernels/cpu/argsort_kernel.cc @@ -83,6 +83,7 @@ void ArgsortKernel(const Context& dev_ctx, // For 0D Tensor if (rank == 0) { phi::Copy(dev_ctx, input, dev_ctx.GetPlace(), false, output); + dev_ctx.template Alloc(indices); phi::funcs::set_constant(dev_ctx, indices, 0); return; } diff --git a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu index f28da8704cbff..b8d9df64c23ef 100644 --- a/paddle/phi/kernels/gpu/argsort_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/argsort_grad_kernel.cu @@ -150,7 +150,7 @@ void ArgsortGradKernel(const Context& dev_ctx, int64_t size = in_grad->numel(); if (rank == 0) { - phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad); return; } diff --git a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc index 00c679f0ab999..4ebab7b37fc30 100644 --- a/paddle/phi/kernels/xpu/argsort_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/argsort_grad_kernel.cc @@ -17,7 +17,6 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -43,7 +42,7 @@ void ArgsortGradKernel(const Context& dev_ctx, if (out_grad.numel() == 0) return; if (rank == 0) { - phi::funcs::set_constant(dev_ctx, in_grad, 1.0); + phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, in_grad); return; } diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 5aec0c8010f3f..546c0c48f9b58 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -868,10 +868,14 @@ def test_sort(self): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), x1.numpy()) + self.assertEqual(out2.numpy(), x2.numpy()) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 1) + self.assertEqual(x2.grad.numpy(), 1) def test_argsort(self): x1 = paddle.rand([]) @@ -886,10 +890,14 @@ def test_argsort(self): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), 0) + self.assertEqual(out2.numpy(), 0) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0) + self.assertEqual(x2.grad.numpy(), 0) class TestSundryAPIStatic(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 8cb27ecf0992a..c0e0de0ac1335 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -659,10 +659,14 @@ def test_sort(self): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), x1.numpy()) + self.assertEqual(out2.numpy(), x2.numpy()) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 1) + self.assertEqual(x2.grad.numpy(), 1) def test_argsort(self): x1 = paddle.rand([]) @@ -677,10 +681,14 @@ def test_argsort(self): self.assertEqual(out1.shape, []) self.assertEqual(out2.shape, []) + self.assertEqual(out1.numpy(), 0) + self.assertEqual(out2.numpy(), 0) self.assertEqual(out1.grad.shape, []) self.assertEqual(out2.grad.shape, []) self.assertEqual(x1.grad.shape, []) self.assertEqual(x2.grad.shape, []) + self.assertEqual(x1.grad.numpy(), 0) + self.assertEqual(x2.grad.numpy(), 0) # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.