Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon 3 No.16】为 Paddle 新增 API paddle.take #44741

Merged
merged 29 commits into from
Aug 30, 2022

Conversation

S-HuaBomb
Copy link
Contributor

@S-HuaBomb S-HuaBomb commented Jul 29, 2022

PR types

New features

PR changes

APIs

Describe

完成飞桨黑客松第三期第16项目开发任务: #44073 (comment)

增加 API paddle.take,对于输入的 Tensor,将输入 Tensor 视为一维 Tensor,实现根据索引返回指定索引上的元素集合组成的新 Tensor。返回结果与索引的形状相同。

RFC 设计文档: https://github.com/PaddlePaddle/community/blob/master/rfcs/APIs/20220714_api_design_for_take.md
docs 中文文档:PaddlePaddle/docs#5099

@paddle-bot
Copy link

paddle-bot bot commented Jul 29, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added contributor External developers status: proposed labels Jul 29, 2022


class TestTakeType(TestTakeAPI):
"""Test take Error"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

缺少index索引越界的报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我直接通过 paddle.index_select 来报错。
Done

The result takes the same shape as the indices.

Args:
input (Tensor): An N-D Tensor, which data type should be int32, int64, float32, float64.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which data type-》its data type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"The data type of 'index' must be one of ['int32', 'int64'], but got {}".format(
index.dtype))
else:
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'take')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index索引越界时需要报错

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

x.take(idx)
# Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
# [[4, 5, 6],
# [7, 8, 9]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例可增加一个negative index和float类型的input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@luotao1 luotao1 self-requested a review August 5, 2022 02:27
Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 requested a review from Ligoml August 5, 2022 03:46

n = np.arange(0, 12).reshape([3, 4])
x_int = paddle.to_tensor(n, dtype='int64')
x_float = paddle.to_tensor(n, dtype='float64')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以使用paddle API直接生成输入的情况下,尽量避免引入第三方库哈~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docs 的 pr 也需要改一下

Ligoml
Ligoml previously approved these changes Aug 5, 2022
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@luotao1
Copy link
Contributor

luotao1 commented Aug 9, 2022

@S-HuaBomb@jeff41404 讨论后,需要根据 PaddlePaddle/community#186 (review) 重新修改下RFC和PR

"""
Returns a new tensor with the elements of input at the given index.
The input tensor is treated as if it were viewed as a 1-D tensor.
Returns a new tensor with the elements of tnput tensor x at the given index.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tnput?是个 typo 嘛?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,fixed,done

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 请 @Ligoml review下文档部分

name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
Tensor: Tensor with the same shape as index, the data type is the same with input.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor 后使用 ,,以避免解析出 Return Type

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, done.

x (Tensor): An N-D Tensor, its data type should be int32, int64, float32, float64.
index (Tensor): An N-D Tensor, its data type should be int32, int64.
mode (str, optional): Specifies how out-of-bounds index will behave.
the candicates are ``'raise'`` | ``'wrap'`` | ``'clip'``.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用 , 分隔即可,下面的内容按照中文文档那边的意见统一改成列表吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, done.

- ``'raise'``: raise an error (default);
- ``'wrap'``: wrap around;
- ``'clip'``: clip to the range. ``'clip'`` mode means that all indices that are too large are replaced by
the index that addresses the last element. Note that this disables indexing with negative numbers.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

啊哦,这里报了一个warning,需要在这一行和参数详解之间增加一个空行来解决

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@Ligoml
Copy link
Contributor

Ligoml commented Aug 25, 2022

另外解决一下冲突~

@S-HuaBomb
Copy link
Contributor Author

另外解决一下冲突~

Done.

luotao1
luotao1 previously approved these changes Aug 26, 2022
Ligoml
Ligoml previously approved these changes Aug 26, 2022
Copy link
Contributor

@Ligoml Ligoml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for docs

@luotao1 luotao1 requested a review from jeff41404 August 26, 2022 03:30
@@ -4776,3 +4775,107 @@ def sgn(x, name=None):
return paddle.as_complex(output)
else:
return paddle.sign(x)

def take(x, index, mode='raise', name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the name of parameter needs to be consistent with rfc, input in rfc while x here, and mode is not in rfc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jeff41404 根据之前的修改意见 PaddlePaddle/community#186 (review) 更新过RFC:PaddlePaddle/community#217
参数的名字按照新的RFC内容进行修改的。

@S-HuaBomb 请先修改完RFC的评审意见吧。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rfc is still old now, should update and merge rfc first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the modified RFC PaddlePaddle/community#217 with instructions added

@S-HuaBomb S-HuaBomb dismissed stale reviews from Ligoml and luotao1 via 9fb6896 August 27, 2022 05:02
# Negative indexes can be enabled,
# but out-of-range indexes will report an error in the following paddle.index_select
index_1d = paddle.where(index_1d < 0, index_1d % max_index, index_1d)
index_1d = paddle.where(index_1d < 0, index_1d + max_index, index_1d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@S-HuaBomb 这个修改是哪个case会出bug呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里本来就应该是 + 号,这样确保负值索引是在合理范围内的,我只需要 + 号把它转成对应的负值索引。如果使用取余 %,那么超出范围的负值索引也会被约束到合理范围,那样是不对的。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那么超出范围的负值索引也会被约束到合理范围,那样是不对的。

Got it. 可以针对这个case补充一个报错的单测么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已加,done.

@luotao1
Copy link
Contributor

luotao1 commented Aug 29, 2022

LGTM
@jeff41404 PaddlePaddle/community#217 修改后的RFC文档已经合入,请再次审核~

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit 5f1a8e4 into PaddlePaddle:develop Aug 30, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants