-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[XPU][PHI Kernels] add scatter_nd_add_grad kernel & bf16 support for slice OPs #58580
Conversation
int loop_time = static_cast<int>( | ||
index_dims_size == 0 ? 1 | ||
: phi::product(phi::slice_ddim( | ||
index.dims(), 0, index_dims_size - 1))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
index tensor最后一维长度为0时,需要按前面所有维度索引updates数组并累加到output中,而不是只累加第一维
self.index_np = np.array([[[], []], [[], []]]).astype("int32") | ||
self.updates_np = np.random.random((2, 2, 10, 10)).astype( | ||
self.dtype | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
针对scatter_nd_add前向bug的单测修改
bool is_out_range = true; | ||
if (std::isinf(value) || std::isnan(value)) { | ||
is_out_range = false; | ||
} | ||
if ((common_type_value >= | ||
static_cast<CommonType>(std::numeric_limits<T>::lowest())) && | ||
(common_type_value <= | ||
static_cast<CommonType>(std::numeric_limits<T>::max()))) { | ||
is_out_range = false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里感觉要修改一下,common_type_value满足条件之后,就算value里有inf和nan,is_out_range也是false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个正常行为,用户可以paddle.full_like一个值为Nan的Tensor,这里的逻辑是如果传入的值在数据类型能表示的范围或者是Nan/inf都是合理的。这个参考了GPU实现
bool is_out_range = true; |
} | ||
|
||
} // namespace phi | ||
|
||
PD_REGISTER_KERNEL( | ||
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) { | ||
nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要在op_list里注册数据类型吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nonzero在Op list里的名字是where_index,这个已经加了
@@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx, | |||
|
|||
} // namespace phi | |||
|
|||
PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {} | |||
PD_REGISTER_KERNEL( | |||
prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prod在Op list里的名字是reduce_prod,这个已经加了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…slice OPs (PaddlePaddle#58580) * bevformer and bf16 support * refine format * refine format * refine format * refine format * fix bugs in compilation
…slice OPs (PaddlePaddle#58580) * bevformer and bf16 support * refine format * refine format * refine format * refine format * fix bugs in compilation
PR types
New features
PR changes
OPs
Description
index.numel() == 0