diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index 61efa409b90c3..fd329acaf5ff2 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -360,19 +360,26 @@ struct ReduceConfig {
     constexpr int max_num_threads = detail::kMaxThread;
 
     // set block size.
-    // 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
-    // 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
-    //    it should reduce in block y.
+    // 1. If reduce_lastdim == true, all threads with the same threadIdx.y
+    //    cooperate on the reduction for one output.
+    //    The number of outputs per block is blockDim.y;
+    // 2. If reduce_lastdim == false, each threadIdx.x handles a different
+    //    reduction and produces its output independently. If it is
+    //    necessary, partial results are further reduced along block y.
+    //    The number of outputs per block is blockDim.x;
+    int block_x, block_y;
     int grid_num, reduce_num_per_thread;
     if (reduce_lastdim) {
-      block_dim->x = detail::GetBlockDim(reduce_num);
-      block_dim->y = 1;
-      grid_num = left_num;
-      reduce_num_per_thread =
-          detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
+      block_x = detail::GetBlockDim(reduce_num);
+      block_y = detail::GetBlockDim(left_num);
+      block_dim->x = block_x;
+      block_dim->y =
+          std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
+      grid_num = detail::AlignUp(left_num, block_dim->y);
+      reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
     } else {
-      int block_x = detail::GetBlockDim(left_num);
-      int block_y = detail::GetBlockDim(reduce_num);
+      block_x = detail::GetBlockDim(left_num);
+      block_y = detail::GetBlockDim(reduce_num);
       block_dim->x = std::min(block_x, 32);
       block_dim->y =
           std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
@@ -467,7 +474,7 @@ struct ReduceConfig {
       grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
       grid_dim.y = 1;
     }
-  } else if (reduce_type == ReduceType::kReduceAny) {
+  } else {
     SetBlockDimForReduceAny(&block_dim, &grid_dim);
   }
 
@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
 template <typename T, typename ReduceOp>
 static __device__ T BlockXReduce(T val, ReduceOp reducer) {
   using detail::kWarpSize;
-  __shared__ T shared[kWarpSize];
+  __shared__ T shared[2 * kWarpSize];
   int block_dim_x = blockDim.x;
   if (blockDim.x > kWarpSize) {
     block_dim_x = blockDim.x / kWarpSize;
     int lane = threadIdx.x % kWarpSize;
-    int wid = threadIdx.x / kWarpSize;
+    int tid = threadIdx.y * blockDim.x + threadIdx.x;
+    int wid = tid / kWarpSize;
+    int bid = threadIdx.y;
     val = WarpReduce(val, reducer);
     if (lane == 0) {
       shared[wid] = val;
     }
     __syncthreads();
-    val = shared[lane];
+    val = shared[bid * block_dim_x + lane];
   }
 
   unsigned mask = 0u;
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
   return val;
 }
 
-// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
-// function will be used
-// blockId.x -> left_num, threadId.x -> reduce_num
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
-__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
-                              TransformOp transformer, Ty init,
-                              int reduce_num) {
-  int idx_x = blockIdx.x * reduce_num;
-  int idx_y = threadIdx.x;
-  Ty reduce_var = init;
-  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
-    reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
-  }
-  __syncthreads();
-
-  reduce_var = BlockXReduce(reduce_var, reducer);
-
-  if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
-  }
-}
-
 // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
 // function will be used
 // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
@@ -613,19 +599,21 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
   }
 }
 
+// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
 // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
 // function will be used
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
+template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
+          typename ReduceIndexCal, typename LeftIndexCal>
 __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
                           TransformOp transformer, Ty init, int reduce_num,
                           int left_num, bool reduce_lastdim,
-                          const IndexCalculator& reduce_index_calculator,
-                          const IndexCalculator& left_index_calculator) {
+                          ReduceIndexCal reduce_index_calculator,
+                          LeftIndexCal left_index_calculator) {
   int input_idx, left_idx, stride;
   // the last dim gets involved in reduction
   if (reduce_lastdim) {
     input_idx = blockIdx.y * blockDim.x + threadIdx.x;
-    left_idx = blockIdx.x;
+    left_idx = blockIdx.x * blockDim.y + threadIdx.y;
     stride = gridDim.y * blockDim.x;
   } else {
     input_idx = blockIdx.y * blockDim.y + threadIdx.y;
@@ -633,7 +621,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
     stride = gridDim.y * blockDim.y;
   }
   // calculate the offset, means the addr where each thread really start.
-  int input_offset = left_index_calculator.Get(left_idx);
+  int input_offset = left_index_calculator(left_idx);
   const Tx* input = x + input_offset;
   Ty reduce_var = init;
 
@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
 #pragma unroll
     for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
       int reduce_idx = input_idx + i * stride;
-      int idx_x = reduce_index_calculator.Get(reduce_idx);
+      int idx_x = reduce_index_calculator(reduce_idx);
       input_reg[i] = input[idx_x];
     }
 #pragma unroll
@@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
       break;
     }
     int reduce_idx = input_idx;
-    int idx_x = reduce_index_calculator.Get(reduce_idx);
+    int idx_x = reduce_index_calculator(reduce_idx);
     input_reg[i] = input[idx_x];
     input_idx += stride;
   }
@@ -680,7 +668,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   }
 
   // 2. reduce in block y
-  if (blockDim.y > 1) {
+  if (!reduce_lastdim && blockDim.y > 1) {
     reduce_var = BlockYReduce(reduce_var, reducer);
   }
   __syncthreads();
@@ -688,8 +676,8 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
   if (reduce_lastdim) {
     // 3. reduce in block x
     reduce_var = BlockXReduce(reduce_var, reducer);
-    if (threadIdx.x == 0) {
-      y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
+    if (left_idx < left_num && threadIdx.x == 0) {
+      y[blockIdx.y * left_num + left_idx] = reduce_var;
     }
   } else {
     if (left_idx < left_num && threadIdx.y == 0) {
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
                              const IndexCalculator& reduce_index_calculator,
                              const IndexCalculator& left_index_calculator) {
   if (reduce_type == ReduceType::kReduceLastDim) {
-    ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
-                                                 init, reduce_num);
+    ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
+        x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
+        [&](int idx) { return idx; },
+        [&](int idx) { return idx * reduce_num; });
 
     // reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
   } else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
   } else {
     ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
         x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
-        [&](int idx) { return reduce_index_calculator.Get(idx); },
-        [&](int idx) { return left_index_calculator.Get(idx); });
+        [&](int idx) { return reduce_index_calculator.Get(idx); },
+        [&](int idx) { return left_index_calculator.Get(idx); });
   }
 }
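
The effect of the new `SetBlockDimForReduceAny` branch on launch shape can be checked on the host. Below is a rough sketch of the `reduce_lastdim` arithmetic only; the value of `kMaxThread` and the behavior of `GetBlockDim` (here a stand-in: largest power of two not exceeding the input, capped at `kMaxThread`) are assumptions, since the patch only shows how `detail::AlignUp` is used (as ceiling division).

```cuda
// Host-side sketch of the new reduce_lastdim block-size arithmetic.
// kMaxThread and GetBlockDim are assumed stand-ins, not Paddle's detail::.
#include <algorithm>
#include <cstdio>

constexpr int kMaxThread = 256;  // assumed; detail::kMaxThread may differ

int AlignUp(int a, int b) { return (a + b - 1) / b; }  // ceiling division

// Stand-in: largest power of two <= n, capped at kMaxThread.
int GetBlockDim(int n) {
  int p = 1;
  while (p * 2 <= std::min(n, kMaxThread)) p *= 2;
  return p;
}

int main() {
  int reduce_num = 32, left_num = 1024;
  // Old scheme: block = (GetBlockDim(reduce_num), 1), grid = left_num,
  // i.e. one tiny block per output. New scheme packs blockDim.y outputs
  // into each block:
  int block_x = GetBlockDim(reduce_num);
  int block_y = std::min(GetBlockDim(left_num), kMaxThread / block_x);
  int grid_num = AlignUp(left_num, block_y);
  int reduce_num_per_thread = AlignUp(reduce_num, block_x);
  printf("block = (%d, %d), grid = %d, elems/thread = %d\n",
         block_x, block_y, grid_num, reduce_num_per_thread);
  // reduce_num=32, left_num=1024 -> block=(32,8), grid=128: 8 outputs share
  // one block instead of launching 1024 blocks of 32 threads each.
  return 0;
}
```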
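For readers following the kernel side of the change: the patch folds the old `ReduceLastDim` kernel into `ReduceAny`, letting one 2-D block produce `blockDim.y` outputs, and makes `BlockXReduce` 2-D-aware by giving each `threadIdx.y` row its own row-major stretch of shared memory (hence the doubled `2 * kWarpSize` buffer and the `shared[bid * block_dim_x + lane]` read). The sketch below is a minimal standalone rendering of that scheme under stated assumptions, not the Paddle code: `WarpReduce`, `BlockXReduce`, `ReduceLastDimKernel`, and `Sum` are simplified stand-ins, and the second stage of the row reduction uses a plain loop where the patch uses another shuffle pass.

```cuda
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

struct Sum {
  __device__ float operator()(float a, float b) const { return a + b; }
};

template <typename T, typename ReduceOp>
__device__ T WarpReduce(T val, ReduceOp reducer) {
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    val = reducer(val, __shfl_down_sync(0xffffffffu, val, offset));
  }
  return val;
}

// Reduce across threadIdx.x within each threadIdx.y row of a 2-D block.
// Assumes blockDim.x is a multiple of kWarpSize and <= 1024 threads/block,
// so at most 2 * kWarpSize warp partials exist. Partials are stored
// row-major (row * warps_per_row + warp_in_row) so rows cannot collide --
// the same reason the patch doubles `shared`.
template <typename T, typename ReduceOp>
__device__ T BlockXReduce(T val, ReduceOp reducer) {
  __shared__ T shared[2 * kWarpSize];
  int warps_per_row = blockDim.x / kWarpSize;
  int lane = threadIdx.x % kWarpSize;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  val = WarpReduce(val, reducer);
  if (lane == 0) shared[tid / kWarpSize] = val;
  __syncthreads();
  // Second stage: thread 0 of each row folds that row's warp partials.
  T out = shared[threadIdx.y * warps_per_row];
  if (threadIdx.x == 0) {
    for (int w = 1; w < warps_per_row; ++w) {
      out = reducer(out, shared[threadIdx.y * warps_per_row + w]);
    }
  }
  return out;  // meaningful on threadIdx.x == 0 of each row
}

// One block produces blockDim.y outputs: threads sharing a threadIdx.y
// stride over the reduce_num elements of one row, mirroring the identity
// reduce-index lambda and the `idx * reduce_num` left-index lambda that
// ReduceModule now passes for the kReduceLastDim case.
template <typename T, typename ReduceOp>
__global__ void ReduceLastDimKernel(const T* x, T* y, ReduceOp reducer, T init,
                                    int reduce_num, int left_num) {
  int left_idx = blockIdx.x * blockDim.y + threadIdx.y;
  T acc = init;
  if (left_idx < left_num) {  // no early return: BlockXReduce syncs the block
    const T* row = x + static_cast<size_t>(left_idx) * reduce_num;
    for (int i = threadIdx.x; i < reduce_num; i += blockDim.x) {
      acc = reducer(acc, row[i]);
    }
  }
  acc = BlockXReduce(acc, reducer);
  if (left_idx < left_num && threadIdx.x == 0) y[left_idx] = acc;
}

int main() {
  const int left_num = 10, reduce_num = 1000;
  std::vector<float> h_x(static_cast<size_t>(left_num) * reduce_num, 1.0f);
  float *d_x, *d_y;
  cudaMalloc(&d_x, h_x.size() * sizeof(float));
  cudaMalloc(&d_y, left_num * sizeof(float));
  cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  dim3 block(128, 4);  // 4 outputs per block instead of 1
  dim3 grid((left_num + block.y - 1) / block.y);
  ReduceLastDimKernel<<<grid, block>>>(d_x, d_y, Sum(), 0.0f, reduce_num,
                                       left_num);
  std::vector<float> h_y(left_num);
  cudaMemcpy(h_y.data(), d_y, left_num * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("y[0] = %.0f (expected %d)\n", h_y[0], reduce_num);
  cudaFree(d_x);
  cudaFree(d_y);
  return 0;
}
```

Note how the kernel avoids an early return for out-of-range rows: every thread must still reach the `__syncthreads()` inside `BlockXReduce`, which is the same reason the patched `ReduceAny` guards only the final store with `left_idx < left_num` rather than exiting early.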