From 665a24c9ac610d92d149334969f513e5dc7668e2 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Sat, 15 Jan 2022 16:23:59 +0000
Subject: [PATCH 1/2] speedup gelu using fast math

---
 paddle/fluid/operators/gelu_op.cu             | 85 +++++++++++++++++++
 paddle/fluid/platform/flags.cc                |  3 +
 .../fluid/tests/unittests/test_gelu_op.py     | 24 ++++++
 3 files changed, 112 insertions(+)

diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index 8151d21fa676d2..966a6ecbe26603 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -16,9 +16,82 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/gelu_op.h"
 
+DECLARE_bool(use_fast_math);
+
 namespace paddle {
 namespace operators {
 
+#ifdef __NVCC__
+template <bool FastMode>
+static __device__ __forceinline__ float FP32Gelu(float x) {
+  auto tmp = 0.79788456f * x * (1.0f + 0.044715f * x * x);
+  float tanh_tmp;
+#if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
+  if (FastMode) {
+    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(tanh_tmp) : "f"(tmp));
+  } else {
+    tanh_tmp = tanhf(tmp);
+  }
+#else
+  tanh_tmp = tanhf(tmp);
+#endif
+  return x * 0.5f * (1.0f + tanh_tmp);
+}
+
+template <int VecSize, bool FastMode>
+static __global__ void FP16FastGeluCUDAKernel(const __half* x, __half* y,
+                                              size_t n) {
+  size_t offset =
+      static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
+  for (; offset < n; offset += stride) {
+    using ArrT = platform::AlignedVector<__half, VecSize>;
+    ArrT in_arr = *reinterpret_cast<const ArrT*>(x + offset);
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      float tmp = __half2float(in_arr[i]);
+      in_arr[i] = __float2half(FP32Gelu<FastMode>(tmp));
+    }
+    *reinterpret_cast<ArrT*>(y + offset) = in_arr;
+  }
+}
+
+static bool TryLaunchFP16FastGeluVectorizeCUDAKernel(
+    const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
+    size_t n) {
+  auto is_aligned = [](const void* p, size_t alignment) {
+    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
+  };
+
+#define PD_LAUNCH_FP16_FAST_GELU_KERNEL(__vec_size, __use_fast_math)          \
+  do {                                                                        \
+    constexpr auto kAlignment =                                               \
+        alignof(platform::AlignedVector<__half, __vec_size>);                 \
+    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
+        is_aligned(y, kAlignment)) {                                          \
+      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
+      size_t block = (n / __vec_size + thread - 1) / thread;                  \
+      block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
+      VLOG(10) << "Use FP16 fast gelu kernel, block = " << block              \
+               << " , thread = " << thread;                                   \
+      FP16FastGeluCUDAKernel<                                                 \
+          __vec_size,                                                         \
+          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n);  \
+      return true;                                                            \
+    }                                                                         \
+  } while (0)
+
+  if (FLAGS_use_fast_math) {
+    PD_LAUNCH_FP16_FAST_GELU_KERNEL(8, true);
+  } else {
+    PD_LAUNCH_FP16_FAST_GELU_KERNEL(8, false);
+  }
+
+#undef PD_LAUNCH_FP16_FAST_GELU_KERNEL
+  return false;
+}
+#endif
+
 template <typename T>
 struct GeluWithApproximateFunctor {
   using MPType = typename details::MPTypeTrait<T>::Type;
@@ -59,7 +132,19 @@ class GeluKernel
     std::vector<framework::Tensor*> outs = {out};
     const auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
+
     if (approximate) {
+#ifdef __NVCC__
+      size_t n = in->numel();
+      if (std::is_same<T, platform::float16>::value) {
+        const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
+        auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
+        if (TryLaunchFP16FastGeluVectorizeCUDAKernel(dev_ctx, in_ptr, out_ptr,
+                                                     n)) {
+          return;
+        }
+      }
+#endif
       LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
           dev_ctx, ins, &outs, 0, GeluWithApproximateFunctor<T>());
     } else {
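Note (illustration, not part of the patch): FP32Gelu above evaluates the tanh approximation of GELU, where 0.79788456f is sqrt(2/pi) and 0.044715 is the standard cubic coefficient. A minimal NumPy sketch of that formula next to the erf-based definition the unit test uses as reference; the tolerance is illustrative:

    import numpy as np
    from scipy.special import erf

    def gelu_tanh(x):
        # Same polynomial as FP32Gelu: 0.79788456 ~ sqrt(2 / pi).
        return 0.5 * x * (1.0 + np.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

    def gelu_erf(x):
        # Exact GELU, matching gelu(x, approximate=False) in test_gelu_op.py.
        return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

    x = np.linspace(-3.0, 3.0, 601)
    assert np.allclose(gelu_tanh(x), gelu_erf(x), atol=5e-3)
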
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index 8b117a5a8292d5..44bd4eaa29b807 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -652,6 +652,9 @@ PADDLE_DEFINE_EXPORTED_bool(
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PADDLE_DEFINE_EXPORTED_bool(conv2d_disable_cudnn, false,
                             "Disable cudnn in conv2d");
+
+PADDLE_DEFINE_EXPORTED_bool(use_fast_math, false,
+                            "Whether to use fast math GPU functions.");
 #endif
 
 /**
diff --git a/python/paddle/fluid/tests/unittests/test_gelu_op.py b/python/paddle/fluid/tests/unittests/test_gelu_op.py
index 13174edad8f174..947f8fc8c3fcb2 100644
--- a/python/paddle/fluid/tests/unittests/test_gelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gelu_op.py
@@ -19,6 +19,8 @@
 from scipy.special import erf
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
+import paddle
+import paddle.nn.functional as F
 
 
 def gelu(x, approximate):
@@ -59,6 +61,28 @@ def test_cases(self):
         if fluid.is_compiled_with_cuda():
             self._test_case1_gpu(approximate)
 
+    def test_fast_math(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        def use_fast_math(enabled):
+            paddle.set_flags({'FLAGS_use_fast_math': enabled})
+
+        x_np = np.random.uniform(-1, 1, size=(11, 17, 8)).astype(np.float16)
+
+        def run_gelu_op(approximate):
+            with dg.guard():
+                x = paddle.to_tensor(x_np)
+                y = F.gelu(x, approximate=approximate)
+                return y.numpy()
+
+        use_fast_math(True)
+        y_fast_math = run_gelu_op(True)
+        use_fast_math(False)
+
+        y_ref = run_gelu_op(False)
+        self.assertTrue(np.allclose(y_ref, y_fast_math, rtol=1e-5, atol=5e-4))
+
 
 if __name__ == '__main__':
     unittest.main()
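Note (usage sketch, not part of the patch): assuming this first patch is applied to a CUDA build of Paddle, the fast path is taken for FP16 inputs with approximate=True, mirroring what the unit test above does:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    # Opt in to tanh.approx.f32 on sm_75+; the flag defaults to off.
    paddle.set_flags({'FLAGS_use_fast_math': True})

    # 11 * 17 * 8 = 1496 elements, divisible by 8, so the vectorized
    # kernel is eligible for this shape.
    x_np = np.random.uniform(-1, 1, size=(11, 17, 8)).astype(np.float16)
    y = F.gelu(paddle.to_tensor(x_np), approximate=True)
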
From 2842e568c9b028944eaf52a335d466782ed55e0e Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Mon, 17 Jan 2022 06:14:15 +0000
Subject: [PATCH 2/2] add bwd part

---
 paddle/fluid/operators/gelu_op.cu             | 130 +++++++++++++++---
 .../fluid/tests/unittests/test_gelu_op.py     |  16 ++-
 2 files changed, 120 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/operators/gelu_op.cu b/paddle/fluid/operators/gelu_op.cu
index 966a6ecbe26603..2f72fbff2668b7 100644
--- a/paddle/fluid/operators/gelu_op.cu
+++ b/paddle/fluid/operators/gelu_op.cu
@@ -23,24 +23,37 @@ namespace operators {
 
 #ifdef __NVCC__
 template <bool FastMode>
-static __device__ __forceinline__ float FP32Gelu(float x) {
-  auto tmp = 0.79788456f * x * (1.0f + 0.044715f * x * x);
-  float tanh_tmp;
+static __device__ __forceinline__ float FP32FastTanh(float x) {
 #if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
   if (FastMode) {
-    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(tanh_tmp) : "f"(tmp));
-  } else {
-    tanh_tmp = tanhf(tmp);
+    float y;
+    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(y) : "f"(x));
+    return y;
   }
-#else
-  tanh_tmp = tanhf(tmp);
 #endif
-  return x * 0.5f * (1.0f + tanh_tmp);
+  return tanhf(x);
+}
+
+template <bool FastMode>
+static __device__ __forceinline__ float FP32GeluFwd(float x) {
+  auto tanh_out =
+      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
+  return x * 0.5f * (1.0f + tanh_out);
+}
+
+template <bool FastMode>
+static __device__ __forceinline__ float FP32GeluBwd(float x, float y_g) {
+  auto tanh_out =
+      FP32FastTanh<FastMode>(0.79788456f * x * (1.0f + 0.044715f * x * x));
+  auto tmp = 0.5f * x * ((1.0f - tanh_out * tanh_out) *
+                         (0.79788456f + 0.1070322243f * x * x)) +
+             0.5f * (1.0f + tanh_out);
+  return tmp * y_g;
 }
 
 template <int VecSize, bool FastMode>
-static __global__ void FP16FastGeluCUDAKernel(const __half* x, __half* y,
-                                              size_t n) {
+static __global__ void FP16FastGeluFwdCUDAKernel(const __half* x, __half* y,
+                                                 size_t n) {
   size_t offset =
       static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
   size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
@@ -50,20 +63,44 @@ static __global__ void FP16FastGeluCUDAKernel(const __half* x, __half* y,
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
       float tmp = __half2float(in_arr[i]);
-      in_arr[i] = __float2half(FP32Gelu<FastMode>(tmp));
+      in_arr[i] = __float2half(FP32GeluFwd<FastMode>(tmp));
     }
     *reinterpret_cast<ArrT*>(y + offset) = in_arr;
   }
 }
 
-static bool TryLaunchFP16FastGeluVectorizeCUDAKernel(
+template <int VecSize, bool FastMode>
+static __global__ void FP16FastGeluBwdCUDAKernel(const __half* x,
+                                                 const __half* y_g, __half* x_g,
+                                                 size_t n) {
+  size_t offset =
+      static_cast<size_t>(threadIdx.x + blockIdx.x * blockDim.x) * VecSize;
+  size_t stride = static_cast<size_t>(blockDim.x * gridDim.x) * VecSize;
+  for (; offset < n; offset += stride) {
+    using ArrT = platform::AlignedVector<__half, VecSize>;
+    ArrT x_in_arr = *reinterpret_cast<const ArrT*>(x + offset);
+    ArrT y_g_in_arr = *reinterpret_cast<const ArrT*>(y_g + offset);
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      __half2 tmp_fp16_2;
+      tmp_fp16_2.x = x_in_arr[i];
+      tmp_fp16_2.y = y_g_in_arr[i];
+      float2 tmp_fp32_2 = __half22float2(tmp_fp16_2);
+      x_in_arr[i] =
+          __float2half(FP32GeluBwd<FastMode>(tmp_fp32_2.x, tmp_fp32_2.y));
+    }
+    *reinterpret_cast<ArrT*>(x_g + offset) = x_in_arr;
+  }
+}
+
+static bool TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(
     const platform::CUDADeviceContext& dev_ctx, const __half* x, __half* y,
     size_t n) {
   auto is_aligned = [](const void* p, size_t alignment) {
     return reinterpret_cast<uintptr_t>(p) % alignment == 0;
   };
 
-#define PD_LAUNCH_FP16_FAST_GELU_KERNEL(__vec_size, __use_fast_math)          \
+#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math)      \
   do {                                                                        \
     constexpr auto kAlignment =                                               \
         alignof(platform::AlignedVector<__half, __vec_size>);                 \
@@ -72,9 +109,9 @@ static bool TryLaunchFP16FastGeluVectorizeCUDAKernel(
       size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
       size_t block = (n / __vec_size + thread - 1) / thread;                  \
       block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x);     \
-      VLOG(10) << "Use FP16 fast gelu kernel, block = " << block              \
+      VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block          \
                << " , thread = " << thread;                                   \
-      FP16FastGeluCUDAKernel<                                                 \
+      FP16FastGeluFwdCUDAKernel<                                              \
           __vec_size,                                                         \
           __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n);  \
       return true;                                                            \
@@ -82,12 +119,49 @@ static bool TryLaunchFP16FastGeluVectorizeCUDAKernel(
   } while (0)
 
   if (FLAGS_use_fast_math) {
-    PD_LAUNCH_FP16_FAST_GELU_KERNEL(8, true);
+    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, true);
   } else {
-    PD_LAUNCH_FP16_FAST_GELU_KERNEL(8, false);
+    PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(8, false);
   }
 
-#undef PD_LAUNCH_FP16_FAST_GELU_KERNEL
+#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
   return false;
 }
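Note (illustration, not part of the patch): the fwd launcher above and the bwd launcher below share the same eligibility test and launch configuration. A Python sketch of that logic, with hypothetical names and the device limits passed in as parameters:

    def can_vectorize(n, ptrs, vec_size=8, elem_bytes=2):
        # n must split into vec_size-wide chunks and every pointer must be
        # aligned to the chunk size (alignof(AlignedVector<__half, 8>) == 16).
        alignment = vec_size * elem_bytes
        return n % vec_size == 0 and all(p % alignment == 0 for p in ptrs)

    def launch_config(n, vec_size=8, max_threads=1024, max_grid_x=2147483647):
        # One thread per vec_size elements, at most 512 threads per block,
        # grid clamped to the device's x-dimension limit.
        thread = min(512, max_threads)
        block = (n // vec_size + thread - 1) // thread
        return min(block, max_grid_x), thread
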
+
+static bool TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(
+    const platform::CUDADeviceContext& dev_ctx, const __half* x,
+    const __half* y_g, __half* x_g, size_t n) {
+  auto is_aligned = [](const void* p, size_t alignment) {
+    return reinterpret_cast<uintptr_t>(p) % alignment == 0;
+  };
+
+#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math)      \
+  do {                                                                        \
+    constexpr auto kAlignment =                                               \
+        alignof(platform::AlignedVector<__half, __vec_size>);                 \
+    if (n % __vec_size == 0 && is_aligned(x, kAlignment) &&                   \
+        is_aligned(y_g, kAlignment) && is_aligned(x_g, kAlignment)) {         \
+      size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
+      size_t block = (n / __vec_size + thread - 1) / thread;                  \
+      block = std::min<size_t>(block,                                         \
+                               dev_ctx.GetCUDAMaxGridDimSize().x);            \
+      VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block          \
+               << " , thread = " << thread;                                   \
+      FP16FastGeluBwdCUDAKernel<                                              \
+          __vec_size,                                                         \
+          __use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g,    \
+                                                                   x_g, n);   \
+      return true;                                                            \
+    }                                                                         \
+  } while (0)
+
+  if (FLAGS_use_fast_math) {
+    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, true);
+  } else {
+    PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(8, false);
+  }
+
+#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
+  return false;
+}
 #endif
@@ -135,12 +209,12 @@ class GeluKernel
 
     if (approximate) {
 #ifdef __NVCC__
-      size_t n = in->numel();
       if (std::is_same<T, platform::float16>::value) {
+        size_t n = in->numel();
         const auto* in_ptr = reinterpret_cast<const __half*>(in->data<T>());
         auto* out_ptr = reinterpret_cast<__half*>(out->data<T>());
-        if (TryLaunchFP16FastGeluVectorizeCUDAKernel(dev_ctx, in_ptr, out_ptr,
-                                                     n)) {
+        if (TryLaunchFP16FastGeluFwdVectorizeCUDAKernel(dev_ctx, in_ptr,
+                                                        out_ptr, n)) {
           return;
         }
       }
@@ -205,6 +279,18 @@ class GeluGradKernel
     const auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     if (approximate) {
+#ifdef __NVCC__
+      if (std::is_same<T, platform::float16>::value) {
+        size_t n = x->numel();
+        const auto* x_ptr = reinterpret_cast<const __half*>(x->data<T>());
+        const auto* y_g_ptr = reinterpret_cast<const __half*>(dout->data<T>());
+        auto* x_g_ptr = reinterpret_cast<__half*>(dx->data<T>());
+        if (TryLaunchFP16FastGeluBwdVectorizeCUDAKernel(dev_ctx, x_ptr, y_g_ptr,
+                                                        x_g_ptr, n)) {
+          return;
+        }
+      }
+#endif
       LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx, ins, &outs, 0, GeluWithApproximateGradFunctor<T>());
     } else {
diff --git a/python/paddle/fluid/tests/unittests/test_gelu_op.py b/python/paddle/fluid/tests/unittests/test_gelu_op.py
index 947f8fc8c3fcb2..de34b63c9398e7 100644
--- a/python/paddle/fluid/tests/unittests/test_gelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gelu_op.py
@@ -68,21 +68,29 @@ def test_fast_math(self):
         def use_fast_math(enabled):
             paddle.set_flags({'FLAGS_use_fast_math': enabled})
 
-        x_np = np.random.uniform(-1, 1, size=(11, 17, 8)).astype(np.float16)
+        shape = [11, 17, 8]
+        x_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
+        y_g_np = np.random.uniform(-1, 1, size=shape).astype(np.float16)
 
         def run_gelu_op(approximate):
             with dg.guard():
                 x = paddle.to_tensor(x_np)
+                x.stop_gradient = False
                 y = F.gelu(x, approximate=approximate)
-                return y.numpy()
+                x_grad = paddle.grad([y], [x], [paddle.to_tensor(y_g_np)])[0]
+                return y.numpy(), x_grad.numpy()
 
         use_fast_math(True)
-        y_fast_math = run_gelu_op(True)
+        y_fast_math, x_g_fast_math = run_gelu_op(True)
         use_fast_math(False)
 
-        y_ref = run_gelu_op(False)
+        y_ref, x_g_ref = run_gelu_op(True)
         self.assertTrue(np.allclose(y_ref, y_fast_math, rtol=1e-5, atol=5e-4))
+        self.assertTrue(
+            np.allclose(
+                x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4))
 
 
 if __name__ == '__main__':
     unittest.main()
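Note (sanity check, not part of the patch): the constant 0.1070322243f in FP32GeluBwd is 3 * 0.044715 * 0.79788456, which comes from differentiating the inner polynomial u(x) = 0.79788456 * x * (1 + 0.044715 * x^2), so u'(x) = 0.79788456 + 0.1070322243 * x^2. A NumPy check of the closed-form local gradient (FP32GeluBwd divided by its y_g factor) against a central finite difference of the forward formula:

    import numpy as np

    def gelu_fwd(x):
        return 0.5 * x * (1.0 + np.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

    def gelu_grad(x):
        # d/dx of the tanh-approximate GELU; FP32GeluBwd multiplies this by y_g.
        t = np.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))
        return (0.5 * x * (1.0 - t * t) * (0.79788456 + 0.1070322243 * x * x)
                + 0.5 * (1.0 + t))

    x = np.linspace(-3.0, 3.0, 601)
    eps = 1e-4
    numeric = (gelu_fwd(x + eps) - gelu_fwd(x - eps)) / (2 * eps)
    assert np.allclose(gelu_grad(x), numeric, atol=1e-6)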