diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e126e41efcf65..d03fd458c1525 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -433,6 +433,7 @@ inner, inverse, isfinite, + isin, isinf, isnan, isneginf, @@ -730,6 +731,7 @@ 'squeeze_', 'to_tensor', 'gather_nd', + 'isin', 'isinf', 'isneginf', 'isposinf', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 4de5e392a8493..13187fa9cd019 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -306,6 +306,7 @@ inner, inverse, isfinite, + isin, isinf, isnan, isneginf, @@ -587,6 +588,7 @@ 'kron', 'kthvalue', 'isfinite', + 'isin', 'isinf', 'isnan', 'isneginf', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index d7d8669ff0c3b..694e47c2340f8 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -7969,3 +7969,187 @@ def sinc_(x, name=None): paddle.sin_(x) paddle.divide_(x, tmp) return paddle.where(~paddle.isnan(x), x, paddle.full_like(x, 1.0)) + + +def isin(x, test_x, assume_unique=False, invert=False, name=None): + r""" + Tests if each element of `x` is in `test_x`. + + Args: + x (Tensor): The input Tensor. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'. + test_x (Tensor): Tensor values against which to test for each input element. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'. + assume_unique (bool, optional): If True, indicates both `x` and `test_x` contain unique elements, which could make the calculation faster. Default: False. + invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False. + name (str, optional): Name for the operation (optional, default is None).For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor), The output Tensor with the same shape as `x`. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.set_device('cpu') + >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(x, test_x) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [False, True, True, False, True]) + + >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32') + >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32') + >>> res = paddle.isin(x, test_x, invert=True) + >>> print(res) + Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True, + [True, False, False, True, False]) + + >>> # Set `assume_unique` to True only when `x` and `test_x` contain unique values, otherwise the result may be incorrect. 
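+            >>> # The inputs below contain repeated values, so the `assume_unique=True`
+            >>> # fast path (which matches equal neighbours after a stable sort) also
+            >>> # pairs duplicates inside `x` itself and returns a wrong answer: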
+ >>> x = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3]) + >>> test_x = paddle.to_tensor([0., 1.]*20) + >>> correct_result = paddle.isin(x, test_x, assume_unique=False) + >>> print(correct_result) + Tensor(shape=[20, 3], dtype=bool, place=Place(cpu), stop_gradient=True, + [[True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False], + [True , True , False]]) + + >>> incorrect_result = paddle.isin(x, test_x, assume_unique=True) + >>> print(incorrect_result) + Tensor(shape=[20, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True, + [[True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , True ], + [True , True , False]]) + + """ + if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(x)}") + if not isinstance(test_x, (paddle.Tensor, Variable, paddle.pir.Value)): + raise TypeError(f"x must be tensor type, but got {type(test_x)}") + + check_variable_and_dtype( + x, + "x", + [ + 'uint16', + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + ], + "isin", + ) + + check_variable_and_dtype( + test_x, + "test_x", + [ + 'uint16', + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + ], + "isin", + ) + + x_zero_dim = False + if len(x.shape) == 0: + x = x.reshape([1]) + x_zero_dim = True + + size_x = math.prod(x.shape) + size_t = math.prod(test_x.shape) + if size_t < math.pow(size_x, 0.145) * 10.0: + # use brute-force searching if the test_x size is small + if len(x.shape) == 0: + return paddle.zeros([], dtype='bool') + + tmp = x.reshape(tuple(x.shape) + ((1,) * test_x.ndim)) + cmp = tmp == test_x + dim = tuple(range(-1, -test_x.ndim - 1, -1)) + cmp = cmp.any(axis=dim) + if invert: + cmp = ~cmp + else: + x_flat = x.flatten() + test_x_flat = test_x.flatten() + if assume_unique: + # if x and test_x both contain unique elements, use stable argsort method which could be faster + all_elements = paddle.concat([x_flat, test_x_flat]) + sorted_index = paddle.argsort(all_elements, stable=True) + sorted_x = all_elements[sorted_index] + + duplicate_mask = paddle.full_like(sorted_index, False, dtype='bool') + if not in_dynamic_mode(): + duplicate_mask = paddle.static.setitem( + duplicate_mask, + paddle.arange(duplicate_mask.numel() - 1), + sorted_x[1:] == sorted_x[:-1], + ) + else: + duplicate_mask[:-1] = sorted_x[1:] == sorted_x[:-1] + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = paddle.empty_like(duplicate_mask) + if not in_dynamic_or_pir_mode(): + mask = paddle.static.setitem(mask, sorted_index, duplicate_mask) + else: + mask[sorted_index] = duplicate_mask + + cmp = mask[0 : x.numel()].reshape(x.shape) + else: + # otherwise use searchsorted method + sorted_test_x = 
paddle.sort(test_x_flat) + idx = paddle.searchsorted(sorted_test_x, x_flat) + test_idx = paddle.where( + idx < sorted_test_x.numel(), + idx, + paddle.zeros_like(idx, 'int64'), + ) + cmp = sorted_test_x[test_idx] == x_flat + cmp = cmp.logical_not() if invert else cmp + cmp = cmp.reshape(x.shape) + + if x_zero_dim: + return cmp.reshape([]) + else: + return cmp diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 736ae891f2fb8..9ec4cd1e2ec7f 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -130,6 +130,7 @@ def argsort(x, axis=-1, descending=False, stable=False, name=None): x, 'x', [ + 'uint16', 'float16', 'float32', 'float64', diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index 8c4cfe9113ab3..6cd56d29e5791 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -753,6 +753,7 @@ if(WITH_DISTRIBUTE) endif() # setting timeout value as 15S +set_tests_properties(test_isin PROPERTIES TIMEOUT 30) set_tests_properties(test_binomial_op PROPERTIES TIMEOUT 30) set_tests_properties(test_run PROPERTIES TIMEOUT 120) set_tests_properties(test_sync_batch_norm_op PROPERTIES TIMEOUT 180) diff --git a/test/legacy_test/test_isin.py b/test/legacy_test/test_isin.py new file mode 100644 index 0000000000000..101d89b4de84f --- /dev/null +++ b/test/legacy_test/test_isin.py @@ -0,0 +1,327 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
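+
+# The tests below compare paddle.isin with np.isin in both dygraph and static
+# graph mode, covering float/int dtypes plus float16 and bfloat16 cases on GPU.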
+ +import unittest + +import numpy as np +from op_test import convert_float_to_uint16 + +import paddle +from paddle import base +from paddle.base import core +from paddle.pir_utils import test_with_pir_api + +DATA_CASES = [ + {'x_data': np.array(1.0), 'test_x_data': np.array(-1.0)}, + { + 'x_data': np.random.randint(-10, 10, (4, 8)), + 'test_x_data': np.random.randint(0, 20, (2, 3)), + }, + { + 'x_data': np.random.randint(-50, 50, (8, 64)), + 'test_x_data': np.random.randint(-20, 0, (4, 256)), + }, +] + +DATA_CASES_UNIQUE = [ + { + 'x_data': np.arange(0, 1000).reshape([2, 5, 100]), + 'test_x_data': np.arange(200, 700), + }, + { + 'x_data': np.arange(-100, 100).reshape([2, 2, 5, 10]), + 'test_x_data': np.arange(50, 150).reshape([4, 5, 5]), + }, +] + +DATA_CASES_BF16 = [ + {'x_data': np.array(1.0), 'test_x_data': np.array(0.0)}, + { + 'x_data': np.random.randint(0, 10, (4, 8)), + 'test_x_data': np.random.randint(5, 15, (2, 3)), + }, + { + 'x_data': np.random.randint(0, 50, (8, 64)), + 'test_x_data': np.random.randint(0, 20, (4, 256)), + }, +] + + +DATA_CASES_UNIQUE_BF16 = [ + { + 'x_data': np.arange(0, 100).reshape([2, 5, 10]), + 'test_x_data': np.arange(50, 150), + }, +] + + +DATA_TYPE = ['float32', 'float64', 'int32', 'int64'] + + +def run_dygraph( + x_data, + test_x_data, + type, + assume_unique=False, + invert=False, + use_gpu=False, +): + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + x_data = x_data.astype(type) + test_x_data = test_x_data.astype(type) + x_e = paddle.to_tensor(x_data) + x_t = paddle.to_tensor(test_x_data) + return paddle.isin(x_e, x_t, assume_unique, invert) + + +def run_static( + x_data, + test_x_data, + type, + assume_unique=False, + invert=False, + use_gpu=False, +): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = base.Executor(place) + with paddle.static.program_guard(main_program, startup_program): + x_data = x_data.astype(type) + test_x_data = test_x_data.astype(type) + x_e = paddle.static.data(name='x_e', shape=x_data.shape, dtype=type) + x_t = paddle.static.data( + name='x_t', shape=test_x_data.shape, dtype=type + ) + res = paddle.isin(x_e, x_t, assume_unique, invert) + static_result = exe.run( + feed={'x_e': x_data, 'x_t': test_x_data}, + fetch_list=[res], + ) + return static_result + + +def test( + data_cases, type_cases, assume_unique=False, invert=False, use_gpu=False +): + for type in type_cases: + for case in data_cases: + x_data = case['x_data'] + test_x_data = case['test_x_data'] + dygraph_result = run_dygraph( + x_data, + test_x_data, + type, + assume_unique, + invert, + use_gpu, + ).numpy() + np_result = np.isin( + x_data.astype(type), + test_x_data.astype(type), + assume_unique=assume_unique, + invert=invert, + ) + np.testing.assert_equal(dygraph_result, np_result) + + @test_with_pir_api + def test_static(): + (static_result,) = run_static( + x_data, + test_x_data, + type, + assume_unique, + invert, + use_gpu, + ) + np.testing.assert_equal(static_result, np_result) + + test_static() + + +def run_dygraph_bf16( + x_data, + test_x_data, + assume_unique=False, + invert=False, + use_gpu=False, +): + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + paddle.disable_static(place) + x_e = 
paddle.to_tensor(convert_float_to_uint16(x_data)) + x_t = paddle.to_tensor(convert_float_to_uint16(test_x_data)) + return paddle.isin(x_e, x_t, assume_unique, invert) + + +def run_static_bf16( + x_data, + test_x_data, + assume_unique=False, + invert=False, + use_gpu=False, +): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + place = paddle.CPUPlace() + if use_gpu and base.core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = base.Executor(place) + with paddle.static.program_guard(main_program, startup_program): + x_data = convert_float_to_uint16(x_data) + test_x_data = convert_float_to_uint16(test_x_data) + x_e = paddle.static.data( + name='x_e', shape=x_data.shape, dtype=np.uint16 + ) + x_t = paddle.static.data( + name='x_t', shape=test_x_data.shape, dtype=np.uint16 + ) + res = paddle.isin(x_e, x_t, assume_unique, invert) + static_result = exe.run( + feed={'x_e': x_data, 'x_t': test_x_data}, + fetch_list=[res], + ) + return static_result + + +def test_bf16(data_cases, assume_unique=False, invert=False, use_gpu=False): + for case in data_cases: + x_data = case['x_data'].astype("float32") + test_x_data = case['test_x_data'].astype("float32") + dygraph_result = run_dygraph_bf16( + x_data, + test_x_data, + assume_unique, + invert, + use_gpu, + ).numpy() + np_result = np.isin( + x_data, + test_x_data, + assume_unique=assume_unique, + invert=invert, + ) + np.testing.assert_equal(dygraph_result, np_result) + + @test_with_pir_api + def test_static(): + (static_result,) = run_static_bf16( + x_data, + test_x_data, + assume_unique, + invert, + use_gpu, + ) + np.testing.assert_equal(static_result, np_result) + + test_static() + + +class TestIsInError(unittest.TestCase): + def test_for_exception(self): + with self.assertRaises(TypeError): + paddle.isin(np.array([1, 2]), np.array([1, 2])) + + +class TestIsIn(unittest.TestCase): + def test_without_gpu(self): + test(DATA_CASES, DATA_TYPE) + + def test_with_gpu(self): + test(DATA_CASES, DATA_TYPE, use_gpu=True) + + def test_invert_without_gpu(self): + test(DATA_CASES, DATA_TYPE, invert=True) + + def test_invert_with_gpu(self): + test(DATA_CASES, DATA_TYPE, invert=True, use_gpu=True) + + def test_unique_without_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True) + + def test_unique_with_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True, use_gpu=True) + + def test_unique_invert_without_gpu(self): + test(DATA_CASES_UNIQUE, DATA_TYPE, assume_unique=True, invert=True) + + def test_unique_invert_with_gpu(self): + test( + DATA_CASES_UNIQUE, + DATA_TYPE, + assume_unique=True, + invert=True, + use_gpu=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the float16", +) +class TestIsInFP16(unittest.TestCase): + def test_default(self): + test(DATA_CASES, ['float16'], use_gpu=True) + + def test_invert(self): + test(DATA_CASES, ['float16'], invert=True, use_gpu=True) + + def test_unique(self): + test(DATA_CASES_UNIQUE, ['float16'], assume_unique=True, use_gpu=True) + + def test_unique_invert(self): + test( + DATA_CASES_UNIQUE, + ['float16'], + assume_unique=True, + invert=True, + use_gpu=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and not support the float16", +) +class TestIsInBF16(unittest.TestCase): + def test_default(self): + 
test_bf16(DATA_CASES_BF16, use_gpu=True) + + def test_invert(self): + test_bf16(DATA_CASES_BF16, invert=True, use_gpu=True) + + def test_unique(self): + test_bf16(DATA_CASES_UNIQUE_BF16, assume_unique=True, use_gpu=True) + + def test_unique_invert(self): + test_bf16( + DATA_CASES_UNIQUE_BF16, + assume_unique=True, + invert=True, + use_gpu=True, + ) + + +if __name__ == '__main__': + unittest.main()
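
For reference, the new `isin` chooses between three lookup strategies: a broadcast
compare when `test_x` is small relative to `x`, a stable-argsort trick when
`assume_unique=True`, and a sort-plus-`searchsorted` lookup otherwise. Below is a
minimal, standalone NumPy sketch of the latter two strategies; the helper names are
illustrative only and are not part of the patch.

    import numpy as np


    def isin_unique_sketch(x, test_x, invert=False):
        # assume_unique path: concatenate both inputs, stable-argsort them, and
        # mark every element that is immediately followed by an equal value.
        # Because the sort is stable and x precedes test_x, the copy coming from
        # x is the one that gets marked. Only valid when each input is duplicate-free.
        x = np.asarray(x).ravel()
        test_x = np.asarray(test_x).ravel()
        all_elements = np.concatenate([x, test_x])
        order = np.argsort(all_elements, kind='stable')
        sorted_vals = all_elements[order]
        duplicate_mask = np.zeros(all_elements.size, dtype=bool)
        duplicate_mask[:-1] = sorted_vals[1:] == sorted_vals[:-1]
        if invert:
            duplicate_mask = ~duplicate_mask
        mask = np.empty_like(duplicate_mask)
        mask[order] = duplicate_mask      # scatter back to pre-sort positions
        return mask[: x.size]             # the first len(x) entries belong to x


    def isin_searchsorted_sketch(x, test_x, invert=False):
        # General path: sort the test values once, binary-search each element of
        # x, clamp out-of-range indices to 0, then confirm the hit by equality.
        x = np.asarray(x).ravel()
        sorted_test = np.sort(np.asarray(test_x).ravel())
        idx = np.searchsorted(sorted_test, x)
        idx = np.where(idx < sorted_test.size, idx, 0)
        hit = sorted_test[idx] == x
        return ~hit if invert else hit


    x = np.array([3, 1, 7])
    test_x = np.array([7, 2, 3])
    print(isin_unique_sketch(x, test_x))        # [ True False  True]
    print(isin_searchsorted_sketch(x, test_x))  # [ True False  True]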