-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 6th No.10】Add isin API to Paddle -part #64001
Changes from all commits
6464b58
321a03a
0816d26
9dd666e
9130a58
5372eb4
77484fc
3f444ae
0d7eb59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7969,3 +7969,187 @@ def sinc_(x, name=None): | |
paddle.sin_(x) | ||
paddle.divide_(x, tmp) | ||
return paddle.where(~paddle.isnan(x), x, paddle.full_like(x, 1.0)) | ||
|
||
|
||
def isin(x, test_x, assume_unique=False, invert=False, name=None):
    r"""
    Tests if each element of `x` is in `test_x`.

    Args:
        x (Tensor): The input Tensor. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'.
        test_x (Tensor): Tensor values against which to test for each input element. Supported data type: 'bfloat16', 'float16', 'float32', 'float64', 'int32', 'int64'.
        assume_unique (bool, optional): If True, indicates both `x` and `test_x` contain unique elements, which could make the calculation faster. Default: False.
        invert (bool, optional): Indicate whether to invert the boolean return tensor. If True, invert the results. Default: False.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        out (Tensor), The output Tensor with the same shape as `x`.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.set_device('cpu')
            >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32')
            >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32')
            >>> res = paddle.isin(x, test_x)
            >>> print(res)
            Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True,
            [False, True, True, False, True])

            >>> x = paddle.to_tensor([-0., -2.1, 2.5, 1.0, -2.1], dtype='float32')
            >>> test_x = paddle.to_tensor([-2.1, 2.5], dtype='float32')
            >>> res = paddle.isin(x, test_x, invert=True)
            >>> print(res)
            Tensor(shape=[5], dtype=bool, place=Place(cpu), stop_gradient=True,
            [True, False, False, True, False])

            >>> # Set `assume_unique` to True only when `x` and `test_x` contain unique values, otherwise the result may be incorrect.
            >>> x = paddle.to_tensor([0., 1., 2.]*20).reshape([20, 3])
            >>> test_x = paddle.to_tensor([0., 1.]*20)
            >>> correct_result = paddle.isin(x, test_x, assume_unique=False)
            >>> print(correct_result)
            Tensor(shape=[20, 3], dtype=bool, place=Place(cpu), stop_gradient=True,
            [[True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False],
             [True , True , False]])

            >>> incorrect_result = paddle.isin(x, test_x, assume_unique=True)
            >>> print(incorrect_result)
            Tensor(shape=[20, 3], dtype=bool, place=Place(gpu:0), stop_gradient=True,
            [[True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , True ],
             [True , True , False]])

    """
    if not isinstance(x, (paddle.Tensor, Variable, paddle.pir.Value)):
        raise TypeError(f"x must be tensor type, but got {type(x)}")
    if not isinstance(test_x, (paddle.Tensor, Variable, paddle.pir.Value)):
        # Fixed: the original message incorrectly said "x" here.
        raise TypeError(f"test_x must be tensor type, but got {type(test_x)}")

    # NOTE: 'uint16' is the internal representation of bfloat16 in Paddle's
    # dtype checker, matching the 'bfloat16' listed in the docstring.
    check_variable_and_dtype(
        x,
        "x",
        [
            'uint16',
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
        ],
        "isin",
    )

    check_variable_and_dtype(
        test_x,
        "test_x",
        [
            'uint16',
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
        ],
        "isin",
    )

    # Promote a 0-D input to shape [1] so the element-wise logic below works
    # uniformly; the result is squeezed back to 0-D at the end.
    x_zero_dim = False
    if len(x.shape) == 0:
        x = x.reshape([1])
        x_zero_dim = True

    size_x = math.prod(x.shape)
    size_t = math.prod(test_x.shape)
    if size_t < math.pow(size_x, 0.145) * 10.0:
        # Brute-force search when test_x is small: broadcast-compare every
        # element of x against every element of test_x, then reduce with
        # any() over the test_x axes.
        # (The original had a dead `if len(x.shape) == 0` early return here;
        # it was unreachable because 0-D x is reshaped above, and it also
        # ignored `invert`. Removed.)
        tmp = x.reshape(tuple(x.shape) + ((1,) * test_x.ndim))
        cmp = tmp == test_x
        dim = tuple(range(-1, -test_x.ndim - 1, -1))
        cmp = cmp.any(axis=dim)
        if invert:
            cmp = ~cmp
    else:
        x_flat = x.flatten()
        test_x_flat = test_x.flatten()
        if assume_unique:
            # If x and test_x both contain unique elements, use a stable
            # argsort method which could be faster: after sorting the
            # concatenation, an element of x that also occurs in test_x
            # shows up as an adjacent duplicate pair.
            all_elements = paddle.concat([x_flat, test_x_flat])
            sorted_index = paddle.argsort(all_elements, stable=True)
            sorted_x = all_elements[sorted_index]

            duplicate_mask = paddle.full_like(sorted_index, False, dtype='bool')
            if not in_dynamic_mode():
                # Static graph requires functional setitem.
                duplicate_mask = paddle.static.setitem(
                    duplicate_mask,
                    paddle.arange(duplicate_mask.numel() - 1),
                    sorted_x[1:] == sorted_x[:-1],
                )
            else:
                duplicate_mask[:-1] = sorted_x[1:] == sorted_x[:-1]

            if invert:
                duplicate_mask = duplicate_mask.logical_not()

            # Scatter the mask back to the pre-sort positions; the first
            # x.numel() entries then correspond to the elements of x.
            mask = paddle.empty_like(duplicate_mask)
            if not in_dynamic_or_pir_mode():
                mask = paddle.static.setitem(mask, sorted_index, duplicate_mask)
            else:
                mask[sorted_index] = duplicate_mask

            cmp = mask[0 : x.numel()].reshape(x.shape)
        else:
            # Otherwise use the searchsorted method: locate each element of x
            # in the sorted test_x and verify an exact match at that position.
            sorted_test_x = paddle.sort(test_x_flat)
            idx = paddle.searchsorted(sorted_test_x, x_flat)
            # Clamp out-of-range indices (element larger than every value in
            # test_x) to 0; the equality check below rejects them anyway.
            test_idx = paddle.where(
                idx < sorted_test_x.numel(),
                idx,
                paddle.zeros_like(idx, 'int64'),
            )
            cmp = sorted_test_x[test_idx] == x_flat
            cmp = cmp.logical_not() if invert else cmp
            cmp = cmp.reshape(x.shape)

    if x_zero_dim:
        return cmp.reshape([])
    else:
        return cmp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
示例代码中,能否体现出 assume_unique=True/False 的区别?注释中是否要加入这段说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
增加了一个
assume_unique
设置错误导致结果出错的例子;在下面代码中加了不同分支对应不同做法的注释