Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Support output 0D for squeeze, unbind, unstack. #52843

Merged: 12 commits merged on Apr 28, 2023.
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3761,6 +3761,9 @@ void SqueezeInferMeta(const MetaTensor& x,
if (!config.is_runtime && axes.FromTensor()) {
// compile time infershape, set all elements to -1.
int output_size = x.dims().size() - axes.GetData().size();
if (x.dims().size() == 0 && output_size == -1) {
output_size = 0;
}
std::vector<int64_t> vec_out_dims(output_size, -1);
out->set_dims(phi::make_ddim(vec_out_dims));
} else {
Expand Down
142 changes: 142 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,56 @@ def test_upsample(self):
self.assertEqual(out1.shape, [2, 3, 12, 12])
self.assertEqual(input_x.grad.shape, [2, 3, 6, 6])

def test_unstack(self):
    """Dygraph: unstack 1-D inputs into 0-D outputs; check forward values,
    output shapes, and that gradients flow back to the inputs."""
    x1 = paddle.full([1], 0)
    x2 = paddle.full([2], 2)
    x1.retain_grads()
    x2.retain_grads()
    x1.stop_gradient = False
    x2.stop_gradient = False

    # Unstacking a shape-[1] tensor along axis 0 yields a single 0-D tensor.
    [out1] = paddle.unstack(x1, 0)
    out1.retain_grads()
    out1.backward()
    # Unstacking a shape-[2] tensor yields two 0-D tensors.
    [out2_1, out2_2] = paddle.unstack(x2, 0)
    out2 = paddle.add_n([out2_1, out2_2])
    out2.retain_grads()
    out2.backward()

    self.assertEqual(out1.shape, [])
    self.assertEqual(out1.numpy(), 0)
    # unstack has a backward pass — verify the gradient reaches the input
    # (addresses review note: x1.grad was previously unchecked).
    self.assertEqual(x1.grad.shape, [1])

    self.assertEqual(out2_1.shape, [])
    self.assertEqual(out2_1.numpy(), 2)
    self.assertEqual(out2_2.shape, [])
    self.assertEqual(out2_2.numpy(), 2)
    self.assertEqual(x2.grad.shape, [2])

def test_unbind(self):
    """Dygraph: unbind 1-D inputs into 0-D outputs; check forward values,
    output shapes, and that gradients flow back to the inputs."""
    x1 = paddle.full([1], 0)
    x2 = paddle.full([2], 2)
    x1.retain_grads()
    x2.retain_grads()
    x1.stop_gradient = False
    x2.stop_gradient = False

    # Unbinding a shape-[1] tensor along axis 0 yields a single 0-D tensor.
    [out1] = paddle.unbind(x1, 0)
    out1.retain_grads()
    out1.backward()
    # Unbinding a shape-[2] tensor yields two 0-D tensors.
    [out2_1, out2_2] = paddle.unbind(x2, 0)
    out2 = paddle.add_n([out2_1, out2_2])
    out2.retain_grads()
    out2.backward()

    self.assertEqual(out1.shape, [])
    self.assertEqual(out1.numpy(), 0)
    # unbind has a backward pass — verify the gradient reaches the input
    # (addresses review note: x1.grad was previously unchecked).
    self.assertEqual(x1.grad.shape, [1])

    self.assertEqual(out2_1.shape, [])
    self.assertEqual(out2_1.numpy(), 2)
    self.assertEqual(out2_2.shape, [])
    self.assertEqual(out2_2.numpy(), 2)
    self.assertEqual(x2.grad.shape, [2])

def test_maseked_select(self):
x = paddle.rand([])
x.stop_gradient = False
Expand All @@ -2306,6 +2356,26 @@ def test_maseked_select(self):
self.assertEqual(x.grad.shape, [])
self.assertEqual(x.grad.numpy(), 1)

def test_squeeze(self):
    """Dygraph: squeezing a 0-D tensor keeps it 0-D, for both an int axis
    and a tensor axis; gradients stay 0-D as well."""
    # Case 1: axis given as a plain int.
    scalar_a = paddle.full([], 2)
    scalar_a.stop_gradient = False
    scalar_a.retain_grads()
    squeezed_a = paddle.squeeze(scalar_a, axis=0)
    squeezed_a.retain_grads()
    squeezed_a.backward()
    self.assertEqual(squeezed_a.shape, [])
    self.assertEqual(scalar_a.grad.shape, [])

    # Case 2: axis given as an int32 tensor.
    scalar_b = paddle.full([], 3)
    axis_tensor = paddle.full([1], 0, dtype='int32')
    scalar_b.stop_gradient = False
    scalar_b.retain_grads()
    squeezed_b = paddle.squeeze(scalar_b, axis=axis_tensor)
    squeezed_b.retain_grads()
    squeezed_b.backward()
    self.assertEqual(squeezed_b.shape, [])
    self.assertEqual(scalar_b.grad.shape, [])

def test_unsqueeze(self):
x1 = paddle.full([], 2)
x1.stop_gradient = False
Expand Down Expand Up @@ -4073,6 +4143,50 @@ def test_upsample(self):
self.assertEqual(res1[0].shape, (2, 3, 12, 12))
self.assertEqual(res1[1].shape, (2, 3, 6, 6))

@prog_scope()
def test_unstack(self):
    """Static graph: unstack 1-D inputs into 0-D outputs and verify the
    fetched forward and gradient shapes."""
    # Shape-[1] input -> a single 0-D output; input grad keeps shape (1,).
    input_a = paddle.full([1], 0, 'float32')
    input_a.stop_gradient = False
    pieces_a = paddle.unstack(input_a, 0)
    sum_a = paddle.add_n(pieces_a)
    paddle.static.append_backward(sum_a)
    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(
        main_prog, feed={}, fetch_list=[sum_a, input_a.grad_name]
    )
    self.assertEqual(fetched[0].shape, ())
    self.assertEqual(fetched[1].shape, (1,))

    # Shape-[2] input -> two 0-D outputs; input grad keeps shape (2,).
    input_b = paddle.full([2], 2, 'float32')
    input_b.stop_gradient = False
    pieces_b = paddle.unstack(input_b, 0)
    sum_b = paddle.add_n(pieces_b)
    paddle.static.append_backward(sum_b)
    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(
        main_prog, feed={}, fetch_list=[sum_b, input_b.grad_name]
    )
    self.assertEqual(fetched[0].shape, ())
    self.assertEqual(fetched[1].shape, (2,))

@prog_scope()
def test_unbind(self):
    """Static graph: unbind 1-D inputs into 0-D outputs and verify the
    fetched forward and gradient shapes."""
    # Shape-[1] input -> a single 0-D output; input grad keeps shape (1,).
    input_a = paddle.full([1], 0, 'float32')
    input_a.stop_gradient = False
    pieces_a = paddle.unbind(input_a, 0)
    sum_a = paddle.add_n(pieces_a)
    paddle.static.append_backward(sum_a)
    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(
        main_prog, feed={}, fetch_list=[sum_a, input_a.grad_name]
    )
    self.assertEqual(fetched[0].shape, ())
    self.assertEqual(fetched[1].shape, (1,))

    # Shape-[2] input -> two 0-D outputs; input grad keeps shape (2,).
    input_b = paddle.full([2], 2, 'float32')
    input_b.stop_gradient = False
    pieces_b = paddle.unbind(input_b, 0)
    sum_b = paddle.add_n(pieces_b)
    paddle.static.append_backward(sum_b)
    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(
        main_prog, feed={}, fetch_list=[sum_b, input_b.grad_name]
    )
    self.assertEqual(fetched[0].shape, ())
    self.assertEqual(fetched[1].shape, (2,))

@prog_scope()
def test_maseked_select(self):
x = paddle.rand([])
Expand All @@ -4089,6 +4203,34 @@ def test_maseked_select(self):
self.assertEqual(res[3].shape, ())
self.assertEqual(res[3], 1)

@prog_scope()
def test_squeeze(self):
    """Static graph: squeezing a 0-D tensor (int axis and tensor axis)
    keeps forward output and input gradient 0-D."""
    # Case 1: axis given as a plain int.
    scalar_a = paddle.full([], 2)
    scalar_a.stop_gradient = False
    squeezed_a = paddle.squeeze(scalar_a, axis=0)
    paddle.static.append_backward(squeezed_a.sum())

    # Case 2: axis given as a 0-D int32 tensor.
    scalar_b = paddle.full([], 3)
    axis_tensor = paddle.full([], 0, dtype='int32')
    scalar_b.stop_gradient = False
    squeezed_b = paddle.squeeze(scalar_b, axis=axis_tensor)
    paddle.static.append_backward(squeezed_b.sum())

    main_prog = paddle.static.default_main_program()
    fetched = self.exe.run(
        main_prog,
        fetch_list=[
            squeezed_a,
            squeezed_b,
            scalar_a.grad_name,
            scalar_b.grad_name,
        ],
    )
    # Every fetched value — both outputs and both grads — must be 0-D.
    for value in fetched:
        self.assertEqual(value.shape, ())

@prog_scope()
def test_unsqueeze(self):
x1 = paddle.full([], 2)
Expand Down