Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] reshape/reshape_/reverse 0D support #49357

Merged
merged 7 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
return;
}

PADDLE_ENFORCE_EQ(!shape.empty(),
true,
platform::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims);
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/gpu/flip_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ void FlipKernel(const Context& dev_ctx,
DenseTensor* out) {
const size_t total_dims = x.dims().size();
switch (total_dims) {
case 0:
LaunchFlipCudaKernel<T, Context, 0>(dev_ctx, x, axis, out);
break;
case 1:
LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
break;
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/fluid/tests/unittests/test_reshape_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,20 @@ def test_check_grad(self):
class TestReshapeOp_ZeroDim1(OpTest):
    # 0-D input reshaped to an explicit 1-element, 1-D target shape.
    def init_data(self):
        """Set shapes: 0-D original, target (1,), inferred (1,).

        Note: removed dead stores (`self.new_shape = 1` / `self.infered_shape = 1`)
        that were immediately overwritten by the tuple values below.
        """
        self.ori_shape = ()
        self.new_shape = (1,)
        self.infered_shape = (1,)


class TestReshapeOp_ZeroDim2(OpTest):
    # 0-D input reshaped through a -1 (infer-all) target dimension.
    def init_data(self):
        """Set shapes: 0-D original, target (-1,), inferred (1,).

        Note: removed dead stores (`self.new_shape = -1` / `self.infered_shape = 1`)
        that were immediately overwritten by the tuple values below.
        """
        self.ori_shape = ()
        self.new_shape = (-1,)
        self.infered_shape = (1,)


class TestReshapeOp_ZeroDim3(OpTest):
    # 1-element, 1-D input reshaped down to a 0-D tensor.
    def init_data(self):
        """Set shapes: (1,) original, empty target (), inferred ().

        Note: removed the dead store `self.ori_shape = 1` that was
        immediately overwritten by the tuple value below.
        """
        self.ori_shape = (1,)
        self.new_shape = ()
        self.infered_shape = ()

Expand Down
171 changes: 171 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,105 @@ def test_scatter_nd(self):
self.assertEqual(out.numpy()[3], 2)
self.assertEqual(out.grad.shape, [5])

def test_reshape_list(self):
    """Dygraph: reshape a 0-D tensor using python-list target shapes."""
    src = paddle.rand([])
    src.stop_gradient = False

    # (target shape, expected output shape); src.grad stays 0-D throughout.
    cases = [
        ([], []),
        ([1], [1]),
        ([-1], [1]),
        ([-1, 1], [1, 1]),
    ]
    for target, expected in cases:
        reshaped = paddle.reshape(src, target)
        reshaped.backward()
        self.assertEqual(src.grad.shape, [])
        self.assertEqual(reshaped.shape, expected)
        self.assertEqual(reshaped.grad.shape, expected)

def test_reshape_tensor(self):
    """Dygraph: reshape a [1, 1] tensor using Tensor-valued target shapes."""
    src = paddle.rand([1, 1])
    src.stop_gradient = False

    # Plain empty list first: collapse to a 0-D tensor.
    reshaped = paddle.reshape(src, [])
    reshaped.backward()
    self.assertEqual(src.grad.shape, [1, 1])
    self.assertEqual(reshaped.shape, [])
    self.assertEqual(reshaped.grad.shape, [])

    # Then 0-D int32 Tensors (and a list of them) as the shape argument.
    tensor_cases = [
        (paddle.full([], 1, "int32"), [1]),
        (paddle.full([], -1, "int32"), [1]),
        ([paddle.full([], -1, "int32"), paddle.full([], 1, "int32")], [1, 1]),
    ]
    for target, expected in tensor_cases:
        reshaped = paddle.reshape(src, target)
        reshaped.backward()
        self.assertEqual(src.grad.shape, [1, 1])
        self.assertEqual(reshaped.shape, expected)
        self.assertEqual(reshaped.grad.shape, expected)

def test_reshape__list(self):
    """Dygraph: in-place reshape_ with list shapes; the tensor mutates each call."""
    t = paddle.rand([])
    # Each call reshapes t in place, so later targets apply to the new shape.
    for target, expected in (([], []), ([1], [1]), ([-1], [1]), ([-1, 1], [1, 1])):
        result = paddle.reshape_(t, target)
        self.assertEqual(result.shape, expected)

def test_reshape__tensor(self):
    """Dygraph: in-place reshape_ with Tensor-valued target shapes.

    Note: removed two stray GitHub review-UI lines ("zhaoyinglia marked
    this conversation as resolved." / "Show resolved Hide resolved") that
    were embedded in the source and are not valid Python.
    """
    x = paddle.rand([1, 1])
    out = paddle.reshape_(x, [])
    self.assertEqual(out.shape, [])

    new_shape = paddle.full([], 1, "int32")
    out = paddle.reshape_(x, new_shape)
    self.assertEqual(out.shape, [1])

    new_shape = paddle.full([], -1, "int32")
    out = paddle.reshape_(x, new_shape)
    self.assertEqual(out.shape, [1])

    new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
    out = paddle.reshape_(x, new_shape)
    self.assertEqual(out.shape, [1, 1])

def test_reverse(self):
    """Dygraph: reverse with an empty axis list on a 0-D tensor keeps 0-D shapes."""
    inp = paddle.rand([])
    inp.stop_gradient = False
    rev = paddle.reverse(inp, axis=[])
    rev.backward()
    # Input, output, and output grad all remain 0-D.
    for shape in (inp.shape, rev.shape, rev.grad.shape):
        self.assertEqual(shape, [])


class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -914,6 +1013,78 @@ def test_scatter_nd(self):
self.assertEqual(res[0].shape, (5,))
self.assertEqual(res[0][3], 2)

@prog_scope()
def test_reshape_list(self):
    """Static graph: reshape 0-D inputs with python-list target shapes."""
    # (target shape, expected fetched numpy shape)
    specs = [([], ()), ([1], (1,)), ([-1], (1,)), ([-1, 1], (1, 1))]

    # Create all inputs first, mirroring the original op-insertion order.
    inputs = []
    for _ in specs:
        t = paddle.rand([])
        t.stop_gradient = False
        inputs.append(t)

    outs = []
    for t, (target, _) in zip(inputs, specs):
        out = paddle.reshape(t, target)
        paddle.static.append_backward(out)
        outs.append(out)

    prog = paddle.static.default_main_program()
    fetched = self.exe.run(prog, fetch_list=outs)
    for res, (_, expected) in zip(fetched, specs):
        self.assertEqual(res.shape, expected)

@prog_scope()
def test_reshape_tensor(self):
    """Static graph: reshape 0-D inputs with Tensor-valued target shapes."""
    # Create all inputs first, mirroring the original op-insertion order.
    inputs = []
    for _ in range(3):
        t = paddle.rand([])
        t.stop_gradient = False
        inputs.append(t)

    # Each maker builds its shape argument lazily so fill ops are inserted
    # into the program right before the matching reshape, as before.
    shape_makers = [
        lambda: paddle.full([], 1, "int32"),
        lambda: paddle.full([], -1, "int32"),
        lambda: [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")],
    ]
    expected_shapes = [(1,), (1,), (1, 1)]

    outs = []
    for t, make_shape in zip(inputs, shape_makers):
        out = paddle.reshape(t, make_shape())
        paddle.static.append_backward(out)
        outs.append(out)

    prog = paddle.static.default_main_program()
    fetched = self.exe.run(prog, fetch_list=outs)
    for res, expected in zip(fetched, expected_shapes):
        self.assertEqual(res.shape, expected)

@prog_scope()
def test_reverse(self):
    """Static graph: reverse with an empty axis list keeps the () shape."""
    inp = paddle.rand([])
    inp.stop_gradient = False

    rev = paddle.reverse(inp, axis=[])
    paddle.static.append_backward(rev)

    prog = paddle.static.default_main_program()
    fetched = self.exe.run(prog, fetch_list=[inp, rev])
    # Both the input and the reversed output come back as 0-D arrays.
    for res in fetched:
        self.assertEqual(res.shape, ())


# Used to test APIs whose zero-dim input tensors have no grad and therefore need no backward testing in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,105 @@ def test_scatter__XD(self):
for i in range(3):
self.assertEqual(out.numpy()[1][i], updates.numpy()[i])

def test_reshape_list(self):
    """reshape of a 0-D tensor with list target shapes; grads stay 0-D."""
    base = paddle.rand([])
    base.stop_gradient = False

    for target, want in [([], []), ([1], [1]), ([-1], [1]), ([-1, 1], [1, 1])]:
        reshaped = paddle.reshape(base, target)
        reshaped.backward()
        self.assertEqual(base.grad.shape, [])
        self.assertEqual(reshaped.shape, want)
        self.assertEqual(reshaped.grad.shape, want)

def test_reshape_tensor(self):
    """reshape of a [1, 1] tensor with Tensor-valued target shapes."""
    base = paddle.rand([1, 1])
    base.stop_gradient = False

    # Empty list target collapses to a 0-D tensor.
    reshaped = paddle.reshape(base, [])
    reshaped.backward()
    self.assertEqual(base.grad.shape, [1, 1])
    self.assertEqual(reshaped.shape, [])
    self.assertEqual(reshaped.grad.shape, [])

    # 0-D int32 Tensors (and a list of them) as the shape argument.
    tensor_cases = [
        (paddle.full([], 1, "int32"), [1]),
        (paddle.full([], -1, "int32"), [1]),
        ([paddle.full([], -1, "int32"), paddle.full([], 1, "int32")], [1, 1]),
    ]
    for target, want in tensor_cases:
        reshaped = paddle.reshape(base, target)
        reshaped.backward()
        self.assertEqual(base.grad.shape, [1, 1])
        self.assertEqual(reshaped.shape, want)
        self.assertEqual(reshaped.grad.shape, want)

def test_reshape__list(self):
    """In-place reshape_ with list shapes; the tensor mutates between calls."""
    data = paddle.rand([])
    for target, want in (([], []), ([1], [1]), ([-1], [1]), ([-1, 1], [1, 1])):
        res = paddle.reshape_(data, target)
        self.assertEqual(res.shape, want)

def test_reshape__tensor(self):
    """In-place reshape_ driven by 0-D int32 shape tensors."""
    data = paddle.rand([1, 1])

    res = paddle.reshape_(data, [])
    self.assertEqual(res.shape, [])

    res = paddle.reshape_(data, paddle.full([], 1, "int32"))
    self.assertEqual(res.shape, [1])

    res = paddle.reshape_(data, paddle.full([], -1, "int32"))
    self.assertEqual(res.shape, [1])

    res = paddle.reshape_(
        data, [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
    )
    self.assertEqual(res.shape, [1, 1])

def test_reverse(self):
    """reverse with an empty axis list on a 0-D tensor keeps 0-D shapes."""
    inp = paddle.rand([])
    inp.stop_gradient = False
    rev = paddle.reverse(inp, axis=[])
    rev.backward()
    # Input, output, and output grad all remain 0-D.
    for shape in (inp.shape, rev.shape, rev.grad.shape):
        self.assertEqual(shape, [])


# Used to test APIs whose zero-dim input tensors have no grad and therefore need no backward testing in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down
6 changes: 1 addition & 5 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3450,7 +3450,7 @@ def reshape(x, shape, name=None):
Args:
x (Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``
shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [].
If ``shape`` is a Tensor, it should be a 1-D Tensor.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -3574,10 +3574,6 @@ def get_attr_shape(list_shape):
shape.stop_gradient = True
inputs["Shape"] = shape
elif isinstance(shape, (list, tuple)):
assert len(shape) > 0, (
"The size of 'shape' in reshape can't be zero, "
"but received %s." % len(shape)
)
attrs["shape"] = get_attr_shape(shape)
if utils._contain_var(shape):
inputs['ShapeTensor'] = utils._convert_to_tensor_list(shape)
Expand Down