From f9815bfee7f74d08ebcd0e3c9e588a3261326121 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Sat, 3 Dec 2022 09:32:18 +0800 Subject: [PATCH] Scatter 0D index for gather, 0D index and 0D updates for scatter. (#48452) --- paddle/phi/infermeta/binary.cc | 78 ++++++++---- paddle/phi/infermeta/ternary.cc | 49 ++++---- paddle/phi/kernels/funcs/gather.cu.h | 9 +- paddle/phi/kernels/funcs/gather.h | 28 +++-- paddle/phi/kernels/funcs/scatter.cu.h | 26 ++-- paddle/phi/kernels/funcs/scatter.h | 91 ++++++++------ paddle/phi/kernels/xpu/gather_grad_kernel.cc | 10 +- paddle/phi/kernels/xpu/gather_kernel.cc | 23 ++-- paddle/phi/kernels/xpu/scatter_kernel.cc | 36 +++--- .../tests/unittests/test_zero_dim_tensor.py | 117 ++++++++++++++++++ .../unittests/xpu/test_zero_dim_tensor_xpu.py | 49 ++++++++ python/paddle/tensor/manipulation.py | 8 +- 12 files changed, 377 insertions(+), 147 deletions(-) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index c48388a03173d..532aed7f66d9e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1268,37 +1268,69 @@ void GatherInferMeta(const MetaTensor& x, index_dims[1])); } else { PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, + index_dims.size() == 1 || index_dims.size() == 0, + true, phi::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", + "The index should be 0D or 1D, when it is not 2D, but we get %d", index_dims.size())); } auto input_dim = x.dims(); auto axis_v = axis.to(); - if (axis.FromTensor() || axis_v == 0) { - // if axis.FromTensor(), we can not obtain correct shape of output - int batch_size = index_dims[0]; - phi::DDim output_dims(input_dim); - output_dims[0] = batch_size; - out->set_dims(output_dims); - out->set_dtype(x.dtype()); - out->share_lod(x); - } else { - int index_size = index_dims[0]; - std::vector out_dim_vec; - for (int i = 0; i < axis_v; i++) { - out_dim_vec.push_back(input_dim[i]); + if (index_dims.size() == 0) { + // 0D index will decrease the dimension + if (input_dim.size() == 1) { + // the index is a 0D tensor and the x is a 1D tensor + out->set_dims(phi::DDim(phi::Dim<0>())); + } else { + if (axis.FromTensor() || axis_v == 0) { + // decrease the output dimension + std::vector out_dim_vec; + for (int i = 1; i < input_dim.size(); ++i) { + out_dim_vec.emplace_back(input_dim[i]); + } + auto output_dims = phi::make_ddim(out_dim_vec); + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + } else { + std::vector out_dim_vec; + for (int i = 0; i < axis_v; i++) { + out_dim_vec.push_back(input_dim[i]); + } + for (int i = axis_v + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = phi::make_ddim(out_dim_vec); + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + } } - out_dim_vec.push_back(index_size); - for (int i = axis_v + 1; i < input_dim.size(); i++) { - out_dim_vec.push_back(input_dim[i]); + } else { + if (axis.FromTensor() || axis_v == 0) { + // if axis.FromTensor(), we can not obtain correct shape of output + int batch_size = index_dims[0]; + phi::DDim output_dims(input_dim); + output_dims[0] = batch_size; + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); + } else { + int index_size = index_dims[0]; + std::vector out_dim_vec; + for (int i = 0; i < axis_v; i++) { + out_dim_vec.push_back(input_dim[i]); + } + out_dim_vec.push_back(index_size); + for (int i = axis_v + 1; i < input_dim.size(); i++) { + out_dim_vec.push_back(input_dim[i]); + } + auto output_dims = phi::make_ddim(out_dim_vec); + out->set_dims(output_dims); + out->set_dtype(x.dtype()); + out->share_lod(x); } - auto output_dims = phi::make_ddim(out_dim_vec); - out->set_dims(output_dims); - out->set_dtype(x.dtype()); - out->share_lod(x); } } diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index 1b945c0254fb3..f7bae3690991f 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -995,31 +995,34 @@ void ScatterInferMeta(const MetaTensor& x, "index is a 2D tensor, but we get %d.", index_dims[1])); } else { + PADDLE_ENFORCE_EQ(index_dims.size() == 1 || index_dims.size() == 0, + true, + phi::errors::InvalidArgument( + "The index should be a 0D or 1D tensor when the " + "index is not a 2D tensor, but we get %d.", + index_dims.size())); + } + if (index_dims.size() != 0) { PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, - phi::errors::InvalidArgument("The index should be a 1D tensor when the " - "index is not a 2D tensor, but we get %d.", - index_dims.size())); + (ref_dims.size() == updates_dims.size()), + true, + phi::errors::InvalidArgument( + "When the Input(Updates) is not a 0D tensor, the " + "Input(X) and Input(Updates) should have the same shape size, " + "but received the size of Input(x)'s shape is %d, the size of " + "Input(Updates)'s shape is %d.", + ref_dims.size(), + updates_dims.size())); + PADDLE_ENFORCE_EQ( + updates_dims[0], + index_dims[0], + phi::errors::InvalidArgument( + "Input(Updates) and Input(Ids) should have same batch-size, but" + " received Input(Updates)'s batch-size is %d, Input(Ids)'s " + "batch-size is %d.", + updates_dims[0], + index_dims[0])); } - PADDLE_ENFORCE_EQ( - ref_dims.size(), - updates_dims.size(), - phi::errors::InvalidArgument( - "Input(X) and Input(Updates) should have the same shape size, " - "but received the size of Input(x)'s shape is %d, the size of " - "Input(Updates)'s shape is %d.", - ref_dims.size(), - updates_dims.size())); - PADDLE_ENFORCE_EQ( - updates_dims[0], - index_dims[0], - phi::errors::InvalidArgument( - "Input(Updates) and Input(Ids) should have same batch-size, but" - " received Input(Updates)'s batch-size is %d, Input(Ids)'s " - "batch-size is %d.", - updates_dims[0], - index_dims[0])); out->set_dims(ref_dims); out->share_lod(x); out->set_dtype(x.dtype()); diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h index ac8487db8f62e..2b1822ece2627 100644 --- a/paddle/phi/kernels/funcs/gather.cu.h +++ b/paddle/phi/kernels/funcs/gather.cu.h @@ -94,12 +94,9 @@ void GPUGather(const phi::GPUContext& ctx, } // index size - int64_t index_size = index.dims()[0]; - if (index_size == 0) return; + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); - phi::DDim output_dims(src_dims); - output_dims[0] = index_size; // slice size int64_t slice_size = 1; @@ -246,7 +243,9 @@ void GatherV2CUDAFunction(const DenseTensor* input, inner_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); } - out_dim_vec.push_back(index_size); + if (index->dims().size() != 0) { + out_dim_vec.push_back(index_size); + } for (int i = axis_index + 1; i < input_dim.size(); i++) { outer_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); diff --git a/paddle/phi/kernels/funcs/gather.h b/paddle/phi/kernels/funcs/gather.h index 094bc46cb6f45..f1ab1a16f1224 100644 --- a/paddle/phi/kernels/funcs/gather.h +++ b/paddle/phi/kernels/funcs/gather.h @@ -38,7 +38,6 @@ void CPUGather(const phi::CPUContext& ctx, const DenseTensor& src, const DenseTensor& index, DenseTensor* output) { - // check index of shape 1-D if (index.dims().size() == 2) { PADDLE_ENFORCE_EQ( index.dims()[1], @@ -48,14 +47,15 @@ void CPUGather(const phi::CPUContext& ctx, "in gather_op, but received value is [%d].", index.dims()[1])); } else { - PADDLE_ENFORCE_EQ(index.dims().size(), - 1, - phi::errors::InvalidArgument( - "index.dims().size() should be 1 or 2 in gather_op," - "but received shape's size is [%d].", - index.dims().size())); + PADDLE_ENFORCE_EQ( + index.dims().size() == 1 || index.dims().size() == 0, + true, + phi::errors::InvalidArgument( + "The index should be 0D or 1D, when it is not 2D, but we get %d", + index.dims().size())); } - int64_t index_size = index.dims()[0]; + + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); @@ -188,7 +188,9 @@ void GatherV2Function(const phi::CPUContext& ctx, inner_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); } - out_dim_vec.push_back(index_size); + if (index->dims().size() != 0) { + out_dim_vec.push_back(index_size); + } for (int i = axis_index + 1; i < input_dim.size(); i++) { outer_dim_size *= input_dim[i]; out_dim_vec.push_back(input_dim[i]); @@ -224,7 +226,13 @@ void GatherV2GradFunction(const phi::CPUContext& ctx, if (input->numel() == 0) return; int axis_index = axis; - int64_t input_index_dim_size = input_dim[axis_index]; + int64_t input_index_dim_size; + if (input_dim.size() == out->dims().size()) { + input_index_dim_size = input_dim[axis_index]; + } else { + // 0d index + input_index_dim_size = 1; + } int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h index 6aeb09b232bd5..c03dcba1e2e7f 100644 --- a/paddle/phi/kernels/funcs/scatter.cu.h +++ b/paddle/phi/kernels/funcs/scatter.cu.h @@ -122,7 +122,6 @@ void GPUScatterAssign(const phi::GPUContext& ctx, const DenseTensor& index, DenseTensor* output, bool overwrite = true) { - // check index of shape 1-D if (index.dims().size() == 2) { PADDLE_ENFORCE_EQ( index.dims()[1], @@ -132,26 +131,33 @@ void GPUScatterAssign(const phi::GPUContext& ctx, "But received value is [%d]", index.dims()[1])); } else { - PADDLE_ENFORCE_EQ(index.dims().size(), - 1, - phi::errors::InvalidArgument( - "index.dims().size() should be 1 or 2 in scatter_op." - "But received value is [%d]", - index.dims().size())); + PADDLE_ENFORCE_EQ( + index.dims().size() == 1 || index.dims().size() == 0, + true, + phi::errors::InvalidArgument( + "index.dims().size() should be 0, 1 or 2 in scatter_op." + "But received value is [%d]", + index.dims().size())); } - int64_t index_size = index.dims()[0]; + + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); phi::DDim output_dims(src_dims); output_dims[0] = index_size; // slice size - int64_t slice_size = 1; - for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + size_t slice_size = 1; + if (index.dims().size() != 0) { + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } else { + for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } const T* p_src = src.data(); const IndexT* p_index = index.data(); T* p_output = output->data(); + const size_t& slice_bytes = slice_size * sizeof(T); // set block and grid num diff --git a/paddle/phi/kernels/funcs/scatter.h b/paddle/phi/kernels/funcs/scatter.h index 0b381e5710651..9ee73a08b06c5 100644 --- a/paddle/phi/kernels/funcs/scatter.h +++ b/paddle/phi/kernels/funcs/scatter.h @@ -76,7 +76,6 @@ void ScatterAssign(const phi::CPUContext& ctx, const DenseTensor& src, const DenseTensor& index, DenseTensor* output) { - // check index of shape 1-D if (index.dims().size() == 2) { PADDLE_ENFORCE_EQ( index.dims()[1], @@ -86,14 +85,15 @@ void ScatterAssign(const phi::CPUContext& ctx, "But received value is [%d]", index.dims()[1])); } else { - PADDLE_ENFORCE_EQ(index.dims().size(), - 1, + PADDLE_ENFORCE_EQ(index.dims().size() == 1 || index.dims().size() == 0, + true, phi::errors::InvalidArgument( - "index.dims().size() should be 1 or 2 in scatter_op." - "But received value is [%d]", + "index.dims().size() should be 0, 1 or 2 in " + "scatter_op. But received value is [%d]", index.dims().size())); } - int64_t index_size = index.dims()[0]; + + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); auto dst_dims = output->dims(); @@ -102,23 +102,29 @@ void ScatterAssign(const phi::CPUContext& ctx, const IndexT* p_index = index.data(); T* p_output = output->data(); - // check src shape and dst shape should match - for (int i = 1; i < src_dims.size(); i++) - PADDLE_ENFORCE_EQ( - src_dims[i], - dst_dims[i], - phi::errors::InvalidArgument( - "The dimensions of the source tensor and target tensor should" - " match, but received source tensor's %d-th dimension is %d," - "target tensor's %d-th dimension is %d.", - i, - src_dims[i], - i, - dst_dims[i])); + if (index.dims().size() != 0) { + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE_EQ( + src_dims[i], + dst_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + src_dims[i], + i, + dst_dims[i])); + } // slice size size_t slice_size = 1; - for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + if (index.dims().size() != 0) { + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } else { + for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } const size_t slice_bytes = slice_size * sizeof(T); @@ -143,43 +149,48 @@ void ScatterAssignAdd(const phi::CPUContext& ctx, const DenseTensor& src, const DenseTensor& index, DenseTensor* output) { - // check index of shape 1-D PADDLE_ENFORCE_EQ( - index.dims().size() == 1 || + index.dims().size() == 1 || index.dims().size() == 0 || (index.dims().size() == 2 && index.dims()[1] == 1), true, phi::errors::InvalidArgument( "index's shape is error, " - "expect index'dims shape is 1 or 2 and index.dims[1] is 1" - "but got index'dims shape is %d", + "expect index'dims shape is 0, 1, 2 (index.dims[1] should " + "be 1), but got index'dims shape is %d", index.dims().size())); - int64_t index_size = index.dims()[0]; + + int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0]; auto src_dims = src.dims(); auto dst_dims = output->dims(); const T* p_src = src.data(); const IndexT* p_index = index.data(); - T* p_output = output->data(); - // check src shape and dst shape should match - for (int i = 1; i < src_dims.size(); i++) - PADDLE_ENFORCE_EQ( - src_dims[i], - dst_dims[i], - phi::errors::InvalidArgument( - "The dimensions of the source tensor and target tensor should" - " match, but received source tensor's %d-th dimension is %d," - "target tensor's %d-th dimension is %d.", - i, - src_dims[i], - i, - dst_dims[i])); + if (index.dims().size() != 0) { + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE_EQ( + src_dims[i], + dst_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + src_dims[i], + i, + dst_dims[i])); + } // slice size size_t slice_size = 1; - for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + if (index.dims().size() != 0) { + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } else { + for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + } const size_t& slice_bytes = slice_size * sizeof(T); diff --git a/paddle/phi/kernels/xpu/gather_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_grad_kernel.cc index 7be22a86d0019..86a6a39f87cf5 100644 --- a/paddle/phi/kernels/xpu/gather_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/gather_grad_kernel.cc @@ -44,10 +44,10 @@ void GatherGradKernel(const Context& dev_ctx, index_dims[1])); } else { PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, + index_dims.size() == 1 || index_dims.size() == 0, + true, phi::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", + "The index should be 0D or 1D, when it is not 2D, but we get %d", index_dims.size())); } std::vector xshape(x_grad->dims().size()); @@ -66,7 +66,7 @@ void GatherGradKernel(const Context& dev_ctx, index.data(), reinterpret_cast(x_grad->data()), xshape, - index.dims()[0], + index.dims().size() == 0 ? 1 : index.dims()[0], axis_v, overwrite); } else { @@ -84,7 +84,7 @@ void GatherGradKernel(const Context& dev_ctx, index_int_ptr_l3, reinterpret_cast(x_grad->data()), xshape, - index.dims()[0], + index.dims().size() == 0 ? 1 : index.dims()[0], axis_v, overwrite); } diff --git a/paddle/phi/kernels/xpu/gather_kernel.cc b/paddle/phi/kernels/xpu/gather_kernel.cc index c3520178d1804..76b2f04ee52ba 100644 --- a/paddle/phi/kernels/xpu/gather_kernel.cc +++ b/paddle/phi/kernels/xpu/gather_kernel.cc @@ -41,10 +41,10 @@ void GatherKernel(const Context& dev_ctx, index_dims[1])); } else { PADDLE_ENFORCE_EQ( - index_dims.size(), - 1, + index_dims.size() == 1 || index_dims.size() == 0, + true, phi::errors::InvalidArgument( - "The index should be 1D, when it is not 2D, but we get %d", + "The index should be 0D, 1D, when it is not 2D, but we get %d", index_dims.size())); } std::vector xshape(x.dims().size()); @@ -56,13 +56,14 @@ void GatherKernel(const Context& dev_ctx, int r = XPU_SUCCESS; if (index_type == DataType::INT32) { - r = xpu::gather(dev_ctx.x_context(), - reinterpret_cast(x.data()), - index.data(), - reinterpret_cast(out->data()), - xshape, - index.dims()[0], - axis_v); + r = xpu::gather( + dev_ctx.x_context(), + reinterpret_cast(x.data()), + index.data(), + reinterpret_cast(out->data()), + xshape, + index.dims().size() == 0 ? 1 : index.dims()[0], + axis_v); } else { r = xpu::gather( dev_ctx.x_context(), @@ -70,7 +71,7 @@ void GatherKernel(const Context& dev_ctx, index.data(), reinterpret_cast(out->data()), xshape, - index.dims()[0], + index.dims().size() == 0 ? 1 : index.dims()[0], axis_v); } PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/kernels/xpu/scatter_kernel.cc b/paddle/phi/kernels/xpu/scatter_kernel.cc index a1db2669e619b..988b8a71568e9 100644 --- a/paddle/phi/kernels/xpu/scatter_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_kernel.cc @@ -43,30 +43,34 @@ void ScatterKernel(const Context &ctx, // check index of shape 1-D PADDLE_ENFORCE_EQ( - index.dims().size() == 1 || + index.dims().size() == 1 || index.dims().size() == 0 || (index.dims().size() == 2 && index.dims()[1] == 1), true, phi::errors::InvalidArgument( "index's shape is error, " - "expect index'dims shape is 1 or 2 and index.dims[1] is 1" - "but got index'dims shape is %d", + "expect index'dims shape is 0, 1, 2 (index.dims[1] should " + "be 1), 0 but got index'dims shape is %d", index.dims().size())); - int index_size = static_cast(index.dims()[0]); + int index_size = + static_cast(index.dims().size() == 0 ? 1 : index.dims()[0]); auto x_dims = x.dims(); auto update_dims = updates.dims(); - for (int i = 1; i < x_dims.size(); i++) - PADDLE_ENFORCE_EQ( - x_dims[i], - update_dims[i], - phi::errors::InvalidArgument( - "The dimensions of the source tensor and target tensor should" - " match, but received source tensor's %d-th dimension is %d," - "target tensor's %d-th dimension is %d.", - i, - x_dims[i], - i, - update_dims[i])); + if (index.dims().size() != 0) { + // only check when the updates tensor is not a 0D tensor + for (int i = 1; i < x_dims.size(); i++) + PADDLE_ENFORCE_EQ( + x_dims[i], + update_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + x_dims[i], + i, + update_dims[i])); + } int dim0 = static_cast(x.dims()[0]); int dim1 = diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index e854b8489af14..e7381350624b9 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -598,6 +598,61 @@ def test_searchsorted(self): self.assertEqual(out.shape, []) self.assertEqual(out.numpy(), 0) + def test_gather_1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) + index = paddle.full([], 2, 'int64') + out = paddle.gather(x, index) + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 5) + self.assertEqual(out.grad.shape, []) + + def test_gather_xD_axis_0(self): + x = paddle.to_tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False + ) + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index) + out.backward() + + self.assertEqual(out.shape, [3]) + for i in range(3): + self.assertEqual(out.numpy()[i], x.numpy()[1][i]) + self.assertEqual(out.grad.shape, [3]) + + def test_gather_xD_axis_1(self): + x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index, axis=1) + + self.assertEqual(out.shape, [2]) + for i in range(2): + self.assertEqual(out.numpy()[i], x.numpy()[i][1]) + + def test_scatter_1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4.0) + out = paddle.scatter(x, index, updates) + out.backward() + + self.assertEqual(out.grad.shape, [5]) + self.assertEqual(out.numpy()[2], 4) + + def test_scatter_XD(self): + x = paddle.to_tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False + ) + index = paddle.full([], 1, 'int64') + updates = paddle.to_tensor([1.0, 2.0, 3.0]) + out = paddle.scatter(x, index, updates) + out.backward() + + for i in range(3): + self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) + self.assertEqual(out.grad.shape, [2, 3]) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -679,6 +734,68 @@ def test_searchsorted(self): self.assertEqual(res[0].shape, ()) self.assertEqual(res[0], 0) + @prog_scope() + def test_gather_1D(self): + x = paddle.full([10], 1.0, 'float32') + index = paddle.full([], 2, 'int64') + out = paddle.gather(x, index) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, ()) + self.assertEqual(res[0], 1) + + @prog_scope() + def test_gather_XD_axis_0(self): + x = paddle.full([2, 3], 1.0, 'float32') + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, (3,)) + for i in range(3): + self.assertEqual(res[0][i], 1) + + @prog_scope() + def test_gather_XD_axis_1(self): + x = paddle.full([2, 3], 1.0, 'float32') + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index, axis=1) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0].shape, (2,)) + for i in range(2): + self.assertEqual(res[0][i], 1) + + @prog_scope() + def test_scatter_1D(self): + x = paddle.full([10], 1.0, 'float32') + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4, 'float32') + out = paddle.scatter(x, index, updates) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + self.assertEqual(res[0][2], 4) + + @prog_scope() + def test_scatter_XD(self): + x = paddle.full([2, 3], 1.0, 'float32') + index = paddle.full([], 1, 'int64') + updates = paddle.full([3], 4, 'float32') + out = paddle.scatter(x, index, updates) + paddle.static.append_backward(out) + + prog = paddle.static.default_main_program() + res = self.exe.run(prog, fetch_list=[out]) + for i in range(3): + self.assertEqual(res[0][1][i], 4) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 6bde8ef947d7c..b07043689f7fe 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -426,6 +426,55 @@ def test_searchsorted(self): self.assertEqual(out.shape, []) self.assertEqual(out.numpy(), 0) + def test_gather_1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False) + index = paddle.full([], 2, 'int64') + out = paddle.gather(x, index) + out.backward() + + self.assertEqual(out.shape, []) + self.assertEqual(out.numpy(), 5) + self.assertEqual(out.grad.shape, []) + + def test_gather_xD_axis_0(self): + x = paddle.to_tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False + ) + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index) + out.backward() + + self.assertEqual(out.shape, [3]) + for i in range(3): + self.assertEqual(out.numpy()[i], x.numpy()[1][i]) + self.assertEqual(out.grad.shape, [3]) + + def test_gather_xD_axis_1(self): + x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + index = paddle.full([], 1, 'int64') + out = paddle.gather(x, index, axis=1) + + self.assertEqual(out.shape, [2]) + for i in range(2): + self.assertEqual(out.numpy()[i], x.numpy()[i][1]) + + def test_scatter_1D(self): + x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0]) + index = paddle.full([], 2, 'int64') + updates = paddle.full([], 4.0) + out = paddle.scatter(x, index, updates) + + self.assertEqual(out.numpy()[2], 4) + + def test_scatter_XD(self): + x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + index = paddle.full([], 1, 'int64') + updates = paddle.to_tensor([1.0, 2.0, 3.0]) + out = paddle.scatter(x, index, updates) + + for i in range(3): + self.assertEqual(out.numpy()[1][i], updates.numpy()[i]) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index fceae51e14564..8c47809d222a9 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2728,13 +2728,13 @@ def gather(x, index, axis=None, name=None): x (Tensor): The source input tensor with rank>=1. Supported data type is int32, int64, float32, float64 and uint8 (only for CPU), float16 (only for GPU). - index (Tensor): The index input tensor with rank=1. Data type is int32 or int64. + index (Tensor): The index input tensor with rank=0 or rank=1. Data type is int32 or int64. axis (Tensor|int, optional): The axis of input to be gathered, it's can be int or a Tensor with data type is int32 or int64. The default value is None, if None, the ``axis`` is 0. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: - output (Tensor), The output is a tensor with the same rank as ``x``. + output (Tensor), If the index is a 1-D tensor, the output is a tensor with the same shape as ``x``. If the index is a 0-D tensor, the output will reduce the dimension where the axis pointing. Examples: @@ -2888,8 +2888,8 @@ def scatter(x, index, updates, overwrite=True, name=None): Args: x (Tensor): The input N-D Tensor with ndim>=1. Data type can be float32, float64. - index (Tensor): The index 1-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length. - updates (Tensor): update input with updates parameter based on index. shape should be the same as input, and dim value with dim > 1 should be the same as input. + index (Tensor): The index is a 1-D or 0-D Tensor. Data type can be int32, int64. The length of index cannot exceed updates's length, and the value in index cannot exceed input's length. + updates (Tensor): Update input with updates parameter based on index. When the index is a 1-D tensor, the updates shape should be the same as input, and dim value with dim > 1 should be the same as input. When the index is a 0-D tensor, the updates should be a (N-1)-D tensor, the ith dim of the updates should be queal with the (i+1)th dim of the input. overwrite (bool): The mode that updating the output when there are same indices. If True, use the overwrite mode to update the output of the same index,