-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
|
||
|
||
class TestTakeType(TestTakeAPI): | ||
"""Test take Error""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
缺少index索引越界的报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我直接通过 paddle.index_select
来报错。
Done
python/paddle/tensor/math.py
Outdated
The result takes the same shape as the indices. | ||
|
||
Args: | ||
input (Tensor): An N-D Tensor, which data type should be int32, int64, float32, float64. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which data type-》its data type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format( | ||
index.dtype)) | ||
else: | ||
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index索引越界时需要报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
x.take(idx) | ||
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, | ||
# [[4, 5, 6], | ||
# [7, 8, 9]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
示例可增加一个negative index和float类型的input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
python/paddle/tensor/math.py
Outdated
|
||
n = np.arange(0, 12).reshape([3, 4]) | ||
x_int = paddle.to_tensor(n, dtype='int64') | ||
x_float = paddle.to_tensor(n, dtype='float64') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以使用paddle API直接生成输入的情况下,尽量避免引入第三方库哈~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docs 的 pr 也需要改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for docs
@S-HuaBomb 和 @jeff41404 讨论后,需要根据 PaddlePaddle/community#186 (review) 重新修改下RFC和PR |
python/paddle/tensor/math.py
Outdated
""" | ||
Returns a new tensor with the elements of input at the given index. | ||
The input tensor is treated as if it were viewed as a 1-D tensor. | ||
Returns a new tensor with the elements of tnput tensor x at the given index. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tnput
?是个 typo 嘛?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,fixed,done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 请 @Ligoml review下文档部分
python/paddle/tensor/math.py
Outdated
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor: Tensor with the same shape as index, the data type is the same with input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tensor 后使用 ,
,以避免解析出 Return Type
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, done.
python/paddle/tensor/math.py
Outdated
x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64. | ||
index (Tensor): An N-D Tensor, its data type should be int32, int64. | ||
mode (str, optional): Specifies how out-of-bounds index will behave. | ||
the candicates are ``'raise'`` | ``'wrap'`` | ``'clip'``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里用 ,
分隔即可,下面的内容按照中文文档那边的意见统一改成列表吧~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx, done.
python/paddle/tensor/math.py
Outdated
- ``'raise'``: raise an error (default); | ||
- ``'wrap'``: wrap around; | ||
- ``'clip'``: clip to the range. ``'clip'`` mode means that all indices that are too large are replaced by | ||
the index that addresses the last element. Note that this disables indexing with negative numbers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
另外解决一下冲突~ |
Done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for docs
@@ -4776,3 +4775,107 @@ def sgn(x, name=None): | |||
return paddle.as_complex(output) | |||
else: | |||
return paddle.sign(x) | |||
|
|||
def take(x, index, mode='raise', name=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the name of parameter needs to be consistent with rfc, input
in rfc while x
here, and mode
is not in rfc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。
@S-HuaBomb 请先修改完RFC的评审意见吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rfc is still old now, should update and merge rfc first
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the modified RFC PaddlePaddle/community#217 with instructions added
# Negative indexes can be enabled, | ||
# but out-of-range indexes will report an error in the following paddle.index_select | ||
index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d) | ||
index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@S-HuaBomb 这个修改是哪个case会出bug呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里本来就应该是 +
号,这样确保负值索引是在合理范围内的,我只需要 +
号把它转成对应的负值索引。如果使用取余 %
,那么超出范围的负值索引也会被约束到合理范围,那样是不对的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那么超出范围的负值索引也会被约束到合理范围,那样是不对的。
Got it. 可以针对这个case补充一个报错的单测么?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已加,done.
LGTM |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
完成飞桨黑客松第三期第16项目开发任务: #44073 (comment)
增加 API
paddle.take
,对于输入的 Tensor,将输入 Tensor 视为一维 Tensor,实现根据索引返回指定索引上的元素集合组成的新 Tensor。返回结果与索引的形状相同。RFC 设计文档: https://github.com/PaddlePaddle/community/blob/master/rfcs/APIs/20220714_api_design_for_take.md
docs 中文文档:PaddlePaddle/docs#5099