Skip to content

Commit

Permalink
supplement the function of slice
Browse files Browse the repository at this point in the history
  • Loading branch information
hbwx24 committed Jul 19, 2021
1 parent bd559a2 commit f41ce1c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import paddle
from paddle.static import InputSpec

SEED = 2020
np.random.seed(SEED)
Expand Down Expand Up @@ -176,5 +177,44 @@ def test_set_value_with_save(self):
output_spec=None)


class TestSliceSupplementCase(unittest.TestCase):
def test_static_slice_step(self):
paddle.enable_static()
array = np.arange(4**3).reshape((4, 4, 4))

x = paddle.static.data(name='x', shape=[4, 4, 4], dtype='int64')
z1 = x[::2]
z2 = x[::-2]

place = paddle.CPUPlace()
prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

out = exe.run(prog, feed={'x': array}, fetch_list=[z1, z2])

self.assertTrue(np.array_equal(out[0], array[::2]))
self.assertTrue(np.array_equal(out[1], array[::-2]))

def test_static_slice_step_dygraph2static(self):
paddle.disable_static()

array = np.arange(4**2 * 5).reshape((5, 4, 4))
inps = paddle.to_tensor(array)

def func(inps):
return inps[::2], inps[::-2]

origin_result = func(inps)
sfunc = paddle.jit.to_static(
func, input_spec=[InputSpec(shape=[None, 4, 4])])
static_result = sfunc(inps)

self.assertTrue(
np.array_equal(origin_result[0].numpy(), static_result[0].numpy()))
self.assertTrue(
np.array_equal(origin_result[1].numpy(), static_result[1].numpy()))


if __name__ == '__main__':
unittest.main()
9 changes: 5 additions & 4 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,14 @@ def _getitem_impl_(var, item):

step = 1 if step is None else step

if start is None and end is None:
assert (step == -1)
if start is None and end is None and step == -1:
reverse_axes.append(dim)
continue

start = 0 if start is None else start
end = MAX_INTEGER if end is None else end
if start is None:
start = 0 if step > 0 else MAX_INTEGER
if end is None:
end = MAX_INTEGER if step > 0 else 0

elif isinstance(slice_item, list):
is_bool_list = False
Expand Down

0 comments on commit f41ce1c

Please sign in to comment.