From 09d70e0fab3f73ba6662c5bb53a476dae4fbd528 Mon Sep 17 00:00:00 2001
From: JamesLim-sy
Date: Mon, 5 Jul 2021 18:18:49 +0000
Subject: [PATCH 1/4] First commit

---
 .../elementwise/elementwise_add_op.h          |   1 -
 .../elementwise/elementwise_op_broadcast.cu.h |  27 +--
 .../elementwise/elementwise_op_impl.cu.h      | 224 ++++++++----------
 paddle/fluid/platform/fast_divmod.h           |  38 ++-
 4 files changed, 152 insertions(+), 138 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index a469ebbaec2ed..ad9066540c23b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -17,7 +17,6 @@ limitations under the License. */
 #include
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 541ff9aacfc46..4e9bb869be21a 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -163,7 +163,7 @@ struct DimensionsTransform {

 struct StridesCalculation {
   std::vector> strides;
-  std::vector divmoders;
+  std::vector divmoders;

  private:
   // To calculate the strides of each input_tensor.
@@ -190,7 +190,7 @@ struct StridesCalculation {
     strides.resize(N, std::vector(dim_size, 1));

     for (int i = 0; i < dim_size; ++i) {
-      divmoders[i] = FastDivMod(out_dims[i]);
+      divmoders[i] = platform::FastDivMod(out_dims[i]);
     }
     CalculateStrides(N, dim_size, in_dims);
   }
@@ -199,15 +199,15 @@

 template
 struct BroadcastArgsWarpper {
-  using InVecType = CudaAlignedVector;
-  using OutVecType = CudaAlignedVector;
+  using InVecType = platform::CudaAlignedVector;
+  using OutVecType = platform::CudaAlignedVector;

   OutT *out_data;
   OutVecType *vec_out_data;
   const InT *__restrict__ in_data[ET];
   const InVecType *__restrict__ vec_in_data[ET];
   bool no_broadcast[ET];
-  FastDivMod divmoders[kDims];
+  platform::FastDivMod divmoders[kDims];
   uint32_t strides[ET][framework::DDim::kMaxRank];
   uint32_t scalar_cal_offset;
   Functor func;
@@ -227,7 +227,7 @@ struct BroadcastArgsWarpper {
     out_data = out->data();
     vec_out_data = reinterpret_cast(out_data);
     memcpy(divmoders, offset_calculator.divmoders.data(),
-           kDims * sizeof(FastDivMod));
+           kDims * sizeof(platform::FastDivMod));
   }

   __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
@@ -310,10 +310,9 @@ __device__ inline void ScalarizedBroadcastKernelImpl(
   OutT args_out;
   broadcast_warpper.LoadScalarizedData(args, tid);

-#pragma unroll(ET)
-  for (int j = 1; j < ET; ++j) {
-    args_out = broadcast_warpper.func(args);
-  }
+  // Calculation of the in_tensor data.
+ args_out = broadcast_warpper.func(args); + broadcast_warpper.StoreScalarizedData(args_out, tid); } @@ -321,7 +320,7 @@ template __device__ inline void VectorizedBroadcastKernelImpl( BroadcastArgsWarpper broadcast_warpper, int tid) { - using OutVecType = CudaAlignedVector; + using OutVecType = platform::CudaAlignedVector; OutVecType args_out; InT ins[ET]; InT args[ET][VecSize]; @@ -367,7 +366,7 @@ void LaunchBroadcastKernelForDifferentDimSize( const std::vector &ins, framework::Tensor *out, int axis, Functor func) { int numel = out->numel(); - const int threads = 256; + int threads = GetThreadsConfig(ctx, numel, VecSize); int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; int main_tid = numel / VecSize; int tail_tid = numel % VecSize; @@ -473,11 +472,11 @@ void LaunchBroadcastElementwiseCudaKernel( int in_vec_size = 4; framework::Tensor *out = (*outs)[0]; for (auto *in : ins) { - auto temp_size = GetVectorizedSizeImpl(in->data()); + auto temp_size = platform::GetVectorizedSize(in->data()); in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size) : in_vec_size; } - int out_vec_size = GetVectorizedSizeImpl(out->data()); + int out_vec_size = platform::GetVectorizedSize(out->data()); int vec_size = std::min(out_vec_size, in_vec_size); switch (vec_size) { diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 101512e35fdcb..4a4de37310898 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -26,7 +26,7 @@ limitations under the License. */ namespace paddle { namespace operators { -enum ElementwiseType { kUnary = 1, kBinary = 2 }; +enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 }; /* * According to NVIDIA, if number of threads per block is 64/128/256/512, @@ -52,98 +52,73 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx, return std::max(64, threads); } -/* -* Only the address of input data is the multiplier of 1,2,4, vectorized load -* with corresponding multiplier-value is possible. Moreover, the maximum length -* of vectorized load is 128 bits once. Hence, valid length of vectorized load -* shall be determined under both former constraints. -*/ -template -int GetVectorizedSizeImpl(const T *pointer) { - constexpr int max_load_bits = 128; - int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); - uint64_t address = reinterpret_cast(pointer); - constexpr int vec8 = - std::alignment_of>::value; // NOLINT - constexpr int vec4 = - std::alignment_of>::value; // NOLINT - constexpr int vec2 = - std::alignment_of>::value; // NOLINT - if (address % vec8 == 0) { - /* - * Currently, decide to deal with no more than 4 data once while adopting - * vectorization load/store, if performance test shows that dealing with - * 8 data once in vectorization load/store does get optimized, return code - * below can be changed into " return std::min(8, valid_vec_size); " . 
-    */
-    return std::min(4, valid_vec_size);
-  } else if (address % vec4 == 0) {
-    return std::min(4, valid_vec_size);
-  } else if (address % vec2 == 0) {
-    return std::min(2, valid_vec_size);
-  } else {
-    return 1;
-  }
-}
-
 template
-int GetVectorizedSize(const std::vector &ins,
-                      const std::vector &outs) {
+int GetVectorizedSizeImpl(const std::vector &ins,
+                          const std::vector &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
-    vec_size =
-        std::min(vec_size, GetVectorizedSizeImpl((*iter)->data()));
+    vec_size = std::min(vec_size,
+                        platform::GetVectorizedSize((*iter)->data()));
   }
   for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
-    vec_size =
-        std::min(vec_size, GetVectorizedSizeImpl((*iter)->data()));
+    vec_size = std::min(
+        vec_size, platform::GetVectorizedSize((*iter)->data()));
   }
   return vec_size;
 }

 template
 struct ElementwiseDataWrapper {
+  using InVecType = platform::CudaAlignedVector;
+  using OutVecType = platform::CudaAlignedVector;
+
+  const InT *__restrict__ in_data[ET];
   OutT *out;
-  const InT *in0;
-  const InT *in1;
-  __device__ ElementwiseDataWrapper(OutT *out, const InT *in0,
-                                    const InT *in1 = nullptr)
-      : out(out), in0(in0), in1(in1) {}
-
-  using InVecType = CudaAlignedVector;
-  using OutVecType = CudaAlignedVector;
-
-  inline __device__ void load_vector(InVecType args[], int idx) {
-    const InVecType *x_vec = reinterpret_cast(in0);
-    args[0] = x_vec[idx];
-    if (ET == ElementwiseType::kBinary) {
-      const InVecType *y_vec = reinterpret_cast(in1);
-      args[1] = y_vec[idx];
+  uint32_t scalar_cal_offset;
+
+  HOSTDEVICE ElementwiseDataWrapper(
+      const std::vector &ins,
+      std::vector *outs, uint32_t scalar_cal_offset)
+      : scalar_cal_offset(scalar_cal_offset) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      in_data[i] = ins[i]->data();
     }
+    out = (*outs)[0]->data();
   }

-  inline __device__ void load_scalar(InT args[], int idx) {
-    args[0] = in0[idx];
-    if (ET == ElementwiseType::kBinary) {
-      args[1] = in1[idx];
+  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      const InVecType *in_vec_data =
+          reinterpret_cast(in_data[i]);
+      vec_args[i] = in_vec_data[tid];
     }
   }

-  inline __device__ void store_vector(OutVecType res, int idx) {
+  inline __device__ void LoadScalarizedData(InT args[], int tid) {
+#pragma unroll
+    for (int i = 0; i < ET; ++i) {
+      args[i] = in_data[i][tid + scalar_cal_offset];
+    }
+  }
+
+  inline __device__ void StoreVectorizedData(OutVecType res, int idx) {
     OutVecType *out_vec = reinterpret_cast(out);
     out_vec[idx] = res;
   }

-  inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; }
+  inline __device__ void StoreScalarizedData(OutT res, int idx) {
+    out[idx + scalar_cal_offset] = res;
+  }
 };

-template
-__device__ inline void VectorizedKernelImpl(
-    ElementwiseDataWrapper data, Functor func,
-    int tid) {
-  using InVecType = CudaAlignedVector;
-  using OutVecType = CudaAlignedVector;
+template
+__device__ inline void VectorizedKernelImpl(DataWarpper data, Functor func,
+                                            int tid) {
+  using InVecType = platform::CudaAlignedVector;
+  using OutVecType = platform::CudaAlignedVector;
   InVecType ins_vec[ET];
   OutVecType out_vec;
   InT *ins_ptr[ET];
@@ -153,7 +128,7 @@ __device__ inline void VectorizedKernelImpl(
     ins_ptr[i] = reinterpret_cast(&(ins_vec[i]));
   }
   // load
-  data.load_vector(ins_vec, tid);
+  data.LoadVectorizedData(ins_vec, tid);

   // compute
 #pragma unroll
@@ -165,52 +140,46 @@ __device__ inline void VectorizedKernelImpl(
     out_vec.val[i] = func(ins);
   }
   // store
-  data.store_vector(out_vec, tid);
+  data.StoreVectorizedData(out_vec, tid);
 }

-template
-__device__ inline void ScalarKernelImpl(
-    ElementwiseDataWrapper data, Functor func,
-    int start, int remain) {
+template
+__device__ inline void ScalarKernelImpl(DataWarpper data, Functor func,
+                                        int tid) {
   InT ins[ET];
   OutT out;
-  for (int i = 0; i < remain; ++i) {
-    int idx = start + i;
-    // load
-    data.load_scalar(ins, idx);
-    // compute
-    out = func(ins);
-    // store
-    data.store_scalar(out, idx);
-  }
+  // load
+  data.LoadScalarizedData(ins, tid);
+  // compute
+  out = func(ins);
+  // store
+  data.StoreScalarizedData(out, tid);
 }

-template
-__global__ void VectorizedKernel(const InT *__restrict__ in0,
-                                 const InT *__restrict__ in1, OutT *out,
-                                 int size, Functor func) {
+template
+__global__ void VectorizedKernel(DataWarpper data, int main_tid, int tail_tid,
+                                 Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = size - VecSize * tid;
-  remain = remain > 0 ? remain : 0;
-  auto data = ElementwiseDataWrapper(out, in0, in1);
-  if (remain >= VecSize) {
-    VectorizedKernelImpl(data, func, tid);
-  } else {
-    ScalarKernelImpl(data, func, tid * VecSize, remain);
+
+  if (tid < main_tid) {
+    VectorizedKernelImpl(
+        data, func, tid);
+  }
+  if (tid < tail_tid) {
+    ScalarKernelImpl(data, func, tid);
   }
 }

-template
-__global__ void ScalarKernel(const InT *__restrict__ in0,
-                             const InT *__restrict__ in1, OutT *out, int size,
-                             Functor func) {
-  auto data = ElementwiseDataWrapper(out, in0, in1);
+template
+__global__ void ScalarKernel(DataWarpper data, int numel, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, func, tid, remain);
+  if (tid < numel) {
+    ScalarKernelImpl(data, func, tid);
+  }
 }

 template
@@ -219,35 +188,48 @@ void LaunchSameDimsElementwiseCudaKernel(
     const std::vector &ins,
     std::vector *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
-  auto size = ins[0]->numel();
-  int vec_size = GetVectorizedSize(ins, *outs);
-  int block_size = GetThreadsConfig(ctx, size, vec_size);
+  auto numel = ins[0]->numel();
+  int vec_size = GetVectorizedSizeImpl(ins, *outs);
+  int block_size = GetThreadsConfig(ctx, numel, vec_size);
   int grid_size =
-      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  const InT *in0 = ins[0]->data();
-  const InT *in1 =
-      (ET == ElementwiseType::kBinary) ? ins[1]->data() : nullptr;
-  OutT *out = (*outs)[0]->data();
+      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
+  int main_tid = numel / vec_size;
+  int tail_tid = numel % vec_size;
+  uint32_t vec_len = main_tid * vec_size;
+
   // cuda kernel
   auto stream = ctx.stream();
   switch (vec_size) {
-    case 4:
-      VectorizedKernel<<>>(
-          in0, in1, out, size, func);
+    case 4: {
+      auto data_warpper =
+          ElementwiseDataWrapper(ins, outs, vec_len);
+      VectorizedKernel<<>>(
+          data_warpper, main_tid, tail_tid, func);
       break;
-    case 2:
-      VectorizedKernel<<>>(
-          in0, in1, out, size, func);
+    }
+    case 2: {
+      auto data_warpper =
+          ElementwiseDataWrapper(ins, outs, vec_len);
+      VectorizedKernel<<>>(
+          data_warpper, main_tid, tail_tid, func);
       break;
-    case 1:
-      ScalarKernel<<>>(in0, in1, out,
-                                                          size, func);
+    }
+    case 1: {
+      auto data_warpper =
+          ElementwiseDataWrapper(ins, outs, 0);
+      ScalarKernel<<>>(data_warpper,
+                                                        numel, func);
       break;
-    default:
+    }
+    default: {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported vectorized size: %d !", vec_size));
       break;
+    }
   }
 }

diff --git a/paddle/fluid/platform/fast_divmod.h b/paddle/fluid/platform/fast_divmod.h
index 5c5903d62cd27..0aba3aa97519f 100644
--- a/paddle/fluid/platform/fast_divmod.h
+++ b/paddle/fluid/platform/fast_divmod.h
@@ -20,7 +20,7 @@ limitations under the License. */
 #define INT_BITS 32

 namespace paddle {
-namespace operators {
+namespace platform {

 template
 struct alignas(sizeof(T) * Size) CudaAlignedVector {
@@ -65,5 +65,39 @@ struct FastDivMod {
   uint32_t multiplier;
 };

+/*
+* A vectorized load of width 1, 2 or 4 is only possible when the address of
+* the input data is a multiple of that width. Moreover, a single vectorized
+* load moves at most 128 bits at once. Hence, the valid vectorized length
+* shall be determined under both of the former constraints.
+*/
+template
+int GetVectorizedSize(const T *pointer) {
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
+  uint64_t address = reinterpret_cast(pointer);
+  constexpr int vec8 =
+      std::alignment_of>::value;  // NOLINT
+  constexpr int vec4 =
+      std::alignment_of>::value;  // NOLINT
+  constexpr int vec2 =
+      std::alignment_of>::value;  // NOLINT
+  if (address % vec8 == 0) {
+    /*
+    * Currently we deal with no more than 4 elements at once in vectorized
+    * load/store. If performance tests show that handling 8 elements at once
+    * does get optimized, the return statement below can be changed into
+    * " return std::min(8, valid_vec_size); " .
+    */
+    return std::min(4, valid_vec_size);
+  } else if (address % vec4 == 0) {
+    return std::min(4, valid_vec_size);
+  } else if (address % vec2 == 0) {
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
+  }
+}
+
-}  // namespace operators
+}  // namespace platform
 }  // namespace paddle

From a73c141b965aab50a9da5ff90cce513339717b7f Mon Sep 17 00:00:00 2001
From: JamesLim-sy
Date: Thu, 8 Jul 2021 09:03:32 +0000
Subject: [PATCH 2/4] Change some variable names.
---
 .../operators/elementwise/elementwise_op_impl.cu.h | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 4a4de37310898..34b43213244b0 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -73,7 +73,7 @@ struct ElementwiseDataWrapper {
   using OutVecType = platform::CudaAlignedVector;

   const InT *__restrict__ in_data[ET];
-  OutT *out;
+  OutT *out_data;
   uint32_t scalar_cal_offset;

   HOSTDEVICE ElementwiseDataWrapper(
@@ -84,7 +84,7 @@ struct ElementwiseDataWrapper {
     for (int i = 0; i < ET; ++i) {
       in_data[i] = ins[i]->data();
     }
-    out = (*outs)[0]->data();
+    out_data = (*outs)[0]->data();
   }

   inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
@@ -103,13 +103,13 @@ struct ElementwiseDataWrapper {
     }
   }

-  inline __device__ void StoreVectorizedData(OutVecType res, int idx) {
-    OutVecType *out_vec = reinterpret_cast(out);
-    out_vec[idx] = res;
+  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
+    OutVecType *out_vec = reinterpret_cast(out_data);
+    out_vec[tid] = res;
   }

-  inline __device__ void StoreScalarizedData(OutT res, int idx) {
-    out[idx + scalar_cal_offset] = res;
+  inline __device__ void StoreScalarizedData(OutT res, int tid) {
+    out_data[tid + scalar_cal_offset] = res;
   }
 };

From 2046e9da58cb3c6d50e83a80600dd89ab44e16c6 Mon Sep 17 00:00:00 2001
From: JamesLim-sy
Date: Wed, 14 Jul 2021 04:57:10 +0000
Subject: [PATCH 3/4] Fix variable and function names

---
 .../elementwise/elementwise_op_impl.cu.h      | 42 ++++++++++---------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 34b43213244b0..ebc437dfe7c93 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -53,8 +53,8 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
 }

 template
-int GetVectorizedSizeImpl(const std::vector &ins,
-                          const std::vector &outs) {
+int GetVectorizedSizeForIO(const std::vector &ins,
+                           const std::vector &outs) {
   int vec_size = 4;
   for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
     vec_size = std::min(vec_size,
@@ -113,10 +113,10 @@ struct ElementwiseDataWrapper {
   }
 };

-template
-__device__ inline void VectorizedKernelImpl(DataWarpper data, Functor func,
-                                            int tid) {
+template
+__device__ inline void VectorizedKernelImpl(ElementwiseWarpper data,
+                                            Functor func, int tid) {
   using InVecType = platform::CudaAlignedVector;
   using OutVecType = platform::CudaAlignedVector;
   InVecType ins_vec[ET];
@@ -143,9 +143,9 @@ __device__ inline void VectorizedKernelImpl(DataWarpper data, Functor func,
   data.StoreVectorizedData(out_vec, tid);
 }

-template
-__device__ inline void ScalarKernelImpl(DataWarpper data, Functor func,
+template
+__device__ inline void ScalarKernelImpl(ElementwiseWarpper data, Functor func,
                                         int tid) {
   InT ins[ET];
   OutT out;
@@ -158,27 +158,29 @@ __device__ inline void ScalarKernelImpl(DataWarpper data, Functor func,
   data.StoreScalarizedData(out, tid);
 }

-template
-__global__ void VectorizedKernel(DataWarpper data, int main_tid, int tail_tid,
-                                 Functor func) {
+template
+__global__ void VectorizedKernel(ElementwiseWarpper data, int main_tid,
+                                 int tail_tid, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;

   if (tid < main_tid) {
-    VectorizedKernelImpl(
+    VectorizedKernelImpl(
         data, func, tid);
   }
   if (tid < tail_tid) {
-    ScalarKernelImpl(data, func, tid);
+    ScalarKernelImpl(data, func,
+                     tid);
   }
 }

-template
-__global__ void ScalarKernel(DataWarpper data, int numel, Functor func) {
+template
+__global__ void ScalarKernel(ElementwiseWarpper data, int numel, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < numel) {
-    ScalarKernelImpl(data, func, tid);
+    ScalarKernelImpl(data, func,
+                     tid);
   }
 }

@@ -189,7 +191,7 @@ void LaunchSameDimsElementwiseCudaKernel(
     std::vector *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
   auto numel = ins[0]->numel();
-  int vec_size = GetVectorizedSizeImpl(ins, *outs);
+  int vec_size = GetVectorizedSizeForIO(ins, *outs);
   int block_size = GetThreadsConfig(ctx, numel, vec_size);
   int grid_size =
       ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;

From c5e8e2b06846afcb82433bc7936d88e0d10863ef Mon Sep 17 00:00:00 2001
From: JamesLim-sy
Date: Wed, 28 Jul 2021 11:16:20 +0000
Subject: [PATCH 4/4] Fix name spelling errors.

---
 .../elementwise/elementwise_op_broadcast.cu.h | 100 +++++++++---------
 .../elementwise/elementwise_op_impl.cu.h      |  40 +++----
 2 files changed, 70 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 4e9bb869be21a..95dc6ed342ffc 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -198,7 +198,7 @@ struct StridesCalculation {

 template
-struct BroadcastArgsWarpper {
+struct BroadcastArgsWrapper {
   using InVecType = platform::CudaAlignedVector;
   using OutVecType = platform::CudaAlignedVector;

@@ -212,7 +212,7 @@ struct BroadcastArgsWarpper {
   uint32_t scalar_cal_offset;
   Functor func;

-  HOSTDEVICE BroadcastArgsWarpper(
+  HOSTDEVICE BroadcastArgsWrapper(
       const std::vector &ins, framework::Tensor *out,
       int scalar_cal_offset, Functor func,
       const StridesCalculation &offset_calculator)
@@ -302,29 +302,29 @@ struct BroadcastArgsWarpper {
   }
 };

-template
 __device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWarpper broadcast_warpper, int tid) {
+    BroadcastArgsWrapper broadcast_wrapper, int tid) {
   InT args[ET];
   OutT args_out;
-  broadcast_warpper.LoadScalarizedData(args, tid);
+  broadcast_wrapper.LoadScalarizedData(args, tid);
   // Calculation of the in_tensor data.
-  args_out = broadcast_warpper.func(args);
+  args_out = broadcast_wrapper.func(args);

-  broadcast_warpper.StoreScalarizedData(args_out, tid);
+  broadcast_wrapper.StoreScalarizedData(args_out, tid);
 }

-template
+template
 __device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWarpper broadcast_warpper, int tid) {
+    BroadcastArgsWrapper broadcast_wrapper, int tid) {
   using OutVecType = platform::CudaAlignedVector;
   OutVecType args_out;
   InT ins[ET];
   InT args[ET][VecSize];
-  broadcast_warpper.LoadVectorizedData(args, tid);
+  broadcast_wrapper.LoadVectorizedData(args, tid);

 #pragma unroll(VecSize)
   for (int i = 0; i < VecSize; ++i) {
@@ -332,30 +332,30 @@ __device__ inline void VectorizedBroadcastKernelImpl(
     for (int j = 0; j < ET; ++j) {
      ins[j] = args[j][i];
     }
-    args_out.val[i] = broadcast_warpper.func(ins);
+    args_out.val[i] = broadcast_wrapper.func(ins);
   }
-  broadcast_warpper.StoreVectorizedData(args_out, tid);
+  broadcast_wrapper.StoreVectorizedData(args_out, tid);
 }

-template
+template
 __global__ void ElementwiseBroadcastKernel(
-    BroadcastArgsWarpper broadcast_warpper, int main_tid, int tail_tid) {
+    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   // Vectorized calculation of major data whose length is the max multiple of
   // VecSize,
   // eg: Calculating the front 1024-length data in total 1027 data once VecSize
   // is 4.
   if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl(
-        broadcast_warpper, tid);
+    VectorizedBroadcastKernelImpl(
+        broadcast_wrapper, tid);
   }
   // Scalarized calculation of rest data whose length cannot fulfill VecSize.
   // eg: Calculating the rest 3-length data in total 1027 data once VecSize is
   // 4.
   if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl(
-        broadcast_warpper, tid);
+    ScalarizedBroadcastKernelImpl(
+        broadcast_wrapper, tid);
   }
 }

   switch (merge_dims.dim_size) {
     case 1: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 2: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 3: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 4: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 5: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 6: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 7: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     case 8: {
-      auto broadcast_warpper =
-          BroadcastArgsWarpper(
+      auto broadcast_wrapper =
+          BroadcastArgsWrapper(
               ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<<>>(
-          broadcast_warpper, main_tid, tail_tid);
+      ElementwiseBroadcastKernel<<>>(
+          broadcast_wrapper, main_tid, tail_tid);
       break;
     }
     default: {
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index ebc437dfe7c93..3bd746ace0610 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -113,9 +113,9 @@ struct ElementwiseDataWrapper {
   }
 };

-template
-__device__ inline void VectorizedKernelImpl(ElementwiseWarpper data,
+template
+__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
                                             Functor func, int tid) {
   using InVecType = platform::CudaAlignedVector;
   using OutVecType = platform::CudaAlignedVector;
   InVecType ins_vec[ET];
@@ -143,9 +143,9 @@ __device__ inline void VectorizedKernelImpl(ElementwiseWarpper data,
   data.StoreVectorizedData(out_vec, tid);
 }

-template
-__device__ inline void ScalarKernelImpl(ElementwiseWarpper data, Functor func,
+template
+__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
                                         int tid) {
   InT ins[ET];
   OutT out;
@@ -158,29 +158,29 @@ __device__ inline void ScalarKernelImpl(ElementwiseWarpper data, Functor func,
   data.StoreScalarizedData(out, tid);
 }

-template
-__global__ void VectorizedKernel(ElementwiseWarpper data, int main_tid,
+template
+__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
                                  int tail_tid, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;

   if (tid < main_tid) {
-    VectorizedKernelImpl(
+    VectorizedKernelImpl(
         data, func, tid);
   }
   if (tid < tail_tid) {
-    ScalarKernelImpl(data, func,
+    ScalarKernelImpl(data, func,
                      tid);
   }
 }

-template
-__global__ void ScalarKernel(ElementwiseWarpper data, int numel, Functor func) {
+template
+__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   if (tid < numel) {
-    ScalarKernelImpl(data, func,
+    ScalarKernelImpl(data, func,
                      tid);
   }
 }

   switch (vec_size) {
     case 4: {
-      auto data_warpper =
+      auto data_wrapper =
           ElementwiseDataWrapper(ins, outs, vec_len);
-      VectorizedKernel<<>>(
-          data_warpper, main_tid, tail_tid, func);
+      VectorizedKernel<<>>(
+          data_wrapper, main_tid, tail_tid, func);
       break;
     }
     case 2: {
-      auto data_warpper =
+      auto data_wrapper =
           ElementwiseDataWrapper(ins, outs, vec_len);
-      VectorizedKernel<<>>(
-          data_warpper, main_tid, tail_tid, func);
+      VectorizedKernel<<>>(
+          data_wrapper, main_tid, tail_tid, func);
       break;
     }
     case 1: {
-      auto data_warpper =
+      auto data_wrapper =
           ElementwiseDataWrapper(ins, outs, 0);
-      ScalarKernel<<>>(data_warpper,
+      ScalarKernel<<>>(data_wrapper,
                                                         numel, func);
       break;
     }
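
Notes on the index arithmetic behind these patches. The two sketches below are
illustrative only: they are hedged reconstructions of the general techniques
the kernels rely on, not code taken from the patches, and every name that does
not appear above (the main() harnesses, the sample size 1027, FastDivModSketch)
is hypothetical.

First, the main/tail split. After patch 1, a launch computes main_tid =
numel / vec_size and tail_tid = numel % vec_size: threads below main_tid each
move one whole CudaAlignedVector of vec_size elements, and threads below
tail_tid each handle one leftover scalar starting at scalar_cal_offset =
main_tid * vec_size. A minimal host-side sketch of that arithmetic, assuming
only numel and vec_size as used in the patches:

#include <cstdio>

int main() {
  const int numel = 1027;  // hypothetical element count
  const int vec_size = 4;  // width as picked by platform::GetVectorizedSize
  const int main_tid = numel / vec_size;             // 256 vectorized threads
  const int tail_tid = numel % vec_size;             // 3 scalar threads
  const int scalar_cal_offset = main_tid * vec_size; // tail begins at 1024
  std::printf("%d vector lanes + %d scalars, tail offset %d\n", main_tid,
              tail_tid, scalar_cal_offset);
  return 0;
}

Second, FastDivMod. The broadcast path builds one divmoder per output
dimension so that per-element offset computation avoids hardware integer
division. The sketch below shows the standard multiply-and-shift ("magic
number") technique that such a class implements; it is not the exact
paddle::platform::FastDivMod code (which these patches do not quote), and it
uses the GCC/Clang-specific unsigned __int128 where a GPU kernel would use
__umulhi:

#include <cassert>
#include <cstdint>

// Precomputes multiplier/shift so that Div(n) == n / d for all 32-bit n,
// using the round-up method: m = ceil(2^(32+s) / d) with 2^s >= d.
struct FastDivModSketch {
  uint32_t divisor;
  uint32_t shift;
  uint64_t multiplier;

  explicit FastDivModSketch(uint32_t d) : divisor(d), shift(0) {
    assert(d >= 1);
    while ((uint64_t{1} << shift) < d) ++shift;  // smallest s with 2^s >= d
    unsigned __int128 one = 1;
    multiplier = static_cast<uint64_t>(((one << (32 + shift)) + d - 1) / d);
  }

  uint32_t Div(uint32_t n) const {
    // The 128-bit product keeps the high bits exact; device code would take
    // the high 32 bits of a 32x32 multiply (__umulhi) instead.
    unsigned __int128 prod = static_cast<unsigned __int128>(multiplier) * n;
    return static_cast<uint32_t>(prod >> (32 + shift));
  }

  uint32_t Mod(uint32_t n) const { return n - Div(n) * divisor; }
};

int main() {
  FastDivModSketch dm(7);
  assert(dm.Div(1027) == 146 && dm.Mod(1027) == 5);
  return 0;
}

The design point the patches exploit is that the divisors (the output tensor
dimensions) are fixed for a whole kernel launch, so the constructor cost is
paid once on the host while every GPU thread reuses the precomputed
multiplier and shift in GetOffsetByDivmod.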