Skip to content

Commit

Permalink
add range support in indexing (#56272)
Browse files Browse the repository at this point in the history
* add range support in indexing

* add getitem ut case
  • Loading branch information
zoooo0820 authored Aug 16, 2023
1 parent 84482da commit 02a7b3c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def contain_tensor_or_list(item):
item = (item,)

for slice_item in item:
if isinstance(slice_item, (list, np.ndarray, Variable)):
if isinstance(slice_item, (list, np.ndarray, Variable, range)):
return True
elif isinstance(slice_item, slice):
if (
Expand Down
8 changes: 5 additions & 3 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,13 @@ def replace_ellipsis(var, item):
return item


def replace_ndarray(item):
def replace_ndarray_and_range(item):
new_item = []
for slice_item in item:
if isinstance(slice_item, np.ndarray):
new_item.append(paddle.assign(slice_item))
elif isinstance(slice_item, range):
new_item.append(list(slice_item))
else:
new_item.append(slice_item)
return new_item
Expand Down Expand Up @@ -416,7 +418,7 @@ def _setitem_impl_(var, item, value):
ends = []
steps = []

item = replace_ndarray(item)
item = replace_ndarray_and_range(item)
item = replace_ellipsis(var, item)
item, none_axes = replace_none(item)
slice_info = SliceInfo()
Expand Down Expand Up @@ -700,7 +702,7 @@ def parse_index(x, indices):
if not isinstance(indices, tuple):
indices = (indices,)

indices = replace_ndarray(indices)
indices = replace_ndarray_and_range(indices)
indices = replace_ellipsis(x, indices)
indices, none_axes = replace_none(indices)

Expand Down
22 changes: 22 additions & 0 deletions test/indexing/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def test_combined_index_11(self):

np.testing.assert_allclose(y.numpy(), np_res)

def test_index_has_range(self):
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, range(3), 4]

x = paddle.to_tensor(np_data)
y = x[:, range(3), 4]

np.testing.assert_allclose(y.numpy(), np_res)


class TestGetitemInStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -312,6 +321,19 @@ def test_combined_index_11(self):

np.testing.assert_allclose(res[0], np_res)

def test_index_has_range(self):
# only one bool tensor with all False
np_data = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
np_res = np_data[:, range(3), 4]
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.to_tensor(np_data)
y = _getitem_static(x, (slice(None, None, None), range(3), 4))
res = self.exe.run(fetch_list=[y.name])

np.testing.assert_allclose(res[0], np_res)


class TestGetItemErrorCase(unittest.TestCase):
def setUp(self):
Expand Down
25 changes: 25 additions & 0 deletions test/indexing/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ def test_combined_index_3(self):

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_has_range(self):
np_data = np.ones((3, 4, 5, 6), dtype='int32')
x = paddle.to_tensor(np_data)

np_data[:, range(3), [1, 2, 4]] = 10
x[:, range(3), [1, 2, 4]] = 10

np.testing.assert_allclose(x.numpy(), np_data)


class TestSetitemInStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -137,3 +146,19 @@ def test_combined_index_5(self):
res = self.exe.run(fetch_list=[y.name])

np.testing.assert_allclose(res[0], np_data)

def test_index_has_range(self):
np_data = np.ones((3, 4, 5, 6), dtype='int32')
np_data[:, range(3), [1, 2, 4]] = 10
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.ones((3, 4, 5, 6), dtype='int32')
y = _setitem_static(
x,
(slice(None, None), range(3), [1, 2, 4]),
10,
)
res = self.exe.run(fetch_list=[y.name])

np.testing.assert_allclose(res[0], np_data)

0 comments on commit 02a7b3c

Please sign in to comment.