From 96c2aafdc4b8950d23fd54064833261d906b15ac Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Wed, 21 Feb 2024 17:51:36 +0800 Subject: [PATCH] repeat_interleave support bf16 dtype (#61854) (#61899) * repeat_interleave support bf16 dtype * support bf16 on cpu --- .../cpu/repeat_interleave_grad_kernel.cc | 6 ++++-- .../kernels/cpu/repeat_interleave_kernel.cc | 6 ++++-- .../gpu/repeat_interleave_grad_kernel.cu | 6 ++++-- .../kernels/gpu/repeat_interleave_kernel.cu | 6 ++++-- test/legacy_test/test_repeat_interleave_op.py | 19 +++++++++++++++++++ 5 files changed, 35 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc index b7b33d4290dae..66f3ef0cd790d 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc @@ -104,7 +104,8 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, CPU, @@ -113,4 +114,5 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc index 388e243eff42a..8b00d7e38f304 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, CPU, @@ -34,4 +35,5 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index 52a0e313398e8..5ff1418b2732a 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, ALL_LAYOUT, @@ -33,4 +34,5 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index ed62278f067e5..7b0675b3a752d 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, GPU, @@ -34,4 +35,5 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t) {} + int64_t, + phi::dtype::bfloat16) {} diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index b2d0a12c6e260..60d11a813263e 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -252,6 +252,25 @@ def test_dygraph_api(self): expect_out = np.repeat(input_x, index, axis=None) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case input dtype is bfloat16 + input_x = np.array([[1, 2, 1], [1, 2, 3]]).astype('uint16') + + with base.dygraph.guard(): + x = paddle.to_tensor(input_x) + index = paddle.to_tensor(index_x) + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index_x, axis=None) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + + with base.dygraph.guard(): + x = paddle.to_tensor(input_x) + index = 2 + z = paddle.repeat_interleave(x, index, None) + np_z = z.numpy() + expect_out = np.repeat(input_x, index, axis=None) + np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) + # case 1: with base.dygraph.guard(): x = base.dygraph.to_variable(self.data_x)