From 074782b4c1bb70ba8f3274bd015282e44d5552b1 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 26 Feb 2023 03:01:13 +0800 Subject: [PATCH 01/25] add fp16 testcase for dist op --- .../fluid/tests/unittests/test_dist_op.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_dist_op.py b/python/paddle/fluid/tests/unittests/test_dist_op.py index 96c0de915cff2..4f6eb345308ef 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_op.py +++ b/python/paddle/fluid/tests/unittests/test_dist_op.py @@ -158,6 +158,46 @@ def init_case(self): self.p = 1.5 +class TestDistOpWithFP16(TestDistOp): + def init_data_type(self): + self.data_type = 'float16' + + +class TestDistOpWithFP16Case1(TestDistOpWithFP16): + def init_case(self): + self.x_shape = (3, 5, 5, 6) + self.y_shape = (5, 5, 6) + self.p = 1.0 + + +class TestDistOpWithFP16Case2(TestDistOpWithFP16): + def init_case(self): + self.x_shape = (10, 10) + self.y_shape = (4, 10, 10) + self.p = 2.0 + + +class TestDistOpWithFP16Case3(TestDistOpWithFP16): + def init_case(self): + self.x_shape = (15, 10) + self.y_shape = (15, 10) + self.p = float("inf") + + +class TestDistOpWithFP16Case4(TestDistOpWithFP16): + def init_case(self): + self.x_shape = (2, 3, 4, 5, 8) + self.y_shape = (3, 1, 5, 8) + self.p = float("-inf") + + +class TestDistOpWithFP16Case5(TestDistOpWithFP16): + def init_case(self): + self.x_shape = (4, 1, 4, 8) + self.y_shape = (2, 2, 1, 4, 4, 8) + self.p = 1.5 + + class TestDistAPI(unittest.TestCase): def init_data_type(self): self.data_type = ( @@ -202,6 +242,11 @@ def test_grad_x(self): paddle.enable_static() +class TestDistAPIWithFP16(unittest.TestCase): + def init_data_type(self): + self.data_type = 'float16' + + if __name__ == '__main__': paddle.enable_static() unittest.main() From 16b2ac31d6f3147c2dffd476c8f73d60b37d8667 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 26 Feb 2023 03:12:00 +0800 Subject: [PATCH 02/25] add desc for fp16, test=document_fix --- python/paddle/tensor/linalg.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 2235cf93cfb60..9ab68675b92ee 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -675,8 +675,8 @@ def dist(x, y, p=2, name=None): ||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}} Args: - x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64. - y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64. + x (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64. + y (Tensor): 1-D to 6-D Tensor, its data type is flaot16, float32 or float64. p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2. name (str, optional): The default value is `None`. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -706,8 +706,12 @@ def dist(x, y, p=2, name=None): if in_dygraph_mode(): return _C_ops.dist(x, y, p) - check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist') - check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist') + check_variable_and_dtype( + x, 'dtype', ['float16', 'float32', 'float64'], 'dist' + ) + check_variable_and_dtype( + y, 'dtype', ['float16', 'float32', 'float64'], 'dist' + ) check_type(p, 'p', (float, int), 'dist') helper = LayerHelper("dist", **locals()) out = helper.create_variable_for_type_inference(x.dtype) From cbcf6bcab69f24f6e0bb1ad6f430e8e7f3234a4a Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Fri, 10 Mar 2023 17:12:48 +0800 Subject: [PATCH 03/25] fix typo --- python/paddle/tensor/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 9ab68675b92ee..57550c86368a2 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -676,7 +676,7 @@ def dist(x, y, p=2, name=None): Args: x (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64. - y (Tensor): 1-D to 6-D Tensor, its data type is flaot16, float32 or float64. + y (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64. p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2. name (str, optional): The default value is `None`. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. From 3c595ae5b08f79b05a8e234e26a767d0e6d97026 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 18:09:11 +0800 Subject: [PATCH 04/25] add fp16 support --- paddle/phi/kernels/dist_kernel.cc | 8 +++- paddle/phi/kernels/funcs/math_cuda_utils.h | 43 +++++++++++++++++++++- paddle/phi/kernels/gpu/dist_kernel.cu | 30 ++++++++++++--- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/dist_kernel.cc b/paddle/phi/kernels/dist_kernel.cc index 90f199531d2a4..af48fe279d74f 100644 --- a/paddle/phi/kernels/dist_kernel.cc +++ b/paddle/phi/kernels/dist_kernel.cc @@ -33,4 +33,10 @@ void DistKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {} +PD_REGISTER_KERNEL(dist, + CPU, + ALL_LAYOUT, + phi::DistKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index b493e2ac41bfd..eb5c462df5052 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -23,6 +23,9 @@ limitations under the License. */ #include +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" + namespace phi { namespace funcs { @@ -178,6 +181,18 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { return val; } +template <> +__inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, + unsigned lane_mask) { + for (int mask = HALF_WARP; mask > 0; mask >>= 1) +#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) + val += __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize); +#else + val += __shfl_xor(val.to_half(), mask, warpSize); +#endif + return phi::dtype::float16(val); +} + /* Calculate the sum of all elements in a block */ template __inline__ __device__ T BlockReduceSum(T val, unsigned mask) { @@ -251,6 +266,18 @@ __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { return val; } +template <> +__inline__ __device__ phi::dtype::float16 WarpReduceMax(phi::dtype::float16 val, + unsigned lane_mask) { + for (int mask = HALF_WARP; mask > 0; mask >>= 1) +#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) + val = max(val, __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize)); +#else + val = max(val, __shfl_xor(val.to_half(), mask, warpSize)); +#endif + return phi::dtype::float16(val); +} + template __inline__ __device__ T WarpReduceMaxV2(T *val) { #pragma unroll @@ -273,6 +300,18 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { return val; } +template <> +__inline__ __device__ phi::dtype::float16 WarpReduceMin(phi::dtype::float16 val, + unsigned lane_mask) { + for (int mask = HALF_WARP; mask > 0; mask >>= 1) +#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) + val = min(val, __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize)); +#else + val = min(val, __shfl_xor(val.to_half(), mask, warpSize)); +#endif + return phi::dtype::float16(val); +} + /* Calculate the minimum of all elements in a warp when actual quantity of * threads are less than warpSize.*/ template @@ -310,7 +349,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : -1e10f; + val = (lane < block_span) ? shared[lane] : T(-1e10f); val = WarpReduceMax(val, mask); return val; @@ -358,7 +397,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : 1e10f; + val = (lane < block_span) ? shared[lane] : T(1e10f); val = WarpReduceMin(val, mask); return val; diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 5040be8eaaca7..689fcf8a5ec5e 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -14,8 +14,10 @@ #include "paddle/phi/kernels/dist_kernel.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/p_norm_kernel.h" @@ -24,11 +26,21 @@ namespace phi { #define FULL_MASK 0xffffffff +template +__device__ T max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ T min(T a, T b) { + return a < b ? a : b; +} + template struct ZeroOrderFunctor { public: __device__ T operator()(const T& x, const T& y) const { - return static_cast((x - y) != 0); + return static_cast((x - y) != T(0)); } }; @@ -55,7 +67,7 @@ struct PowFunctor { template __global__ void ReduceSumWithSubtract( const T* x, const T* y, T* out, int64_t N, Functor func) { - T sum_val = 0; + T sum_val = T(0); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { sum_val += func(x[i], y[i]); @@ -73,7 +85,7 @@ __global__ void ReduceMaxWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T max_val = -1e10f; + T max_val = T(-1e10f); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { max_val = max(max_val, abs(x[i] - y[i])); @@ -91,7 +103,7 @@ __global__ void ReduceMinWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T min_val = 1e10f; + T min_val = T(1e10f); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { min_val = min(min_val, abs(x[i] - y[i])); @@ -160,7 +172,7 @@ void DistKernel(const Context& dev_ctx, const DenseTensor* tmp_norm = out; std::vector ins = {tmp_norm}; std::vector outs = {out}; - T p_order_ = static_cast(1. / p_order); + T p_order_ = static_cast(T(1.) / p_order); phi::funcs::ElementwiseKernel( dev_ctx, ins, &outs, PowFunctor(p_order_)); } @@ -173,4 +185,10 @@ void DistKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {} +PD_REGISTER_KERNEL(dist, + GPU, + ALL_LAYOUT, + phi::DistKernel, + phi::dtype::float16, + float, + double) {} From 358ab86d37985b270e825f6d25e2769f6a66468f Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 18:26:47 +0800 Subject: [PATCH 05/25] remove fp16 for cpu --- paddle/phi/kernels/dist_kernel.cc | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/paddle/phi/kernels/dist_kernel.cc b/paddle/phi/kernels/dist_kernel.cc index af48fe279d74f..90f199531d2a4 100644 --- a/paddle/phi/kernels/dist_kernel.cc +++ b/paddle/phi/kernels/dist_kernel.cc @@ -33,10 +33,4 @@ void DistKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(dist, - CPU, - ALL_LAYOUT, - phi::DistKernel, - phi::dtype::float16, - float, - double) {} +PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {} From b636d36d9e2fc7eb1dc11e5c8de10ea3fae1a9c3 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 18:56:08 +0800 Subject: [PATCH 06/25] modify the ut --- .../fluid/tests/unittests/test_dist_op.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_op.py b/python/paddle/fluid/tests/unittests/test_dist_op.py index 4f6eb345308ef..3654da4e00fae 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_op.py +++ b/python/paddle/fluid/tests/unittests/test_dist_op.py @@ -158,40 +158,40 @@ def init_case(self): self.p = 1.5 -class TestDistOpWithFP16(TestDistOp): +class TestDistFP16Op(OpTest): def init_data_type(self): self.data_type = 'float16' -class TestDistOpWithFP16Case1(TestDistOpWithFP16): +class TestDistFP16OpCase1(TestDistFP16Op): def init_case(self): self.x_shape = (3, 5, 5, 6) self.y_shape = (5, 5, 6) self.p = 1.0 -class TestDistOpWithFP16Case2(TestDistOpWithFP16): +class TestDistFP16OpCase2(TestDistFP16Op): def init_case(self): self.x_shape = (10, 10) self.y_shape = (4, 10, 10) self.p = 2.0 -class TestDistOpWithFP16Case3(TestDistOpWithFP16): +class TestDistFP16OpCase3(TestDistFP16Op): def init_case(self): self.x_shape = (15, 10) self.y_shape = (15, 10) self.p = float("inf") -class TestDistOpWithFP16Case4(TestDistOpWithFP16): +class TestDistFP16OpCase4(TestDistFP16Op): def init_case(self): self.x_shape = (2, 3, 4, 5, 8) self.y_shape = (3, 1, 5, 8) self.p = float("-inf") -class TestDistOpWithFP16Case5(TestDistOpWithFP16): +class TestDistFP16OpCase5(TestDistFP16Op): def init_case(self): self.x_shape = (4, 1, 4, 8) self.y_shape = (2, 2, 1, 4, 4, 8) @@ -242,11 +242,6 @@ def test_grad_x(self): paddle.enable_static() -class TestDistAPIWithFP16(unittest.TestCase): - def init_data_type(self): - self.data_type = 'float16' - - if __name__ == '__main__': paddle.enable_static() unittest.main() From 3f2feff3f4c7a8459401ff675dbe9f9f00be9348 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 20:17:51 +0800 Subject: [PATCH 07/25] fix erros --- paddle/phi/kernels/funcs/math_cuda_utils.h | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index eb5c462df5052..928360748dc40 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -184,13 +184,14 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, unsigned lane_mask) { + __half val_half = val.to_half(); for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val += __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize); + val_half += __shfl_xor_sync(lane_mask, val_half, mask, warpSize); #else - val += __shfl_xor(val.to_half(), mask, warpSize); + val_half += __shfl_xor(val_half, mask, warpSize); #endif - return phi::dtype::float16(val); + return phi::dtype::float16(val_half); } /* Calculate the sum of all elements in a block */ @@ -269,13 +270,15 @@ __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceMax(phi::dtype::float16 val, unsigned lane_mask) { + __half val_half = val.to_half(); for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = max(val, __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize)); + val_half = + max(val_half, __shfl_xor_sync(lane_mask, val_half, mask, warpSize)); #else - val = max(val, __shfl_xor(val.to_half(), mask, warpSize)); + val_half = max(val_half, __shfl_xor(val_half, mask, warpSize)); #endif - return phi::dtype::float16(val); + return phi::dtype::float16(val_half); } template @@ -303,13 +306,15 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceMin(phi::dtype::float16 val, unsigned lane_mask) { + __half val_half = val.to_half(); for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = min(val, __shfl_xor_sync(lane_mask, val.to_half(), mask, warpSize)); + val_half = + min(val_half, __shfl_xor_sync(lane_mask, val_half, mask, warpSize)); #else - val = min(val, __shfl_xor(val.to_half(), mask, warpSize)); + val_half = min(val_half, __shfl_xor(val_half, mask, warpSize)); #endif - return phi::dtype::float16(val); + return phi::dtype::float16(val_half); } /* Calculate the minimum of all elements in a warp when actual quantity of From a54bdb544d5a0e9f34102ddc271eb612120b53a0 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 20:54:19 +0800 Subject: [PATCH 08/25] fix errors --- paddle/phi/kernels/funcs/math_cuda_utils.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index 928360748dc40..de7a2d7ae6e61 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -29,6 +29,16 @@ limitations under the License. */ namespace phi { namespace funcs { +template +__device__ T max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ T min(T a, T b) { + return a < b ? a : b; +} + template __device__ __forceinline__ T FromFloat(float a); From 8dc9f982fc48a8bbb9968d754bde6fb6b85506d4 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 21:35:32 +0800 Subject: [PATCH 09/25] fix pow related funcs --- paddle/phi/kernels/gpu/dist_kernel.cu | 31 ++++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 689fcf8a5ec5e..6f2834c499606 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -17,7 +17,6 @@ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" -#include "paddle/phi/kernels/funcs/eigen/extensions.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/gpu/reduce.h" #include "paddle/phi/kernels/p_norm_kernel.h" @@ -44,24 +43,25 @@ struct ZeroOrderFunctor { } }; -template +template struct OtherOrderFunctor { - explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {} - __device__ T operator()(const T& x, const T& y) const { - return static_cast(pow(abs(x - y), p_order_)); + explicit OtherOrderFunctor(const Ty& p_order) : p_order_(p_order) {} + __device__ Tx operator()(const Tx& x, const Tx& y) const { + return static_cast( + pow(abs(static_cast(x) - static_cast(y)), p_order_)); } private: - T p_order_; + Ty p_order_; }; -template +template struct PowFunctor { - explicit PowFunctor(const T& p_order) : p_order_(p_order) {} - HOSTDEVICE inline T operator()(const T x) const { - return static_cast(pow(x, p_order_)); + explicit PowFunctor(const Ty& p_order) : p_order_(p_order) {} + HOSTDEVICE inline Tx operator()(const Tx x) const { + return static_cast(pow(static_cast(x), p_order_)); } - T p_order_; + Ty p_order_; }; template @@ -162,19 +162,20 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - T p_order = static_cast(p); + using MT = typename phi::dtype::MPTypeTrait::Type; + MT p_order = static_cast(p); ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); + x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); phi::funcs::ReduceKernel>( dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); const DenseTensor* tmp_norm = out; std::vector ins = {tmp_norm}; std::vector outs = {out}; - T p_order_ = static_cast(T(1.) / p_order); + MT p_order_ = static_cast(static_cast(1.) / p_order); phi::funcs::ElementwiseKernel( - dev_ctx, ins, &outs, PowFunctor(p_order_)); + dev_ctx, ins, &outs, PowFunctor(p_order_)); } } else { From 4a3afbec96c3b8670b79fdcd67cd28fd9f46f411 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 22:39:10 +0800 Subject: [PATCH 10/25] fix rocm --- paddle/phi/kernels/funcs/math_cuda_utils.h | 37 ++++++++++++---------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index de7a2d7ae6e61..5e1cbd4bac69c 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -194,14 +194,15 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, unsigned lane_mask) { - __half val_half = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val_half += __shfl_xor_sync(lane_mask, val_half, mask, warpSize); + __half val_ret = val.to_half(); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret += __shfl_xor_sync(lane_mask, val_ret, mask, warpSize); #else - val_half += __shfl_xor(val_half, mask, warpSize); + val_ret = static_cast(val); + val_ret += __shfl_xor(val_ret, mask, warpSize); #endif - return phi::dtype::float16(val_half); + return static_cast(val_ret); } /* Calculate the sum of all elements in a block */ @@ -280,15 +281,16 @@ __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceMax(phi::dtype::float16 val, unsigned lane_mask) { - __half val_half = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val_half = - max(val_half, __shfl_xor_sync(lane_mask, val_half, mask, warpSize)); + __half val_ret = val.to_half(); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret = max(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); #else - val_half = max(val_half, __shfl_xor(val_half, mask, warpSize)); + float val_ret = static_cast(val); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret = max(val_ret, __shfl_xor(val_ret, mask, warpSize)); #endif - return phi::dtype::float16(val_half); + return static_cast(val_ret); } template @@ -316,15 +318,16 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { template <> __inline__ __device__ phi::dtype::float16 WarpReduceMin(phi::dtype::float16 val, unsigned lane_mask) { - __half val_half = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val_half = - min(val_half, __shfl_xor_sync(lane_mask, val_half, mask, warpSize)); + __half val_ret = val.to_half(); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret = min(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); #else - val_half = min(val_half, __shfl_xor(val_half, mask, warpSize)); + float val_ret = static_cast(val); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret = min(val_ret, __shfl_xor(val_ret, mask, warpSize)); #endif - return phi::dtype::float16(val_half); + return static_cast(val_ret); } /* Calculate the minimum of all elements in a warp when actual quantity of From 471b91ef97c298b06c025dd05447f2aef1433d1f Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 22:47:08 +0800 Subject: [PATCH 11/25] add dist grad register fp16 --- paddle/phi/kernels/dist_grad_kernel.cc | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/dist_grad_kernel.cc b/paddle/phi/kernels/dist_grad_kernel.cc index 17c24fa905b5c..3c7ffc01edf07 100644 --- a/paddle/phi/kernels/dist_grad_kernel.cc +++ b/paddle/phi/kernels/dist_grad_kernel.cc @@ -98,6 +98,11 @@ PD_REGISTER_KERNEL( dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL( - dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} +PD_REGISTER_KERNEL(dist_grad, + GPU, + ALL_LAYOUT, + phi::DistGradKernel, + phi::dtype::float16, + float, + double) {} #endif From df42d267eebf173c4d093cc54981fa303fdf16a7 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 23:14:58 +0800 Subject: [PATCH 12/25] fix error --- paddle/phi/kernels/funcs/math_cuda_utils.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index 5e1cbd4bac69c..e9a9dc0fa06ff 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -200,7 +200,8 @@ __inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, val_ret += __shfl_xor_sync(lane_mask, val_ret, mask, warpSize); #else val_ret = static_cast(val); - val_ret += __shfl_xor(val_ret, mask, warpSize); + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val_ret += __shfl_xor(val_ret, mask, warpSize); #endif return static_cast(val_ret); } From 426c0a01f0bef34989ab78090166597a6a85818f Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Sun, 19 Mar 2023 23:41:37 +0800 Subject: [PATCH 13/25] fix error --- paddle/phi/kernels/funcs/math_cuda_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index e9a9dc0fa06ff..332376f3440c9 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -199,7 +199,7 @@ __inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, for (int mask = HALF_WARP; mask > 0; mask >>= 1) val_ret += __shfl_xor_sync(lane_mask, val_ret, mask, warpSize); #else - val_ret = static_cast(val); + float val_ret = static_cast(val); for (int mask = HALF_WARP; mask > 0; mask >>= 1) val_ret += __shfl_xor(val_ret, mask, warpSize); #endif From 6b9a6eee855a5b0fb31c74c2493aec46244cc7d4 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 22 Mar 2023 21:42:56 +0800 Subject: [PATCH 14/25] call std::max/min --- paddle/phi/kernels/funcs/math_cuda_utils.h | 32 ++++++++-------------- paddle/phi/kernels/gpu/dist_kernel.cu | 22 +++++---------- 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index 332376f3440c9..9fe5d65995cac 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -29,16 +29,6 @@ limitations under the License. */ namespace phi { namespace funcs { -template -__device__ T max(T a, T b) { - return a > b ? a : b; -} - -template -__device__ T min(T a, T b) { - return a < b ? a : b; -} - template __device__ __forceinline__ T FromFloat(float a); @@ -272,9 +262,9 @@ template __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); + val = std::max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); #else - val = max(val, __shfl_xor(val, mask, warpSize)); + val = std::max(val, __shfl_xor(val, mask, warpSize)); #endif return val; } @@ -285,11 +275,12 @@ __inline__ __device__ phi::dtype::float16 WarpReduceMax(phi::dtype::float16 val, #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) __half val_ret = val.to_half(); for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = max(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); + val_ret = + std::max(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); #else float val_ret = static_cast(val); for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = max(val_ret, __shfl_xor(val_ret, mask, warpSize)); + val_ret = std::max(val_ret, __shfl_xor(val_ret, mask, warpSize)); #endif return static_cast(val_ret); } @@ -309,9 +300,9 @@ template __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); + val = std::min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); #else - val = min(val, __shfl_xor(val, mask, warpSize)); + val = std::min(val, __shfl_xor(val, mask, warpSize)); #endif return val; } @@ -322,11 +313,12 @@ __inline__ __device__ phi::dtype::float16 WarpReduceMin(phi::dtype::float16 val, #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) __half val_ret = val.to_half(); for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = min(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); + val_ret = + std::min(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); #else float val_ret = static_cast(val); for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = min(val_ret, __shfl_xor(val_ret, mask, warpSize)); + val_ret = std::min(val_ret, __shfl_xor(val_ret, mask, warpSize)); #endif return static_cast(val_ret); } @@ -368,7 +360,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : T(-1e10f); + val = (lane < block_span) ? shared[lane] : (T)-1e10f; val = WarpReduceMax(val, mask); return val; @@ -416,7 +408,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : T(1e10f); + val = (lane < block_span) ? shared[lane] : (T)1e10f; val = WarpReduceMin(val, mask); return val; diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 6f2834c499606..438e7bbab5e11 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/dist_kernel.h" +#include + #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/dist_kernel.h" #include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/gpu/reduce.h" @@ -25,16 +27,6 @@ namespace phi { #define FULL_MASK 0xffffffff -template -__device__ T max(T a, T b) { - return a > b ? a : b; -} - -template -__device__ T min(T a, T b) { - return a < b ? a : b; -} - template struct ZeroOrderFunctor { public: @@ -85,10 +77,10 @@ __global__ void ReduceMaxWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T max_val = T(-1e10f); + T max_val = (T)-1e10f; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - max_val = max(max_val, abs(x[i] - y[i])); + max_val = std::max(max_val, abs(x[i] - y[i])); } __syncthreads(); @@ -103,10 +95,10 @@ __global__ void ReduceMinWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T min_val = T(1e10f); + T min_val = (T)1e10f; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - min_val = min(min_val, abs(x[i] - y[i])); + min_val = std::min(min_val, abs(x[i] - y[i])); } __syncthreads(); From 6e5325a1c9cf480a407bfe7bac6417e114506ca9 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Mon, 27 Mar 2023 15:20:33 +0800 Subject: [PATCH 15/25] refactor WarpReduce* to support fp16/bf16 --- paddle/phi/kernels/funcs/math_cuda_utils.h | 68 ++-------------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index 9fe5d65995cac..d3b23e635ea07 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -23,8 +23,7 @@ limitations under the License. */ #include -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/float16.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" namespace phi { namespace funcs { @@ -173,29 +172,10 @@ struct KeyValuePair { template __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val += __shfl_xor_sync(lane_mask, val, mask, warpSize); -#else - val += __shfl_xor(val, mask, warpSize); -#endif + val += CudaShuffleXorSync(lane_mask, val, mask); return val; } -template <> -__inline__ __device__ phi::dtype::float16 WarpReduceSum(phi::dtype::float16 val, - unsigned lane_mask) { -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - __half val_ret = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret += __shfl_xor_sync(lane_mask, val_ret, mask, warpSize); -#else - float val_ret = static_cast(val); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret += __shfl_xor(val_ret, mask, warpSize); -#endif - return static_cast(val_ret); -} - /* Calculate the sum of all elements in a block */ template __inline__ __device__ T BlockReduceSum(T val, unsigned mask) { @@ -261,30 +241,10 @@ __inline__ __device__ T BlockReduceSumV2(T *val) { template __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = std::max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); -#else - val = std::max(val, __shfl_xor(val, mask, warpSize)); -#endif + val = std::max(val, CudaShuffleXorSync(lane_mask, val, mask)); return val; } -template <> -__inline__ __device__ phi::dtype::float16 WarpReduceMax(phi::dtype::float16 val, - unsigned lane_mask) { -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - __half val_ret = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = - std::max(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); -#else - float val_ret = static_cast(val); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = std::max(val_ret, __shfl_xor(val_ret, mask, warpSize)); -#endif - return static_cast(val_ret); -} - template __inline__ __device__ T WarpReduceMaxV2(T *val) { #pragma unroll @@ -299,30 +259,10 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) { template __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - val = std::min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize)); -#else - val = std::min(val, __shfl_xor(val, mask, warpSize)); -#endif + val = std::min(val, CudaShuffleXorSync(lane_mask, val, mask, warpSize)); return val; } -template <> -__inline__ __device__ phi::dtype::float16 WarpReduceMin(phi::dtype::float16 val, - unsigned lane_mask) { -#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000) - __half val_ret = val.to_half(); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = - std::min(val_ret, __shfl_xor_sync(lane_mask, val_ret, mask, warpSize)); -#else - float val_ret = static_cast(val); - for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val_ret = std::min(val_ret, __shfl_xor(val_ret, mask, warpSize)); -#endif - return static_cast(val_ret); -} - /* Calculate the minimum of all elements in a warp when actual quantity of * threads are less than warpSize.*/ template From 78d62faee838c4e65844b920485c5cf09cf3e20c Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Mon, 27 Mar 2023 15:38:56 +0800 Subject: [PATCH 16/25] adopt fp32 sum/pow on p!=inf path for better acc --- paddle/phi/kernels/gpu/dist_kernel.cu | 47 +++++++++++++++++---------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 438e7bbab5e11..c0d64572185d3 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -30,21 +30,22 @@ namespace phi { template struct ZeroOrderFunctor { public: - __device__ T operator()(const T& x, const T& y) const { - return static_cast((x - y) != T(0)); + using MT = typename phi::dtype::MPTypeTrait::Type; + __device__ MT operator()(const T& x, const T& y) const { + return static_cast((static_cast(x) - static_cast(y)) != MT(0)); } }; -template +template struct OtherOrderFunctor { - explicit OtherOrderFunctor(const Ty& p_order) : p_order_(p_order) {} - __device__ Tx operator()(const Tx& x, const Tx& y) const { - return static_cast( - pow(abs(static_cast(x) - static_cast(y)), p_order_)); + using MT = typename phi::dtype::MPTypeTrait::Type; + explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {} + __device__ T operator()(const T& x, const T& y) const { + return static_cast(pow(abs(x - y), p_order_)); } private: - Ty p_order_; + T p_order_; }; template @@ -114,6 +115,7 @@ void DistKernel(const Context& dev_ctx, const DenseTensor& y, float p, DenseTensor* out) { + using MT = typename phi::dtype::MPTypeTrait::Type; DenseTensor intermediate; const T* x_ptr = x.data(); const T* y_ptr = y.data(); @@ -125,20 +127,25 @@ void DistKernel(const Context& dev_ctx, auto n = x.numel(); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); intermediate.Resize(phi::make_ddim({config.block_per_grid.x})); - T* i_ptr = dev_ctx.template Alloc(&intermediate); std::vector axis_dims = {static_cast(-1)}; std::vector reduce_axis = funcs::details::GetReduceDim(axis_dims, xdim.size(), true); if (p == 0) { - ReduceSumWithSubtract + MT* i_ptr = dev_ctx.template Alloc(&intermediate); + ReduceSumWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); - phi::funcs::ReduceKernel>( - dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); + phi::funcs::ReduceKernel>( + dev_ctx, + intermediate, + out, + kps::IdentityFunctor(), + reduce_axis); } else if (p == INFINITY) { + T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMaxWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n); @@ -146,6 +153,7 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else if (p == -INFINITY) { + T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMinWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n); @@ -154,13 +162,18 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - using MT = typename phi::dtype::MPTypeTrait::Type; + MT* i_ptr = dev_ctx.template Alloc(&intermediate); MT p_order = static_cast(p); - ReduceSumWithSubtract + ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); - phi::funcs::ReduceKernel>( - dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); + x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); + phi::funcs:: + ReduceKernel>( + dev_ctx, + intermediate, + out, + kps::IdentityFunctor(), + reduce_axis); const DenseTensor* tmp_norm = out; std::vector ins = {tmp_norm}; From 2b141957a9393934758f40a2d58595674b006a6d Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Mon, 27 Mar 2023 19:30:48 +0800 Subject: [PATCH 17/25] fix interface mistakes --- paddle/phi/kernels/funcs/math_cuda_utils.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index d3b23e635ea07..35a5876356c42 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -172,7 +172,7 @@ struct KeyValuePair { template __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val += CudaShuffleXorSync(lane_mask, val, mask); + val += phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask); return val; } @@ -241,7 +241,8 @@ __inline__ __device__ T BlockReduceSumV2(T *val) { template __inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val = std::max(val, CudaShuffleXorSync(lane_mask, val, mask)); + val = std::max( + val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask)); return val; } @@ -259,7 +260,8 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) { template __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) { for (int mask = HALF_WARP; mask > 0; mask >>= 1) - val = std::min(val, CudaShuffleXorSync(lane_mask, val, mask, warpSize)); + val = std::min( + val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask)); return val; } From c134c118545ff7ed9d6f1d9bff05460efc5f6ca0 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Mon, 27 Mar 2023 19:49:53 +0800 Subject: [PATCH 18/25] fix interface mistakes --- paddle/phi/kernels/gpu/dist_kernel.cu | 31 +++++++++++++-------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index c0d64572185d3..ce87c0fd1eb8c 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -27,25 +27,24 @@ namespace phi { #define FULL_MASK 0xffffffff -template +template struct ZeroOrderFunctor { public: - using MT = typename phi::dtype::MPTypeTrait::Type; - __device__ MT operator()(const T& x, const T& y) const { - return static_cast((static_cast(x) - static_cast(y)) != MT(0)); + __device__ Ty operator()(const Tx& x, const Tx& y) const { + return static_cast((static_cast(x) - static_cast(y)) != Ty(0)); } }; -template +template struct OtherOrderFunctor { - using MT = typename phi::dtype::MPTypeTrait::Type; - explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {} - __device__ T operator()(const T& x, const T& y) const { - return static_cast(pow(abs(x - y), p_order_)); + explicit OtherOrderFunctor(const Ty& p_order) : p_order_(p_order) {} + __device__ Ty operator()(const Tx& x, const Tx& y) const { + return static_cast( + pow(abs(static_cast(x) - static_cast(y)), p_order_)); } private: - T p_order_; + Ty p_order_; }; template @@ -57,17 +56,17 @@ struct PowFunctor { Ty p_order_; }; -template +template __global__ void ReduceSumWithSubtract( - const T* x, const T* y, T* out, int64_t N, Functor func) { - T sum_val = T(0); + const Tx* x, const Tx* y, Ty* out, int64_t N, Functor func) { + Ty sum_val = Ty(0); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { sum_val += func(x[i], y[i]); } __syncthreads(); - sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); + sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); if (threadIdx.x == 0) { out[blockIdx.x] = sum_val; } @@ -136,7 +135,7 @@ void DistKernel(const Context& dev_ctx, MT* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); + x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); phi::funcs::ReduceKernel>( dev_ctx, intermediate, @@ -166,7 +165,7 @@ void DistKernel(const Context& dev_ctx, MT p_order = static_cast(p); ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); + x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); phi::funcs:: ReduceKernel>( dev_ctx, From 23fb3241fa988191fb617aed966d27ab0f435614 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Tue, 28 Mar 2023 00:02:00 +0800 Subject: [PATCH 19/25] fix interface mistakes --- paddle/phi/kernels/gpu/dist_kernel.cu | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index ce87c0fd1eb8c..00a7a9c4f6587 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -136,12 +136,13 @@ void DistKernel(const Context& dev_ctx, ReduceSumWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); - phi::funcs::ReduceKernel>( - dev_ctx, - intermediate, - out, - kps::IdentityFunctor(), - reduce_axis); + phi::funcs:: + ReduceKernel>( + dev_ctx, + intermediate, + out, + kps::IdentityFunctor(), + reduce_axis); } else if (p == INFINITY) { T* i_ptr = dev_ctx.template Alloc(&intermediate); @@ -167,7 +168,7 @@ void DistKernel(const Context& dev_ctx, <<>>( x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); phi::funcs:: - ReduceKernel>( + ReduceKernel>( dev_ctx, intermediate, out, From 367bdbade77b7203da64cc97690f548d07103c36 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Tue, 28 Mar 2023 00:55:42 +0800 Subject: [PATCH 20/25] fix min/max val --- paddle/phi/kernels/funcs/math_cuda_utils.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/funcs/math_cuda_utils.h b/paddle/phi/kernels/funcs/math_cuda_utils.h index 35a5876356c42..26ead61e4e5f9 100644 --- a/paddle/phi/kernels/funcs/math_cuda_utils.h +++ b/paddle/phi/kernels/funcs/math_cuda_utils.h @@ -24,6 +24,7 @@ limitations under the License. */ #include #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/common/data_type.h" namespace phi { namespace funcs { @@ -302,7 +303,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : (T)-1e10f; + val = (lane < block_span) ? shared[lane] : std::numeric_limits::min(); val = WarpReduceMax(val, mask); return val; @@ -350,7 +351,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) { // align block_span to warpSize int block_span = (blockDim.x + warpSize - 1) >> 5; - val = (lane < block_span) ? shared[lane] : (T)1e10f; + val = (lane < block_span) ? shared[lane] : std::numeric_limits::max(); val = WarpReduceMin(val, mask); return val; From 81f268581a99dddec040e1580ac3347a77cd4051 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Tue, 28 Mar 2023 01:38:51 +0800 Subject: [PATCH 21/25] fix acc --- paddle/phi/kernels/gpu/dist_kernel.cu | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 00a7a9c4f6587..9fd61202361a1 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -77,7 +77,7 @@ __global__ void ReduceMaxWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T max_val = (T)-1e10f; + T max_val = std::numeric_limits::min(); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { max_val = std::max(max_val, abs(x[i] - y[i])); @@ -95,7 +95,7 @@ __global__ void ReduceMinWithSubtract(const T* x, const T* y, T* out, int64_t N) { - T min_val = (T)1e10f; + T min_val = std::numeric_limits::max(); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { min_val = std::min(min_val, abs(x[i] - y[i])); @@ -136,14 +136,20 @@ void DistKernel(const Context& dev_ctx, ReduceSumWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); + DenseTensor intermediate2(intermediate); + dev_ctx.template Alloc(&intermediate2); phi::funcs:: - ReduceKernel>( + ReduceKernel>( dev_ctx, intermediate, - out, - kps::IdentityFunctor(), + &intermediate2, + kps::IdentityFunctor(), reduce_axis); + std::vector ins = {&intermediate2}; + std::vector outs = {out}; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, kps::IdentityFunctor()); } else if (p == INFINITY) { T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMaxWithSubtract @@ -167,16 +173,17 @@ void DistKernel(const Context& dev_ctx, ReduceSumWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); + DenseTensor intermediate2(intermediate); + dev_ctx.template Alloc(&intermediate2); phi::funcs:: ReduceKernel>( dev_ctx, intermediate, - out, + &intermediate2, kps::IdentityFunctor(), reduce_axis); - const DenseTensor* tmp_norm = out; - std::vector ins = {tmp_norm}; + std::vector ins = {&intermediate2}; std::vector outs = {out}; MT p_order_ = static_cast(static_cast(1.) / p_order); phi::funcs::ElementwiseKernel( From 7ad62a933405dbdf7cd2b15994fd486fea2d83e7 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Tue, 25 Apr 2023 17:25:02 +0800 Subject: [PATCH 22/25] modify the p!=inf --- paddle/phi/kernels/gpu/dist_kernel.cu | 87 ++++++++++----------------- 1 file changed, 32 insertions(+), 55 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 9fd61202361a1..97c6caa915820 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -1,3 +1,4 @@ + // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,46 +28,43 @@ namespace phi { #define FULL_MASK 0xffffffff -template +template struct ZeroOrderFunctor { - public: - __device__ Ty operator()(const Tx& x, const Tx& y) const { - return static_cast((static_cast(x) - static_cast(y)) != Ty(0)); - } -}; + HOSTDEVICE explicit inline ZeroOrderFunctor() {} -template -struct OtherOrderFunctor { - explicit OtherOrderFunctor(const Ty& p_order) : p_order_(p_order) {} - __device__ Ty operator()(const Tx& x, const Tx& y) const { - return static_cast( - pow(abs(static_cast(x) - static_cast(y)), p_order_)); + HOSTDEVICE inline Ty operator()(const Tx x) const { + return static_cast(x != Tx(0.0)); } - private: - Ty p_order_; + HOSTDEVICE inline Ty operator()(const Tx x, const Tx y) const { + return static_cast(x != y); + } }; -template +template struct PowFunctor { - explicit PowFunctor(const Ty& p_order) : p_order_(p_order) {} - HOSTDEVICE inline Tx operator()(const Tx x) const { - return static_cast(pow(static_cast(x), p_order_)); + HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order) + : p_order(_p_order) {} + + HOSTDEVICE inline Ty operator()(const Tx x) const { + return static_cast(pow(abs(static_cast(x)), p_order)); } - Ty p_order_; + + private: + Ty p_order; }; -template +template __global__ void ReduceSumWithSubtract( - const Tx* x, const Tx* y, Ty* out, int64_t N, Functor func) { - Ty sum_val = Ty(0); + const T* x, const T* y, T* out, int64_t N, Functor func) { + T sum_val = T(0.0); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { sum_val += func(x[i], y[i]); } __syncthreads(); - sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); + sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); if (threadIdx.x == 0) { out[blockIdx.x] = sum_val; } @@ -132,24 +130,12 @@ void DistKernel(const Context& dev_ctx, funcs::details::GetReduceDim(axis_dims, xdim.size(), true); if (p == 0) { - MT* i_ptr = dev_ctx.template Alloc(&intermediate); - ReduceSumWithSubtract + T* i_ptr = dev_ctx.template Alloc(&intermediate); + ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); - DenseTensor intermediate2(intermediate); - dev_ctx.template Alloc(&intermediate2); - phi::funcs:: - ReduceKernel>( - dev_ctx, - intermediate, - &intermediate2, - kps::IdentityFunctor(), - reduce_axis); - - std::vector ins = {&intermediate2}; - std::vector outs = {out}; - phi::funcs::ElementwiseKernel( - dev_ctx, ins, &outs, kps::IdentityFunctor()); + x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); + phi::funcs::ReduceKernel>( + dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else if (p == INFINITY) { T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMaxWithSubtract @@ -168,22 +154,13 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - MT* i_ptr = dev_ctx.template Alloc(&intermediate); + auto t = Subtract(dev_ctx, x, y); MT p_order = static_cast(p); - ReduceSumWithSubtract - <<>>( - x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); - DenseTensor intermediate2(intermediate); - dev_ctx.template Alloc(&intermediate2); - phi::funcs:: - ReduceKernel>( - dev_ctx, - intermediate, - &intermediate2, - kps::IdentityFunctor(), - reduce_axis); - - std::vector ins = {&intermediate2}; + phi::funcs::ReduceKernel>( + dev_ctx, t, out, PowFunctor(p_order), reduce_axis); + + const DenseTensor* tmp_norm = out; + std::vector ins = {tmp_norm}; std::vector outs = {out}; MT p_order_ = static_cast(static_cast(1.) / p_order); phi::funcs::ElementwiseKernel( From f897737ba25f05160af05440db6d87c6ed37a6ce Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 26 Apr 2023 23:41:42 +0800 Subject: [PATCH 23/25] refactor p=0 p>0 by RuduceSumWithSubtract + ReduceKernel --- paddle/phi/kernels/gpu/dist_kernel.cu | 33 +++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 97c6caa915820..96d1f6564a8ac 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -32,11 +32,11 @@ template struct ZeroOrderFunctor { HOSTDEVICE explicit inline ZeroOrderFunctor() {} - HOSTDEVICE inline Ty operator()(const Tx x) const { + HOSTDEVICE inline Ty operator()(const Tx& x) const { return static_cast(x != Tx(0.0)); } - HOSTDEVICE inline Ty operator()(const Tx x, const Tx y) const { + HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const { return static_cast(x != y); } }; @@ -46,25 +46,30 @@ struct PowFunctor { HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order) : p_order(_p_order) {} - HOSTDEVICE inline Ty operator()(const Tx x) const { + HOSTDEVICE inline Ty operator()(const Tx& x) const { return static_cast(pow(abs(static_cast(x)), p_order)); } + HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const { + return static_cast( + pow(abs(static_cast(x) - static_cast(y)), p_order)); + } + private: Ty p_order; }; -template +template __global__ void ReduceSumWithSubtract( - const T* x, const T* y, T* out, int64_t N, Functor func) { - T sum_val = T(0.0); + const Tx* x, const Tx* y, Ty* out, int64_t N, Functor func) { + Ty sum_val = Ty(0.0); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { sum_val += func(x[i], y[i]); } __syncthreads(); - sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); + sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); if (threadIdx.x == 0) { out[blockIdx.x] = sum_val; } @@ -154,10 +159,18 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - auto t = Subtract(dev_ctx, x, y); + T* i_ptr = dev_ctx.template Alloc(&intermediate); MT p_order = static_cast(p); - phi::funcs::ReduceKernel>( - dev_ctx, t, out, PowFunctor(p_order), reduce_axis); + ReduceSumWithSubtract + <<>>( + x_ptr, y_ptr, i_ptr, n, PowFunctor(p_order)); + phi::funcs:: + ReduceKernel>( + dev_ctx, + intermediate, + out, + kps::IdentityFunctor(), + reduce_axis); const DenseTensor* tmp_norm = out; std::vector ins = {tmp_norm}; From 88afbf0f7cc23b26afc8f8418b7c27d8f827c008 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 27 Apr 2023 00:32:45 +0800 Subject: [PATCH 24/25] fix errs --- paddle/phi/kernels/gpu/dist_kernel.cu | 54 +++++++++++++-------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index 96d1f6564a8ac..b5c1dd375d782 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -1,4 +1,3 @@ - // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -32,24 +31,16 @@ template struct ZeroOrderFunctor { HOSTDEVICE explicit inline ZeroOrderFunctor() {} - HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(x != Tx(0.0)); - } - HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const { return static_cast(x != y); } }; template -struct PowFunctor { - HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order) +struct OtherOrderFunctor { + HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& _p_order) : p_order(_p_order) {} - HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(pow(abs(static_cast(x)), p_order)); - } - HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const { return static_cast( pow(abs(static_cast(x) - static_cast(y)), p_order)); @@ -59,19 +50,33 @@ struct PowFunctor { Ty p_order; }; -template +template +struct PowFunctor { + HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order) + : p_order(_p_order) {} + + HOSTDEVICE inline Tx operator()(const Tx x) const { + return static_cast(pow(static_cast(x), p_order)); + } + + private: + Ty p_order; +}; + +template __global__ void ReduceSumWithSubtract( - const Tx* x, const Tx* y, Ty* out, int64_t N, Functor func) { - Ty sum_val = Ty(0.0); + const T* x, const T* y, T* out, int64_t N, Functor func) { + using MT = typename phi::dtype::MPTypeTrait::Type; + MT sum_val(0.0); for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - sum_val += func(x[i], y[i]); + sum_val += static_cast(func(x[i], y[i])); } __syncthreads(); - sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); + sum_val = phi::funcs::BlockReduceSum(sum_val, FULL_MASK); if (threadIdx.x == 0) { - out[blockIdx.x] = sum_val; + out[blockIdx.x] = static_cast(sum_val); } } @@ -159,18 +164,13 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - T* i_ptr = dev_ctx.template Alloc(&intermediate); + T* i_ptr = dev_ctx.template Alloc(&intermediate); MT p_order = static_cast(p); - ReduceSumWithSubtract + ReduceSumWithSubtract <<>>( - x_ptr, y_ptr, i_ptr, n, PowFunctor(p_order)); - phi::funcs:: - ReduceKernel>( - dev_ctx, - intermediate, - out, - kps::IdentityFunctor(), - reduce_axis); + x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor(p_order)); + phi::funcs::ReduceKernel>( + dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); const DenseTensor* tmp_norm = out; std::vector ins = {tmp_norm}; From 6cc08db2b946248ac8fad6d16618ba23394f85b4 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Thu, 27 Apr 2023 10:00:14 +0800 Subject: [PATCH 25/25] remove redundancy --- paddle/phi/kernels/gpu/dist_kernel.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu index b5c1dd375d782..0a197053e1016 100644 --- a/paddle/phi/kernels/gpu/dist_kernel.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -134,20 +134,19 @@ void DistKernel(const Context& dev_ctx, auto n = x.numel(); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n); intermediate.Resize(phi::make_ddim({config.block_per_grid.x})); + T* i_ptr = dev_ctx.template Alloc(&intermediate); std::vector axis_dims = {static_cast(-1)}; std::vector reduce_axis = funcs::details::GetReduceDim(axis_dims, xdim.size(), true); if (p == 0) { - T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceSumWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor()); phi::funcs::ReduceKernel>( dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else if (p == INFINITY) { - T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMaxWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n); @@ -155,7 +154,6 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else if (p == -INFINITY) { - T* i_ptr = dev_ctx.template Alloc(&intermediate); ReduceMinWithSubtract <<>>( x_ptr, y_ptr, i_ptr, n); @@ -164,7 +162,6 @@ void DistKernel(const Context& dev_ctx, dev_ctx, intermediate, out, kps::IdentityFunctor(), reduce_axis); } else { - T* i_ptr = dev_ctx.template Alloc(&intermediate); MT p_order = static_cast(p); ReduceSumWithSubtract <<>>(