[XPU][PHI Kernels] add scatter_nd_add_grad kernel & bf16 support for slice OPs #58580

Merged: 6 commits, Nov 2, 2023

Contributor

@lj970926 lj970926 commented Nov 1, 2023

PR types

New features

PR changes

OPs

Description

  1. Add the scatter_nd_add_grad kernel for XPU.
  2. Add bf16 support for slice, slice_grad, strided_slice, and strided_slice_grad.
  3. Fix bugs in scatter_nd_add when index.numel() == 0.
  4. Add int64 support for reduce_prod and nonzero.
  5. Fix the assertion in the full_like kernel when the input value is NaN/inf. The kernel needs to accept these special values, and the XPU API handles them properly.
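As a rough illustration of item 1, the gradients of scatter_nd_add can be written in NumPy: the gradient with respect to x is the output gradient unchanged, and the gradient with respect to updates is a gather of the output gradient at the indexed positions. This is a reference sketch only, not the XPU kernel; `scatter_nd_add_grad_ref` and its argument names are hypothetical.

```python
import numpy as np


def scatter_nd_add_grad_ref(index, out_grad):
    """Reference gradients of out = scatter_nd_add(x, index, updates).

    Hypothetical helper for illustration; it does not call any Paddle API.
    """
    # out = x + (scattered updates), so d(out)/d(x) is the identity.
    x_grad = out_grad.copy()

    batch_shape = index.shape[:-1]  # one update slice per position here
    k = index.shape[-1]             # leading dims of x addressed by each index row
    updates_grad = np.empty(batch_shape + out_grad.shape[k:], dtype=out_grad.dtype)

    for pos in np.ndindex(*batch_shape):
        target = tuple(int(i) for i in index[pos])  # empty tuple when k == 0
        updates_grad[pos] = out_grad[target]        # gather_nd of out_grad
    return x_grad, updates_grad
```

With an index of shape (2, 2, 0), every position gathers the whole output gradient, so updates_grad has shape (2, 2) + out_grad.shape.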

int loop_time = static_cast<int>(
    index_dims_size == 0 ? 1
                         : phi::product(phi::slice_ddim(
                               index.dims(), 0, index_dims_size - 1)));
Contributor Author

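When the last dimension of the index tensor has length 0, the updates array must be indexed over all preceding dimensions and accumulated into the output, rather than accumulating over the first dimension only.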
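The behavior described in this comment can be sketched with a NumPy reference, assuming the usual scatter_nd_add contract that updates has shape index.shape[:-1] + x.shape[index.shape[-1]:]. The function name `scatter_nd_add_ref` is hypothetical and the loop is only a model of the semantics, not the XPU kernel.

```python
import numpy as np


def scatter_nd_add_ref(x, index, updates):
    """NumPy reference for scatter_nd_add (illustrative, not the XPU kernel)."""
    out = x.copy()
    k = index.shape[-1]             # leading dims of x addressed by each index row
    batch_shape = index.shape[:-1]  # the loop count is the product of these dims
    assert updates.shape == batch_shape + x.shape[k:]

    for pos in np.ndindex(*batch_shape):
        target = tuple(int(i) for i in index[pos])  # empty tuple when k == 0
        # With k == 0 the target is all of `out`, so every update slice
        # across all leading dimensions is accumulated.
        out[target] += updates[pos]
    return out
```

For an index of shape (2, 2, 0) and updates of shape (2, 2, 10, 10), the loop runs 2 * 2 = 4 times and the result equals x plus the sum of updates over its first two axes.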

self.index_np = np.array([[[], []], [[], []]]).astype("int32")
self.updates_np = np.random.random((2, 2, 10, 10)).astype(
    self.dtype
)
Contributor Author

@lj970926 lj970926 Nov 1, 2023

Unit-test change covering the scatter_nd_add forward bug.

@paddle-bot paddle-bot bot added the contributor External developers label Nov 1, 2023
Comment on lines +66 to +75
bool is_out_range = true;
if (std::isinf(value) || std::isnan(value)) {
  is_out_range = false;
}
if ((common_type_value >=
     static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
    (common_type_value <=
     static_cast<CommonType>(std::numeric_limits<T>::max()))) {
  is_out_range = false;
}
Contributor

I think this needs to be changed: once common_type_value satisfies the condition, is_out_range is false even if value contains inf or NaN.

Contributor Author

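This is the expected behavior: a user can call paddle.full_like to create a Tensor filled with NaN. The logic here is that the value passed in is valid if it is either within the range representable by the data type or NaN/inf. This follows the GPU implementation.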
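The rule stated in this reply can be mirrored in a few lines of Python. This is a sketch of the logic only, assuming NumPy dtypes stand in for the kernel's template type T; `is_out_of_range` is a hypothetical name and not part of Paddle.

```python
import math

import numpy as np


def is_out_of_range(value, dtype):
    """Return True only if `value` is finite and not representable in `dtype`."""
    # NaN and inf are treated as valid fill values, as in the reply above.
    if math.isinf(value) or math.isnan(value):
        return False
    if np.issubdtype(dtype, np.floating):
        info = np.finfo(dtype)
        lowest, highest = float(info.min), float(info.max)
    else:
        info = np.iinfo(dtype)
        lowest, highest = int(info.min), int(info.max)
    # In-range finite values are valid; everything else is out of range.
    return not (lowest <= value <= highest)
```

Either condition alone is enough to accept the value, which is why the two checks in the kernel each clear the flag independently.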

bool is_out_range = true;

}

} // namespace phi

 PD_REGISTER_KERNEL(
-    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float) {
+    nonzero, XPU, ALL_LAYOUT, phi::NonZeroKernel, int, bool, float, int64_t) {
Contributor

Do the data types need to be registered in op_list here?

Contributor Author

nonzero is named where_index in the op list, and that entry has already been added.

@@ -50,4 +50,5 @@ void ProdKernel(const Context& dev_ctx,

} // namespace phi

-PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {}
+PD_REGISTER_KERNEL(
+    prod, XPU, ALL_LAYOUT, phi::ProdKernel, float, int, int64_t) {}
Contributor

Same as above.

Contributor Author

prod is named reduce_prod in the op list, and that entry has already been added.

Contributor

@RuohengMa RuohengMa left a comment

LGTM

@QingshuChen QingshuChen merged commit 038d4b4 into PaddlePaddle:develop Nov 2, 2023
@paddle-bot paddle-bot bot removed the contributor External developers label Nov 3, 2023
zeroRains pushed a commit to zeroRains/Paddle that referenced this pull request Nov 8, 2023
…slice OPs (PaddlePaddle#58580)

* bevformer and bf16 support

* refine format

* refine format

* refine format

* refine format

* fix bugs in compilation
danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…slice OPs (PaddlePaddle#58580)

* bevformer and bf16 support

* refine format

* refine format

* refine format

* refine format

* fix bugs in compilation
3 participants