Skip to content

Commit

Permalink
Fix gather, scatter op 0d tenor GPU error. (#50271)
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu authored Feb 7, 2023
1 parent 0dd41a2 commit 05c9c0a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/funcs/gather.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ void GatherV2GradCUDAFunction(const DenseTensor* input,

if (input->numel() == 0) return;
int axis_index = axis;
int64_t input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size =
index->dims().size() == 0 ? 1 : input_dim[axis_index];

int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
Expand Down
10 changes: 7 additions & 3 deletions paddle/phi/kernels/funcs/scatter.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,15 @@ template <typename T, typename IndexT = int>
void GPUScatterGradForX(const phi::GPUContext& ctx,
const DenseTensor& index,
DenseTensor* output) {
int64_t index_size = index.dims()[0];
int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
auto dst_dims = output->dims();
// slice size
int64_t slice_size = 1;
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
int64_t slice_size = 1; // slice size
if (index.dims().size() != 0) {
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
} else {
for (int i = 0; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
}
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
const size_t& slice_bytes = slice_size * sizeof(T);
Expand Down
12 changes: 6 additions & 6 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def test_gather_xD_axis_0(self):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3])

def _test_gather_xD_axis_1(self):
def test_gather_xD_axis_1(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
Expand All @@ -901,7 +901,7 @@ def _test_gather_xD_axis_1(self):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [2])

def _test_scatter_1D(self):
def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
Expand All @@ -913,7 +913,7 @@ def _test_scatter_1D(self):
self.assertEqual(out.numpy()[2], 4)
self.assertEqual(out.grad.shape, [5])

def _test_scatter_XD(self):
def test_scatter_XD(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
Expand Down Expand Up @@ -1925,7 +1925,7 @@ def test_gather_XD_axis_0(self):
self.assertEqual(res[2].shape, (3,))

@prog_scope()
def _test_gather_XD_axis_1(self):
def test_gather_XD_axis_1(self):
x = paddle.full([2, 3], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 1, 'int64')
Expand All @@ -1940,7 +1940,7 @@ def _test_gather_XD_axis_1(self):
self.assertEqual(res[2].shape, (2,))

@prog_scope()
def _test_scatter_1D(self):
def test_scatter_1D(self):
x = paddle.full([10], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 2, 'int64')
Expand All @@ -1956,7 +1956,7 @@ def _test_scatter_1D(self):
self.assertEqual(res[2].shape, (10,))

@prog_scope()
def _test_scatter_XD(self):
def test_scatter_XD(self):
x = paddle.full([2, 3], 1.0, 'float32')
x.stop_gradient = False
index = paddle.full([], 1, 'int64')
Expand Down

0 comments on commit 05c9c0a

Please sign in to comment.