Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] support 0D output for matrix_rank/multi_dot #52861

Merged
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
// Infers the output shape when a matrix op (e.g. matrix_rank) reduces the
// trailing two (row, col) dimensions of `dim_x`:
//   - 2-D input  -> 0-D (scalar) output, per the zero-dim tensor convention;
//   - >2-D input -> the leading batch dims, with the last two dims dropped.
// NOTE(review): the scraped diff showed both the pre-change `{1}` and
// post-change `{}` return lines stacked; this is the post-change version.
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
    // Single matrix: the rank is a scalar, so return an empty (0-D) shape.
    return phi::make_ddim({});
  }
  // Batched input: strip the trailing matrix dims, keep the batch dims.
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2344,7 +2344,7 @@ void MultiDotInferMeta(const std::vector<const MetaTensor*>& x,
// If the last tensor is 1D of size n view it as a column vector (n, 1)
if (last_dim.size() == 1) {
last_dim = phi::make_ddim({static_cast<int>(last_dim[0]), 1});
out_dim = is_vector ? phi::make_ddim({1}) : phi::make_ddim({first_dim[0]});
out_dim = is_vector ? phi::make_ddim({}) : phi::make_ddim({first_dim[0]});
} else {
out_dim = is_vector ? phi::make_ddim({last_dim[1]})
: phi::make_ddim({first_dim[0], last_dim[1]});
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace detail {
// Infers the output shape for ops that reduce over the last two matrix
// dimensions of `dim_x` (e.g. matrix_rank in unary infermeta):
//   - 2-D input  -> 0-D (scalar) output (zero-dim tensor support);
//   - >2-D input -> batch dims only, trailing two dims erased.
// NOTE(review): the scraped diff showed both the removed `{1}` line and the
// added `{}` line; this reconstructs the post-change function.
static DDim CheckAndGetOutputDim(const DDim& dim_x) {
  auto x_vec = phi::vectorize(dim_x);
  if (x_vec.size() == 2) {
    // A single matrix reduces to a scalar -> empty dim list (0-D).
    return phi::make_ddim({});
  }
  // Keep the batch dimensions, drop the (row, col) matrix dimensions.
  x_vec.erase(x_vec.end() - 2, x_vec.end());
  return phi::make_ddim(x_vec);
}
Expand Down
114 changes: 114 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2087,6 +2087,23 @@ def test_linalg_slogdet(self):
self.assertTrue(out1.shape, [2, 3])
self.assertTrue(x1.grad.shape, [3, 3, 3])

def test_multi_dot(self):
    # Dygraph: a 1-D @ 2-D @ 1-D chain should produce a 0-D (scalar)
    # output, while each input keeps its own gradient shape.
    vec_left = paddle.randn([4])
    vec_left.stop_gradient = False
    mat = paddle.randn([4, 5])
    mat.stop_gradient = False
    vec_right = paddle.randn([5])
    vec_right.stop_gradient = False

    out = paddle.linalg.multi_dot([vec_left, mat, vec_right])
    out.retain_grads()
    out.backward()

    # Zero-dim output; gradients mirror the input shapes.
    self.assertEqual(out.shape, [])
    self.assertEqual(vec_left.grad.shape, [4])
    self.assertEqual(mat.grad.shape, [4, 5])
    self.assertEqual(vec_right.grad.shape, [5])


class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -3654,6 +3671,26 @@ def test_linalg_slogdet(self):
self.assertEqual(res[0].shape, (2, 3))
self.assertEqual(res[1].shape, (3, 3, 3))

@prog_scope()
def test_multi_dot(self):
    # Static graph: the same 1-D @ 2-D @ 1-D chain yields a 0-D output,
    # fetched here as a numpy scalar with shape ().
    vec_left = paddle.randn([4])
    vec_left.stop_gradient = False
    mat = paddle.randn([4, 5])
    mat.stop_gradient = False
    vec_right = paddle.randn([5])
    vec_right.stop_gradient = False

    out = paddle.linalg.multi_dot([vec_left, mat, vec_right])
    paddle.static.append_backward(out.sum())
    main_prog = paddle.static.default_main_program()
    res = self.exe.run(
        main_prog,
        fetch_list=[
            out,
            vec_left.grad_name,
            mat.grad_name,
            vec_right.grad_name,
        ],
    )
    self.assertEqual(res[0].shape, ())
    self.assertEqual(res[1].shape, (4,))
    self.assertEqual(res[2].shape, (4, 5))
    self.assertEqual(res[3].shape, (5,))


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down Expand Up @@ -3845,6 +3882,38 @@ def test_unique(self):
self.assertEqual(inverse.shape, [1])
self.assertEqual(counts.shape, [1])

def test_matrix_rank(self):
    # NOTE(review, translated): also cover the second usage of matrix_rank:
    # 1. tol as float: x must be 2-D or 3-D; the rank output is 0-D for a
    #    2-D input and 1-D for a 3-D input.
    # 2. tol as Tensor: the batch dims of x broadcast against tol's shape,
    #    and the broadcasted shape is the output shape.
    x = paddle.eye(10)
    x.stop_gradient = False
    out = paddle.linalg.matrix_rank(x)

    # 2D input, no tol: rank of a single matrix is a 0-D tensor.
    self.assertEqual(out.shape, [])
    np.testing.assert_equal(out, np.array(10))

    c = paddle.ones(shape=[3, 4, 5])
    c.stop_gradient = False
    out_c = paddle.linalg.matrix_rank(c)
    # 3D input: one rank per batch matrix -> 1-D output.
    self.assertEqual(out_c.shape, [3])
    np.testing.assert_equal(out_c, np.array([1, 1, 1]))

    # 2D, tol->float : OUTPUT 0D
    x_tol = paddle.eye(10)
    x_tol.stop_gradient = False
    out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
    self.assertEqual(out_tol.shape, [])

    # 3D, tol->float : OUTPUT 1D
    c_tol = paddle.ones(shape=[3, 4, 5])
    c_tol.stop_gradient = False
    out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
    self.assertEqual(out_c_tol.shape, [3])

    tol_2 = paddle.randn([2])
    # 2D, tol->Tensor[1,2] : OUTPUT 1D (batch dims broadcast with tol)
    d = paddle.eye(10)
    out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
    self.assertEqual(out_d.shape, [2])


class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -4079,6 +4148,51 @@ def test_unique(self):
self.assertEqual(res[2].shape, (1,))
self.assertEqual(res[3].shape, (1,))

@prog_scope()
def test_static_matrix_rank(self):
    # NOTE(review, translated): mirrors the dynamic-graph matrix_rank cases;
    # the reviewer suggested naming this simply `test_matrix_rank` with the
    # @prog_scope() decorator. Name kept for backward compatibility.
    # 2D : OUTPUT 0D
    x = paddle.eye(10)
    x.stop_gradient = False
    out = paddle.linalg.matrix_rank(x)
    prog = paddle.static.default_main_program()
    res = self.exe.run(prog, fetch_list=[out])
    self.assertEqual(res[0].shape, ())

    # 3D : OUTPUT 1D
    c = paddle.ones(shape=[3, 4, 5])
    c.stop_gradient = False
    out_c = paddle.linalg.matrix_rank(c)
    prog = paddle.static.default_main_program()
    self.exe.run(paddle.static.default_startup_program())
    res = self.exe.run(prog, fetch_list=[out_c])
    self.assertEqual(res[0].shape, (3,))

    # 2D, tol->float : OUTPUT 0D
    x_tol = paddle.eye(10)
    x_tol.stop_gradient = False
    out_tol = paddle.linalg.matrix_rank(x_tol, tol=0.1)
    prog = paddle.static.default_main_program()
    res = self.exe.run(prog, fetch_list=[out_tol])
    self.assertEqual(res[0].shape, ())

    # 3D, tol->float : OUTPUT 1D
    c_tol = paddle.ones(shape=[3, 4, 5])
    c_tol.stop_gradient = False
    out_c_tol = paddle.linalg.matrix_rank(c_tol, tol=0.1)
    prog = paddle.static.default_main_program()
    self.exe.run(paddle.static.default_startup_program())
    res = self.exe.run(prog, fetch_list=[out_c_tol])
    self.assertEqual(res[0].shape, (3,))

    tol_2 = paddle.randn([2])
    # 2D, tol->Tensor[1,2] : OUTPUT 1D (batch dims broadcast with tol)
    d = paddle.eye(10)
    out_d = paddle.linalg.matrix_rank(d, tol=tol_2)
    prog = paddle.static.default_main_program()
    self.exe.run(paddle.static.default_startup_program())
    res = self.exe.run(prog, fetch_list=[out_d])
    self.assertEqual(res[0].shape, (2,))


unary_apis_with_complex_input = [
paddle.real,
Expand Down