[NPU] add one_hot_op_npu and tests #34258
Conversation
Thanks for your contribution!

@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
Qiuliang changed this in PR #34240. If this PR is merged after Qiuliang's PR, the skipIf code here can be removed; the same applies to all the following classes.
Updated.
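For context, the skipIf pattern under discussion guards an entire test class when NPU support is absent. A minimal self-contained sketch (using a stand-in flag instead of the real paddle.is_compiled_with_npu(), which this sketch does not require) shows how the whole class is skipped:

```python
import unittest

# Stand-in for paddle.is_compiled_with_npu(); in the real test this
# is the actual compile-time check.
NPU_AVAILABLE = False

@unittest.skipIf(not NPU_AVAILABLE, "core is not compiled with NPU")
class TestOneHotNPU(unittest.TestCase):
    def test_check_output(self):
        pass  # would call check_output_with_place on an NPU build

suite = unittest.TestLoader().loadTestsFromTestCase(TestOneHotNPU)
result = unittest.TextTestRunner(verbosity=0).run(suite)
# With NPU_AVAILABLE = False, the single test is recorded as skipped.
```

Moving this check into a shared base class (as PR #34240 does) lets each per-op test file drop the decorator, which is what the reviewer suggests above.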
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
    self.check_output_with_place(paddle.NPUPlace(0), check_dygraph=False)
Remove check_dygraph=False; the same applies to the later ones.
Following test_one_hot_op.py: removing check_dygraph=False causes an error.
OK.
LGTM
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
    self.check_output_with_place(paddle.NPUPlace(0), check_dygraph=False)
OK.
runner.SetType("OneHot")
    .AddInput(transformed_in)
    .AddInput(std::vector<int32_t>({static_cast<int32_t>(depth)}))
    .AddInput(on_value)
To make the code clearer, maybe on_value and off_value can be written like depth.
LGTM
PR types
Others
PR changes
OPs
Describe
add one_hot_op_npu and tests
The Ascend one-hot operator supports out-of-range indices, similar to the CUDA kernel implementation.
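The out-of-range behavior described above can be illustrated with a small NumPy reference (a hypothetical helper, not the PR's code): an index outside [0, depth) produces an all-off row instead of raising an error, which is the CUDA-kernel-like semantics the PR mirrors on NPU:

```python
import numpy as np

def one_hot_ref(indices, depth, on_value=1.0, off_value=0.0):
    # Hypothetical reference for the described semantics: each valid
    # index sets one position to on_value; an out-of-range index
    # leaves its row entirely at off_value rather than failing.
    out = np.full((len(indices), depth), off_value, dtype=np.float32)
    for row, idx in enumerate(indices):
        if 0 <= idx < depth:
            out[row, idx] = on_value
    return out

# Index 5 is out of range for depth=4, so its row stays all zeros.
print(one_hot_ref([1, 5, 2], depth=4))
```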