diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
index 17cf7c762def2..129c90a22be6b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -15,10 +15,14 @@
 #pragma once

 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 namespace paddle {
 namespace operators {

+#define MAX_INPUT_NUM 3  // the max number of inputs (ET) for BroadcastConfig
+
+namespace kps = paddle::operators::kernel_primitives;
+
 struct DimensionsTransform {
   using DimVector = std::vector<int64_t>;
   typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
@@ -161,202 +165,113 @@ struct DimensionsTransform {
   }
 };

-struct StridesCalculation {
-  std::vector<std::vector<uint32_t>> strides;
-  std::vector<platform::FastDivMod> divmoders;
-
- private:
-  // To calculate the strides of each input_tensor.
-  __inline__ void CalculateStrides(
-      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
-    for (int j = 0; j < N; ++j) {
-      for (int i = 0; i < dim_size; ++i) {
-        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
-        strides[j][i] =
-            (i != 0 && strides[j][i] != 0)
-                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i,
-                                  1, std::multiplies<int64_t>())
-                : strides[j][i];
-      }
-    }
-  }
-
- public:
-  explicit StridesCalculation(const int64_t &dim_size,
-                              const std::vector<std::vector<int64_t>> &in_dims,
-                              const std::vector<int64_t> &out_dims) {
-    const auto N = in_dims.size();
-    divmoders.resize(dim_size);
-    strides.resize(N, std::vector<uint32_t>(dim_size, 1));
-
-    for (int i = 0; i < dim_size; ++i) {
-      divmoders[i] = platform::FastDivMod(out_dims[i]);
-    }
-    CalculateStrides(N, dim_size, in_dims);
-  }
-};
-
-template <typename InT, typename OutT, typename Functor, ElementwiseType ET,
-          int VecSize, int kDims>
-struct BroadcastArgsWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  OutT *out_data;
-  OutVecType *vec_out_data;
-  const InT *__restrict__ in_data[ET];
-  const InVecType *__restrict__ vec_in_data[ET];
-  bool no_broadcast[ET];
-  platform::FastDivMod divmoders[kDims];
-  uint32_t strides[ET][framework::DDim::kMaxRank];
-  uint32_t scalar_cal_offset;
-  Functor func;
-
-  HOSTDEVICE BroadcastArgsWrapper(
-      const std::vector<const framework::Tensor *> &ins,
-      framework::Tensor *out, int scalar_cal_offset, Functor func,
-      const StridesCalculation &offset_calculator)
-      : scalar_cal_offset(scalar_cal_offset), func(func) {
-    for (int j = 0; j < ET; ++j) {
-      in_data[j] = ins[j]->data<InT>();
-      vec_in_data[j] = reinterpret_cast<const InVecType *>(in_data[j]);
-      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
-      memcpy(strides[j], offset_calculator.strides[j].data(),
-             kDims * sizeof(uint32_t));
-    }
-    out_data = out->data<OutT>();
-    vec_out_data = reinterpret_cast<OutVecType *>(out_data);
-    memcpy(divmoders, offset_calculator.divmoders.data(),
-           kDims * sizeof(platform::FastDivMod));
-  }
-
-  __device__ __forceinline__ uint32_t GetOffsetByDivmod(int idx, int in_idx) {
-    uint32_t offset = 0;
-
-#pragma unroll(kDims)
-    for (int i = 0; i < kDims; ++i) {
-      auto fast_divmoder = divmoders[i].Divmod(idx);
-      idx = fast_divmoder.val[0];
-      offset += fast_divmoder.val[1] * strides[in_idx][i];
-    }
-    return offset;
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataCommon(
-      InVecType *vector_args, int tid, int idx) {
-    *vector_args = vec_in_data[idx][tid];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedDataByDivmod(InT *scalar_args,
-                                                             int tid, int idx) {
-    int index = tid * VecSize;
-#pragma unroll(VecSize)
-    for (int i = 0; i < VecSize; ++i) {
-      uint32_t offset = GetOffsetByDivmod(index + i, idx);
-      scalar_args[i] = in_data[idx][offset];
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataCommon(InT args[], int tid,
-                                                           int idx) {
-    args[idx] = in_data[idx][tid + scalar_cal_offset];
-  }
-
-  __device__ __forceinline__ void LoadScalarizedDataByDivmod(InT args[],
-                                                             int tid, int idx) {
-    auto offset = GetOffsetByDivmod(tid + scalar_cal_offset, idx);
-    args[idx] = in_data[idx][offset];
-  }
-
-  __device__ __forceinline__ void LoadVectorizedData(InT (*args)[VecSize],
-                                                     int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        InVecType *vector_args = reinterpret_cast<InVecType *>(args[j]);
-        LoadVectorizedDataCommon(vector_args, tid, j);
-      } else {
-        LoadVectorizedDataByDivmod(args[j], tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      if (no_broadcast[j]) {
-        LoadScalarizedDataCommon(args, tid, j);
-      } else {
-        LoadScalarizedDataByDivmod(args, tid, j);
-      }
-    }
-  }
-
-  __device__ __forceinline__ void StoreVectorizedData(OutVecType vec_args_out,
-                                                      int tid) {
-    vec_out_data[tid] = vec_args_out;
-  }
-
-  __device__ __forceinline__ void StoreScalarizedData(OutT args_out, int tid) {
-    out_data[scalar_cal_offset + tid] = args_out;
-  }
-};
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET>
-__device__ inline void ScalarizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  InT args[ET];
-  OutT args_out;
-  broadcast_wrapper.LoadScalarizedData(args, tid);
-
-  // Calcualtion of the in_tensor data.
-  args_out = broadcast_wrapper.func(args);
-
-  broadcast_wrapper.StoreScalarizedData(args_out, tid);
-}
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__device__ inline void VectorizedBroadcastKernelImpl(
-    BroadcastArgsWrapper broadcast_wrapper, int tid) {
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  OutVecType args_out;
-  InT ins[ET];
-  InT args[ET][VecSize];
-  broadcast_wrapper.LoadVectorizedData(args, tid);
-
-#pragma unroll(VecSize)
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll(ET)
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = args[j][i];
-    }
-    args_out.val[i] = broadcast_wrapper.func(ins);
-  }
-  broadcast_wrapper.StoreVectorizedData(args_out, tid);
-}
-
-template <typename InT, typename OutT, typename BroadcastArgsWrapper,
-          ElementwiseType ET, int VecSize>
-__global__ void ElementwiseBroadcastKernel(
-    BroadcastArgsWrapper broadcast_wrapper, int main_tid, int tail_tid) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  // Vectorized calculation of major data whose length is the max multipler of
-  // VecSize,
-  // eg: Calcualting the front 1024-length data in total 1027 data once VecSize
-  // is 4.
-  if (tid < main_tid) {
-    VectorizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET, VecSize>(
-        broadcast_wrapper, tid);
-  }
-  // Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
-  // eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
-  // 4.
-  if (tid < tail_tid) {
-    ScalarizedBroadcastKernelImpl<InT, OutT, BroadcastArgsWrapper, ET>(
-        broadcast_wrapper, tid);
-  }
-}
-
+template <typename T, int VecSize, int Size, bool IsBoundary = false>
+__device__ __forceinline__ void LoadData(
+    T *dst, const T *__restrict__ src, uint32_t block_offset,
+    const kps::details::BroadcastConfig<Size> &config, int numel, int num,
+    bool need_broadcast) {
+  // numel: the total number of output elements
+  // num: the number of elements dealt with in this pass
+  if (need_broadcast) {
+    kps::ReadDataBc<T, VecSize, 1, 1, Size, IsBoundary>(
+        dst, src, block_offset, config, numel, 1, 1);
+  } else {
+    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
+  }
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int Size,
+          int VecSize, typename Functor, bool IsBoundary = false>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
+    const framework::Array<bool, MAX_INPUT_NUM> &use_broadcast, uint32_t numel,
+    const framework::Array<kps::details::BroadcastConfig<Size>,
+                           MAX_INPUT_NUM> &configlists,
+    int num, Functor func) {
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+// load
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    LoadData<InT, VecSize, Size, IsBoundary>(args[i], in[i], block_offset,
+                                             configlists[i], numel, num,
+                                             use_broadcast[i]);
+  }
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+  // store
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
+                                                  num);
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int Size,
+          int VecSize, typename Functor>
+__global__ void BroadcastKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out,
+    framework::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
+    framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
+        configlists,
+    int main_tid, int tail_tid, Functor func) {
+  int block_offset = blockIdx.x * blockDim.x * VecSize;
+  // data offset of this block
+  if (blockIdx.x < main_tid) {
+    int num = blockDim.x * VecSize;  // blockIdx.x < main_tid
+    DealSegment<ET, InT, OutT, Size, VecSize, Functor, false>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  } else {  // remainder
+    int num = tail_tid;
+    DealSegment<ET, InT, OutT, Size, VecSize, Functor, true>(
+        in, out, use_broadcast, numel, configlists, num, func);
+  }
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int Size,
+          int VecSize, typename Functor>
+void LaunchKernel(const platform::CUDADeviceContext &ctx,
+                  const std::vector<const framework::Tensor *> &ins,
+                  framework::Tensor *out, Functor func,
+                  DimensionsTransform merge_dims) {
+  int numel = out->numel();
+  const int threads = 256;
+  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
+
+  int main_tid = numel / (VecSize * threads);
+  int tail_tid = numel % (VecSize * threads);
+  auto stream = ctx.stream();
+  OutT *out_data = out->data<OutT>();
+
+  framework::Array<kps::details::BroadcastConfig<Size>, MAX_INPUT_NUM>
+      configlists;
+  framework::Array<bool, MAX_INPUT_NUM> use_broadcast;
+  framework::Array<const InT *__restrict__, ET> ins_data;
+
+  for (int i = 0; i < ET; i++) {
+    use_broadcast[i] = (ins[i]->numel() != numel);
+    ins_data[i] = ins[i]->data<InT>();
+    if (use_broadcast[i]) {
+      // get the broadcast config;
+      // if the data shape is [m, n], then data_dim should be set to {n, m},
+      // e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}
+      configlists[i] = kps::details::BroadcastConfig<Size>(
+          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
+    }
+  }
+
+  BroadcastKernel<ET, InT, OutT, Size, VecSize,
+                  Functor><<<blocks, threads, 0, stream>>>(
+      ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid,
+      func);
+}

 template <ElementwiseType ET, typename InT, typename OutT, int VecSize, typename Functor>
 void LaunchBroadcastKernelForDifferentDimSize(
     const platform::CUDADeviceContext &ctx,
     const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
     int axis, Functor func) {
-  int numel = out->numel();
-  int threads = GetThreadsConfig(ctx, numel, VecSize);
-  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
-  int main_tid = numel / VecSize;
-  int tail_tid = numel % VecSize;
-  int vec_len = main_tid * VecSize;
-  auto stream = ctx.stream();
-
   const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
-  const auto offset_calculator = StridesCalculation(
-      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
+
+#define DIM_SIZE(size)                                              \
+  case size: {                                                      \
+    LaunchKernel<ET, InT, OutT, size, VecSize>(ctx, ins, out, func, \
+                                               merge_dims);         \
+  } break;

   switch (merge_dims.dim_size) {
-    case 1: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 1>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 2: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 2>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 3: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 3>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 4: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 4>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 5: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 5>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 6: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 6>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 7: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 7>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    case 8: {
-      auto broadcast_wrapper =
-          BroadcastArgsWrapper<InT, OutT, Functor, ET, VecSize, 8>(
-              ins, out, vec_len, func, offset_calculator);
-      ElementwiseBroadcastKernel<InT, OutT, decltype(broadcast_wrapper), ET,
-                                 VecSize><<<blocks, threads, 0, stream>>>(
-          broadcast_wrapper, main_tid, tail_tid);
-      break;
-    }
-    default: {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "The maximum dimension of input tensor is expected to be less than "
-          "%d, but recieved %d.\n",
-          merge_dims.dim_size, framework::DDim::kMaxRank));
-    }
+    DIM_SIZE(1);
+    DIM_SIZE(2);
+    DIM_SIZE(3);
+    DIM_SIZE(4);
+    DIM_SIZE(5);
+    DIM_SIZE(6);
+    DIM_SIZE(7);
+    DIM_SIZE(8);
   }
+#undef DIM_SIZE
 }

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -528,5 +369,7 @@ void LaunchElementwiseCudaKernel(
   }
 }

+#undef MAX_INPUT_NUM
+
 }  // namespace operators
 }  // namespace paddle
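A note on the launch arithmetic introduced above: LaunchKernel splits the flattened output into main_tid complete chunks of blockDim.x * VecSize elements plus one boundary chunk of tail_tid elements, and BroadcastKernel takes the IsBoundary = true path only for that last chunk. The host-only sketch below simply replays this arithmetic for hypothetical sizes (numel = 1027, VecSize = 4, threads = 256); the names mirror the patch, but the program itself is illustrative and not part of it.

#include <cstdio>

int main() {
  // Hypothetical values, chosen only to illustrate the partitioning.
  const int numel = 1027;   // total number of output elements
  const int VecSize = 4;    // elements produced per thread
  const int threads = 256;  // fixed block size used by LaunchKernel
  // Blocks that process a complete blockDim.x * VecSize chunk.
  int main_tid = numel / (VecSize * threads);
  // Leftover elements handled by the boundary (IsBoundary = true) path.
  int tail_tid = numel % (VecSize * threads);
  // Grid size, computed with the same expression as in LaunchKernel.
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
  // Prints: main_tid = 1, tail_tid = 3, blocks = 2
  printf("main_tid = %d, tail_tid = %d, blocks = %d\n", main_tid, tail_tid,
         blocks);
  return 0;
}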
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 1b680cfc995a5..e591b145d2388 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"

@@ -26,6 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+namespace kps = paddle::operators::kernel_primitives;
 enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };

 /*
@@ -67,121 +69,74 @@ int GetVectorizedSizeForIO(const std::vector<const framework::Tensor *> &ins,
   return vec_size;
 }

-template <ElementwiseType ET, int VecSize, typename InT, typename OutT>
-struct ElementwiseDataWrapper {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-
-  const InT *__restrict__ in_data[ET];
-  OutT *out_data;
-  uint32_t scalar_cal_offset;
-
-  HOSTDEVICE ElementwiseDataWrapper(
-      const std::vector<const framework::Tensor *> &ins,
-      std::vector<framework::Tensor *> *outs, uint32_t scalar_cal_offset)
-      : scalar_cal_offset(scalar_cal_offset) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      in_data[i] = ins[i]->data<InT>();
-    }
-    out_data = (*outs)[0]->data<OutT>();
-  }
-
-  inline __device__ void LoadVectorizedData(InVecType vec_args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      const InVecType *in_vec_data =
-          reinterpret_cast<const InVecType *>(in_data[i]);
-      vec_args[i] = in_vec_data[tid];
-    }
-  }
-
-  inline __device__ void LoadScalarizedData(InT args[], int tid) {
-#pragma unroll
-    for (int i = 0; i < ET; ++i) {
-      args[i] = in_data[i][tid + scalar_cal_offset];
-    }
-  }
-
-  inline __device__ void StoreVectorizedData(OutVecType res, int tid) {
-    OutVecType *out_vec = reinterpret_cast<OutVecType *>(out_data);
-    out_vec[tid] = res;
-  }
-
-  inline __device__ void StoreScalarizedData(OutT res, int tid) {
-    out_data[tid + scalar_cal_offset] = res;
-  }
-};
-
-template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
-          typename InT, typename OutT, typename Functor>
-__device__ inline void VectorizedKernelImpl(ElementwiseWrapper data,
-                                            Functor func, int tid) {
-  using InVecType = platform::AlignedVector<InT, VecSize>;
-  using OutVecType = platform::AlignedVector<OutT, VecSize>;
-  InVecType ins_vec[ET];
-  OutVecType out_vec;
-  InT *ins_ptr[ET];
-  InT ins[ET];
-#pragma unroll
-  for (int i = 0; i < ET; ++i) {
-    ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
-  }
-  // load
-  data.LoadVectorizedData(ins_vec, tid);
-// compute
-#pragma unroll
-  for (int i = 0; i < VecSize; ++i) {
-#pragma unroll
-    for (int j = 0; j < ET; ++j) {
-      ins[j] = ins_ptr[j][i];
-    }
-    out_vec.val[i] = func(ins);
-  }
-  // store
-  data.StoreVectorizedData(out_vec, tid);
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__device__ inline void ScalarKernelImpl(ElementwiseWrapper data, Functor func,
-                                        int tid) {
-  InT ins[ET];
-  OutT out;
-
-  // load
-  data.LoadScalarizedData(ins, tid);
-  // compute
-  out = func(ins);
-  // store
-  data.StoreScalarizedData(out, tid);
-}
-
-template <ElementwiseType ET, int VecSize, typename ElementwiseWrapper,
-          typename InT, typename OutT, typename Functor>
-__global__ void VectorizedKernel(ElementwiseWrapper data, int main_tid,
-                                 int tail_tid, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (tid < main_tid) {
-    VectorizedKernelImpl<ET, VecSize, ElementwiseWrapper, InT, OutT, Functor>(
-        data, func, tid);
-  }
-  if (tid < tail_tid) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
-
-template <ElementwiseType ET, typename ElementwiseWrapper, typename InT,
-          typename OutT, typename Functor>
-__global__ void ScalarKernel(ElementwiseWrapper data, int numel, Functor func) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < numel) {
-    ScalarKernelImpl<ET, ElementwiseWrapper, InT, OutT, Functor>(data, func,
-                                                                 tid);
-  }
-}
-
+template <ElementwiseType ET, typename InT, typename OutT, int VecSize,
+          typename Functor, bool IsBoundary>
+__device__ void DealSegment(
+    const framework::Array<const InT *__restrict__, ET> &in, OutT *out,
+    int num, Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  InT args[ET][VecSize];
+  OutT result[VecSize];
+// load data
+#pragma unroll
+  for (int i = 0; i < ET; i++) {
+    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
+    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
+                                                  num);
+  }
+  // compute
+  if (ET == kUnary) {
+    kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                             func);
+  } else if (ET == kBinary) {
+    kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
+                                                              args[1], func);
+  } else {
+    kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
+        result, args[0], args[1], args[2], func);
+  }
+  // store
+  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
+                                                  num);
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int VecSize,
+          typename Functor>
+__global__ void ElementVectorizeKernel(
+    framework::Array<const InT *__restrict__, ET> in, OutT *out, int size,
+    Functor func) {
+  int data_offset = VecSize * blockIdx.x * blockDim.x;
+  int num = size - data_offset;
+  // num: the number of elements this block has to deal with
+  if (VecSize * blockDim.x > num) {  // remainder segment
+    DealSegment<ET, InT, OutT, VecSize, Functor, true>(in, out, num, func);
+  } else {  // complete segment
+    DealSegment<ET, InT, OutT, VecSize, Functor, false>(in, out, num, func);
+  }
+}
+
+template <ElementwiseType ET, typename InT, typename OutT, int VecSize,
+          typename Functor>
+void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
+                           const std::vector<const framework::Tensor *> &ins,
+                           std::vector<framework::Tensor *> *outs,
+                           Functor func) {
+  auto numel = ins[0]->numel();
+  int block_size = GetThreadsConfig(ctx, numel, VecSize);
+  int grid_size =
+      ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
+
+  auto stream = ctx.stream();
+  OutT *out = (*outs)[0]->data<OutT>();
+  framework::Array<const InT *__restrict__, ET> in;
+  for (int i = 0; i < ET; i++) {
+    in[i] = ins[i]->data<InT>();
+  }
+  ElementVectorizeKernel<ET, InT, OutT, VecSize,
+                         Functor><<<grid_size, block_size, 0, stream>>>(
+      in, out, numel, func);
+}

 template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
@@ -190,43 +145,17 @@ void LaunchSameDimsElementwiseCudaKernel(
     const std::vector<const framework::Tensor *> &ins,
     std::vector<framework::Tensor *> *outs, Functor func) {
   // calculate the max vec_size for all ins and outs
-  auto numel = ins[0]->numel();
   int vec_size = GetVectorizedSizeForIO(ins, *outs);
-  int block_size = GetThreadsConfig(ctx, numel, vec_size);
-  int grid_size =
-      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
-  int main_tid = numel / vec_size;
-  int tail_tid = numel % vec_size;
-  uint32_t vec_len = main_tid * vec_size;
-
-  // cuda kernel
-  auto stream = ctx.stream();
-
   switch (vec_size) {
-    case 4: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 4, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, 4, decltype(data_wrapper), InT, OutT,
-                       Functor><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 4:
+      ElementwiseCudaKernel<ET, InT, OutT, 4>(ctx, ins, outs, func);
       break;
-    }
-    case 2: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 2, InT, OutT>(ins, outs, vec_len);
-      VectorizedKernel<ET, 2, decltype(data_wrapper), InT, OutT,
-                       Functor><<<grid_size, block_size, 0, stream>>>(
-          data_wrapper, main_tid, tail_tid, func);
+    case 2:
+      ElementwiseCudaKernel<ET, InT, OutT, 2>(ctx, ins, outs, func);
       break;
-    }
-    case 1: {
-      auto data_wrapper =
-          ElementwiseDataWrapper<ET, 1, InT, OutT>(ins, outs, 0);
-      ScalarKernel<ET, decltype(data_wrapper), InT, OutT,
-                   Functor><<<grid_size, block_size, 0, stream>>>(data_wrapper,
-                                                                  numel, func);
+    case 1:
+      ElementwiseCudaKernel<ET, InT, OutT, 1>(ctx, ins, outs, func);
       break;
-    }
     default: {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported vectorized size: %d !", vec_size));
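A similar sanity check applies to the same-dims path: ElementVectorizeKernel classifies a block as the remainder segment whenever VecSize * blockDim.x exceeds the number of elements left at its offset, so only that block pays for the bounds checks inside kps::ReadData and kps::WriteData. The host-only sketch below replays that per-block decision for hypothetical sizes (size = 1000, VecSize = 4, 64 threads per block); it is illustrative and not part of the patch.

#include <cstdio>

int main() {
  // Hypothetical values, chosen only to illustrate the segment classification.
  const int size = 1000;     // total number of elements
  const int VecSize = 4;     // elements handled per thread
  const int block_dim = 64;  // threads per block
  // Grid size, computed with the same expression as in ElementwiseCudaKernel.
  const int grid = ((size + VecSize - 1) / VecSize + block_dim - 1) / block_dim;
  for (int block = 0; block < grid; ++block) {
    int data_offset = VecSize * block * block_dim;
    int num = size - data_offset;  // elements left from this block's offset on
    bool remainder = (VecSize * block_dim > num);
    // Blocks 0-2 are complete segments of 256 elements each; block 3 is the
    // remainder segment with 232 elements.
    printf("block %d: offset %4d, num %4d, %s\n", block, data_offset, num,
           remainder ? "remainder" : "complete");
  }
  return 0;
}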