From 6a05a9bd3a2daf8fe1b70f792b23820c95fbcf28 Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Sat, 22 Apr 2023 16:25:52 +0000 Subject: [PATCH 1/3] add 0D support for trace, test=allcase --- paddle/phi/infermeta/unary.cc | 1 - .../tests/unittests/test_zero_dim_tensor.py | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index bfe744446a97b..ea27eba513051 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -4402,7 +4402,6 @@ void TraceInferMeta( auto sizes = vectorize(x_dims); if (x_dims.size() == 2) { sizes.clear(); - sizes.push_back(1); } else { sizes.erase(sizes.begin() + std::max(dim1_, dim2_)); sizes.erase(sizes.begin() + std::min(dim1_, dim2_)); diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index d18f94e78da39..a3d1cb12c446e 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -2423,6 +2423,16 @@ def test_multi_dot(self): self.assertEqual(b.grad.shape, [4, 5]) self.assertEqual(c.grad.shape, [5]) + def test_trace(self): + x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") + x.stop_gradient = False + out = paddle.trace(x) + out.backward() + + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(12)) + self.assertEqual(x.grad.shape, [2, 2]) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -4399,6 +4409,20 @@ def test_multi_dot(self): self.assertEqual(res[2].shape, (4, 5)) self.assertEqual(res[3].shape, (5,)) + @prog_scope() + def test_trace(self): + x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32") + x.stop_gradient = False + out = paddle.trace(x) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name]) + + 
self.assertEqual(res[0].shape, ()) + self.assertEqual(res[1].shape, (2, 2)) + np.testing.assert_allclose(res[0], np.array(12)) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): From 1bb3c0df43062c15cac88f1e75b6862a73566030 Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Sun, 23 Apr 2023 05:35:38 +0000 Subject: [PATCH 2/3] fix trace gpu kernel 0d error, test=allcase --- paddle/phi/kernels/gpu/trace_kernel.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/trace_kernel.cu b/paddle/phi/kernels/gpu/trace_kernel.cu index 671ca490e136a..304bf778094d3 100644 --- a/paddle/phi/kernels/gpu/trace_kernel.cu +++ b/paddle/phi/kernels/gpu/trace_kernel.cu @@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx, auto diag = funcs::Diagonal(ctx, &x, offset, axis1, axis2); if (diag.numel() > 0) { std::vector<int64_t> reduce_dims; - reduce_dims.push_back(out->dims().size()); + // Adapt to 0D output + auto out_dim_size = out->dims().size(); + if (out_dim_size == 0) out_dim_size = 1; + reduce_dims.push_back(out_dim_size); funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims); } else { From 026adfe68ca92221f29b35957e592a39ba6d1a2b Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Sun, 23 Apr 2023 08:21:56 +0000 Subject: [PATCH 3/3] fix windows error, test=allcase --- paddle/phi/kernels/impl/trace_grad_kernel_impl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h index 90a2327ef3e20..1099f27f3622e 100644 --- a/paddle/phi/kernels/impl/trace_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/trace_grad_kernel_impl.h @@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx, auto input_dims = in_grad->dims(); auto input_stride = phi::stride(input_dims); auto output_dims = 
out_grad.dims(); - auto output_stride = phi::stride(output_dims); + auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims) + : phi::stride(output_dims); auto* out_data = out_grad.data<T>(); T* x_data = ctx.template Alloc<T>(in_grad);