From 2af5eab91cc985385c260d719eba65c7e2389690 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 3 Sep 2021 08:30:16 +0000 Subject: [PATCH] merge develop --- .../elementwise/elementwise_op_broadcast.cu.h | 464 ++++++++---------- .../elementwise/elementwise_op_impl.cu.h | 300 ++++++----- 2 files changed, 333 insertions(+), 431 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index 95dc6ed342ffc..e18141f7b62c8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -15,10 +15,14 @@ #pragma once #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" - +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" namespace paddle { namespace operators { +#define MAX_INPUT_NUM 3 // the max num of ET for BroadcacstConfig + +namespace kps = paddle::operators::kernel_primitives; + struct DimensionsTransform { using DimVector = std::vector; typedef void (*MergeFunctor)(bool &, std::vector &, DimVector &, @@ -161,201 +165,191 @@ struct DimensionsTransform { } }; -struct StridesCalculation { - std::vector> strides; - std::vector divmoders; - - private: - // To calculate the strides of each input_tensor. - __inline__ void CalculateStrides( - int N, int dim_size, const std::vector> &in_dims) { - for (int j = 0; j < N; ++j) { - for (int i = 0; i < dim_size; ++i) { - strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i]; - strides[j][i] = - (i != 0 && strides[j][i] != 0) - ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i, 1, - std::multiplies()) - : strides[j][i]; - } - } - } - - public: - explicit StridesCalculation(const int64_t &dim_size, - const std::vector> &in_dims, - const std::vector &out_dims) { - const auto N = in_dims.size(); - divmoders.resize(dim_size); - strides.resize(N, std::vector(dim_size, 1)); - - for (int i = 0; i < dim_size; ++i) { - divmoders[i] = platform::FastDivMod(out_dims[i]); - } - CalculateStrides(N, dim_size, in_dims); - } -}; - -template -struct BroadcastArgsWrapper { - using InVecType = platform::CudaAlignedVector; - using OutVecType = platform::CudaAlignedVector; - - OutT *out_data; - OutVecType *vec_out_data; - const InT *__restrict__ in_data[ET]; - const InVecType *__restrict__ vec_in_data[ET]; - bool no_broadcast[ET]; - platform::FastDivMod divmoders[kDims]; - uint32_t strides[ET][framework::DDim::kMaxRank]; - uint32_t scalar_cal_offset; - Functor func; - - HOSTDEVICE BroadcastArgsWrapper( - const std::vector &ins, framework::Tensor *out, - int scalar_cal_offset, Functor func, - const StridesCalculation &offset_calculator) - : scalar_cal_offset(scalar_cal_offset), func(func) { - for (int j = 0; j < ET; ++j) { - in_data[j] = ins[j]->data(); - vec_in_data[j] = reinterpret_cast(in_data[j]); - no_broadcast[j] = ins[j]->dims() == out->dims() ? 
true : false; - memcpy(strides[j], offset_calculator.strides[j].data(), - kDims * sizeof(uint32_t)); - } - out_data = out->data(); - vec_out_data = reinterpret_cast(out_data); - memcpy(divmoders, offset_calculator.divmoders.data(), - kDims * sizeof(platform::FastDivMod)); - } - - __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) { - uint32_t offset = 0; - -#pragma unroll(kDims) - for (int i = 0; i < kDims; ++i) { - auto fast_divmoder = divmoders[i].Divmod(idx); - idx = fast_divmoder.val[0]; - offset += fast_divmoder.val[1] * strides[in_idx][i]; - } - return offset; - } - - __device__ __forceinline__ void LoadVectorizedDataCommon( - InVecType *vector_args, int tid, int idx) { - *vector_args = vec_in_data[idx][tid]; - } - - __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args, - int tid, int idx) { - int index = tid * VecSize; -#pragma unroll(VecSize) - for (int i = 0; i < VecSize; ++i) { - uint32_t offset = GetOffsetByDivmod(index + i, idx); - scalar_args[i] = in_data[idx][offset]; - } - } - - __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid, - int idx) { - args[idx] = in_data[idx][tid + scalar_cal_offset]; - } - - __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[], - int tid, int idx) { - auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx); - args[idx] = in_data[idx][offset]; - } - - __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize], - int tid) { -#pragma unroll(ET) - for (int j = 0; j < ET; ++j) { - if (no_broadcast[j]) { - InVecType *vector_args = reinterpret_cast(args[j]); - LoadVectorizedDataCommon(vector_args, tid, j); - } else { - LoadVectorizedDataByDivmod(args[j], tid, j); - } - } +template +__device__ __forceinline__ void LoadData( + T *dst, const T *__restrict__ src, uint32_t block_offset, + const kps::details::BroadcastConfig &config, int numel, int num, + bool need_broadcast) { + // numel : whole num of output + // num: how many data will be deal with in this time + if (need_broadcast) { + kps::ReadDataBc( + dst, src, block_offset, config, numel, 1, 1); + } else { + kps::ReadData(dst, src + block_offset, num); } +} - __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) { -#pragma unroll(ET) - for (int j = 0; j < ET; ++j) { - if (no_broadcast[j]) { - LoadScalarizedDataCommon(args, tid, j); - } else { - LoadScalarizedDataByDivmod(args, tid, j); - } - } +template +__global__ void BroadcastKernelTernary( + const InT *__restrict__ in0, const InT *__restrict__ in1, + const InT *__restrict__ in2, OutT *out, + framework::Array use_broadcast, uint32_t numel, + framework::Array, MAX_INPUT_NUM> + configlists, + int main_tid, int tail_tid, Functor func) { + int block_offset = + blockIdx.x * blockDim.x * VecSize; // data offset of this block + int num = tail_tid; + InT arg[3][VecSize]; + OutT result[VecSize]; + const bool is_boundary = true; + if (blockIdx.x < main_tid) { + num = blockDim.x * VecSize; // blockIdx.x < main_tid + // load in0, in1, in2 + LoadData(arg[0], in0, block_offset, configlists[0], + numel, num, use_broadcast[0]); + LoadData(arg[1], in1, block_offset, configlists[1], + numel, num, use_broadcast[1]); + LoadData(arg[2], in2, block_offset, configlists[2], + numel, num, use_broadcast[2]); + kps::ElementwiseTernary( + result, arg[0], arg[1], arg[2], func); + kps::WriteData(out + block_offset, result, num); + } else { // blockIdx.x == main_tid + // This is the last block and tial_tid != 0, set is_boundary = true + // is_boundary = 
true, boundary judgment needs to be made when loading data + // to avoid access storage overflow + kps::Init(arg[0], static_cast(1.0f)); + kps::Init(arg[1], static_cast(1.0f)); + kps::Init(arg[2], static_cast(1.0f)); + LoadData(arg[0], in0, block_offset, + configlists[0], numel, num, + use_broadcast[0]); + LoadData(arg[1], in1, block_offset, + configlists[1], numel, num, + use_broadcast[1]); + LoadData(arg[2], in2, block_offset, + configlists[2], numel, num, + use_broadcast[2]); + kps::ElementwiseTernary( + result, arg[0], arg[1], arg[2], func); + kps::WriteData(out + block_offset, result, + num); } +} - __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out, - int tid) { - vec_out_data[tid] = vec_args_out; +template +__global__ void BroadcastKernelBinary( + const InT *__restrict__ in0, const InT *__restrict__ in1, OutT *out, + framework::Array use_broadcast, uint32_t numel, + framework::Array, MAX_INPUT_NUM> + configlists, + int main_tid, int tail_tid, Functor func) { + int block_offset = + blockIdx.x * blockDim.x * VecSize; // data offset of this block + int num = tail_tid; + InT arg[2][VecSize]; + OutT result[VecSize]; + const bool is_boundary = true; + if (blockIdx.x < main_tid) { + num = blockDim.x * VecSize; // blockIdx.x < main_tid + LoadData(arg[0], in0, block_offset, configlists[0], + numel, num, use_broadcast[0]); + LoadData(arg[1], in1, block_offset, configlists[1], + numel, num, use_broadcast[1]); + kps::ElementwiseBinary(result, arg[0], + arg[1], func); + kps::WriteData(out + block_offset, result, num); + } else { // reminder + // This is the last block and tial_tid != 0, set is_boundary = true + // is_boundary = true, boundary judgment needs to be made when loading data + // to avoid access storage overflow + kps::Init(arg[0], static_cast(1.0f)); + kps::Init(arg[1], static_cast(1.0f)); + LoadData(arg[0], in0, block_offset, + configlists[0], numel, num, + use_broadcast[0]); + LoadData(arg[1], in1, block_offset, + configlists[1], numel, num, + use_broadcast[1]); + kps::ElementwiseBinary(result, arg[0], + arg[1], func); + kps::WriteData(out + block_offset, result, + num); } +} - __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) { - out_data[scalar_cal_offset + tid] = args_out; +template +__global__ void BroadcastKernelUnary( + const InT *__restrict__ in, OutT *out, int numel, + kps::details::BroadcastConfig config, int main_tid, int tail_tid, + Functor func) { + int block_offset = + blockIdx.x * blockDim.x * VecSize; // data offset of this block + int num = tail_tid; + InT arg[VecSize]; + OutT result[VecSize]; + const bool is_boundary = true; + if (blockIdx.x < main_tid) { + num = blockDim.x * VecSize; // blockIdx.x < main_tid + kps::ReadDataBc(&arg[0], in, block_offset, + config, numel, 1, 1); + kps::ElementwiseUnary(&result[0], + &arg[0], func); + kps::WriteData(out + block_offset, &result[0], num); + } else { + // This is the last block and tial_tid != 0, set is_boundary = true + // is_boundary = true, boundary judgment needs to be made when loading data + // to avoid access storage overflow + kps::Init(&arg[0], static_cast(1.0f)); + kps::ReadDataBc( + &arg[0], in, block_offset, config, numel, 1, 1); + kps::ElementwiseUnary(&result[0], + &arg[0], func); + kps::WriteData(out + block_offset, + &result[0], num); } -}; - -template -__device__ inline void ScalarizedBroadcastKernelImpl( - BroadcastArgsWrapper broadcast_wrapper, int tid) { - InT args[ET]; - OutT args_out; - broadcast_wrapper.LoadScalarizedData(args, tid); - - // 
Calcualtion of the in_tensor data. - args_out = broadcast_wrapper.func(args); - - broadcast_wrapper.StoreScalarizedData(args_out, tid); } -template -__device__ inline void VectorizedBroadcastKernelImpl( - BroadcastArgsWrapper broadcast_wrapper, int tid) { - using OutVecType = platform::CudaAlignedVector; - OutVecType args_out; - InT ins[ET]; - InT args[ET][VecSize]; - broadcast_wrapper.LoadVectorizedData(args, tid); +template +void LaunchKernel(const platform::CUDADeviceContext &ctx, + const std::vector &ins, + framework::Tensor *out, Functor func, + DimensionsTransform merge_dims) { + int numel = out->numel(); + const int threads = 256; + int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; -#pragma unroll(VecSize) - for (int i = 0; i < VecSize; ++i) { -#pragma unroll(ET) - for (int j = 0; j < ET; ++j) { - ins[j] = args[j][i]; + int main_tid = numel / (VecSize * threads); + int tail_tid = numel % (VecSize * threads); + auto stream = ctx.stream(); + OutT *out_data = out->data(); + + framework::Array, MAX_INPUT_NUM> + configlists; + framework::Array use_broadcast; + + for (int i = 0; i < ET; i++) { + use_broadcast[i] = (ins[i]->numel() != numel); + if (use_broadcast[i]) { + // get the broadcast config, + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + configlists[i] = kps::details::BroadcastConfig( + merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); } - args_out.val[i] = broadcast_wrapper.func(ins); } - broadcast_wrapper.StoreVectorizedData(args_out, tid); -} -template -__global__ void ElementwiseBroadcastKernel( - BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - // Vectorized calculation of major data whose length is the max multipler of - // VecSize, - // eg: Calcualting the front 1024-length data in total 1027 data once VecSize - // is 4. - if (tid < main_tid) { - VectorizedBroadcastKernelImpl( - broadcast_wrapper, tid); - } - // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize. - // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is - // 4. 
- if (tid < tail_tid) { - ScalarizedBroadcastKernelImpl( - broadcast_wrapper, tid); + if (ET == kUnary) { // for unary eg: relu + BroadcastKernelUnary<<>>( + ins[0]->data(), out_data, numel, configlists[0], main_tid, + tail_tid, func); + } else if (ET == kBinary) { // for binary eg: add: a + b + BroadcastKernelBinary<<>>( + ins[0]->data(), ins[1]->data(), out_data, use_broadcast, + numel, configlists, main_tid, tail_tid, func); + } else { // for ternary eg:fma : a * b + c + BroadcastKernelTernary<<>>( + ins[0]->data(), ins[1]->data(), ins[2]->data(), out_data, + use_broadcast, numel, configlists, main_tid, tail_tid, func); } } @@ -365,98 +359,24 @@ void LaunchBroadcastKernelForDifferentDimSize( const platform::CUDADeviceContext &ctx, const std::vector &ins, framework::Tensor *out, int axis, Functor func) { - int numel = out->numel(); - int threads = GetThreadsConfig(ctx, numel, VecSize); - int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; - int main_tid = numel / VecSize; - int tail_tid = numel % VecSize; - int vec_len = main_tid * VecSize; - auto stream = ctx.stream(); - const auto merge_dims = DimensionsTransform(ins, out->dims(), axis); - const auto offset_calculator = StridesCalculation( - merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims); +#define DIM_SIZE(size) \ + case size: { \ + LaunchKernel(ctx, ins, out, func, \ + merge_dims); \ + } break; switch (merge_dims.dim_size) { - case 1: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 2: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 3: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 4: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 5: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 6: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 7: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - case 8: { - auto broadcast_wrapper = - BroadcastArgsWrapper( - ins, out, vec_len, func, offset_calculator); - ElementwiseBroadcastKernel<<>>( - broadcast_wrapper, main_tid, tail_tid); - break; - } - default: { - PADDLE_THROW(platform::errors::InvalidArgument( - "The maximum dimension of input tensor is expected to be less than " - "%d, but recieved %d.\n", - merge_dims.dim_size, framework::DDim::kMaxRank)); - } + DIM_SIZE(1); + DIM_SIZE(2); + DIM_SIZE(3); + DIM_SIZE(4); + DIM_SIZE(5); + DIM_SIZE(6); + DIM_SIZE(7); + DIM_SIZE(8); } +#undef DIM_SIZE } template @@ -528,5 +448,7 @@ void LaunchElementwiseCudaKernel( } } +#undef MAX_INPUT_NUM + } // namespace operators } // namespace paddle diff --git 
a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h index 101512e35fdcb..5cc2e3b1c2fba 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/fast_divmod.h" @@ -26,7 +27,8 @@ limitations under the License. */ namespace paddle { namespace operators { -enum ElementwiseType { kUnary = 1, kBinary = 2 }; +namespace kps = paddle::operators::kernel_primitives; +enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 }; /* * According to NVIDIA, if number of threads per block is 64/128/256/512, @@ -52,165 +54,156 @@ inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx, return std::max(64, threads); } -/* -* Only the address of input data is the multiplier of 1,2,4, vectorized load -* with corresponding multiplier-value is possible. Moreover, the maximum length -* of vectorized load is 128 bits once. Hence, valid length of vectorized load -* shall be determined under both former constraints. -*/ -template -int GetVectorizedSizeImpl(const T *pointer) { - constexpr int max_load_bits = 128; - int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); - uint64_t address = reinterpret_cast(pointer); - constexpr int vec8 = - std::alignment_of>::value; // NOLINT - constexpr int vec4 = - std::alignment_of>::value; // NOLINT - constexpr int vec2 = - std::alignment_of>::value; // NOLINT - if (address % vec8 == 0) { - /* - * Currently, decide to deal with no more than 4 data once while adopting - * vectorization load/store, if performance test shows that dealing with - * 8 data once in vectorization load/store does get optimized, return code - * below can be changed into " return std::min(8, valid_vec_size); " . 
- */ - return std::min(4, valid_vec_size); - } else if (address % vec4 == 0) { - return std::min(4, valid_vec_size); - } else if (address % vec2 == 0) { - return std::min(2, valid_vec_size); - } else { - return 1; - } -} - template -int GetVectorizedSize(const std::vector &ins, - const std::vector &outs) { +int GetVectorizedSizeForIO(const std::vector &ins, + const std::vector &outs) { int vec_size = 4; for (auto iter = ins.begin(); iter != ins.end(); ++iter) { - vec_size = - std::min(vec_size, GetVectorizedSizeImpl((*iter)->data())); + vec_size = std::min(vec_size, + platform::GetVectorizedSize((*iter)->data())); } for (auto iter = outs.begin(); iter != outs.end(); ++iter) { - vec_size = - std::min(vec_size, GetVectorizedSizeImpl((*iter)->data())); + vec_size = std::min( + vec_size, platform::GetVectorizedSize((*iter)->data())); } return vec_size; } -template -struct ElementwiseDataWrapper { - OutT *out; - const InT *in0; - const InT *in1; - __device__ ElementwiseDataWrapper(OutT *out, const InT *in0, - const InT *in1 = nullptr) - : out(out), in0(in0), in1(in1) {} - - using InVecType = CudaAlignedVector; - using OutVecType = CudaAlignedVector; - - inline __device__ void load_vector(InVecType args[], int idx) { - const InVecType *x_vec = reinterpret_cast(in0); - args[0] = x_vec[idx]; - if (ET == ElementwiseType::kBinary) { - const InVecType *y_vec = reinterpret_cast(in1); - args[1] = y_vec[idx]; - } - } - - inline __device__ void load_scalar(InT args[], int idx) { - args[0] = in0[idx]; - if (ET == ElementwiseType::kBinary) { - args[1] = in1[idx]; - } - } - - inline __device__ void store_vector(OutVecType res, int idx) { - OutVecType *out_vec = reinterpret_cast(out); - out_vec[idx] = res; - } - - inline __device__ void store_scalar(OutT res, int idx) { out[idx] = res; } -}; - -template -__device__ inline void VectorizedKernelImpl( - ElementwiseDataWrapper data, Functor func, - int tid) { - using InVecType = CudaAlignedVector; - using OutVecType = CudaAlignedVector; - InVecType ins_vec[ET]; - OutVecType out_vec; - InT *ins_ptr[ET]; - InT ins[ET]; -#pragma unroll - for (int i = 0; i < ET; ++i) { - ins_ptr[i] = reinterpret_cast(&(ins_vec[i])); +template +__global__ void ElementVectorizedUnary(const InT *__restrict__ in0, OutT *out, + int size, Functor func) { + int data_offset = VecSize * blockIdx.x * blockDim.x; + // data offset of this block + int num = size - data_offset; + num = (VecSize * blockDim.x) > num ? 
num : VecSize * blockDim.x; + // the num this time have to deal with + InT args[VecSize]; + OutT result[VecSize]; + const bool is_reminder = true; + if (VecSize * blockDim.x > num) { // reminder segment + kps::Init(&args[0], static_cast(1.0f)); + kps::ReadData(args, in0 + data_offset, + num); + kps::ElementwiseUnary(result, args, + func); + kps::WriteData(out + data_offset, result, + num); + } else { // complete segment + kps::ReadData(args, in0 + data_offset, num); + kps::ElementwiseUnary(result, args, + func); + kps::WriteData(out + data_offset, result, num); } - // load - data.load_vector(ins_vec, tid); - -// compute -#pragma unroll - for (int i = 0; i < VecSize; ++i) { -#pragma unroll - for (int j = 0; j < ET; ++j) { - ins[j] = ins_ptr[j][i]; - } - out_vec.val[i] = func(ins); - } - // store - data.store_vector(out_vec, tid); } -template -__device__ inline void ScalarKernelImpl( - ElementwiseDataWrapper data, Functor func, - int start, int remain) { - InT ins[ET]; - OutT out; - - for (int i = 0; i < remain; ++i) { - int idx = start + i; - // load - data.load_scalar(ins, idx); - // compute - out = func(ins); - // store - data.store_scalar(out, idx); +template +__global__ void ElementVectorizedBinary(const InT *__restrict__ in0, + const InT *__restrict__ in1, OutT *out, + int size, Functor func) { + int data_offset = VecSize * blockIdx.x * blockDim.x; + // data offset of this block + int num = size - data_offset; + num = (VecSize * blockDim.x) > num ? num : VecSize * blockDim.x; + // the num this time have to deal with + InT arg0[VecSize]; + InT arg1[VecSize]; + OutT result[VecSize]; + + const bool is_reminder = true; + if (VecSize * blockDim.x > num) { // reminder segment + kps::Init(&arg0[0], static_cast(1.0f)); + kps::Init(&arg1[0], static_cast(1.0f)); + kps::ReadData(&arg0[0], in0 + data_offset, + num); + kps::ReadData(&arg1[0], in1 + data_offset, + num); + kps::ElementwiseBinary(result, &arg0[0], + &arg1[0], func); + kps::WriteData(out + data_offset, result, + num); + } else { // complete segment + kps::ReadData(&arg0[0], in0 + data_offset, num); + kps::ReadData(&arg1[0], in1 + data_offset, num); + kps::ElementwiseBinary(result, &arg0[0], + &arg1[0], func); + kps::WriteData(out + data_offset, result, num); } } -template -__global__ void VectorizedKernel(const InT *__restrict__ in0, - const InT *__restrict__ in1, OutT *out, - int size, Functor func) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int remain = size - VecSize * tid; - remain = remain > 0 ? remain : 0; - auto data = ElementwiseDataWrapper(out, in0, in1); - if (remain >= VecSize) { - VectorizedKernelImpl(data, func, tid); +template +__global__ void ElementVectorizedTernary(const InT *__restrict__ in0, + const InT *__restrict__ in1, + const InT *__restrict__ in2, OutT *out, + int size, Functor func) { + int data_offset = VecSize * blockIdx.x * blockDim.x; + // data offset of this block + int num = size - data_offset; + num = (VecSize * blockDim.x) > num ? 
num : VecSize * blockDim.x; + // the num this time have to deal with + InT args[3][VecSize]; + OutT result[VecSize]; + + const bool is_reminder = true; + if (VecSize * blockDim.x > num) { // reminder segment + kps::Init(args[0], static_cast(1.0f)); + kps::Init(args[1], static_cast(1.0f)); + kps::Init(args[2], static_cast(1.0f)); + kps::ReadData(args[0], in0 + data_offset, + num); + kps::ReadData(args[1], in1 + data_offset, + num); + kps::ReadData(args[2], in2 + data_offset, + num); + kps::ElementwiseTernary( + result, args[0], args[1], args[2], func); + kps::WriteData(out + data_offset, result, + num); } else { - ScalarKernelImpl(data, func, tid * VecSize, remain); + kps::ReadData(args[0], in0 + data_offset, num); + kps::ReadData(args[1], in1 + data_offset, num); + kps::ReadData(args[2], in2 + data_offset, num); + kps::ElementwiseTernary( + result, args[0], args[1], args[2], func); + kps::WriteData(out + data_offset, result, num); } } -template -__global__ void ScalarKernel(const InT *__restrict__ in0, - const InT *__restrict__ in1, OutT *out, int size, - Functor func) { - auto data = ElementwiseDataWrapper(out, in0, in1); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int remain = tid < size ? 1 : 0; - ScalarKernelImpl(data, func, tid, remain); +template +void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + auto numel = ins[0]->numel(); + int block_size = GetThreadsConfig(ctx, numel, VecSize); + int grid_size = + ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; + const InT *in0 = ins[0]->data(); + OutT *out = (*outs)[0]->data(); + // cuda kernel + auto stream = ctx.stream(); + switch (ET) { + case ElementwiseType::kTernary: + ElementVectorizedTernary<<>>( + in0, ins[1]->data(), ins[2]->data(), out, numel, func); + break; + case ElementwiseType::kBinary: + ElementVectorizedBinary<<>>( + in0, ins[1]->data(), out, numel, func); + break; + case ElementwiseType::kUnary: + ElementVectorizedUnary<<>>( + in0, out, numel, func); + break; + default: { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported this ElementwiseType : %d !", ET)); + break; + } + } } template @@ -219,35 +212,22 @@ void LaunchSameDimsElementwiseCudaKernel( const std::vector &ins, std::vector *outs, Functor func) { // calculate the max vec_size for all ins and outs - auto size = ins[0]->numel(); - int vec_size = GetVectorizedSize(ins, *outs); - int block_size = GetThreadsConfig(ctx, size, vec_size); - int grid_size = - ((size + vec_size - 1) / vec_size + block_size - 1) / block_size; - const InT *in0 = ins[0]->data(); - const InT *in1 = - (ET == ElementwiseType::kBinary) ? ins[1]->data() : nullptr; - OutT *out = (*outs)[0]->data(); - // cuda kernel - auto stream = ctx.stream(); - + int vec_size = GetVectorizedSizeForIO(ins, *outs); switch (vec_size) { case 4: - VectorizedKernel<<>>( - in0, in1, out, size, func); + ElementwiseCudaKernel(ctx, ins, outs, func); break; case 2: - VectorizedKernel<<>>( - in0, in1, out, size, func); + ElementwiseCudaKernel(ctx, ins, outs, func); break; case 1: - ScalarKernel<<>>(in0, in1, out, - size, func); + ElementwiseCudaKernel(ctx, ins, outs, func); break; - default: + default: { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported vectorized size: %d !", vec_size)); break; + } } }
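
The new LaunchKernel above fixes threads at 256 and splits the output into full tiles of blockDim.x * VecSize elements plus one partial tile: main_tid counts the full blocks, tail_tid is the leftover element count, and the kernels branch on blockIdx.x < main_tid. A minimal host-side sketch of that arithmetic (standalone C++, not Paddle code; variable names mirror the launch code above):

#include <cassert>
#include <cstdio>

int main() {
  const int threads = 256;      // as in LaunchKernel
  const int vec_size = 4;       // VecSize
  const int numel = 1000003;    // arbitrary output size

  int tile = threads * vec_size;  // elements covered by one full block
  int blocks = ((numel + vec_size - 1) / vec_size + threads - 1) / threads;
  int main_tid = numel / tile;    // number of full blocks
  int tail_tid = numel % tile;    // elements left for the boundary block

  // Full blocks plus (at most) one partial block cover every element once.
  assert(main_tid * tile + tail_tid == numel);
  assert(blocks == main_tid + (tail_tid > 0 ? 1 : 0));
  std::printf("blocks=%d main_tid=%d tail_tid=%d\n", blocks, main_tid, tail_tid);
  return 0;
}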
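
Both the removed StridesCalculation/GetOffsetByDivmod path and the kps::details::BroadcastConfig consumed by ReadDataBc describe the same index mapping: peel digits off the output's linear index with the (fastest-varying-first) output dims and accumulate each digit times the input stride, where a broadcast axis of size 1 contributes stride 0. A plain host-side sketch of that mapping, using ordinary / and % instead of platform::FastDivMod (BroadcastOffset is an illustrative name, and treating BroadcastConfig as an equivalent packaging of this shape information is an assumption):

#include <cstdio>
#include <vector>

// Map an output linear index to the linear offset inside one (possibly
// broadcast) input. Dims are stored fastest-varying first, matching the
// reversed layout noted in LaunchKernel above. A broadcast axis
// (in_dim == 1) contributes stride 0, so its digit is simply dropped.
int BroadcastOffset(int out_idx, const std::vector<int> &out_dims,
                    const std::vector<int> &in_dims) {
  int offset = 0;
  int stride = 1;
  for (size_t i = 0; i < out_dims.size(); ++i) {
    int digit = out_idx % out_dims[i];
    out_idx /= out_dims[i];
    if (in_dims[i] != 1) {
      offset += digit * stride;  // non-broadcast axis: digit * input stride
    }
    stride *= in_dims[i];        // input stride for the next axis
  }
  return offset;
}

int main() {
  // out shape {3, 4} (fastest axis first), input shape {1, 4}:
  // the input is broadcast along the fastest-varying axis.
  std::vector<int> out_dims = {3, 4}, in_dims = {1, 4};
  for (int idx = 0; idx < 12; ++idx)
    std::printf("out %2d -> in %d\n", idx, BroadcastOffset(idx, out_dims, in_dims));
  return 0;
}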
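
GetVectorizedSizeForIO takes the minimum vectorization width over every input and output pointer; for a single pointer the width is bounded by a 128-bit load and by the pointer's alignment, as the removed GetVectorizedSizeImpl shows. A simplified standalone stand-in for that per-pointer check (GetVecSize is an illustrative name, not the actual platform::GetVectorizedSize):

#include <cstdint>
#include <cstdio>

// Largest supported element count (capped at 4) such that a vectorized
// load of that many T's from ptr is aligned; falls back to scalar loads.
template <typename T>
int GetVecSize(const T *ptr) {
  constexpr int kMaxLoadBits = 128;
  int max_vec = kMaxLoadBits / (8 * sizeof(T));
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  for (int vec = (max_vec < 4 ? max_vec : 4); vec > 1; vec /= 2) {
    if (addr % (vec * sizeof(T)) == 0) return vec;  // aligned for this width
  }
  return 1;
}

int main() {
  alignas(16) float buf[8];
  std::printf("aligned  : %d\n", GetVecSize(buf));      // 4
  std::printf("offset +1: %d\n", GetVecSize(buf + 1));  // 1
  std::printf("offset +2: %d\n", GetVecSize(buf + 2));  // 2
  return 0;
}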
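
Each kernel above compiles two paths: a fast path where a thread's full VecSize chunk is known to be in range, and a boundary path (IsBoundary = true) where registers are first initialized with kps::Init and loads/stores are guarded against num. A CPU-side model of that pattern (ReadChunk is an illustrative name, not the kps:: implementation):

#include <cstdio>

// Model of a boundary-aware vectorized read: copy up to VecSize elements
// starting at src, but never past the num valid tail elements.
template <typename T, int VecSize, bool IsBoundary>
void ReadChunk(T dst[VecSize], const T *src, int num) {
  if (IsBoundary) {
    for (int i = 0; i < VecSize; ++i) {
      if (i < num) dst[i] = src[i];   // guard each element in the tail
    }
  } else {
    for (int i = 0; i < VecSize; ++i) dst[i] = src[i];  // full chunk, no checks
  }
}

int main() {
  float src[3] = {1.f, 2.f, 3.f};
  float dst[4] = {-1.f, -1.f, -1.f, -1.f};   // pre-initialized, like kps::Init above
  ReadChunk<float, 4, true>(dst, src, 3);    // tail of 3 elements, 4-wide chunk
  for (float v : dst) std::printf("%g ", v); // 1 2 3 -1
  std::printf("\n");
  return 0;
}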
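
With kTernary added to ElementwiseType, the kernels hand a functor plus per-thread register arrays to kps::ElementwiseUnary/Binary/Ternary, which conceptually reduce to a per-element loop over VecSize. A host-side sketch using the fma-style ternary op mentioned in the comments above (a * b + c); the loop models the idea only and is not the kps:: code, and FmaFunctor is an illustrative name:

#include <cstdio>

// What an elementwise ternary application boils down to per thread:
// result[i] = func(a[i], b[i], c[i]) for the VecSize registers it owns.
template <typename InT, typename OutT, int VecSize, typename Functor>
void ElementwiseTernary(OutT result[VecSize], const InT a[VecSize],
                        const InT b[VecSize], const InT c[VecSize],
                        Functor func) {
  for (int i = 0; i < VecSize; ++i) result[i] = func(a[i], b[i], c[i]);
}

struct FmaFunctor {  // example ternary op: a * b + c
  float operator()(float a, float b, float c) const { return a * b + c; }
};

int main() {
  const int kVecSize = 4;
  float a[kVecSize] = {1, 2, 3, 4}, b[kVecSize] = {10, 10, 10, 10},
        c[kVecSize] = {0.5f, 0.5f, 0.5f, 0.5f};
  float out[kVecSize];
  ElementwiseTernary<float, float, kVecSize>(out, a, b, c, FmaFunctor());
  for (float v : out) std::printf("%g ", v);  // 10.5 20.5 30.5 40.5
  std::printf("\n");
  return 0;
}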