diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 404c55eebb061..7b13f997569c6 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 31aa85245ebdd..6a0a3796c8854 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -2344,7 +2344,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
   // If the last tensor is 1D of size n view it as a column vector (n, 1)
   if (last_dim.size() == 1) {
     last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
-    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
+    out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
   } else {
     out_dim = is_vector ? phi::make_ddim({last_dim[1]})
                         : phi::make_ddim({first_dim[0], last_dim[1]});
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index c1ee2b5d4ec11..8eb994f43965a 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -38,7 +38,7 @@ namespace detail {
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 3b909b7822572..6c8dc45c50f91 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -2087,6 +2087,23 @@ def test_linalg_slogdet(self):
         self.assertTrue(out1.shape, [2, 3])
         self.assertTrue(x1.grad.shape, [3, 3, 3])
 
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        self.assertEqual(a.grad.shape, [4])
+        self.assertEqual(b.grad.shape, [4, 5])
+        self.assertEqual(c.grad.shape, [5])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -3654,6 +3671,26 @@ def test_linalg_slogdet(self):
         self.assertEqual(res[0].shape, (2, 3))
         self.assertEqual(res[1].shape, (3, 3, 3))
 
+    @prog_scope()
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        paddle.static.append_backward(out.sum())
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (4,))
+        self.assertEqual(res[2].shape, (4, 5))
+        self.assertEqual(res[3].shape, (5,))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -3845,6 +3882,38 @@ def test_unique(self):
         self.assertEqual(inverse.shape, [1])
         self.assertEqual(counts.shape, [1])
 
+    def test_matrix_rank(self):
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_equal(out, np.array(10))
+
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        self.assertEqual(out_c.shape, [3])
+        np.testing.assert_equal(out_c, np.array([1, 1, 1]))
+
+        # 2D, tol->float : OUTPUT 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        self.assertEqual(out_tol.shape, [])
+
+        # 3D, tol->float : OUTPUT 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        self.assertEqual(out_c_tol.shape, [3])
+
+        tol_2 = paddle.randn([2])
+        # 2D, tol->Tensor[1,2] : OUTPUT 1D
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        self.assertEqual(out_d.shape, [2])
+
 
 class TestNoBackwardAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -4079,6 +4148,51 @@ def test_unique(self):
         self.assertEqual(res[2].shape, (1,))
         self.assertEqual(res[3].shape, (1,))
 
+    @prog_scope()
+    def test_static_matrix_rank(self):
+        # 2D : OUTPUT 0D
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D : OUTPUT 1D
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c])
+        self.assertEqual(res[0].shape, (3,))
+
+        # 2D, tol->float : OUTPUT 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out_tol])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D, tol->float : OUTPUT 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c_tol])
+        self.assertEqual(res[0].shape, (3,))
+
+        tol_2 = paddle.randn([2])
+        # 2D, tol->Tensor[1,2] : OUTPUT 1D
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_d])
+        self.assertEqual(res[0].shape, (2,))
+
 
 unary_apis_with_complex_input = [
     paddle.real,
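
For quick reference, a minimal standalone sketch (not part of the patch) of the behavior the InferMeta changes above enable, assuming Paddle's dynamic-graph mode; it mirrors the new unit tests:

import paddle

# multi_dot over [vector, matrix, vector] now infers a 0-D output (previously shape [1]).
a = paddle.randn([4])
b = paddle.randn([4, 5])
c = paddle.randn([5])
print(paddle.linalg.multi_dot([a, b, c]).shape)  # []

# matrix_rank of a single 2-D matrix is likewise 0-D; batched inputs keep their batch dims.
print(paddle.linalg.matrix_rank(paddle.eye(10)).shape)          # []
print(paddle.linalg.matrix_rank(paddle.ones([3, 4, 5])).shape)  # [3]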