-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[npu]Add argsort op #34865
[npu]Add argsort op #34865
Conversation
Thanks for your contribution! |
limitations under the License. */ | ||
|
||
#include <memory> | ||
#include <string> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
12-13行头文件删掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经删掉啦
indices->mutable_data<int32_t>(ctx.GetPlace()); | ||
|
||
int32_t axis = ctx.Attr<int>("axis"); | ||
bool descending = ctx.Attr<bool>("descending"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一行 axis = (axis < 0) ? (in_dims.size() + axis) : axis; 使得axis永远为正数。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经添加啦
framework::NPUAttributeMap sort_attr_input = {{"axis", npu_axis}, | ||
{"descending", descending}}; | ||
|
||
if (axis != -1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改为 if (axis == -1 || axis + 1 == in_dims.size()) {
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改啦
.stream(); | ||
int32_t npu_axis = -1; | ||
framework::NPUAttributeMap sort_attr_input = {{"axis", npu_axis}, | ||
{"descending", descending}}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
冗余定义 npu_axis,直接写成
framework::NPUAttributeMap sort_attr_input = {{"axis", static_cast<int32_t>(-1)},
{"descending", descending}};
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经修改啦
|
||
|
||
@unittest.skipIf(not paddle.is_compiled_with_npu(), | ||
"core is not compiled with NPU") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除所有34-35行的类似代码
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经全部删除
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
} | ||
}; | ||
|
||
template <typename T, typename Type> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the typename T used for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
framework::Tensor* output) { | ||
output->ShareDataWith(*input); | ||
output->Resize(framework::make_ddim(std::move(input_shapes))); | ||
input = output; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this statement?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
.AddInput(input_reshape_tensor) | ||
.AddOutput(input_scatter_tensor); | ||
runner_scatter.Run(stream); | ||
ReshapeNPU<T, Type>(ctx, &input_scatter_tensor, trans_shapes, t_out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I saw t_out
is ShareDataWith
and a t_out
can be dX
, which is the output of the ArgsortGradNPUKernel
, this is not allowed in PaddlePaddle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已改成拷贝
}; | ||
|
||
template <typename Type> | ||
static void ReshapeNPU(const framework::ExecutionContext& ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctx 是冗余的参数
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for ShareDataWith
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
[NPU] Support npu kernel for argsort and argsort_grad
![image](https://user-images.githubusercontent.com/17897185/129216064-7967b3a4-76ad-4843-bb30-5c358ffcefb5.png)
![image](https://user-images.githubusercontent.com/17897185/129216106-bb20f2bd-6966-4db5-8e33-185c96482a06.png)
![image](https://user-images.githubusercontent.com/17897185/129216139-e25a31f0-04bb-4221-8193-1da85f8c5076.png)