【0D output】support 0D output for matrix_rank/multi_dot (#52861)
* support 0D output for matrix_rank/multi_dot, test=allcase

* add 0D output tests for matrix_rank and multi_dot, test=allcase

* fix assert error, test=allcase

* fix test error, test=allcase

* fix other test error, test=allcase

* fix other test error, test=allcase

* fix test error, test=allcase

* fix matrix_rank and multi_dot test error, test=allcase

* fix test error, test=allcase

* fix zero-dim test, test=allcase

* add static backward test for multi_dot, test=allcase

* add 2D broadcast test case for tol, test=allcase
GGBond8488 authored Apr 21, 2023
1 parent 07878a3 commit 47fa806
Showing 4 changed files with 117 additions and 3 deletions.
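In user-facing terms, both reductions now return true 0D tensors. A minimal dygraph sketch of the new behavior; the snippet is illustrative and not part of the diff, but the shapes in the comments are exactly what this commit's tests assert:

import paddle

# rank of a single 2D matrix is now a 0D scalar (shape [], previously [1])
print(paddle.linalg.matrix_rank(paddle.eye(10)).shape)  # []

# vector @ matrix @ vector in multi_dot likewise collapses to 0D
out = paddle.linalg.multi_dot(
    [paddle.randn([4]), paddle.randn([4, 5]), paddle.randn([5])]
)
print(out.shape)  # []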
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/binary.cc
@@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
   }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
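The helper's shape rule, restated as a small Python sketch (the function name here is illustrative, not a Paddle API):

def matrix_rank_out_shape(dim_x):
    # Mirrors CheckAndGetOutputDim: the rank of one 2D matrix is a 0D
    # scalar (shape (), previously (1,)); batched inputs drop the two
    # trailing matrix dims and keep their batch dims.
    if len(dim_x) == 2:
        return ()
    return tuple(dim_x[:-2])

assert matrix_rank_out_shape((10, 10)) == ()
assert matrix_rank_out_shape((3, 4, 5)) == (3,)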
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -2345,7 +2345,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
   // If the last tensor is 1D of size n view it as a column vector (n, 1)
   if (last_dim.size() == 1) {
     last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
-    out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
+    out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
   } else {
     out_dim = is_vector ? phi::make_ddim({last_dim[1]})
                         : phi::make_ddim({first_dim[0], last_dim[1]});
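Restating MultiDotInferMeta's output-shape rule as a sketch (the helper name is illustrative, not a Paddle API, and it assumes is_vector means the first operand is 1D, consistent with the hunk above): a 1D first operand is viewed as a row vector (1, n), a 1D last operand as a column vector (n, 1), and when both ends are vectors the (1, 1) product now squeezes down to 0D.

def multi_dot_out_shape(shapes):
    first_is_vec = len(shapes[0]) == 1   # viewed as (1, n)
    last_is_vec = len(shapes[-1]) == 1   # viewed as (n, 1)
    if first_is_vec and last_is_vec:
        return ()                        # was (1,) before this commit
    if last_is_vec:
        return (shapes[0][0],)
    if first_is_vec:
        return (shapes[-1][1],)
    return (shapes[0][0], shapes[-1][1])

assert multi_dot_out_shape([(4,), (4, 5), (5,)]) == ()
assert multi_dot_out_shape([(3, 4), (4, 5), (5, 6)]) == (3, 6)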
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.cc
@@ -38,7 +38,7 @@ namespace detail {
 static DDim CheckAndGetOutputDim(const DDim& dim_x) {
   auto x_vec = phi::vectorize(dim_x);
   if (x_vec.size() == 2) {
-    return phi::make_ddim({1});
+    return phi::make_ddim({});
  }
   x_vec.erase(x_vec.end() - 2, x_vec.end());
   return phi::make_ddim(x_vec);
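unary.cc carries this helper for matrix_rank's no-tol / float-tol path, while binary.cc covers the tol-as-Tensor overload. Taken together, the output shapes line up as in this illustrative sketch; the shapes in the comments match the tests added below:

import paddle

c = paddle.ones(shape=[3, 4, 5])  # a batch of three 4x5 matrices

print(paddle.linalg.matrix_rank(paddle.eye(10), tol=0.1).shape)  # []
print(paddle.linalg.matrix_rank(c, tol=0.1).shape)               # [3]

# a 1D Tensor tol of shape [2] broadcasts against a 2D input -> 1D output
tol_2 = paddle.randn([2])
print(paddle.linalg.matrix_rank(paddle.eye(10), tol=tol_2).shape)  # [2]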
114 changes: 114 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -2123,6 +2123,23 @@ def test_linalg_slogdet(self):
         self.assertEqual(out1.shape, [2, 3])
         self.assertEqual(x1.grad.shape, [3, 3, 3])
 
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        out.retain_grads()
+        out.backward()
+
+        self.assertEqual(out.shape, [])
+        self.assertEqual(a.grad.shape, [4])
+        self.assertEqual(b.grad.shape, [4, 5])
+        self.assertEqual(c.grad.shape, [5])
+
 
 class TestSundryAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -3710,6 +3727,26 @@ def test_linalg_slogdet(self):
         self.assertEqual(res[0].shape, (2, 3))
         self.assertEqual(res[1].shape, (3, 3, 3))
 
+    @prog_scope()
+    def test_multi_dot(self):
+        a = paddle.randn([4])
+        a.stop_gradient = False
+        b = paddle.randn([4, 5])
+        b.stop_gradient = False
+        c = paddle.randn([5])
+        c.stop_gradient = False
+
+        out = paddle.linalg.multi_dot([a, b, c])
+        paddle.static.append_backward(out.sum())
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(
+            prog, fetch_list=[out, a.grad_name, b.grad_name, c.grad_name]
+        )
+        self.assertEqual(res[0].shape, ())
+        self.assertEqual(res[1].shape, (4,))
+        self.assertEqual(res[2].shape, (4, 5))
+        self.assertEqual(res[3].shape, (5,))
+
 
 # Used to test APIs whose zero-dim input tensors have no grad and don't need backward testing in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -3901,6 +3938,38 @@ def test_unique(self):
         self.assertEqual(inverse.shape, [1])
         self.assertEqual(counts.shape, [1])
 
+    def test_matrix_rank(self):
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+
+        self.assertEqual(out.shape, [])
+        np.testing.assert_equal(out, np.array(10))
+
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        self.assertEqual(out_c.shape, [3])
+        np.testing.assert_equal(out_c, np.array([1, 1, 1]))
+
+        # 2D input, float tol: output is 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        self.assertEqual(out_tol.shape, [])
+
+        # 3D input, float tol: output is 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        self.assertEqual(out_c_tol.shape, [3])
+
+        # 2D input, tol is a Tensor of shape [2] (broadcast): output is 1D
+        tol_2 = paddle.randn([2])
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        self.assertEqual(out_d.shape, [2])
+
 
 class TestNoBackwardAPIStatic(unittest.TestCase):
     def setUp(self):
@@ -4135,6 +4204,51 @@ def test_unique(self):
         self.assertEqual(res[2].shape, (1,))
         self.assertEqual(res[3].shape, (1,))
 
+    @prog_scope()
+    def test_static_matrix_rank(self):
+        # 2D input: output is 0D
+        x = paddle.eye(10)
+        x.stop_gradient = False
+        out = paddle.linalg.matrix_rank(x)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D input: output is 1D
+        c = paddle.ones(shape=[3, 4, 5])
+        c.stop_gradient = False
+        out_c = paddle.linalg.matrix_rank(c)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c])
+        self.assertEqual(res[0].shape, (3,))
+
+        # 2D input, float tol: output is 0D
+        x_tol = paddle.eye(10)
+        x_tol.stop_gradient = False
+        out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        res = self.exe.run(prog, fetch_list=[out_tol])
+        self.assertEqual(res[0].shape, ())
+
+        # 3D input, float tol: output is 1D
+        c_tol = paddle.ones(shape=[3, 4, 5])
+        c_tol.stop_gradient = False
+        out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_c_tol])
+        self.assertEqual(res[0].shape, (3,))
+
+        # 2D input, tol is a Tensor of shape [2] (broadcast): output is 1D
+        tol_2 = paddle.randn([2])
+        d = paddle.eye(10)
+        out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
+        prog = paddle.static.default_main_program()
+        self.exe.run(paddle.static.default_startup_program())
+        res = self.exe.run(prog, fetch_list=[out_d])
+        self.assertEqual(res[0].shape, (2,))
+
 
 unary_apis_with_complex_input = [
     paddle.real,
