Skip to content

Commit

Permalink
add unittests for supporting get_item where the index is a bool scala…
Browse files Browse the repository at this point in the history
…r tensor
  • Loading branch information
FlyingQianMM committed Mar 24, 2022
1 parent 895486b commit 767317f
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 1 deletion.
92 changes: 92 additions & 0 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,15 @@ def _test_bool_index(self):
with self.assertRaises(IndexError):
var_tensor[paddle.to_tensor([[True, False, False, False]])]

def _test_scalar_bool_index(self):
shape = (1, 2, 5, 64)
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [True]
tensor_index = paddle.to_tensor(index)
var = [var_tensor[tensor_index].numpy(), ]
self.assertTrue(np.array_equal(var[0], np_value[index]))

def _test_for_var(self):
np_value = np.random.random((30, 100, 100)).astype('float32')
w = fluid.dygraph.to_variable(np_value)
Expand Down Expand Up @@ -882,6 +891,7 @@ def test_slice(self):
self._test_for_getitem_ellipsis_index()
self._test_none_index()
self._test_bool_index()
self._test_scalar_bool_index()
self._test_numpy_index()
self._test_list_index()

Expand Down Expand Up @@ -1193,6 +1203,88 @@ def set_dtype(self):
self.dtype = "float64"


class TestVarBaseSetitemBoolIndex(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.set_dtype()
self.set_input()

def set_input(self):
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value)

def set_dtype(self):
self.dtype = "int32"

def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)

id_origin = id(self.tensor_x)
index_1 = paddle.to_tensor(np.array([True, False, False, False]))
self.tensor_x[index_1] = value
self.assertEqual(self.tensor_x.inplace_version, 1)

if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(self.dtype) + value

else:
result = self.np_value

self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

index_2 = paddle.to_tensor(np.array([False, True, False, False]))
self.tensor_x[index_2] = value
self.assertEqual(self.tensor_x.inplace_version, 2)
self.assertTrue(np.array_equal(self.tensor_x[1].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

index_3 = paddle.to_tensor(np.array([True, True, True, True]))
self.tensor_x[index_3] = value
self.assertEqual(self.tensor_x.inplace_version, 3)
self.assertTrue(np.array_equal(self.tensor_x[3].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))

def test_value_tensor(self):
paddle.disable_static()
self._test(self.tensor_value)

def test_value_numpy(self):
paddle.disable_static()
self._test(self.np_value)

def test_value_int(self):
paddle.disable_static()
self._test(10)


class TestVarBaseSetitemBoolScalarIndex(unittest.TestCase):
def set_input(self):
self.tensor_x = paddle.to_tensor(np.ones((1, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value)

def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)

id_origin = id(self.tensor_x)
index = paddle.to_tensor(np.array([True]))
self.tensor_x[index] = value
self.assertEqual(self.tensor_x.inplace_version, 1)

if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(self.dtype) + value

else:
result = self.np_value

self.assertTrue(np.array_equal(self.tensor_x[0].numpy(), result))
self.assertEqual(id_origin, id(self.tensor_x))


class TestVarBaseInplaceVersion(unittest.TestCase):
def test_setitem(self):
paddle.disable_static()
Expand Down
56 changes: 55 additions & 1 deletion python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,23 @@ def _test_slice_index_list_bool(self, place):
with paddle.static.program_guard(prog):
res = x[[False, False]]

def _test_slice_index_scalar_bool(self, place):
data = np.random.rand(1, 3, 4).astype("float32")
np_idx = np.array([True])
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.assign(data)
idx = paddle.assign(np_idx)

out = x[idx]

exe = paddle.static.Executor(place)
result = exe.run(prog, fetch_list=[out])

expected = [data[np_idx]]

self.assertTrue((result[0] == expected[0]).all())

def test_slice(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
Expand All @@ -347,6 +364,7 @@ def test_slice(self):
self._test_slice_index_list(place)
self._test_slice_index_ellipsis(place)
self._test_slice_index_list_bool(place)
self._test_slice_index_scalar_bool(place)

def _tostring(self):
b = default_main_program().current_block()
Expand Down Expand Up @@ -705,7 +723,7 @@ def run_setitem_list_index(self, array, index, value_np):
fetch_list=fetch_list)

self.assertTrue(
np.array_equal(array2, setitem_pp[0]),
np.allclose(array2, setitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(array2, setitem_pp[0]))

def test_static_graph_setitem_list_index(self):
Expand Down Expand Up @@ -769,6 +787,42 @@ def test_static_graph_setitem_list_index(self):
index_mod = (index % (min(array.shape))).tolist()
self.run_setitem_list_index(array, index_mod, value_np)

def test_static_graph_setitem_bool_index(self):
paddle.enable_static()

# case 1:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, False, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)

# case 2:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([False, True, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)

# case 3:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, True, True, True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)

def test_static_graph_setitem_bool_scalar_index(self):
paddle.enable_static()
array = np.ones((1, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_setitem_list_index(array, index, value_np)

def test_static_graph_tensor_index_setitem_muti_dim(self):
paddle.enable_static()
inps_shape = [3, 4, 5, 4]
Expand Down

1 comment on commit 767317f

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.