-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Zero-Dim] Fix 0-dim tensor for arg_min_max op. #49570
Conversation
paddle/phi/infermeta/unary.cc
Outdated
if (config.is_runtime) { | ||
if (dtype == phi::TransToProtoVarType(DataType::INT32)) { | ||
int64_t all_element_num = 0; | ||
if (flatten) { | ||
if (x_rank == 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个也不用写分支,因为phi::product(x_dims); 里面已经支持了0D的product计算结果是1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
额,这里有个flatten的配置,怕这个 为 false的话,后面可能有问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
0D的时候axis只能为None,就是flatten的情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加下axis的检查吧,0D的axis只能为None
dev_ctx.template Alloc<int64_t>(out); | ||
if (x.dims().size() == 0) { | ||
phi::funcs::set_constant(dev_ctx, out, 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
XPU的设置为常数:
xpu::constant<T>(
dev_ctx.x_context(), dx_data, x->numel(), static_cast<T>(1.0));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要再微调下单测写法:
单测在TestSundry和TestSundryStatic里分别实现,测试axis=0/-1/None三种情况,测试前向shape与前向值
* fix 0-d tensor for arg_min_max op. * fix xpu. * fix zero dims * fix * Update arg_min_max_kernel.cc * Update arg_min_max_kernel.cc * Update arg_min_max_kernel.cc * Update test_zero_dim_tensor.py * Update test_zero_dim_tensor_xpu.py * Update test_zero_dim_tensor.py * Update arg_min_max_kernel.cc * Update arg_min_max_kernel.cc * Update arg_min_max_kernel.cc
PR types
New features
PR changes
OPs
Describe
Fix 0-dim tensor for arg_min_max op.