From a79c04f363163fea6624696b8ad4889d275fbc5f Mon Sep 17 00:00:00 2001 From: JYChen Date: Sun, 23 Apr 2023 14:03:30 +0800 Subject: [PATCH] Cherry pick getitem/setitem 0d (#53125) * support 0-D output and 0-D as indice in __getitem__ * fix tests * fix inference and UT * add unittest for setitem * fix xpu test * fix xpu 0-d --- paddle/fluid/framework/attribute_checker.h | 23 +- paddle/fluid/inference/tensorrt/engine.h | 4 +- paddle/fluid/pybind/eager_method.cc | 11 - paddle/fluid/pybind/imperative.cc | 12 - paddle/phi/infermeta/unary.cc | 3 - paddle/phi/kernels/funcs/slice_utils.h | 6 - .../phi/kernels/xpu/set_value_grad_kernel.cc | 5 + .../auto_parallel/operators/dist_slice.py | 6 +- python/paddle/fft.py | 4 +- .../unittests/test_imperative_numpy_bridge.py | 2 +- .../fluid/tests/unittests/test_kthvalue_op.py | 8 +- .../tests/unittests/test_set_value_op.py | 2 +- .../fluid/tests/unittests/test_slice_op.py | 4 +- .../fluid/tests/unittests/test_var_base.py | 11 +- .../fluid/tests/unittests/test_variable.py | 3 +- .../fluid/tests/unittests/test_while_op.py | 6 +- .../tests/unittests/test_zero_dim_tensor.py | 246 ++++++++++++++++++ python/paddle/fluid/variable_index.py | 18 +- .../incubate/optimizer/functional/lbfgs.py | 8 +- .../jit/dy2static/variable_trans_func.py | 6 +- python/paddle/nn/layer/rnn.py | 2 +- python/paddle/tensor/manipulation.py | 18 +- test/dygraph_to_static/test_list.py | 2 +- test/xpu/test_set_value_op_xpu.py | 2 +- test/xpu/test_slice_op_xpu.py | 4 +- test/xpu/test_zero_dim_tensor_xpu.py | 134 ++++++++++ 26 files changed, 440 insertions(+), 110 deletions(-) diff --git a/paddle/fluid/framework/attribute_checker.h b/paddle/fluid/framework/attribute_checker.h index 67eb69efdf3d0..2e5e7bf8939ae 100644 --- a/paddle/fluid/framework/attribute_checker.h +++ b/paddle/fluid/framework/attribute_checker.h @@ -73,10 +73,10 @@ class TypedAttrVarInfoChecker { platform::errors::InvalidArgument( "Required Attribute with Variable type shall not be nullptr.")); auto shape = var_desc->GetShape(); - PADDLE_ENFORCE_EQ(shape.size(), + PADDLE_ENFORCE_LE(shape.size(), 1U, platform::errors::InvalidArgument( - "Required shape rank of Attribute(%s) == 1, " + "Required shape rank of Attribute(%s) <= 1, " "but received rank == %s", var_desc->Name(), shape.size())); @@ -105,20 +105,21 @@ class TypedAttrVarInfoChecker { platform::errors::InvalidArgument( "Required Attribute with Variable type shall not be nullptr.")); auto shape = var_desc->GetShape(); - PADDLE_ENFORCE_EQ(shape.size(), + PADDLE_ENFORCE_LE(shape.size(), 1U, platform::errors::InvalidArgument( - "Required shape rank of Attribute(%s) == 1, " + "Required shape rank of Attribute(%s) <= 1, " "but received rank == %s", var_desc->Name(), shape.size())); - PADDLE_ENFORCE_EQ(shape[0] == 1U || shape[0] == -1, - true, - platform::errors::InvalidArgument( - "Required shape[0] of Attribute(%s) == 1 or -1, " - "but received shape[0] == %s", - var_desc->Name(), - shape[0])); + PADDLE_ENFORCE_EQ( + shape.size() == 0U || shape[0] == 1U || shape[0] == -1, + true, + platform::errors::InvalidArgument( + "Required shape is (), or shape[0] of Attribute(%s) == 1 or -1, " + "but received shape[0] == %s", + var_desc->Name(), + shape[0])); } } }; diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 0d77991deaf91..05746ea5123c8 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -86,10 +86,10 @@ template nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, 
std::string input, bool with_dynamic_shape = false) { - PADDLE_ENFORCE_GT(shape.size(), + PADDLE_ENFORCE_GE(shape.size(), 0UL, platform::errors::InvalidArgument( - "TensorRT's tensor input requires at least 1 " + "TensorRT's tensor input requires at least 0 " "dimensions, but input %s has %d dims.", input, shape.size())); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 2455eed34fe36..11dca753092fd 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -923,17 +923,6 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, } if (!none_axes.empty()) { - // Deal with cases when all axes are decreased. - // After slice, the shape of out is [1], which should have been - // [], but Paddle doesn't support scalar. - // In order to ensure the correctness of the final shape of out, - // one dimension of out needs to be decreased. - // For example: - // # x.shape: (2,3,4) - // out = x[0, 1, 1, None] # out.shape : (1) - if (static_cast(decrease_axis.size()) == tensor->dims().size()) { - none_axes.pop_back(); - } if (!none_axes.empty()) { paddle::Tensor new_out; { diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index e78a5bfd35da2..44d4d070eafb1 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1068,18 +1068,6 @@ void BindImperative(py::module *m_ptr) { tracer->TraceOp(op_type, ins, outs, std::move(attrs)); } if (!none_axes.empty()) { - // Deal with cases when all axes are decreased. - // After slice, the shape of out is [1], which should have been - // [], but Paddle doesn't support scalar. - // In order to ensure the correctness of the final shape of out, - // one dimension of out needs to be decreased. - // For example: - // # x.shape: (2,3,4) - // out = x[0, 1, 1, None] # out.shape : (1) - if (static_cast(decrease_axis.size()) == - tensor->dims().size()) { - none_axes.pop_back(); - } if (!none_axes.empty()) { // Deal with cases that decrease_axes is not empty // For example: diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c1ee2b5d4ec11..8220a6e09eb1d 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3927,9 +3927,6 @@ void StridedSliceRawInferMeta(const MetaTensor& x, new_out_shape.push_back(out_dims[i]); } } - if (new_out_shape.size() == 0) { - new_out_shape.push_back(1); - } out_dims = phi::make_ddim(new_out_shape); } VLOG(4) << "out_dims: " << out_dims; diff --git a/paddle/phi/kernels/funcs/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h index a56a5e16f6503..9bbb7681dd888 100644 --- a/paddle/phi/kernels/funcs/slice_utils.h +++ b/paddle/phi/kernels/funcs/slice_utils.h @@ -203,12 +203,6 @@ inline DDim GetDecreasedDims(const DDim slice_dims, } } - // NOTE(liym27): Paddle does not support that the rank of Tensor is 0, and - // uses [1] instead. 
- if (new_shape.size() == 0) { - new_shape.push_back(1); - } - decreased_dims = phi::make_ddim(new_shape); } return decreased_dims; diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc index d7e1ed8114e00..d80a2a97da8cf 100644 --- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc @@ -266,6 +266,11 @@ void SetValueGradImpl(const Context& dev_ctx, {fake_value_grad_dims.Get(), fake_value_grad_dims.size()}, static_cast(0)); auto value_grad_dims_vec = phi::vectorize(value_grad_dims); + // for value is a 0-D Tensor + if (value_grad_dims.size() == 0) { + value_grad_dims_vec = + phi::vectorize(phi::make_ddim(std::vector({1}))); + } for (auto offset : offsets) { for (int i = 0; i < out_dims_size; i++) { slice_end[i] = offset[i] + fake_value_grad_dims[i]; diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py index 17e68002fa42f..69ccd8d7bc868 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py @@ -70,9 +70,7 @@ def is_output_compatible(self, dist_op): if i not in decrease_axis: ref_indices.append(i) if ref_indices == []: - assert len(out_dims_mapping) == 1 - if is_dim_shard(out_dims_mapping[0]): - return False + assert len(out_dims_mapping) == 0 else: for i in range(len(out_dims_mapping)): ref_index = ref_indices[i] @@ -142,9 +140,7 @@ def update_dims_mapping(self, dist_op): ref_indices.append(i) if ref_dims_mapping == []: - ref_dims_mapping = [-1] assert len(ref_dims_mapping) == len(out_dims_mapping) - assert ref_dims_mapping[0] == out_dims_mapping[0] changed = False else: assert len(ref_dims_mapping) == len(out_dims_mapping) diff --git a/python/paddle/fft.py b/python/paddle/fft.py index 1ce18f120c19e..48c20f7fdafaf 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -1371,7 +1371,7 @@ def fftshift(x, axes=None, name=None): elif isinstance(axes, int): shifts = shape[axes] // 2 else: - shifts = paddle.concat([shape[ax] // 2 for ax in axes]) + shifts = paddle.stack([shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) @@ -1416,7 +1416,7 @@ def ifftshift(x, axes=None, name=None): elif isinstance(axes, int): shifts = -shape[axes] // 2 else: - shifts = paddle.concat([-shape[ax] // 2 for ax in axes]) + shifts = paddle.stack([-shape[ax] // 2 for ax in axes]) return paddle.roll(x, shifts, axes, name=name) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py index effcfece0f5ee..58059a295539d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_numpy_bridge.py @@ -43,7 +43,7 @@ def test_tensor_from_numpy(self): np.testing.assert_array_equal(var2.numpy(), data_np) data_np[0][0] = -1 self.assertEqual(data_np[0][0], -1) - self.assertNotEqual(var2[0][0].numpy()[0], -1) + self.assertNotEqual(var2[0][0].numpy(), -1) self.assertFalse(np.array_equal(var2.numpy(), data_np)) diff --git a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py index 66389a870e46f..0bf3d8e948097 100644 --- a/python/paddle/fluid/tests/unittests/test_kthvalue_op.py +++ b/python/paddle/fluid/tests/unittests/test_kthvalue_op.py @@ -140,16 +140,16 @@ def 
test_nan_in_cpu_kernel(): nan_position = 100 self.x[0, nan_position, 2] = float('nan') v, inds = self.x.kthvalue(k=200, axis=1) - self.assertTrue(np.isnan(v[0, 2].numpy()[0])) - self.assertEqual(inds[0, 2].numpy()[0], nan_position) + self.assertTrue(np.isnan(v[0, 2].numpy())) + self.assertEqual(inds[0, 2].numpy(), nan_position) def test_nan_in_gpu_kernel(): paddle.set_device('gpu') nan_position = 100 self.x[0, nan_position, 2] = float('nan') v, inds = self.x.kthvalue(k=200, axis=1) - self.assertTrue(np.isnan(v[0, 2].numpy()[0])) - self.assertEqual(inds[0, 2].numpy()[0], nan_position) + self.assertTrue(np.isnan(v[0, 2].numpy())) + self.assertEqual(inds[0, 2].numpy(), nan_position) test_nan_in_cpu_kernel() if fluid.core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 9c5a71df01835..66d4b8f05b75d 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -1590,7 +1590,7 @@ def test_inplace(self): a.stop_gradient = False b = a[:] c = b - b[paddle.to_tensor(0)] = 1.0 + b[paddle.zeros([], dtype='int32')] = 1.0 self.assertTrue(id(b) == id(c)) np.testing.assert_array_equal(b.numpy(), c.numpy()) diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 0314a37170d0e..f43bd4b140d7d 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -541,8 +541,8 @@ class TestSliceAPI(unittest.TestCase): def test_1(self): with paddle_static_guard(): input = np.random.random([3, 4, 5, 6]).astype("float64") - minus_1 = paddle.tensor.fill_constant([1], "int32", -1) - minus_3 = paddle.tensor.fill_constant([1], "int64", -3) + minus_1 = paddle.tensor.fill_constant([], "int32", -1) + minus_3 = paddle.tensor.fill_constant([], "int64", -3) starts = paddle.static.data( name='starts', shape=[1, 3], dtype="float32" ) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 24920eb375ce1..c9607f89197a5 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -604,8 +604,7 @@ def _test_slice(self): nw = w[1, 1, 1] - self.assertEqual(len(nw.shape), 1) - self.assertEqual(nw.shape[0], 1) + self.assertEqual(len(nw.shape), 0) nw = w[:, :, :-1] self.assertEqual((784, 100, 99), tuple(nw.shape)) @@ -705,10 +704,10 @@ def _test_slice_for_tensor_attr(self): var = paddle.to_tensor(tensor_array) - one = paddle.ones(shape=[1], dtype="int32") - two = paddle.full(shape=[1], fill_value=2, dtype="int32") - negative_one = paddle.full(shape=[1], fill_value=-1, dtype="int32") - four = paddle.full(shape=[1], fill_value=4, dtype="int32") + one = paddle.ones(shape=[], dtype="int32") + two = paddle.full(shape=[], fill_value=2, dtype="int32") + negative_one = paddle.full(shape=[], fill_value=-1, dtype="int32") + four = paddle.full(shape=[], fill_value=4, dtype="int32") var = fluid.dygraph.to_variable(tensor_array) var1 = var[0, one, one] diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index b709510371edf..6d5bd96f9aca2 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -132,8 +132,7 @@ def _test_slice(self, place): nw = w[1, 1, 1] - 
self.assertEqual(len(nw.shape), 1) - self.assertEqual(nw.shape[0], 1) + self.assertEqual(len(nw.shape), 0) nw = w[:, :, :-1] self.assertEqual((784, 100, 99), nw.shape) diff --git a/python/paddle/fluid/tests/unittests/test_while_op.py b/python/paddle/fluid/tests/unittests/test_while_op.py index ea6d2d4945395..8ae9fa8c5c2bf 100644 --- a/python/paddle/fluid/tests/unittests/test_while_op.py +++ b/python/paddle/fluid/tests/unittests/test_while_op.py @@ -192,9 +192,9 @@ def test_outputs_exists_inputs(self): with fluid.program_guard(main_program, startup_program): def func(x): - s = paddle.zeros([1]) - i = paddle.ones([1]) - max_len = paddle.shape(x)[0] + s = paddle.zeros([]) + i = paddle.ones([]) + max_len = paddle.shape(x) def cond(i, s, x): return i < max_len diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 3b909b7822572..4db8a5eee8f85 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -580,6 +580,140 @@ def test_create_parameter_var(self): self.assertEqual(zero_dim_var.shape, []) self.assertEqual(zero_dim_var.item(), 0.5) + def test_getitem(self): + # case1: When all axis have a scalar indice, output should be a 0-d Tensor; + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x[1, 2, 3, 4] + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(119)) + self.assertEqual(out.grad.shape, []) + np.testing.assert_allclose(out.grad, 1.0) + self.assertEqual(x.grad.shape, [2, 3, 4, 5]) + x_grad_expected = np.zeros((2, 3, 4, 5)) + x_grad_expected[1, 2, 3, 4] = 1.0 + np.testing.assert_allclose(x.grad, x_grad_expected) + + # case2: When one axis has a 0-d Tensor indice, the output should be same as int indice. + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out1 = x[1, 2] + out2 = x[ + paddle.full([], 1, dtype='int32'), paddle.full([], 2, dtype='int32') + ] + np.testing.assert_allclose(out1, out2) + + # case3: When all axis have a scalar indice (i.e. case1) and has None indice, + # ndim of output should be same with numbers of None. + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out1 = x[1, 2, None, 3, 4] + self.assertEqual(out1.shape, [1]) + np.testing.assert_allclose(out1, np.array([119])) + out2 = x[1, None, 2, None, 3, 4] + self.assertEqual(out2.shape, [1, 1]) + np.testing.assert_allclose(out2, np.array([[119]])) + + # case4: 1-D Tensor will be treated as vector, no axis decrease will happen. 
+ x = paddle.ones((2, 3, 4)) + indice = paddle.ones([1], dtype='int32') + out1 = x[indice] + self.assertEqual(out1.shape, [1, 3, 4]) + np.testing.assert_allclose(out1, np.ones((1, 3, 4))) + out2 = x[indice, indice] + self.assertEqual(out2.shape, [1, 4]) + np.testing.assert_allclose(out2, np.ones((1, 4))) + + def test_setitem(self): + # case1: all axis have a scalar indice + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x * 2 + out[1, 2, 3, 4] = 10 + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10)) + self.assertEqual(x.grad.shape, [2, 3, 4, 5]) + x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + + # case2: 0-D Tensor indice in some axis + # NOTE(zoooo0820): Now, int/slice with 0-D Tensor will still be + # treated as combined indexing, which is not support backward. + # There should have more test cases such as out[1, indice, :] = 0.5 when this + # problem is fixed. + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice, indice] = 0.5 + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1, 1] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + + # case3:0-D Tensor indice in some axis, value is a Tensor + # and there is broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones((4, 5), dtype='float32') * 5 + v.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1], np.ones((3, 4, 5)) * 5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones((4, 5)) * 3 + np.testing.assert_allclose(v.grad, value_grad_expected) + + # case4: value is a 0-D tensor and there is broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones([], dtype='float32') * 5 + v.stop_gradient = False + out = x * 1 + indice = paddle.full([], 0, dtype='int32') + out[indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + self.assertEqual(v.grad.shape, []) + np.testing.assert_allclose(out[0], np.ones((3, 4, 5)) * 5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[0] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones(()) * 3 * 4 * 5 + np.testing.assert_allclose(v.grad, value_grad_expected) + + # case5: indice / value is 0-D Tensor, and there is no broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones([], dtype='float32') * 2 + v.stop_gradient = False + out = x * 1 + indice = paddle.full([], 0, dtype='int32') + out[indice, indice, indice, indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + self.assertEqual(v.grad.shape, []) + np.testing.assert_allclose(out[0, 0, 0, 0], np.ones(()) * 2) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[0, 0, 0, 0] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones(()) + np.testing.assert_allclose(v.grad, value_grad_expected) + def test_expand(self): # case1 x = paddle.full([], 1, 'float32') @@ -2110,6 +2244,118 @@ def 
test_create_parameter_var(self): self.assertEqual(res[0].shape, ()) self.assertEqual(res[0], 0.5) + @prog_scope() + def test_getitem(self): + # case1: When all axis have a scalar indice, output should be a 0-d Tensor; + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x[1, 2, 3, 4] + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) + + self.assertEqual(res[0].shape, ()) + np.testing.assert_allclose(res[0], np.array(119)) + self.assertEqual(res[2].shape, ()) + np.testing.assert_allclose(res[2], 1.0) + self.assertEqual(res[1].shape, (2, 3, 4, 5)) + x_grad_expected = np.zeros((2, 3, 4, 5)) + x_grad_expected[1, 2, 3, 4] = 1.0 + np.testing.assert_allclose(res[1], x_grad_expected) + + # case2: When one axis has a 0-d Tensor indice, the output should be same as int indice. + x2 = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out1 = x2[1, 2] + out2 = x2[ + paddle.full([], 1, dtype='int32'), paddle.full([], 2, dtype='int32') + ] + res = self.exe.run(prog, fetch_list=[out1, out2]) + np.testing.assert_allclose(res[0], res[1]) + + # case3: When all axis have a scalar indice (i.e. case1) and has None indice, + # ndim of output should be same with numbers of None. + x3 = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out3 = x3[1, 2, None, 3, 4] + out4 = x3[1, None, 2, None, 3, 4] + res = self.exe.run(prog, fetch_list=[out3, out4]) + self.assertEqual(res[0].shape, (1,)) + np.testing.assert_allclose(res[0], np.array([119])) + self.assertEqual(res[1].shape, (1, 1)) + np.testing.assert_allclose(res[1], np.array([[119]])) + + # case4: 1-D Tensor will be treated as vector, no axis decrease will happen. + x4 = paddle.ones((2, 3, 4)) + indice = paddle.ones([1], dtype='int32') + out5 = x4[indice] + out6 = x4[indice, indice] + res = self.exe.run(prog, fetch_list=[out5, out6]) + + self.assertEqual(res[0].shape, (1, 3, 4)) + np.testing.assert_allclose(res[0], np.ones((1, 3, 4))) + self.assertEqual(res[1].shape, (1, 4)) + np.testing.assert_allclose(res[1], np.ones((1, 4))) + + @prog_scope() + def test_setitem(self): + # NOTE(zoooo0820): __setitem__ has gradient problem in static graph. + # To solve this, we may not support __setitem__ in static graph. + # These unit tests will delete soon. + + # case1: all axis have a scalar indice + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x * 2 + out[1, 2, 3, 4] = 10 + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name]) + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(res[0][1, 2, 3, 4], np.array(10)) + self.assertEqual(res[1].shape, (2, 3, 4, 5)) + x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 + np.testing.assert_allclose(res[1], x_grad_expected) + + # case2: 0-D Tensor indice in some axis + # NOTE(zoooo0820): Now, int/slice with 0-D Tensor will still be + # treated as combined indexing, which is not support backward. + # There should have more test cases such as out[1, indice, :] = 0.5 when this + # problem is fixed. 
+ x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice, indice] = 0.5 + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name]) + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(res[0][1, 1], np.ones((4, 5)) * 0.5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1, 1] = 0 + np.testing.assert_allclose(res[1], x_grad_expected) + + # case3:0-D Tensor indice in some axis, value is a Tensor + # and there is broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones((4, 5), dtype='float32') * 5 + v.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice] = v + paddle.static.append_backward(out.sum()) + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out, x.grad_name, v.grad_name]) + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(res[0][1], np.ones((3, 4, 5)) * 5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1] = 0 + np.testing.assert_allclose(res[1], x_grad_expected) + @prog_scope() def test_expand(self): x = paddle.full([], 1, 'float32') diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index 0d866860b314a..252322f4c4bcc 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -282,7 +282,7 @@ def is_integer_or_scalar_tensor(ele): if isinstance(ele, int): return True elif isinstance(ele, Variable): - if len(ele.shape) == 1 and ele.shape[0] == 1: + if len(ele.shape) == 0: return True return False @@ -573,15 +573,6 @@ def _getitem_impl_(var, item): out = reverse(out, axis=reverse_axes) - # Deal with cases when all axes are decreased. - # After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar. - # In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased. - # For example: - # # x.shape: (2,3,4) - # out = x[0, 1, 1, None] # out.shape : (1) - if len(decrease_axes) == len(var.shape): - none_axes = none_axes[1:] - if len(none_axes) > 0: # Deal with cases that decrease_axes is not empty # For example: @@ -592,13 +583,6 @@ def _getitem_impl_(var, item): new_axis = axis - l none_axes[idx] = new_axis - # Deal with cases when all axes are decreased. - # After slice, the shape of out is [1], which should have been [], but Paddle doesn't support scalar. - # In order to ensure the correctness of the final shape of out, one dimension of out needs to be decreased. 
- # For example: - # # x.shape: (2,3,4) - # out = x[0, 1, 1, None] # out.shape : (1) - from ..tensor import unsqueeze out = unsqueeze(out, axis=none_axes) diff --git a/python/paddle/incubate/optimizer/functional/lbfgs.py b/python/paddle/incubate/optimizer/functional/lbfgs.py index a7221f0925e76..e3620c4ffc0ed 100644 --- a/python/paddle/incubate/optimizer/functional/lbfgs.py +++ b/python/paddle/incubate/optimizer/functional/lbfgs.py @@ -125,9 +125,7 @@ def func(x): is_converge = paddle.full(shape=[1], fill_value=False, dtype='bool') num_func_calls = paddle.full(shape=[1], fill_value=1, dtype='int64') - history_size = paddle.full( - shape=[1], fill_value=history_size, dtype='int64' - ) + history_size = paddle.full(shape=[], fill_value=history_size, dtype='int64') head = paddle.full(shape=[1], fill_value=1, dtype='int64') tail = paddle.full(shape=[1], fill_value=0, dtype='int64') @@ -177,7 +175,7 @@ def body( q = paddle.assign(g1) # In a array circle, the index may out of range, so must use mod. i = paddle.full( - shape=[1], fill_value=(head - 1).mod(history_size), dtype='int64' + shape=[], fill_value=(head - 1).mod(history_size), dtype='int64' ) def cond(i, q): @@ -193,7 +191,7 @@ def body(i, q): r = paddle.matmul(H0, q) - i = paddle.full(shape=[1], fill_value=tail + 1, dtype='int64') + i = paddle.full(shape=[], fill_value=tail + 1, dtype='int64') def cond(i, r): return i != head diff --git a/python/paddle/jit/dy2static/variable_trans_func.py b/python/paddle/jit/dy2static/variable_trans_func.py index 20f0fb6317e3b..80c4487dc29c6 100644 --- a/python/paddle/jit/dy2static/variable_trans_func.py +++ b/python/paddle/jit/dy2static/variable_trans_func.py @@ -51,11 +51,11 @@ def to_static_variable(x): Translate a Python Tensor to PaddlePaddle static graph Tensor ''' if isinstance(x, bool): - return paddle.full(shape=[1], dtype='bool', fill_value=x) + return paddle.full(shape=[], dtype='bool', fill_value=x) if isinstance(x, float): - return paddle.full(shape=[1], dtype='float64', fill_value=x) + return paddle.full(shape=[], dtype='float64', fill_value=x) if isinstance(x, int): - return paddle.full(shape=[1], dtype='int64', fill_value=x) + return paddle.full(shape=[], dtype='int64', fill_value=x) if isinstance(x, UndefinedVar) or x is None: """ for early return case, we need a variable to represent None, current we use data_layer_not_check. diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index ffd27a545b993..cc8ab648b8895 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -271,7 +271,7 @@ def _switch_grad(x, stop=False): mask = paddle.reverse(mask, axis=[0]) if sequence_length else None with paddle.fluid.framework.device_guard("cpu"): - start_i = paddle.zeros([1], dtype="int64") + start_i = paddle.zeros([], dtype="int64") end = max_seq_len end = paddle.cast(end, "int64") diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 32e05851c9521..da7305be3dc5c 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3160,19 +3160,19 @@ def tile(x, repeat_times, name=None): ) if isinstance(repeat_times, Variable): assert ( - len(repeat_times.shape) == 1 - ), 'repeat_times must be an 1-D Tensor.' + repeat_times.numel() == 1 + ), 'repeat_times must be a Tensor with one element.' else: for elem in repeat_times: if isinstance(elem, Variable): assert ( - len(elem.shape) == 1 - ), 'Elements in repeat_times must be 1-D Tensors or integers.' 
+ elem.numel() == 1 + ), 'Elements in repeat_times must be Tensor with one element or integers.' else: type_tuple = (int, np.int32, np.int64) assert isinstance( elem, type_tuple - ), 'Elements in repeat_times must be 1-D Tensors or integers.' + ), 'Elements in repeat_times must be Tensor with one element or integers.' check_variable_and_dtype( x, @@ -3416,18 +3416,18 @@ def expand(x, shape, name=None): return _C_ops.expand(x, shape) else: if isinstance(shape, Variable): - assert len(shape.shape) == 1, 'shape must be an 1-D Tensor.' + assert shape.numel() == 1, 'shape must be a Tensor with one element' else: for elem in shape: if isinstance(elem, Variable): assert ( - len(elem.shape) == 1 - ), 'Elements in shape must be 1-D Tensors or integers.' + elem.numel() == 1 + ), 'Elements in shape must be Tensor with one element or integers.' else: type_tuple = (int, np.int32, np.int64) assert isinstance( elem, type_tuple - ), 'Elements in shape must be 1-D Tensors or integers.' + ), 'Elements in shape must be Tensor with one element or integers.' check_variable_and_dtype( x, diff --git a/test/dygraph_to_static/test_list.py b/test/dygraph_to_static/test_list.py index 44e02950bc548..091d261ed7458 100644 --- a/test/dygraph_to_static/test_list.py +++ b/test/dygraph_to_static/test_list.py @@ -364,7 +364,7 @@ def test_to_static(self): x = paddle.to_tensor([2, 3, 4], dtype='float32') index = paddle.to_tensor([1]) res = net(x, index) - self.assertEqual(res[0], 48.0) + self.assertEqual(res, 48.0) if __name__ == '__main__': diff --git a/test/xpu/test_set_value_op_xpu.py b/test/xpu/test_set_value_op_xpu.py index e749eb8bc1b11..a373d6a0ba5f8 100644 --- a/test/xpu/test_set_value_op_xpu.py +++ b/test/xpu/test_set_value_op_xpu.py @@ -1432,7 +1432,7 @@ def test_inplace(self): a.stop_gradient = False b = a[:] c = b - b[paddle.to_tensor(0)] = 1.0 + b[paddle.zeros([], dtype='int32')] = 1.0 self.assertTrue(id(b) == id(c)) np.testing.assert_array_equal(b.numpy(), c.numpy()) diff --git a/test/xpu/test_slice_op_xpu.py b/test/xpu/test_slice_op_xpu.py index f19c3d37e283e..7cc0550740e1b 100644 --- a/test/xpu/test_slice_op_xpu.py +++ b/test/xpu/test_slice_op_xpu.py @@ -166,7 +166,7 @@ def config(self): self.starts = [0, 1, 2, 3] self.ends = [1, 2, 3, 4] self.axes = [0, 1, 2, 3] - self.decrease_axis = [0, 1, 2, 3] + self.decrease_axis = [0, 1, 2] self.infer_flags = [1, 1, 1] self.out = self.input[0, 1, 2, 3:4] @@ -188,7 +188,7 @@ def config(self): self.axes = [0, 1, 2, 3] self.decrease_axis = [0, 1, 2, 3] self.infer_flags = [1, 1, 1] - self.out = self.input[0, 1, 2, 3:4] + self.out = self.input[0, 1, 2, 3] support_types = get_xpu_op_support_types('slice') diff --git a/test/xpu/test_zero_dim_tensor_xpu.py b/test/xpu/test_zero_dim_tensor_xpu.py index 1a9f59040d534..9ecce0af8305d 100644 --- a/test/xpu/test_zero_dim_tensor_xpu.py +++ b/test/xpu/test_zero_dim_tensor_xpu.py @@ -344,6 +344,140 @@ def setUp(self): paddle.disable_static() self.x = paddle.rand([]) + def test_getitem(self): + # case1: When all axis have a scalar indice, output should be a 0-d Tensor; + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x[1, 2, 3, 4] + out.retain_grads() + out.backward() + self.assertEqual(out.shape, []) + np.testing.assert_allclose(out, np.array(119)) + self.assertEqual(out.grad.shape, []) + np.testing.assert_allclose(out.grad, 1.0) + self.assertEqual(x.grad.shape, [2, 3, 4, 5]) + x_grad_expected = np.zeros((2, 3, 4, 5)) + x_grad_expected[1, 2, 3, 4] = 1.0 + np.testing.assert_allclose(x.grad, 
x_grad_expected) + + # case2: When one axis has a 0-d Tensor indice, the output should be same as int indice. + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out1 = x[1, 2] + out2 = x[ + paddle.full([], 1, dtype='int32'), paddle.full([], 2, dtype='int32') + ] + np.testing.assert_allclose(out1, out2) + + # case3: When all axis have a scalar indice (i.e. case1) and has None indice, + # ndim of output should be same with numbers of None. + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + out1 = x[1, 2, None, 3, 4] + self.assertEqual(out1.shape, [1]) + np.testing.assert_allclose(out1, np.array([119])) + out2 = x[1, None, 2, None, 3, 4] + self.assertEqual(out2.shape, [1, 1]) + np.testing.assert_allclose(out2, np.array([[119]])) + + # case4: 1-D Tensor will be treated as vector, no axis decrease will happen. + x = paddle.ones((2, 3, 4)) + indice = paddle.ones([1], dtype='int32') + out1 = x[indice] + self.assertEqual(out1.shape, [1, 3, 4]) + np.testing.assert_allclose(out1, np.ones((1, 3, 4))) + out2 = x[indice, indice] + self.assertEqual(out2.shape, [1, 4]) + np.testing.assert_allclose(out2, np.ones((1, 4))) + + def test_setitem(self): + # case1: all axis have a scalar indice + x = paddle.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5)) + x.stop_gradient = False + out = x * 2 + out[1, 2, 3, 4] = 10 + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10)) + self.assertEqual(x.grad.shape, [2, 3, 4, 5]) + x_grad_expected = np.ones((2, 3, 4, 5)) * 2 + x_grad_expected[1, 2, 3, 4] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + + # case2: 0-D Tensor indice in some axis + # NOTE(zoooo0820): Now, int/slice with 0-D Tensor will still be + # treated as combined indexing, which is not support backward. + # There should have more test cases such as out[1, indice, :] = 0.5 when this + # problem is fixed. 
+ x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice, indice] = 0.5 + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1, 1] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + + # case3:0-D Tensor indice in some axis, value is a Tensor + # and there is broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones((4, 5), dtype='float32') * 5 + v.stop_gradient = False + indice = paddle.full([], 1, dtype='int32') + out = x * 1 + out[indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + np.testing.assert_allclose(out[1], np.ones((3, 4, 5)) * 5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[1] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones((4, 5)) * 3 + np.testing.assert_allclose(v.grad, value_grad_expected) + + # case4: value is a 0-D tensor and there is broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones([], dtype='float32') * 5 + v.stop_gradient = False + out = x * 1 + indice = paddle.full([], 0, dtype='int32') + out[indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + self.assertEqual(v.grad.shape, []) + np.testing.assert_allclose(out[0], np.ones((3, 4, 5)) * 5) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[0] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones(()) * 3 * 4 * 5 + np.testing.assert_allclose(v.grad, value_grad_expected) + + # case5: indice / value is 0-D Tensor, and there is no broadcast + x = paddle.randn((2, 3, 4, 5)) + x.stop_gradient = False + v = paddle.ones([], dtype='float32') * 2 + v.stop_gradient = False + out = x * 1 + indice = paddle.full([], 0, dtype='int32') + out[indice, indice, indice, indice] = v + out.backward() + + self.assertEqual(out.shape, x.shape) + self.assertEqual(v.grad.shape, []) + np.testing.assert_allclose(out[0, 0, 0, 0], np.ones(()) * 2) + x_grad_expected = np.ones((2, 3, 4, 5)) + x_grad_expected[0, 0, 0, 0] = 0 + np.testing.assert_allclose(x.grad, x_grad_expected) + value_grad_expected = np.ones(()) + np.testing.assert_allclose(v.grad, value_grad_expected) + def test_expand(self): # case1 x = paddle.full([], 1, 'float32')
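
A minimal sketch of the indexing behavior this change enables, assuming a Paddle build that already includes this cherry-pick (the snippet below is illustrative only and is not part of the patch):

    import numpy as np
    import paddle

    x = paddle.arange(2 * 3 * 4).reshape((2, 3, 4))

    # Indexing every axis with a scalar now yields a 0-D Tensor (shape []),
    # where it previously produced a Tensor of shape [1].
    out = x[1, 2, 3]
    print(out.shape)  # []

    # A 0-D integer Tensor used as an index behaves the same as a Python int.
    i = paddle.full([], 1, dtype='int64')
    np.testing.assert_array_equal(x[i, 2].numpy(), x[1, 2].numpy())

    # __setitem__ also accepts a 0-D Tensor index, equivalent to y[0] = 1.0.
    y = paddle.zeros((2, 3))
    y[paddle.full([], 0, dtype='int32')] = 1.0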