From 6464b58b7b7002bce8c265387d37c4f79fdfe21a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 30 Apr 2024 12:50:22 +0800 Subject: [PATCH 1/9] add isin --- python/paddle/__init__.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 116 +++++++++++++++++++++ test/legacy_test/test_isin.py | 166 +++++++++++++++++++++++++++++++ 4 files changed, 286 insertions(+) create mode 100644 test/legacy_test/test_isin.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e126e41efcf65..d03fd458c1525 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -433,6 +433,7 @@ inner, inverse, isfinite, + isin, isinf, isnan, isneginf, @@ -730,6 +731,7 @@ 'squeeze_', 'to_tensor', 'gather_nd', + 'isin', 'isinf', 'isneginf', 'isposinf', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 4de5e392a8493..13187fa9cd019 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -306,6 +306,7 @@ inner, inverse, isfinite, + isin, isinf, isnan, isneginf, @@ -587,6 +588,7 @@ 'kron', 'kthvalue', 'isfinite', + 'isin', 'isinf', 'isnan', 'isneginf', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index d7d8669ff0c3b..f9e4b0added27 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7969,3 +7969,119 @@ def sinc_(x, name=None): paddle.sin_(x) paddle.divide_(x, tmp) return paddle.where(~paddle.isnan(x), x, paddle.full_like(x, 1.0)) + + +def isin(elements, test_elements, assume_unique=False, invert=False, name=None): + r""" + Tests if each element of `elements` is in `test_elements`. + + Args: + elements (Tensor): The input Tensor. Supported data type: 'float32', 'float64', 'int32', 'int64'. + test_elements (Tensor): Tensor values against which to test for each input element. Supported data type: 'float32', 'float64', 'int32', 'int64'. + assume_unique (bool, optional): If True, indicates both `elements` and `test_elements` contain unique elements. Default: False. + invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor), The output Tensor with the same shape as `elements`. + + Examples: + .. 
code-block:: python + >>> import paddle + >>> paddle.set_device('cpu') + >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_elements = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(elements, test_elements) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, True, True, False, True]) + >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_elements = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(elements, test_elements, invert=True) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [True, False, False, True, False]) + """ + if not isinstance(elements, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(elements)}") + if not isinstance( + test_elements, (paddle.Tensor, Variable, paddle.pir.Value) + ): + raise TypeError(f"x must be tensor type, but got {type(test_elements)}") + + check_variable_and_dtype( + elements, + "elements", + [ + 'float32', + 'float64', + 'int32', + 'int64', + ], + "isin", + ) + + check_variable_and_dtype( + test_elements, + "test_elements", + [ + 'float32', + 'float64', + 'int32', + 'int64', + ], + "isin", + ) + + elements_zero_dim = False + if len(elements.shape) == 0: + elements = elements.reshape([1]) + elements_zero_dim = True + + size_elements = paddle.cast(paddle.numel(elements), 'float32') + if test_elements.numel() < 10.0 * paddle.pow(size_elements, 0.145): + if len(elements.shape) == 0: + return paddle.zeros([], dtype='bool') + + x = elements.reshape( + tuple(elements.shape) + ((1,) * test_elements.ndim) + ) + cmp = x == test_elements + dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + cmp = cmp.any(axis=dim) + if invert: + cmp = ~cmp + else: + elements_flat = elements.flatten() + test_elements_flat = test_elements.flatten() + if assume_unique: + all_elements = paddle.concat([elements_flat, test_elements_flat]) + sorted_index = paddle.argsort(all_elements, stable=True) + sorted_elements = all_elements[sorted_index] + + duplicate_mask = paddle.full_like(sorted_index, False, dtype='bool') + duplicate_mask[:-1] = sorted_elements[1:] == sorted_elements[:-1] + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = paddle.empty_like(duplicate_mask) + mask = sorted_index[duplicate_mask] + + cmp = mask[0 : elements.numel()] + else: + sorted_test_elements = paddle.sort(test_elements_flat) + idx = paddle.searchsorted(sorted_test_elements, elements_flat) + test_idx = paddle.where( + idx < sorted_test_elements.numel(), + idx, + paddle.zeros_like(idx, 'int64'), + ) + cmp = sorted_test_elements[test_idx] == elements_flat + cmp = cmp.logical_not() if invert else cmp + cmp = cmp.reshape(elements.shape) + + if elements_zero_dim: + return cmp.reshape([]) + else: + return cmp diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py new file mode 100644 index 0000000000000..80db6e66f12df --- /dev/null +++ b/test/legacy_test/test_isin.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base, static + +DATA_CASES = [ + {'elements_data': np.array(1.0), 'test_elements_data': np.array(-1.0)}, + { + 'elements_data': np.random.randint(-10, 10, (4, 8)), + 'test_elements_data': np.random.randint(0, 20, (2, 3)), + }, + { + 'elements_data': np.random.randint(-50, 50, (8, 64)), + 'test_elements_data': np.random.randint(-20, 0, (4, 256)), + }, +] + +DATA_CASES_UNIQUE = [ + {'elements_data': np.array(-1.0), 'test_elements_data': np.array(1.0)}, + { + 'elements_data': np.arange(-100, 100).reshape([4, 5, 10]), + 'test_elements_data': np.arange(-10, 10), + }, + { + 'elements_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), + 'test_elements_data': np.arange(50, 150).reshape([4, 5, 5]), + }, +] + +DATA_TYPE = ['float32', 'float64', 'int32', 'int64'] + + +def run_dygraph( + elements_data, + test_elements_data, + type, + assume_unique=False, + invert=False, + use_gpu=False, +): + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + elements_data = elements_data.astype(type) + test_elements_data = test_elements_data.astype(type) + x_e = paddle.to_tensor(elements_data) + x_t = paddle.to_tensor(test_elements_data) + return paddle.isin(x_e, x_t, assume_unique, invert) + + +def run_static( + elements_data, + test_elements_data, + type, + assume_unique=False, + invert=False, + use_gpu=False, +): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = base.Executor(place) + with static.program_guard(main_program, startup_program): + elements_data = elements_data.astype(type) + test_elements_data = test_elements_data.astype(type) + x_e = paddle.static.data( + name='x_e', shape=elements_data.shape, dtype=type + ) + x_t = paddle.static.data( + name='x_t', shape=test_elements_data.shape, dtype=type + ) + res = paddle.isin(x_e, x_t, assume_unique, invert) + static_result = exe.run( + feed={'x_e': elements_data, 'x_t': test_elements_data}, + fetch_list=[res], + ) + return static_result + + +def test( + data_cases, type_cases, assume_unique=False, invert=False, use_gpu=False +): + for type in type_cases: + for case in data_cases: + elements_data = case['elements_data'] + test_elements_data = case['test_elements_data'] + dygraph_result = run_dygraph( + elements_data, test_elements_data, type, invert, use_gpu + ).numpy() + np_result = np.isin( + elements_data.astype(type), + test_elements_data.astype(type), + assume_unique=assume_unique, + invert=invert, + ) + np.testing.assert_equal(dygraph_result, np_result) + + def test_static(): + (static_result,) = run_static( + elements_data, test_elements_data, type, invert, use_gpu + ) + np.testing.assert_equal(static_result, np_result) + + test_static() + + +class TestIsInError(unittest.TestCase): + def test_for_exception(self): + with self.assertRaises(TypeError): + paddle.isin(np.array([1, 2]), np.array([1, 2])) + + 
+class TestIsIn(unittest.TestCase): + def test_without_gpu(self): + test(DATA_CASES, DATA_TYPE) + + def test_with_gpu(self): + test(DATA_CASES, DATA_TYPE, use_gpu=True) + + def test_invert_without_gpu(self): + test(DATA_CASES, DATA_TYPE, invert=True) + + def test_invert_with_gpu(self): + test(DATA_CASES, DATA_TYPE, invert=True, use_gpu=True) + + def test_unique_without_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True) + + def test_unique_with_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True, use_gpu=True) + + def test_unique_invert_without_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True, invert=True) + + def test_unique_invert_with_gpu(self): + test( + DATA_CASES_UNIQUE, + DATA_TYPE, + assume_unique=True, + invert=True, + use_gpu=True, + ) + + +if __name__ == '__main__': + unittest.main() From 321a03a16f339d754114fcbaf1c90e369ae9f861 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 30 Apr 2024 14:01:58 +0800 Subject: [PATCH 2/9] fix test --- test/legacy_test/test_isin.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py index 80db6e66f12df..9ec6a6e64fff4 100644 --- a/test/legacy_test/test_isin.py +++ b/test/legacy_test/test_isin.py @@ -105,7 +105,12 @@ def test( elements_data = case['elements_data'] test_elements_data = case['test_elements_data'] dygraph_result = run_dygraph( - elements_data, test_elements_data, type, invert, use_gpu + elements_data, + test_elements_data, + type, + assume_unique, + invert, + use_gpu, ).numpy() np_result = np.isin( elements_data.astype(type), @@ -117,7 +122,12 @@ def test( def test_static(): (static_result,) = run_static( - elements_data, test_elements_data, type, invert, use_gpu + elements_data, + test_elements_data, + type, + assume_unique, + invert, + use_gpu, ) np.testing.assert_equal(static_result, np_result) From 0816d26e7d74264f075d7f8589ad3e458b6db2d6 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 30 Apr 2024 14:17:51 +0800 Subject: [PATCH 3/9] update test --- test/legacy_test/test_isin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py index 9ec6a6e64fff4..7532e3e3fdadd 100644 --- a/test/legacy_test/test_isin.py +++ b/test/legacy_test/test_isin.py @@ -34,8 +34,8 @@ DATA_CASES_UNIQUE = [ {'elements_data': np.array(-1.0), 'test_elements_data': np.array(1.0)}, { - 'elements_data': np.arange(-100, 100).reshape([4, 5, 10]), - 'test_elements_data': np.arange(-10, 10), + 'elements_data': np.arange(0, 500).reshape([5, 10, 10]), + 'test_elements_data': np.arange(200, 400), }, { 'elements_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), From 9dd666ee34ed4fbc73a59e3c3181f6b0b273d700 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Tue, 30 Apr 2024 16:31:34 +0800 Subject: [PATCH 4/9] fix unique --- python/paddle/tensor/math.py | 4 ++-- test/legacy_test/test_isin.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index f9e4b0added27..1b3d6bf28f303 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -8066,9 +8066,9 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): duplicate_mask = duplicate_mask.logical_not() mask = paddle.empty_like(duplicate_mask) - mask = sorted_index[duplicate_mask] + mask[sorted_index] = 
duplicate_mask - cmp = mask[0 : elements.numel()] + cmp = mask[0 : elements.numel()].reshape(elements.shape) else: sorted_test_elements = paddle.sort(test_elements_flat) idx = paddle.searchsorted(sorted_test_elements, elements_flat) diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py index 7532e3e3fdadd..031f51fe200d8 100644 --- a/test/legacy_test/test_isin.py +++ b/test/legacy_test/test_isin.py @@ -34,8 +34,8 @@ DATA_CASES_UNIQUE = [ {'elements_data': np.array(-1.0), 'test_elements_data': np.array(1.0)}, { - 'elements_data': np.arange(0, 500).reshape([5, 10, 10]), - 'test_elements_data': np.arange(200, 400), + 'elements_data': np.arange(0, 1000).reshape([2, 5, 100]), + 'test_elements_data': np.arange(200, 700), }, { 'elements_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), From 9130a58559d75ef6552856b94ff3d99a614940fe Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sat, 4 May 2024 19:24:06 +0800 Subject: [PATCH 5/9] fix timeout and en docs --- python/paddle/tensor/math.py | 2 ++ test/legacy_test/CMakeLists.txt | 1 + 2 files changed, 3 insertions(+) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 1b3d6bf28f303..b6e70d09ef9a0 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7987,6 +7987,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): Examples: .. code-block:: python + >>> import paddle >>> paddle.set_device('cpu') >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') @@ -7995,6 +7996,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): >>> print(res) Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True, True, False, True]) + >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') >>> test_elements = paddle.to_tensor([-2.1, 2.5], dtype='float32') >>> res = paddle.isin(elements, test_elements, invert=True) diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 8c4cfe9113ab3..6cd56d29e5791 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -753,6 +753,7 @@ if(WITH_DISTRIBUTE) endif() # setting timeout value as 15S +set_tests_properties(test_isin PROPERTIES TIMEOUT 30) set_tests_properties(test_binomial_op PROPERTIES TIMEOUT 30) set_tests_properties(test_run PROPERTIES TIMEOUT 120) set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 180) From 5372eb4734393061c90a68f3e939dac6c1a4f7c1 Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sat, 11 May 2024 21:29:08 +0800 Subject: [PATCH 6/9] add code example for assume_unique --- python/paddle/tensor/math.py | 52 ++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b6e70d09ef9a0..b7b2019fadc10 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -8003,6 +8003,58 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): >>> print(res) Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, [True, False, False, True, False]) + + >>> # Set `assume_unique` to True only when `elements` and `test_elements` contain unique values, ortherwise the result may be incorrect. 
+ >>> elements = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3]) + >>> test_elements = paddle.to_tensor([0., 1.]*20) + >>> correct_result = paddle.isin(elements, test_elements, assume_unique=False) + >>> print(correct_result) + Tensor(shape=[20, 3], dtype=bool, place=Place(cpu), stop_gradient=True, + [[True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False]]) + + >>> incorrect_result = paddle.isin(elements, test_elements, assume_unique=True) + >>> print(incorrect_result) + Tensor(shape=[20, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, + [[True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , False]]) + """ if not isinstance(elements, (paddle.Tensor, Variable, paddle.pir.Value)): raise TypeError(f"x must be tensor type, but got {type(elements)}") From 77484fca45f285a719049205f426445e233edc1a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Sat, 11 May 2024 21:38:53 +0800 Subject: [PATCH 7/9] add notations --- python/paddle/tensor/math.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index b7b2019fadc10..bc8e65974c9ce 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7978,7 +7978,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): Args: elements (Tensor): The input Tensor. Supported data type: 'float32', 'float64', 'int32', 'int64'. test_elements (Tensor): Tensor values against which to test for each input element. Supported data type: 'float32', 'float64', 'int32', 'int64'. - assume_unique (bool, optional): If True, indicates both `elements` and `test_elements` contain unique elements. Default: False. + assume_unique (bool, optional): If True, indicates both `elements` and `test_elements` contain unique elements, which could make the calculation faster. Default: False. invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. @@ -8004,7 +8004,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, [True, False, False, True, False]) - >>> # Set `assume_unique` to True only when `elements` and `test_elements` contain unique values, ortherwise the result may be incorrect. + >>> # Set `assume_unique` to True only when `elements` and `test_elements` contain unique values, otherwise the result may be incorrect. 
>>> elements = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3]) >>> test_elements = paddle.to_tensor([0., 1.]*20) >>> correct_result = paddle.isin(elements, test_elements, assume_unique=False) @@ -8094,6 +8094,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): size_elements = paddle.cast(paddle.numel(elements), 'float32') if test_elements.numel() < 10.0 * paddle.pow(size_elements, 0.145): + # use brute-force searching if the test_elements size is small if len(elements.shape) == 0: return paddle.zeros([], dtype='bool') @@ -8109,6 +8110,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): elements_flat = elements.flatten() test_elements_flat = test_elements.flatten() if assume_unique: + # if elements and test_elements both contain unique elements, use stable argsort method which could be faster all_elements = paddle.concat([elements_flat, test_elements_flat]) sorted_index = paddle.argsort(all_elements, stable=True) sorted_elements = all_elements[sorted_index] @@ -8124,6 +8126,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): cmp = mask[0 : elements.numel()].reshape(elements.shape) else: + # otherwise use searchsorted method sorted_test_elements = paddle.sort(test_elements_flat) idx = paddle.searchsorted(sorted_test_elements, elements_flat) test_idx = paddle.where( From 3f444aeb753682056fffef841a8d59fb413ab3db Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 13 May 2024 18:09:10 +0800 Subject: [PATCH 8/9] revise parameter name --- python/paddle/tensor/math.py | 102 ++++++++++++++++------------------ test/legacy_test/test_isin.py | 64 +++++++++++---------- 2 files changed, 80 insertions(+), 86 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bc8e65974c9ce..fb66db2faec07 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7971,43 +7971,43 @@ def sinc_(x, name=None): return paddle.where(~paddle.isnan(x), x, paddle.full_like(x, 1.0)) -def isin(elements, test_elements, assume_unique=False, invert=False, name=None): +def isin(x, test_x, assume_unique=False, invert=False, name=None): r""" - Tests if each element of `elements` is in `test_elements`. + Tests if each element of `x` is in `test_x`. Args: - elements (Tensor): The input Tensor. Supported data type: 'float32', 'float64', 'int32', 'int64'. - test_elements (Tensor): Tensor values against which to test for each input element. Supported data type: 'float32', 'float64', 'int32', 'int64'. - assume_unique (bool, optional): If True, indicates both `elements` and `test_elements` contain unique elements, which could make the calculation faster. Default: False. + x (Tensor): The input Tensor. Supported data type: 'float32', 'float64', 'int32', 'int64'. + test_x (Tensor): Tensor values against which to test for each input element. Supported data type: 'float32', 'float64', 'int32', 'int64'. + assume_unique (bool, optional): If True, indicates both `x` and `test_x` contain unique elements, which could make the calculation faster. Default: False. invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor), The output Tensor with the same shape as `elements`. + out (Tensor), The output Tensor with the same shape as `x`. Examples: .. 
code-block:: python >>> import paddle >>> paddle.set_device('cpu') - >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') - >>> test_elements = paddle.to_tensor([-2.1, 2.5], dtype='float32') - >>> res = paddle.isin(elements, test_elements) + >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(x, test_x) >>> print(res) Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, [False, True, True, False, True]) - >>> elements = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') - >>> test_elements = paddle.to_tensor([-2.1, 2.5], dtype='float32') - >>> res = paddle.isin(elements, test_elements, invert=True) + >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(x, test_x, invert=True) >>> print(res) Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, [True, False, False, True, False]) - >>> # Set `assume_unique` to True only when `elements` and `test_elements` contain unique values, otherwise the result may be incorrect. - >>> elements = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3]) - >>> test_elements = paddle.to_tensor([0., 1.]*20) - >>> correct_result = paddle.isin(elements, test_elements, assume_unique=False) + >>> # Set `assume_unique` to True only when `x` and `test_x` contain unique values, otherwise the result may be incorrect. + >>> x = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3]) + >>> test_x = paddle.to_tensor([0., 1.]*20) + >>> correct_result = paddle.isin(x, test_x, assume_unique=False) >>> print(correct_result) Tensor(shape=[20, 3], dtype=bool, place=Place(cpu), stop_gradient=True, [[True , True , False], @@ -8031,7 +8031,7 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): [True , True , False], [True , True , False]]) - >>> incorrect_result = paddle.isin(elements, test_elements, assume_unique=True) + >>> incorrect_result = paddle.isin(x, test_x, assume_unique=True) >>> print(incorrect_result) Tensor(shape=[20, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, [[True , True , True ], @@ -8056,16 +8056,14 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): [True , True , False]]) """ - if not isinstance(elements, (paddle.Tensor, Variable, paddle.pir.Value)): - raise TypeError(f"x must be tensor type, but got {type(elements)}") - if not isinstance( - test_elements, (paddle.Tensor, Variable, paddle.pir.Value) - ): - raise TypeError(f"x must be tensor type, but got {type(test_elements)}") + if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(test_x, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(test_x)}") check_variable_and_dtype( - elements, - "elements", + x, + "x", [ 'float32', 'float64', @@ -8076,8 +8074,8 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): ) check_variable_and_dtype( - test_elements, - "test_elements", + test_x, + "test_x", [ 'float32', 'float64', @@ -8087,36 +8085,34 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): "isin", ) - elements_zero_dim = False - if len(elements.shape) == 0: - elements = elements.reshape([1]) - elements_zero_dim = True + x_zero_dim = False + if len(x.shape) == 0: + x = 
x.reshape([1]) + x_zero_dim = True - size_elements = paddle.cast(paddle.numel(elements), 'float32') - if test_elements.numel() < 10.0 * paddle.pow(size_elements, 0.145): - # use brute-force searching if the test_elements size is small - if len(elements.shape) == 0: + size_x = paddle.cast(paddle.numel(x), 'float32') + if test_x.numel() < 10.0 * paddle.pow(size_x, 0.145): + # use brute-force searching if the test_x size is small + if len(x.shape) == 0: return paddle.zeros([], dtype='bool') - x = elements.reshape( - tuple(elements.shape) + ((1,) * test_elements.ndim) - ) - cmp = x == test_elements - dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + tmp = x.reshape(tuple(x.shape) + ((1,) * test_x.ndim)) + cmp = tmp == test_x + dim = tuple(range(-1, -test_x.ndim - 1, -1)) cmp = cmp.any(axis=dim) if invert: cmp = ~cmp else: - elements_flat = elements.flatten() - test_elements_flat = test_elements.flatten() + x_flat = x.flatten() + test_x_flat = test_x.flatten() if assume_unique: - # if elements and test_elements both contain unique elements, use stable argsort method which could be faster - all_elements = paddle.concat([elements_flat, test_elements_flat]) + # if x and test_x both contain unique elements, use stable argsort method which could be faster + all_elements = paddle.concat([x_flat, test_x_flat]) sorted_index = paddle.argsort(all_elements, stable=True) - sorted_elements = all_elements[sorted_index] + sorted_x = all_elements[sorted_index] duplicate_mask = paddle.full_like(sorted_index, False, dtype='bool') - duplicate_mask[:-1] = sorted_elements[1:] == sorted_elements[:-1] + duplicate_mask[:-1] = sorted_x[1:] == sorted_x[:-1] if invert: duplicate_mask = duplicate_mask.logical_not() @@ -8124,21 +8120,21 @@ def isin(elements, test_elements, assume_unique=False, invert=False, name=None): mask = paddle.empty_like(duplicate_mask) mask[sorted_index] = duplicate_mask - cmp = mask[0 : elements.numel()].reshape(elements.shape) + cmp = mask[0 : x.numel()].reshape(x.shape) else: # otherwise use searchsorted method - sorted_test_elements = paddle.sort(test_elements_flat) - idx = paddle.searchsorted(sorted_test_elements, elements_flat) + sorted_test_x = paddle.sort(test_x_flat) + idx = paddle.searchsorted(sorted_test_x, x_flat) test_idx = paddle.where( - idx < sorted_test_elements.numel(), + idx < sorted_test_x.numel(), idx, paddle.zeros_like(idx, 'int64'), ) - cmp = sorted_test_elements[test_idx] == elements_flat + cmp = sorted_test_x[test_idx] == x_flat cmp = cmp.logical_not() if invert else cmp - cmp = cmp.reshape(elements.shape) + cmp = cmp.reshape(x.shape) - if elements_zero_dim: + if x_zero_dim: return cmp.reshape([]) else: return cmp diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py index 031f51fe200d8..395d0f0fd87a1 100644 --- a/test/legacy_test/test_isin.py +++ b/test/legacy_test/test_isin.py @@ -20,26 +20,26 @@ from paddle import base, static DATA_CASES = [ - {'elements_data': np.array(1.0), 'test_elements_data': np.array(-1.0)}, + {'x_data': np.array(1.0), 'test_x_data': np.array(-1.0)}, { - 'elements_data': np.random.randint(-10, 10, (4, 8)), - 'test_elements_data': np.random.randint(0, 20, (2, 3)), + 'x_data': np.random.randint(-10, 10, (4, 8)), + 'test_x_data': np.random.randint(0, 20, (2, 3)), }, { - 'elements_data': np.random.randint(-50, 50, (8, 64)), - 'test_elements_data': np.random.randint(-20, 0, (4, 256)), + 'x_data': np.random.randint(-50, 50, (8, 64)), + 'test_x_data': np.random.randint(-20, 0, (4, 256)), }, ] DATA_CASES_UNIQUE = [ - 
{'elements_data': np.array(-1.0), 'test_elements_data': np.array(1.0)}, + {'x_data': np.array(-1.0), 'test_x_data': np.array(1.0)}, { - 'elements_data': np.arange(0, 1000).reshape([2, 5, 100]), - 'test_elements_data': np.arange(200, 700), + 'x_data': np.arange(0, 1000).reshape([2, 5, 100]), + 'test_x_data': np.arange(200, 700), }, { - 'elements_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), - 'test_elements_data': np.arange(50, 150).reshape([4, 5, 5]), + 'x_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), + 'test_x_data': np.arange(50, 150).reshape([4, 5, 5]), }, ] @@ -47,8 +47,8 @@ def run_dygraph( - elements_data, - test_elements_data, + x_data, + test_x_data, type, assume_unique=False, invert=False, @@ -58,16 +58,16 @@ def run_dygraph( if use_gpu and base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) paddle.disable_static(place) - elements_data = elements_data.astype(type) - test_elements_data = test_elements_data.astype(type) - x_e = paddle.to_tensor(elements_data) - x_t = paddle.to_tensor(test_elements_data) + x_data = x_data.astype(type) + test_x_data = test_x_data.astype(type) + x_e = paddle.to_tensor(x_data) + x_t = paddle.to_tensor(test_x_data) return paddle.isin(x_e, x_t, assume_unique, invert) def run_static( - elements_data, - test_elements_data, + x_data, + test_x_data, type, assume_unique=False, invert=False, @@ -81,17 +81,15 @@ def run_static( place = paddle.CUDAPlace(0) exe = base.Executor(place) with static.program_guard(main_program, startup_program): - elements_data = elements_data.astype(type) - test_elements_data = test_elements_data.astype(type) - x_e = paddle.static.data( - name='x_e', shape=elements_data.shape, dtype=type - ) + x_data = x_data.astype(type) + test_x_data = test_x_data.astype(type) + x_e = paddle.static.data(name='x_e', shape=x_data.shape, dtype=type) x_t = paddle.static.data( - name='x_t', shape=test_elements_data.shape, dtype=type + name='x_t', shape=test_x_data.shape, dtype=type ) res = paddle.isin(x_e, x_t, assume_unique, invert) static_result = exe.run( - feed={'x_e': elements_data, 'x_t': test_elements_data}, + feed={'x_e': x_data, 'x_t': test_x_data}, fetch_list=[res], ) return static_result @@ -102,19 +100,19 @@ def test( ): for type in type_cases: for case in data_cases: - elements_data = case['elements_data'] - test_elements_data = case['test_elements_data'] + x_data = case['x_data'] + test_x_data = case['test_x_data'] dygraph_result = run_dygraph( - elements_data, - test_elements_data, + x_data, + test_x_data, type, assume_unique, invert, use_gpu, ).numpy() np_result = np.isin( - elements_data.astype(type), - test_elements_data.astype(type), + x_data.astype(type), + test_x_data.astype(type), assume_unique=assume_unique, invert=invert, ) @@ -122,8 +120,8 @@ def test( def test_static(): (static_result,) = run_static( - elements_data, - test_elements_data, + x_data, + test_x_data, type, assume_unique, invert, From 0d7eb597bbc3acf87575b98fe21e43202f6e005a Mon Sep 17 00:00:00 2001 From: AyaseNana <13659110308@163.com> Date: Mon, 27 May 2024 15:09:47 +0800 Subject: [PATCH 9/9] update isin to support fp16 and bf16/ update argsort --- python/paddle/tensor/math.py | 27 ++++-- python/paddle/tensor/search.py | 1 + test/legacy_test/test_isin.py | 159 ++++++++++++++++++++++++++++++++- 3 files changed, 178 insertions(+), 9 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index fb66db2faec07..694e47c2340f8 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7976,8 
+7976,8 @@ def isin(x, test_x, assume_unique=False, invert=False, name=None): Tests if each element of `x` is in `test_x`. Args: - x (Tensor): The input Tensor. Supported data type: 'float32', 'float64', 'int32', 'int64'. - test_x (Tensor): Tensor values against which to test for each input element. Supported data type: 'float32', 'float64', 'int32', 'int64'. + x (Tensor): The input Tensor. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'. + test_x (Tensor): Tensor values against which to test for each input element. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'. assume_unique (bool, optional): If True, indicates both `x` and `test_x` contain unique elements, which could make the calculation faster. Default: False. invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False. name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. @@ -8065,6 +8065,8 @@ def isin(x, test_x, assume_unique=False, invert=False, name=None): x, "x", [ + 'uint16', + 'float16', 'float32', 'float64', 'int32', @@ -8077,6 +8079,8 @@ def isin(x, test_x, assume_unique=False, invert=False, name=None): test_x, "test_x", [ + 'uint16', + 'float16', 'float32', 'float64', 'int32', @@ -8090,8 +8094,9 @@ def isin(x, test_x, assume_unique=False, invert=False, name=None): x = x.reshape([1]) x_zero_dim = True - size_x = paddle.cast(paddle.numel(x), 'float32') - if test_x.numel() < 10.0 * paddle.pow(size_x, 0.145): + size_x = math.prod(x.shape) + size_t = math.prod(test_x.shape) + if size_t < math.pow(size_x, 0.145) * 10.0: # use brute-force searching if the test_x size is small if len(x.shape) == 0: return paddle.zeros([], dtype='bool') @@ -8112,13 +8117,23 @@ def isin(x, test_x, assume_unique=False, invert=False, name=None): sorted_x = all_elements[sorted_index] duplicate_mask = paddle.full_like(sorted_index, False, dtype='bool') - duplicate_mask[:-1] = sorted_x[1:] == sorted_x[:-1] + if not in_dynamic_mode(): + duplicate_mask = paddle.static.setitem( + duplicate_mask, + paddle.arange(duplicate_mask.numel() - 1), + sorted_x[1:] == sorted_x[:-1], + ) + else: + duplicate_mask[:-1] = sorted_x[1:] == sorted_x[:-1] if invert: duplicate_mask = duplicate_mask.logical_not() mask = paddle.empty_like(duplicate_mask) - mask[sorted_index] = duplicate_mask + if not in_dynamic_or_pir_mode(): + mask = paddle.static.setitem(mask, sorted_index, duplicate_mask) + else: + mask[sorted_index] = duplicate_mask cmp = mask[0 : x.numel()].reshape(x.shape) else: diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 736ae891f2fb8..9ec4cd1e2ec7f 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -130,6 +130,7 @@ def argsort(x, axis=-1, descending=False, stable=False, name=None): x, 'x', [ + 'uint16', 'float16', 'float32', 'float64', diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py index 395d0f0fd87a1..101d89b4de84f 100644 --- a/test/legacy_test/test_isin.py +++ b/test/legacy_test/test_isin.py @@ -15,9 +15,12 @@ import unittest import numpy as np +from op_test import convert_float_to_uint16 import paddle -from paddle import base, static +from paddle import base +from paddle.base import core +from paddle.pir_utils import test_with_pir_api DATA_CASES = [ {'x_data': np.array(1.0), 'test_x_data': np.array(-1.0)}, @@ -32,7 +35,6 @@ ] DATA_CASES_UNIQUE = [ - {'x_data': 
np.array(-1.0), 'test_x_data': np.array(1.0)}, { 'x_data': np.arange(0, 1000).reshape([2, 5, 100]), 'test_x_data': np.arange(200, 700), @@ -43,6 +45,27 @@ }, ] +DATA_CASES_BF16 = [ + {'x_data': np.array(1.0), 'test_x_data': np.array(0.0)}, + { + 'x_data': np.random.randint(0, 10, (4, 8)), + 'test_x_data': np.random.randint(5, 15, (2, 3)), + }, + { + 'x_data': np.random.randint(0, 50, (8, 64)), + 'test_x_data': np.random.randint(0, 20, (4, 256)), + }, +] + + +DATA_CASES_UNIQUE_BF16 = [ + { + 'x_data': np.arange(0, 100).reshape([2, 5, 10]), + 'test_x_data': np.arange(50, 150), + }, +] + + DATA_TYPE = ['float32', 'float64', 'int32', 'int64'] @@ -80,7 +103,7 @@ def run_static( if use_gpu and base.core.is_compiled_with_cuda(): place = paddle.CUDAPlace(0) exe = base.Executor(place) - with static.program_guard(main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): x_data = x_data.astype(type) test_x_data = test_x_data.astype(type) x_e = paddle.static.data(name='x_e', shape=x_data.shape, dtype=type) @@ -118,6 +141,7 @@ def test( ) np.testing.assert_equal(dygraph_result, np_result) + @test_with_pir_api def test_static(): (static_result,) = run_static( x_data, @@ -132,6 +156,86 @@ def test_static(): test_static() +def run_dygraph_bf16( + x_data, + test_x_data, + assume_unique=False, + invert=False, + use_gpu=False, +): + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + x_e = paddle.to_tensor(convert_float_to_uint16(x_data)) + x_t = paddle.to_tensor(convert_float_to_uint16(test_x_data)) + return paddle.isin(x_e, x_t, assume_unique, invert) + + +def run_static_bf16( + x_data, + test_x_data, + assume_unique=False, + invert=False, + use_gpu=False, +): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = base.Executor(place) + with paddle.static.program_guard(main_program, startup_program): + x_data = convert_float_to_uint16(x_data) + test_x_data = convert_float_to_uint16(test_x_data) + x_e = paddle.static.data( + name='x_e', shape=x_data.shape, dtype=np.uint16 + ) + x_t = paddle.static.data( + name='x_t', shape=test_x_data.shape, dtype=np.uint16 + ) + res = paddle.isin(x_e, x_t, assume_unique, invert) + static_result = exe.run( + feed={'x_e': x_data, 'x_t': test_x_data}, + fetch_list=[res], + ) + return static_result + + +def test_bf16(data_cases, assume_unique=False, invert=False, use_gpu=False): + for case in data_cases: + x_data = case['x_data'].astype("float32") + test_x_data = case['test_x_data'].astype("float32") + dygraph_result = run_dygraph_bf16( + x_data, + test_x_data, + assume_unique, + invert, + use_gpu, + ).numpy() + np_result = np.isin( + x_data, + test_x_data, + assume_unique=assume_unique, + invert=invert, + ) + np.testing.assert_equal(dygraph_result, np_result) + + @test_with_pir_api + def test_static(): + (static_result,) = run_static_bf16( + x_data, + test_x_data, + assume_unique, + invert, + use_gpu, + ) + np.testing.assert_equal(static_result, np_result) + + test_static() + + class TestIsInError(unittest.TestCase): def test_for_exception(self): with self.assertRaises(TypeError): @@ -170,5 +274,54 @@ def test_unique_invert_with_gpu(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with 
CUDA and not support the float16", +) +class TestIsInFP16(unittest.TestCase): + def test_default(self): + test(DATA_CASES, ['float16'], use_gpu=True) + + def test_invert(self): + test(DATA_CASES, ['float16'], invert=True, use_gpu=True) + + def test_unique(self): + test(DATA_CASES_UNIQUE, ['float16'], assume_unique=True, use_gpu=True) + + def test_unique_invert(self): + test( + DATA_CASES_UNIQUE, + ['float16'], + assume_unique=True, + invert=True, + use_gpu=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the float16", +) +class TestIsInBF16(unittest.TestCase): + def test_default(self): + test_bf16(DATA_CASES_BF16, use_gpu=True) + + def test_invert(self): + test_bf16(DATA_CASES_BF16, invert=True, use_gpu=True) + + def test_unique(self): + test_bf16(DATA_CASES_UNIQUE_BF16, assume_unique=True, use_gpu=True) + + def test_unique_invert(self): + test_bf16( + DATA_CASES_UNIQUE_BF16, + assume_unique=True, + invert=True, + use_gpu=True, + ) + + if __name__ == '__main__': unittest.main()
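
---

Editor's note (not part of the patch series): the isin implementation built up above dispatches between two strategies — a broadcast comparison when test_x is small relative to x, and a sort-based search otherwise, with a stable-argsort shortcut when assume_unique=True. Below is a minimal NumPy sketch of that dispatch logic, reconstructed for exposition only; the function name np_isin_sketch is made up, and the 0-d and static-graph special cases the real code handles are omitted.

import numpy as np

def np_isin_sketch(x, test_x, assume_unique=False, invert=False):
    x, test_x = np.asarray(x), np.asarray(test_x)
    # Size heuristic from the patch: broadcasting is cheaper when test_x is
    # small relative to x (the threshold grows very slowly with x.size).
    if test_x.size < 10.0 * x.size ** 0.145:
        # Compare each element of x against all of test_x, then reduce over
        # the trailing test_x axes.
        cmp = (x.reshape(x.shape + (1,) * test_x.ndim) == test_x).any(
            axis=tuple(range(-1, -test_x.ndim - 1, -1))
        )
        return ~cmp if invert else cmp
    x_flat, t_flat = x.ravel(), test_x.ravel()
    if assume_unique:
        # Stable argsort trick (valid only when both inputs hold unique
        # values): sort x and test_x together; an element of x that also
        # occurs in test_x lands immediately before its test_x duplicate,
        # so an equal-neighbor check marks exactly those elements.
        all_elems = np.concatenate([x_flat, t_flat])
        order = np.argsort(all_elems, kind='stable')
        sorted_elems = all_elems[order]
        dup = np.zeros(order.shape, dtype=bool)
        dup[:-1] = sorted_elems[1:] == sorted_elems[:-1]
        if invert:
            dup = ~dup
        mask = np.empty_like(dup)
        mask[order] = dup            # scatter flags back to pre-sort order
        return mask[: x_flat.size].reshape(x.shape)
    # General path: binary-search each element of x in sorted test_x.
    sorted_t = np.sort(t_flat)
    idx = np.searchsorted(sorted_t, x_flat)
    idx[idx == sorted_t.size] = 0    # clamp out-of-range insertion points
    cmp = (sorted_t[idx] == x_flat).reshape(x.shape)
    return ~cmp if invert else cmp

For example, np_isin_sketch([[1, 4], [9, 16]], [4, 9]) takes the broadcast branch and returns [[False, True], [True, False]].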
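The fix in patch 4 is worth calling out: patch 1's mask = sorted_index[duplicate_mask] gathers sorted positions, producing an index array, whereas the corrected mask[sorted_index] = duplicate_mask scatters each neighbor-equality flag back to its pre-sort position, which is what the membership mask needs. A minimal repro with made-up values (NumPy stand-ins for the Paddle tensors):

import numpy as np

x_flat = np.array([3, 1])                         # elements under test
t_flat = np.array([1, 2])                         # unique test values
all_elems = np.concatenate([x_flat, t_flat])      # [3, 1, 1, 2]
order = np.argsort(all_elems, kind='stable')      # [1, 2, 3, 0]
sorted_elems = all_elems[order]                   # [1, 1, 2, 3]
dup = np.zeros(order.shape, dtype=bool)
dup[:-1] = sorted_elems[1:] == sorted_elems[:-1]  # [True, False, False, False]

buggy = order[dup]            # gather: [1] -- an index array, not a mask
mask = np.empty_like(dup)
mask[order] = dup             # scatter: [False, True, False, False]
print(mask[: x_flat.size])    # [False  True]: 3 is not in t_flat, 1 is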
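One further observation, hedged since the commit message does not spell it out: patch 9 replaces the tensor-based threshold (paddle.numel plus paddle.pow) with host-side math.prod and math.pow, presumably so the branch is decided in Python rather than on device — a tensor-valued condition cannot drive a Python if under static-graph compilation, and it also avoids a pow kernel for the newly supported fp16/bf16 dtypes. The heuristic itself is cheap to evaluate; the shapes below are made up:

import math

size_x = math.prod((8, 64))   # numel of x -> 512
size_t = math.prod((2, 3))    # numel of test_x -> 6
use_brute_force = size_t < 10.0 * math.pow(size_x, 0.145)
print(use_brute_force)        # True: 6 < ~24.7, so the broadcast path is taken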