From 487f12a43e53817004c9a9db2baaf79a28f856c7 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 3 Sep 2021 11:44:14 +0000 Subject: [PATCH] merge ET --- .../elementwise/elementwise_op_broadcast.cu.h | 175 +++++------------- 1 file changed, 48 insertions(+), 127 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index e18141f7b62c8..129c90a22be6b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -180,128 +180,59 @@ __device__ __forceinline__ void LoadData( } } -template -__global__ void BroadcastKernelTernary( - const InT *__restrict__ in0, const InT *__restrict__ in1, - const InT *__restrict__ in2, OutT *out, - framework::Array use_broadcast, uint32_t numel, - framework::Array, MAX_INPUT_NUM> - configlists, - int main_tid, int tail_tid, Functor func) { - int block_offset = - blockIdx.x * blockDim.x * VecSize; // data offset of this block - int num = tail_tid; - InT arg[3][VecSize]; +template +__device__ void DealSegment( + const framework::Array &in, OutT *out, + const framework::Array &use_broadcast, uint32_t numel, + const framework::Array, + MAX_INPUT_NUM> &configlists, + int num, Functor func) { + InT args[ET][VecSize]; OutT result[VecSize]; - const bool is_boundary = true; - if (blockIdx.x < main_tid) { - num = blockDim.x * VecSize; // blockIdx.x < main_tid - // load in0, in1, in2 - LoadData(arg[0], in0, block_offset, configlists[0], - numel, num, use_broadcast[0]); - LoadData(arg[1], in1, block_offset, configlists[1], - numel, num, use_broadcast[1]); - LoadData(arg[2], in2, block_offset, configlists[2], - numel, num, use_broadcast[2]); - kps::ElementwiseTernary( - result, arg[0], arg[1], arg[2], func); - kps::WriteData(out + block_offset, result, num); - } else { // blockIdx.x == main_tid - // This is the last block and tial_tid != 0, set is_boundary = true - // is_boundary = true, boundary judgment needs to be made when loading data - // to avoid access storage overflow - kps::Init(arg[0], static_cast(1.0f)); - kps::Init(arg[1], static_cast(1.0f)); - kps::Init(arg[2], static_cast(1.0f)); - LoadData(arg[0], in0, block_offset, - configlists[0], numel, num, - use_broadcast[0]); - LoadData(arg[1], in1, block_offset, - configlists[1], numel, num, - use_broadcast[1]); - LoadData(arg[2], in2, block_offset, - configlists[2], numel, num, - use_broadcast[2]); + int block_offset = blockIdx.x * blockDim.x * VecSize; +// load +#pragma unroll + for (int i = 0; i < ET; i++) { + kps::Init(args[i], static_cast(1.0f)); + LoadData(args[i], in[i], block_offset, + configlists[i], numel, num, + use_broadcast[i]); + } + // compute + if (ET == kUnary) { + kps::ElementwiseUnary(result, args[0], + func); + } else if (ET == kBinary) { + kps::ElementwiseBinary(result, args[0], + args[1], func); + } else { kps::ElementwiseTernary( - result, arg[0], arg[1], arg[2], func); - kps::WriteData(out + block_offset, result, - num); + result, args[0], args[1], args[2], func); } + // compute + kps::WriteData(out + block_offset, result, + num); } -template -__global__ void BroadcastKernelBinary( - const InT *__restrict__ in0, const InT *__restrict__ in1, OutT *out, +template +__global__ void BroadcastKernel( + framework::Array in, OutT *out, framework::Array use_broadcast, uint32_t numel, framework::Array, MAX_INPUT_NUM> configlists, int main_tid, int tail_tid, Functor func) { - int block_offset = - blockIdx.x * blockDim.x * VecSize; // data offset of this block - int num = tail_tid; - InT arg[2][VecSize]; - OutT result[VecSize]; - const bool is_boundary = true; + int block_offset = blockIdx.x * blockDim.x * VecSize; + // data offset of this block if (blockIdx.x < main_tid) { - num = blockDim.x * VecSize; // blockIdx.x < main_tid - LoadData(arg[0], in0, block_offset, configlists[0], - numel, num, use_broadcast[0]); - LoadData(arg[1], in1, block_offset, configlists[1], - numel, num, use_broadcast[1]); - kps::ElementwiseBinary(result, arg[0], - arg[1], func); - kps::WriteData(out + block_offset, result, num); + int num = blockDim.x * VecSize; // blockIdx.x < main_tid + DealSegment( + in, out, use_broadcast, numel, configlists, num, func); } else { // reminder - // This is the last block and tial_tid != 0, set is_boundary = true - // is_boundary = true, boundary judgment needs to be made when loading data - // to avoid access storage overflow - kps::Init(arg[0], static_cast(1.0f)); - kps::Init(arg[1], static_cast(1.0f)); - LoadData(arg[0], in0, block_offset, - configlists[0], numel, num, - use_broadcast[0]); - LoadData(arg[1], in1, block_offset, - configlists[1], numel, num, - use_broadcast[1]); - kps::ElementwiseBinary(result, arg[0], - arg[1], func); - kps::WriteData(out + block_offset, result, - num); - } -} - -template -__global__ void BroadcastKernelUnary( - const InT *__restrict__ in, OutT *out, int numel, - kps::details::BroadcastConfig config, int main_tid, int tail_tid, - Functor func) { - int block_offset = - blockIdx.x * blockDim.x * VecSize; // data offset of this block - int num = tail_tid; - InT arg[VecSize]; - OutT result[VecSize]; - const bool is_boundary = true; - if (blockIdx.x < main_tid) { - num = blockDim.x * VecSize; // blockIdx.x < main_tid - kps::ReadDataBc(&arg[0], in, block_offset, - config, numel, 1, 1); - kps::ElementwiseUnary(&result[0], - &arg[0], func); - kps::WriteData(out + block_offset, &result[0], num); - } else { - // This is the last block and tial_tid != 0, set is_boundary = true - // is_boundary = true, boundary judgment needs to be made when loading data - // to avoid access storage overflow - kps::Init(&arg[0], static_cast(1.0f)); - kps::ReadDataBc( - &arg[0], in, block_offset, config, numel, 1, 1); - kps::ElementwiseUnary(&result[0], - &arg[0], func); - kps::WriteData(out + block_offset, - &result[0], num); + int num = tail_tid; + DealSegment( + in, out, use_broadcast, numel, configlists, num, func); } } @@ -323,9 +254,11 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx, framework::Array, MAX_INPUT_NUM> configlists; framework::Array use_broadcast; + framework::Array ins_data; for (int i = 0; i < ET; i++) { use_broadcast[i] = (ins[i]->numel() != numel); + ins_data[i] = ins[i]->data(); if (use_broadcast[i]) { // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} @@ -335,22 +268,10 @@ void LaunchKernel(const platform::CUDADeviceContext &ctx, } } - if (ET == kUnary) { // for unary eg: relu - BroadcastKernelUnary<<>>( - ins[0]->data(), out_data, numel, configlists[0], main_tid, - tail_tid, func); - } else if (ET == kBinary) { // for binary eg: add: a + b - BroadcastKernelBinary<<>>( - ins[0]->data(), ins[1]->data(), out_data, use_broadcast, - numel, configlists, main_tid, tail_tid, func); - } else { // for ternary eg:fma : a * b + c - BroadcastKernelTernary<<>>( - ins[0]->data(), ins[1]->data(), ins[2]->data(), out_data, - use_broadcast, numel, configlists, main_tid, tail_tid, func); - } + BroadcastKernel<<>>( + ins_data, out_data, use_broadcast, numel, configlists, main_tid, tail_tid, + func); } template