
[npu]Add argsort op #34865

Merged: 7 commits merged into PaddlePaddle:develop on Aug 20, 2021
Conversation

lzzyzlbb (Contributor)

PR types

New features

PR changes

OPs

Describe

[NPU] Support npu kernel for argsort and argsort_grad
(three images attached in the original description)

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

LielinJiang previously approved these changes Aug 13, 2021
limitations under the License. */

#include <memory>
#include <string>
Contributor:

Remove the header includes on lines 12-13.

Contributor Author:

Removed.

indices->mutable_data<int32_t>(ctx.GetPlace());

int32_t axis = ctx.Attr<int>("axis");
bool descending = ctx.Attr<bool>("descending");
Contributor:

Add a line axis = (axis < 0) ? (in_dims.size() + axis) : axis; so that axis is always non-negative.

Contributor Author:

Added.
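
For reference, a minimal sketch of how the suggested normalization could sit next to the attribute reads shown above (in_dims is assumed to hold the input tensor's dims(), as elsewhere in the kernel; the exact surrounding code is an assumption):

// in_dims is the input tensor's dims() (assumed from the surrounding kernel).
int32_t axis = ctx.Attr<int>("axis");
bool descending = ctx.Attr<bool>("descending");
// Map a negative axis (e.g. -1) onto its non-negative equivalent so that
// later comparisons only have to handle non-negative values.
axis = (axis < 0) ? (in_dims.size() + axis) : axis;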

framework::NPUAttributeMap sort_attr_input = {{"axis", npu_axis},
                                              {"descending", descending}};

if (axis != -1) {
Contributor:

Change this to: if (axis == -1 || axis + 1 == in_dims.size()) {

Contributor Author:

Changed.
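
A sketch of the resulting branch structure (an assumption about the kernel's layout, not a quote of the final code). Since the kernel always calls the NPU Sort op with axis set to -1, i.e. along the last dimension (see the snippet above), a normalized axis equal to the last dimension can take the direct path too, and only other axes need the transpose round-trip:

if (axis == -1 || axis + 1 == in_dims.size()) {
  // The target axis is already the last dimension: run the NPU Sort op
  // on the input directly.
} else {
  // Transpose the target axis to the last dimension, sort, then transpose
  // the values and indices back to the original layout.
}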

.stream();
int32_t npu_axis = -1;
framework::NPUAttributeMap sort_attr_input = {{"axis", npu_axis},
                                              {"descending", descending}};
Contributor:

The definition of npu_axis is redundant; write it directly as

framework::NPUAttributeMap sort_attr_input = {{"axis", static_cast<int32_t>(-1)},
                                               {"descending", descending}};

Contributor Author:

Changed.



@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
Contributor:

Delete all code similar to lines 34-35:

@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")

Contributor Author:

All removed.

qili93 previously approved these changes Aug 18, 2021

qili93 (Contributor) left a comment:

LGTM

LielinJiang changed the title from "Add argsort npu" to "[npu]Add argsort op" on Aug 18, 2021
LielinJiang previously approved these changes Aug 18, 2021
}
};

template <typename T, typename Type>
Member:

What's the typename T used for?

Contributor Author:

Fixed.

framework::Tensor* output) {
output->ShareDataWith(*input);
output->Resize(framework::make_ddim(std::move(input_shapes)));
input = output;
Member:

What's the purpose of this statement?

Contributor Author:

Fixed.

.AddInput(input_reshape_tensor)
.AddOutput(input_scatter_tensor);
runner_scatter.Run(stream);
ReshapeNPU<T, Type>(ctx, &input_scatter_tensor, trans_shapes, t_out);
Member:

I see that t_out is assigned via ShareDataWith, and t_out can be dX, the output of ArgsortGradNPUKernel; this is not allowed in PaddlePaddle.

Reference: https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/op_notes_cn.html#sharedatawith

Contributor Author:

Changed to a copy.
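
A minimal sketch of what "changed to a copy" could look like in place of the ShareDataWith into dX (tmp_out is a hypothetical name for the intermediate result; the device-context expression is the usual fluid NPU-kernel pattern and is assumed here, not quoted from the final code):

// Copy the intermediate result into the kernel output instead of sharing its
// buffer, since an op output must not ShareDataWith an internal tensor.
framework::TensorCopy(
    tmp_out, ctx.GetPlace(),
    ctx.template device_context<platform::NPUDeviceContext>(), dX);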

lzzyzlbb dismissed stale reviews from LielinJiang and qili93 via a068ac3 on August 19, 2021 07:40
};

template <typename Type>
static void ReshapeNPU(const framework::ExecutionContext& ctx,
Member:

ctx is a redundant parameter.

Contributor Author:

Fixed.
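
A sketch of what the helper might look like once the unused ctx parameter is dropped (the remaining parameter names and types are assumptions based on the fragments above); the reshape only rewrites metadata, so no execution context is needed:

template <typename Type>
static void ReshapeNPU(const framework::Tensor* input,
                       const std::vector<Type>& input_shapes,
                       framework::Tensor* output) {
  // Share the underlying buffer and only rewrite the dims metadata.
  output->ShareDataWith(*input);
  output->Resize(framework::make_ddim(input_shapes));
}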

qili93 previously approved these changes Aug 19, 2021

qili93 (Contributor) left a comment:

LGTM

zhhsplendid previously approved these changes Aug 20, 2021

zhhsplendid (Member) left a comment:

LGTM for ShareDataWith

lzzyzlbb dismissed stale reviews from zhhsplendid and qili93 via 7606e4d on August 20, 2021 03:50
zhhsplendid (Member) left a comment:

LGTM

qili93 merged commit 99ffeff into PaddlePaddle:develop on Aug 20, 2021