From b17e2bcf060337e25ab8bd842d3bbd1011cd9b43 Mon Sep 17 00:00:00 2001 From: "haowen.han@mthreads.com" Date: Mon, 13 May 2024 13:29:31 +0800 Subject: [PATCH] Revert "repeat_interleave support bf16 dtype (#61854) (#61899)" This reverts commit 96c2aafdc4b8950d23fd54064833261d906b15ac. --- .../cpu/repeat_interleave_grad_kernel.cc | 6 ++---- .../kernels/cpu/repeat_interleave_kernel.cc | 6 ++---- .../gpu/repeat_interleave_grad_kernel.cu | 6 ++---- .../kernels/gpu/repeat_interleave_kernel.cu | 6 ++---- test/legacy_test/test_repeat_interleave_op.py | 19 ------------------- 5 files changed, 8 insertions(+), 35 deletions(-) diff --git a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc index 66f3ef0cd790d1..b7b33d4290daec 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_grad_kernel.cc @@ -104,8 +104,7 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} PD_REGISTER_KERNEL(repeat_interleave_grad, CPU, @@ -114,5 +113,4 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} diff --git a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc index 8b00d7e38f304c..388e243eff42a0 100644 --- a/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc +++ b/paddle/phi/kernels/cpu/repeat_interleave_kernel.cc @@ -25,8 +25,7 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, CPU, @@ -35,5 +34,4 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu index 5ff1418b2732ad..52a0e313398e8b 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_grad_kernel.cu @@ -25,8 +25,7 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index_grad, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} PD_REGISTER_KERNEL(repeat_interleave_grad, GPU, ALL_LAYOUT, @@ -34,5 +33,4 @@ PD_REGISTER_KERNEL(repeat_interleave_grad, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} diff --git a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu index 7b0675b3a752df..ed62278f067e5f 100644 --- a/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu +++ b/paddle/phi/kernels/gpu/repeat_interleave_kernel.cu @@ -25,8 +25,7 @@ PD_REGISTER_KERNEL(repeat_interleave, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, GPU, @@ -35,5 +34,4 @@ PD_REGISTER_KERNEL(repeat_interleave_with_tensor_index, float, double, int, - int64_t, - phi::dtype::bfloat16) {} + int64_t) {} diff --git a/test/legacy_test/test_repeat_interleave_op.py b/test/legacy_test/test_repeat_interleave_op.py index 60d11a813263e5..b2d0a12c6e260d 100644 --- a/test/legacy_test/test_repeat_interleave_op.py +++ b/test/legacy_test/test_repeat_interleave_op.py @@ -252,25 +252,6 @@ def test_dygraph_api(self): expect_out = np.repeat(input_x, index, axis=None) np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) - # case input dtype is bfloat16 - input_x = np.array([[1, 2, 1], [1, 2, 3]]).astype('uint16') - - with base.dygraph.guard(): - x = paddle.to_tensor(input_x) - index = paddle.to_tensor(index_x) - z = paddle.repeat_interleave(x, index, None) - np_z = z.numpy() - expect_out = np.repeat(input_x, index_x, axis=None) - np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) - - with base.dygraph.guard(): - x = paddle.to_tensor(input_x) - index = 2 - z = paddle.repeat_interleave(x, index, None) - np_z = z.numpy() - expect_out = np.repeat(input_x, index, axis=None) - np.testing.assert_allclose(expect_out, np_z, rtol=1e-05) - # case 1: with base.dygraph.guard(): x = base.dygraph.to_variable(self.data_x)