From 555f0fc7777980ad1040c59ee483d7af55c74ae7 Mon Sep 17 00:00:00 2001 From: liym27 Date: Wed, 12 May 2021 20:30:37 +0800 Subject: [PATCH] Support static Variable getitem for Ellipsis index Remove ellipsis --- .../fluid/tests/unittests/test_variable.py | 22 +++++++++++++++++++ python/paddle/fluid/variable_index.py | 1 + 2 files changed, 23 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 6ffecd33f8f48..4162fa436798f 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -17,6 +17,7 @@ import unittest import paddle from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core @@ -218,6 +219,26 @@ def _test_slice_index_list(self, place): self.assertTrue((result[2] == expected[2]).all()) self.assertTrue((result[3] == expected[3]).all()) + def _test_slice_index_ellipsis(self, place): + data = np.random.rand(2, 3, 4).astype("float32") + prog = paddle.static.Program() + with paddle.static.program_guard(prog): + x = paddle.assign(data) + out1 = x[0:, ..., 1:] + out2 = x[0:, ...] + out3 = x[..., 1:] + out4 = x[...] + + exe = paddle.static.Executor(place) + result = exe.run(prog, fetch_list=[out1, out2, out3, out4]) + + expected = [data[0:, ..., 1:], data[0:, ...], data[..., 1:], data[...]] + + self.assertTrue((result[0] == expected[0]).all()) + self.assertTrue((result[1] == expected[1]).all()) + self.assertTrue((result[2] == expected[2]).all()) + self.assertTrue((result[3] == expected[3]).all()) + with self.assertRaises(IndexError): res = x[[1, 0], [0, 0]] @@ -233,6 +254,7 @@ def test_slice(self): self._test_slice(place) self._test_slice_index_tensor(place) self._test_slice_index_list(place) + self._test_slice_index_ellipsis(place) def _tostring(self): b = default_main_program().current_block() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index aed8c82d43b4d..e289ae7f837d5 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -112,6 +112,7 @@ def _getitem_impl_(var, item): use_strided_slice = False item, none_axes = replace_none(item) + item = replace_ellipsis(var, item) for dim, slice_item in enumerate(item): if is_integer_or_scalar_tensor(slice_item):