From 90d9355be68f8c756a5e9faa2bedbe41a33c9f0b Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:51:17 +0800 Subject: [PATCH] repeat_interleave support bf16 dtype (#61854) * repeat_interleave support bf16 dtype * support bf16 on cpu --- .../cpu/repeat_interleave_grad_kernel.cc | 6 ++++-- .../kernels/cpu/repeat_interleave_kernel.cc | 6 ++++-- .../gpu/repeat_interleave_grad_kernel.cu | 6 ++++-- .../kernels/gpu/repeat_interleave_kernel.cu | 6 ++++-- test/legacy_test/test_repeat_interleave_op.py | 19 +++++++++++++++++++ 5 files changed, 35 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc index b7b33d4290daec..66f3ef0cd790d1 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc @@ -104,7 +104,8 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, CPU, @@ -113,4 +114,5 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc index 388e243eff42a0..8b00d7e38f304c 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, CPU, @@ -34,4 +35,5 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index 52a0e313398e8b..5ff1418b2732ad 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, ALL_LAYOUT, @@ -33,4 +34,5 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index ed62278f067e5f..7b0675b3a752df 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, GPU, @@ -34,4 +35,5 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index b2d0a12c6e260d..60d11a813263e5 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -252,6 +252,25 @@ def test_dygraph_api(self): expect_out = np.repeat(input_x, index, axis=None) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case input dtype is bfloat16 + input_x = np.array([[1, 2, 1], [1, 2, 3]]).astype('uint16') + + with base.dygraph.guard(): + x = paddle.to_tensor(input_x) + index = paddle.to_tensor(index_x) + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index_x, axis=None) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + + with base.dygraph.guard(): + x = paddle.to_tensor(input_x) + index = 2 + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index, axis=None) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case 1: with base.dygraph.guard(): x = base.dygraph.to_variable(self.data_x)