From b72dfeade38de0ca498becc9264a66523a6e757d Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Tue, 7 Mar 2023 21:32:20 +0800 Subject: [PATCH 01/13] add fp16 for embedding and bf16 for lerp --- paddle/phi/kernels/embedding_kernel.h | 1 + paddle/phi/kernels/gpu/lerp_kernel.cu | 11 ++++++- paddle/phi/kernels/lerp_kernel.h | 1 + .../fluid/tests/unittests/test_lerp_op.py | 33 ++++++++++++++++++- .../unittests/test_lookup_table_v2_op.py | 32 ++++++++++++++++++ 5 files changed, 76 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/embedding_kernel.h b/paddle/phi/kernels/embedding_kernel.h index cd7d675d6dc6cd..8e672cc8806b21 100644 --- a/paddle/phi/kernels/embedding_kernel.h +++ b/paddle/phi/kernels/embedding_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index 3f6862ff9795e2..3ac58c475a21af 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -15,7 +15,16 @@ #include "paddle/phi/kernels/lerp_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/lerp_kernel_impl.h" -PD_REGISTER_KERNEL(lerp, GPU, ALL_LAYOUT, phi::LerpKernel, float, double) {} +PD_REGISTER_KERNEL(lerp, + GPU, + ALL_LAYOUT, + phi::LerpKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/lerp_kernel.h b/paddle/phi/kernels/lerp_kernel.h index 0f73c5bcf9ed62..72adbf519bd2c8 100644 --- a/paddle/phi/kernels/lerp_kernel.h +++ b/paddle/phi/kernels/lerp_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index 625d5b1b13dfe7..5e9472f05ac652 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid.core as core @@ -195,5 +195,36 @@ def test_x_y_broadcast_w(self): paddle.enable_static() +class TestLerpBF16(OpTest): + def setUp(self): + self.op_type = "lerp" + self.python_api = paddle.lerp + self.init_dtype() + self.init_shape() + x = np.arange(1.0, 101.0).astype(np.float32).reshape(self.shape) + y = np.full(100, 10.0).astype(np.float32).reshape(self.shape) + w = np.asarray([0.5]).astype(np.float32) + self.inputs = { + 'X': convert_float_to_uint16(x), + 'Y': convert_float_to_uint16(y), + 'Weight': convert_float_to_uint16(w), + } + self.outputs = {'Out': convert_float_to_uint16(x + w * (y - x))} + + def init_dtype(self): + self.dtype = np.uint16 + + def init_shape(self): + self.shape = [100] + + def test_check_output(self): + self.check_output(check_eager=True, atol=1e-2) + + def test_check_grad(self): + self.check_grad( + ['X', 'Y'], 'Out', check_eager=True, max_relative_error=1e-2 + ) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index dc0c8f3174bb5d..51c017f8e953b0 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -294,6 +294,38 @@ def test_param_dtype(): ) +class TestEmbeddingFP16OP(OpTest): + def setUp(self): + self.op_type = "lookup_table_v2" + self.dtype = "float16" + self.python_api = paddle.nn.functional.embedding + table = np.random.random((17, 31)).astype(np.float32) + ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) + self.inputs = {'W': table.astype(self.dtype), 'Ids': ids} + self.outputs = {'Out': table[ids].astype(self.dtype)} + + def id_dtype(self): + return "int64" + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_grad_with_place( + place, + ['X'], + 'Out', + no_grad_set=set('Ids'), + max_relative_error=1e-2, + ) + + if __name__ == "__main__": paddle.enable_static() unittest.main() From 6d8725935bda66d23f1d53885ffedaa47c339406 Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Thu, 9 Mar 2023 17:16:38 +0800 Subject: [PATCH 02/13] fix bug --- paddle/phi/kernels/gpu/lerp_grad_kernel.cu | 11 ++++++++--- paddle/phi/kernels/gpu/lerp_kernel.cu | 4 +--- paddle/phi/kernels/lerp_grad_kernel.h | 1 + python/paddle/fluid/tests/unittests/test_lerp_op.py | 12 ++++++++---- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index f42f316aae9803..b2fb56d927c9e4 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -16,8 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" - #include "paddle/phi/kernels/broadcast_tensors_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" @@ -270,5 +270,10 @@ void LerpGradKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL( - lerp_grad, GPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {} +PD_REGISTER_KERNEL(lerp_grad, + GPU, + ALL_LAYOUT, + phi::LerpGradKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index 3ac58c475a21af..b9d369e05e8fab 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -15,9 +15,7 @@ #include "paddle/phi/kernels/lerp_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/lerp_kernel_impl.h" diff --git a/paddle/phi/kernels/lerp_grad_kernel.h b/paddle/phi/kernels/lerp_grad_kernel.h index b44af08f03dbe9..df3b994e179d64 100644 --- a/paddle/phi/kernels/lerp_grad_kernel.h +++ b/paddle/phi/kernels/lerp_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index 5e9472f05ac652..88db4c27e36f7d 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -218,12 +218,16 @@ def init_shape(self): self.shape = [100] def test_check_output(self): - self.check_output(check_eager=True, atol=1e-2) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output(check_eager=True, atol=1e-2) def test_check_grad(self): - self.check_grad( - ['X', 'Y'], 'Out', check_eager=True, max_relative_error=1e-2 - ) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_grad( + ['X', 'Y'], 'Out', check_eager=True, max_relative_error=1e-2 + ) if __name__ == "__main__": From 63b16bae04b0903d0f59e706b8a5cb416549b52a Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Thu, 9 Mar 2023 20:57:52 +0800 Subject: [PATCH 03/13] fix bug --- .../phi/kernels/gpu/broadcast_tensors_grad_kernel.cu | 5 +++-- paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu | 5 +++-- paddle/phi/kernels/gpu/lerp_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/lerp_kernel.cu | 2 +- python/paddle/fluid/tests/unittests/test_lerp_op.py | 10 +++++++--- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu index 7acfd33e94a9a4..a607a8ea5266b0 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu @@ -16,7 +16,7 @@ #include -#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" @@ -108,4 +108,5 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu index 82dac4552a42fe..ccc71b267ef6ab 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu @@ -14,7 +14,7 @@ #include "paddle/phi/kernels/broadcast_tensors_kernel.h" -#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h" @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(broadcast_tensors, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index b2fb56d927c9e4..584743e069af21 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -16,7 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/broadcast_tensors_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index b9d369e05e8fab..79b01d5d331a8d 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -15,7 +15,7 @@ #include "paddle/phi/kernels/lerp_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/lerp_kernel_impl.h" diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index 88db4c27e36f7d..8a9fb84a1ef551 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -220,13 +220,17 @@ def init_shape(self): def test_check_output(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_output(check_eager=True, atol=1e-2) + self.check_output_with_place(place, check_eager=True, atol=1e-2) def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_grad( - ['X', 'Y'], 'Out', check_eager=True, max_relative_error=1e-2 + self.check_grad_with_place( + place, + ['X', 'Y'], + 'Out', + check_eager=True, + max_relative_error=1e-2, ) From e87a23899cb927630c70c1e15566d53d10bfaa9c Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Thu, 9 Mar 2023 22:02:34 +0800 Subject: [PATCH 04/13] fix bug --- python/paddle/fluid/tests/unittests/test_lerp_op.py | 1 + python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index 8a9fb84a1ef551..c2da786b9e510d 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -201,6 +201,7 @@ def setUp(self): self.python_api = paddle.lerp self.init_dtype() self.init_shape() + self.__class__.op_type = self.op_type x = np.arange(1.0, 101.0).astype(np.float32).reshape(self.shape) y = np.full(100, 10.0).astype(np.float32).reshape(self.shape) w = np.asarray([0.5]).astype(np.float32) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 51c017f8e953b0..267c6eeb98e0dc 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -299,6 +299,7 @@ def setUp(self): self.op_type = "lookup_table_v2" self.dtype = "float16" self.python_api = paddle.nn.functional.embedding + self.__class__.op_type = self.op_type table = np.random.random((17, 31)).astype(np.float32) ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) self.inputs = {'W': table.astype(self.dtype), 'Ids': ids} From 4e0dadb6320f7f6f4181f716f9c17298b88de6db Mon Sep 17 00:00:00 2001 From: loneranger <836253168@qq.com> Date: Fri, 10 Mar 2023 20:23:40 +0800 Subject: [PATCH 05/13] fix bug --- .../fluid/tests/unittests/test_lerp_op.py | 27 +++++++++-------- .../unittests/test_lookup_table_v2_op.py | 29 ++++++++++--------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index c2da786b9e510d..b0dfa962a7051d 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -195,6 +195,11 @@ def test_x_y_broadcast_w(self): paddle.enable_static() +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) class TestLerpBF16(OpTest): def setUp(self): self.op_type = "lerp" @@ -219,20 +224,18 @@ def init_shape(self): self.shape = [100] def test_check_output(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - self.check_output_with_place(place, check_eager=True, atol=1e-2) + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=True, atol=1e-2) def test_check_grad(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, - ['X', 'Y'], - 'Out', - check_eager=True, - max_relative_error=1e-2, - ) + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X', 'Y'], + 'Out', + check_eager=True, + max_relative_error=1e-2, + ) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 267c6eeb98e0dc..97de9247c23ade 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -294,6 +294,11 @@ def test_param_dtype(): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_float16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the float16", +) class TestEmbeddingFP16OP(OpTest): def setUp(self): self.op_type = "lookup_table_v2" @@ -309,22 +314,18 @@ def id_dtype(self): return "int64" def test_check_output(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - if core.is_float16_supported(place): - self.check_output_with_place(place, atol=1e-3) + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) def test_check_grad(self): - if core.is_compiled_with_cuda(): - place = core.CUDAPlace(0) - if core.is_float16_supported(place): - self.check_grad_with_place( - place, - ['X'], - 'Out', - no_grad_set=set('Ids'), - max_relative_error=1e-2, - ) + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, + ['X'], + 'Out', + no_grad_set=set('Ids'), + max_relative_error=1e-2, + ) if __name__ == "__main__": From df2b387d995fcf30c459bb64a3170b68a7d91baf Mon Sep 17 00:00:00 2001 From: LoneRanger <836253168@qq.com> Date: Sat, 11 Mar 2023 17:47:05 +0800 Subject: [PATCH 06/13] Update test_lookup_table_v2_op.py fix bug --- .../paddle/fluid/tests/unittests/test_lookup_table_v2_op.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 97de9247c23ade..df1d12e8874956 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -302,13 +302,13 @@ def test_param_dtype(): class TestEmbeddingFP16OP(OpTest): def setUp(self): self.op_type = "lookup_table_v2" - self.dtype = "float16" + self.dtype = np.float16 self.python_api = paddle.nn.functional.embedding self.__class__.op_type = self.op_type table = np.random.random((17, 31)).astype(np.float32) ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) self.inputs = {'W': table.astype(self.dtype), 'Ids': ids} - self.outputs = {'Out': table[ids].astype(self.dtype)} + self.outputs = {'Out': table[ids]} def id_dtype(self): return "int64" @@ -321,7 +321,7 @@ def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( place, - ['X'], + ['W'], 'Out', no_grad_set=set('Ids'), max_relative_error=1e-2, From 675dda23e5e7067d0888781d5b5e723945e983c2 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Sun, 19 Mar 2023 21:07:31 +0800 Subject: [PATCH 07/13] fix bug --- paddle/phi/kernels/embedding_kernel.h | 1 - .../gpu/broadcast_tensors_grad_kernel.cu | 2 +- .../kernels/gpu/broadcast_tensors_kernel.cu | 2 +- paddle/phi/kernels/gpu/lerp_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/lerp_kernel.cu | 1 - paddle/phi/kernels/lerp_grad_kernel.h | 1 - paddle/phi/kernels/lerp_kernel.h | 1 - .../fluid/tests/unittests/test_lerp_op.py | 10 ++---- .../unittests/test_lookup_table_v2_op.py | 33 ++----------------- 9 files changed, 7 insertions(+), 46 deletions(-) diff --git a/paddle/phi/kernels/embedding_kernel.h b/paddle/phi/kernels/embedding_kernel.h index 8e672cc8806b21..cd7d675d6dc6cd 100644 --- a/paddle/phi/kernels/embedding_kernel.h +++ b/paddle/phi/kernels/embedding_kernel.h @@ -15,7 +15,6 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu index a607a8ea5266b0..d885fb4c454539 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu @@ -16,7 +16,7 @@ #include -#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu index ccc71b267ef6ab..ede6e8496d9ccf 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu @@ -14,7 +14,7 @@ #include "paddle/phi/kernels/broadcast_tensors_kernel.h" -#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h" diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index 584743e069af21..ff915716b45091 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -16,8 +16,8 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" -#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" + #include "paddle/phi/kernels/broadcast_tensors_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index 79b01d5d331a8d..99455e8490c8d7 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -15,7 +15,6 @@ #include "paddle/phi/kernels/lerp_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/lerp_kernel_impl.h" diff --git a/paddle/phi/kernels/lerp_grad_kernel.h b/paddle/phi/kernels/lerp_grad_kernel.h index df3b994e179d64..b44af08f03dbe9 100644 --- a/paddle/phi/kernels/lerp_grad_kernel.h +++ b/paddle/phi/kernels/lerp_grad_kernel.h @@ -15,7 +15,6 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/paddle/phi/kernels/lerp_kernel.h b/paddle/phi/kernels/lerp_kernel.h index 72adbf519bd2c8..0f73c5bcf9ed62 100644 --- a/paddle/phi/kernels/lerp_kernel.h +++ b/paddle/phi/kernels/lerp_kernel.h @@ -15,7 +15,6 @@ #pragma once #include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/device_context.h" namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index b0dfa962a7051d..96f7d2cafd1fed 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -225,17 +225,11 @@ def init_shape(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_eager=True, atol=1e-2) + self.check_output_with_place(place) def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place( - place, - ['X', 'Y'], - 'Out', - check_eager=True, - max_relative_error=1e-2, - ) + self.check_grad_with_place(place, ['X', 'Y'], 'Out') if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index df1d12e8874956..0e77d416d0104c 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -294,38 +294,9 @@ def test_param_dtype(): ) -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_float16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the float16", -) -class TestEmbeddingFP16OP(OpTest): - def setUp(self): - self.op_type = "lookup_table_v2" - self.dtype = np.float16 - self.python_api = paddle.nn.functional.embedding - self.__class__.op_type = self.op_type - table = np.random.random((17, 31)).astype(np.float32) - ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) - self.inputs = {'W': table.astype(self.dtype), 'Ids': ids} - self.outputs = {'Out': table[ids]} - +class TestEmbeddingFP16OP(TestLookupTableOp): def id_dtype(self): - return "int64" - - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-3) - - def test_check_grad(self): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, - ['W'], - 'Out', - no_grad_set=set('Ids'), - max_relative_error=1e-2, - ) + return np.float16 if __name__ == "__main__": From 94c694feb2b6dde52e4371d3cec6a9eb2d31bb2f Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Sun, 19 Mar 2023 23:05:18 +0800 Subject: [PATCH 08/13] fix bug --- .../fluid/tests/unittests/test_lookup_table_v2_op.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 0e77d416d0104c..b72445ac2afb17 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -48,11 +48,17 @@ class TestLookupTableOp(OpTest): def setUp(self): self.op_type = "lookup_table_v2" self.python_api = paddle.nn.functional.embedding - table = np.random.random((17, 31)).astype("float64") + self.init_dtype() + + table = np.random.random((17, 31)).astype(self.dtype) ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) + self.inputs = {'W': table, 'Ids': ids} self.outputs = {'Out': table[ids]} + def init_dtype(self): + self.dtype = "float64" + def id_dtype(self): return "int64" @@ -295,8 +301,8 @@ def test_param_dtype(): class TestEmbeddingFP16OP(TestLookupTableOp): - def id_dtype(self): - return np.float16 + def init_dtype(self): + self.dtype = np.float16 if __name__ == "__main__": From 20bab053be49633d011e352135ca549704e17a92 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Sat, 15 Apr 2023 19:48:17 +0800 Subject: [PATCH 09/13] fix bug --- python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 0b28829297f6e4..29fcc539073c20 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -49,7 +49,7 @@ def setUp(self): self.python_api = paddle.nn.functional.embedding self.init_dtype() - table = np.random.random((17, 31)).astype(self.dtype) + table = np.random.random((17, 32)).astype(self.dtype) ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) self.inputs = {'W': table, 'Ids': ids} From dc477092a63048d8284d34f9bf62fc5dbaedc5bb Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Mon, 17 Apr 2023 06:13:24 +0000 Subject: [PATCH 10/13] fix bug --- python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index 29fcc539073c20..bf88c0e8246142 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -108,7 +108,7 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): def test_check_output(self): ids = np.squeeze(self.inputs['Ids']) padding_idx = np.random.choice(ids, 1)[0] - self.outputs['Out'][ids == padding_idx] = np.zeros(31) + self.outputs['Out'][ids == padding_idx] = np.zeros(32) self.attrs = {'padding_idx': int(padding_idx)} self.check_output() From f4cdd47fc14cd3f903b50ff29ec846a503e1df62 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Wed, 19 Apr 2023 07:32:43 +0000 Subject: [PATCH 11/13] remove the support for lerp --- .../gpu/broadcast_tensors_grad_kernel.cu | 3 +- .../kernels/gpu/broadcast_tensors_kernel.cu | 3 +- paddle/phi/kernels/gpu/lerp_grad_kernel.cu | 1 - paddle/phi/kernels/gpu/lerp_kernel.cu | 1 - .../fluid/tests/unittests/test_lerp_op.py | 39 +------------------ .../unittests/test_lookup_table_v2_op.py | 4 +- 6 files changed, 5 insertions(+), 46 deletions(-) diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu index d885fb4c454539..7acfd33e94a9a4 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu @@ -108,5 +108,4 @@ PD_REGISTER_KERNEL(broadcast_tensors_grad, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu index ede6e8496d9ccf..82dac4552a42fe 100644 --- a/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu +++ b/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu @@ -27,5 +27,4 @@ PD_REGISTER_KERNEL(broadcast_tensors, int64_t, float, double, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu index 43cf0deab6dd9d..8a1f31a5cfe37b 100644 --- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu @@ -279,6 +279,5 @@ PD_REGISTER_KERNEL(lerp_grad, ALL_LAYOUT, phi::LerpGradKernel, phi::dtype::float16, - phi::dtype::bfloat16, float, double) {} diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu index 8deb69918e031a..25f37bb170476b 100644 --- a/paddle/phi/kernels/gpu/lerp_kernel.cu +++ b/paddle/phi/kernels/gpu/lerp_kernel.cu @@ -23,6 +23,5 @@ PD_REGISTER_KERNEL(lerp, ALL_LAYOUT, phi::LerpKernel, phi::dtype::float16, - phi::dtype::bfloat16, float, double) {} diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index 6b32a017910a21..7bf2cb73380407 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest, convert_float_to_uint16 +from eager_op_test import OpTest import paddle from paddle.fluid import core @@ -220,42 +220,5 @@ def test_x_y_broadcast_w(self): paddle.enable_static() -@unittest.skipIf( - not core.is_compiled_with_cuda() - or not core.is_bfloat16_supported(core.CUDAPlace(0)), - "core is not complied with CUDA and not support the bfloat16", -) -class TestLerpBF16(OpTest): - def setUp(self): - self.op_type = "lerp" - self.python_api = paddle.lerp - self.init_dtype() - self.init_shape() - self.__class__.op_type = self.op_type - x = np.arange(1.0, 101.0).astype(np.float32).reshape(self.shape) - y = np.full(100, 10.0).astype(np.float32).reshape(self.shape) - w = np.asarray([0.5]).astype(np.float32) - self.inputs = { - 'X': convert_float_to_uint16(x), - 'Y': convert_float_to_uint16(y), - 'Weight': convert_float_to_uint16(w), - } - self.outputs = {'Out': convert_float_to_uint16(x + w * (y - x))} - - def init_dtype(self): - self.dtype = np.uint16 - - def init_shape(self): - self.shape = [100] - - def test_check_output(self): - place = core.CUDAPlace(0) - self.check_output_with_place(place) - - def test_check_grad(self): - place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Y'], 'Out') - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index bf88c0e8246142..0b28829297f6e4 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -49,7 +49,7 @@ def setUp(self): self.python_api = paddle.nn.functional.embedding self.init_dtype() - table = np.random.random((17, 32)).astype(self.dtype) + table = np.random.random((17, 31)).astype(self.dtype) ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) self.inputs = {'W': table, 'Ids': ids} @@ -108,7 +108,7 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): def test_check_output(self): ids = np.squeeze(self.inputs['Ids']) padding_idx = np.random.choice(ids, 1)[0] - self.outputs['Out'][ids == padding_idx] = np.zeros(32) + self.outputs['Out'][ids == padding_idx] = np.zeros(31) self.attrs = {'padding_idx': int(padding_idx)} self.check_output() From 74b33036957ce3ec0c1bbb4adaf2deba0e98f8b0 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Thu, 3 Aug 2023 14:49:00 +0800 Subject: [PATCH 12/13] fix bug --- test/legacy_test/test_lookup_table_v2_op.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/legacy_test/test_lookup_table_v2_op.py b/test/legacy_test/test_lookup_table_v2_op.py index bc69aa619f9147..39480593559480 100644 --- a/test/legacy_test/test_lookup_table_v2_op.py +++ b/test/legacy_test/test_lookup_table_v2_op.py @@ -304,6 +304,17 @@ def test_param_dtype(): class TestEmbeddingFP16OP(TestLookupTableOp): + def setUp(self): + self.op_type = "lookup_table_v2" + self.python_api = paddle.nn.functional.embedding + self.init_dtype() + + table = np.random.random((18, 32)).astype(self.dtype) + ids = np.random.randint(0, 18, 4).astype(self.id_dtype()) + + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids]} + def init_dtype(self): self.dtype = np.float16 From 5a755a9f8a148bab74b03d65a51be8557becbd71 Mon Sep 17 00:00:00 2001 From: longranger2 <836253168@qq.com> Date: Thu, 3 Aug 2023 23:10:12 +0800 Subject: [PATCH 13/13] add bfloat16 test for embedding --- test/legacy_test/test_lookup_table_v2_op.py | 33 ++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_lookup_table_v2_op.py b/test/legacy_test/test_lookup_table_v2_op.py index 39480593559480..586e347cf00b84 100644 --- a/test/legacy_test/test_lookup_table_v2_op.py +++ b/test/legacy_test/test_lookup_table_v2_op.py @@ -15,7 +15,7 @@ import unittest import numpy as np -from eager_op_test import OpTest, skip_check_grad_ci +from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci from op import Operator import paddle @@ -319,6 +319,37 @@ def init_dtype(self): self.dtype = np.float16 +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestEmbeddingBF16OP(OpTest): + def setUp(self): + self.op_type = "lookup_table_v2" + self.python_api = paddle.nn.functional.embedding + self.dtype = np.uint16 + + table = np.random.random((18, 32)).astype("float32") + ids = np.random.randint(0, 18, 4).astype(self.id_dtype()) + + self.inputs = {'W': convert_float_to_uint16(table), 'Ids': ids} + self.outputs = {'Out': convert_float_to_uint16(table[ids])} + + def id_dtype(self): + return "int64" + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_cinn=True) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['W'], 'Out', no_grad_set=set('Ids'), check_cinn=True + ) + + if __name__ == "__main__": paddle.enable_static() unittest.main()