Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Fix 0-dim tensor for arg_min_max op. #49570

Merged
merged 15 commits into from
Feb 1, 2023

Conversation

ZHUI
Copy link
Collaborator

@ZHUI ZHUI commented Jan 5, 2023

PR types

New features

PR changes

OPs

Describe

Fix 0-dim tensor for arg_min_max op.

@ZHUI ZHUI requested a review from zhwesky2010 January 5, 2023 07:51
if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
if (flatten) {
if (x_rank == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也不用写分支,因为phi::product(x_dims); 里面已经支持了0D的product计算结果是1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

额,这里有个flatten的配置,怕这个 为 false的话,后面可能有问题。

Copy link
Contributor

@zhwesky2010 zhwesky2010 Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0D的时候axis只能为None,就是flatten的情况

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加下axis的检查吧,0D的axis只能为None

dev_ctx.template Alloc<int64_t>(out);
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XPU的设置为常数:

xpu::constant<T>(
        dev_ctx.x_context(), dx_data, x->numel(), static_cast<T>(1.0));

@ZHUI ZHUI requested a review from zhwesky2010 February 1, 2023 08:47
Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要再微调下单测写法:
单测在TestSundry和TestSundryStatic里分别实现,测试axis=0/-1/None三种情况,测试前向shape与前向值

@zhwesky2010 zhwesky2010 merged commit e4e94a8 into PaddlePaddle:develop Feb 1, 2023
@ZHUI ZHUI deleted the fix_0d_tensor branch February 1, 2023 09:03
pangengzheng pushed a commit to pangengzheng/Paddle that referenced this pull request Feb 2, 2023
* fix 0-d tensor for arg_min_max op.

* fix xpu.

* fix zero dims

* fix

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update test_zero_dim_tensor.py

* Update test_zero_dim_tensor_xpu.py

* Update test_zero_dim_tensor.py

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants