diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py index 67d3778bcc387..085b4ec6f99e1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py @@ -18,6 +18,7 @@ import numpy as np import paddle +from paddle.static import InputSpec SEED = 2020 np.random.seed(SEED) @@ -176,5 +177,44 @@ def test_set_value_with_save(self): output_spec=None) +class TestSliceSupplementCase(unittest.TestCase): + def test_static_slice_step(self): + paddle.enable_static() + array = np.arange(4**3).reshape((4, 4, 4)) + + x = paddle.static.data(name='x', shape=[4, 4, 4], dtype='int64') + z1 = x[::2] + z2 = x[::-2] + + place = paddle.CPUPlace() + prog = paddle.static.default_main_program() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + + out = exe.run(prog, feed={'x': array}, fetch_list=[z1, z2]) + + self.assertTrue(np.array_equal(out[0], array[::2])) + self.assertTrue(np.array_equal(out[1], array[::-2])) + + def test_static_slice_step_dygraph2static(self): + paddle.disable_static() + + array = np.arange(4**2 * 5).reshape((5, 4, 4)) + inps = paddle.to_tensor(array) + + def func(inps): + return inps[::2], inps[::-2] + + origin_result = func(inps) + sfunc = paddle.jit.to_static( + func, input_spec=[InputSpec(shape=[None, 4, 4])]) + static_result = sfunc(inps) + + self.assertTrue( + np.array_equal(origin_result[0].numpy(), static_result[0].numpy())) + self.assertTrue( + np.array_equal(origin_result[1].numpy(), static_result[1].numpy())) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index c9363dff13d81..7d55e64aa9fa9 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -144,13 +144,14 @@ def _getitem_impl_(var, item): step = 1 if step is None else step - if start is None and end is None: - assert (step == -1) + if start is None and end is None and step == -1: reverse_axes.append(dim) continue - start = 0 if start is None else start - end = MAX_INTEGER if end is None else end + if start is None: + start = 0 if step > 0 else MAX_INTEGER + if end is None: + end = MAX_INTEGER if step > 0 else 0 elif isinstance(slice_item, list): is_bool_list = False