-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741
Changes from 1 commit
982d01e
69b0a3e
b07c062
09d2836
c8482f6
0665e50
c5a9e16
10b41c4
6852760
9649b87
ec1cfd7
6806a8f
27b6943
5d32c52
b35d831
cc2f4f4
c4161f2
aaee858
ca2604f
5979d5f
7b3fc1d
668964d
64b688a
cdd1080
eca0483
4ca5c41
9fb6896
7fd6c85
046ff44
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4701,15 +4701,22 @@ def frac(x, name=None): | |
type="trunc", inputs=inputs, attrs=attrs, outputs={"Out": y}) | ||
return _elementwise_op(LayerHelper(op_type, **locals())) | ||
|
||
def take(input, index, name=None): | ||
def take(x, index, mode='raise', name=None): | ||
""" | ||
Returns a new tensor with the elements of input at the given index. | ||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
Returns a new tensor with the elements of tnput tensor x at the given index. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,fixed,done |
||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
The result takes the same shape as the index. | ||
|
||
Args: | ||
input (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. | ||
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. | ||
index (Tensor): An N-D Tensor, its data type should be int32, int64. | ||
mode (str, optional): Specifies how out-of-bounds index will behave. | ||
the candicates are ``'raise'`` | ``'wrap'`` | ``'clip'``. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thx, done. |
||
If :attr:`mode` is ``'raise'``, raise an error (default); | ||
If :attr:`mode` is ``'wrap'``, wrap around; | ||
If :attr:`mode` is ``'clip'``, clip to the range. | ||
``'clip'`` mode means that all indices that are too large are replaced by the index that | ||
addresses the last element. Note that this disables indexing with negative numbers. | ||
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
|
@@ -4746,6 +4753,9 @@ def take(input, index, name=None): | |
# [[4, 5, 6], | ||
# [7, 8, 9]]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 示例可增加一个negative index和float类型的input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
""" | ||
if mode not in ['raise', 'wrap', 'clip']: | ||
raise ValueError( | ||
"'mode' in 'take' should be 'raise', 'wrap', 'clip', but received {}.".format(mode)) | ||
|
||
if paddle.in_dynamic_mode(): | ||
if not isinstance(index, (paddle.Tensor, Variable)): | ||
|
@@ -4755,14 +4765,28 @@ def take(input, index, name=None): | |
raise TypeError( | ||
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( | ||
index.dtype)) | ||
|
||
else: | ||
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index索引越界时需要报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
input_1d = input.flatten() | ||
input_1d = x.flatten() | ||
index_1d = index.flatten() | ||
max_index = input_1d.shape[-1] | ||
|
||
if mode == 'raise': | ||
# This processing enables 'take' to handle negative indexes within the correct range. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以补充下注释,negative indexes可以enable,但越界的索引会在下面的index_select报错 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. THX,Done |
||
index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d) | ||
pass | ||
elif mode == 'wrap': | ||
# The out of range indices are constrained by taking the remainder. | ||
index_1d = paddle.where(index_1d < 0, | ||
index_1d % max_index, index_1d) | ||
index_1d = paddle.where(index_1d >= max_index, | ||
index_1d % max_index, index_1d) | ||
elif mode == 'clip': | ||
# 'clip' mode disables indexing with negative numbers. | ||
index_1d = clip(index_1d, 0, max_index - 1) | ||
|
||
# This processing enables 'take' to handle negative indexes within the correct range | ||
index_1d = paddle.where(index_1d < 0, index_1d + input_1d.shape[0], index_1d) | ||
out = input_1d.index_select(index_1d).reshape(index.shape) | ||
|
||
return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the name of parameter needs to be consistent with rfc,
input
in rfc whilex
here, andmode
is not in rfc.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。
@S-HuaBomb 请先修改完RFC的评审意见吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rfc is still old now, should update and merge rfc first
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the modified RFC PaddlePaddle/community#217 with instructions added