Unify the block/grid strategy and implementation of ReduceLastDim and ReduceAny. #34436

Merged: 1 commit, Aug 2, 2021
Changes from all commits
91 changes: 41 additions & 50 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -360,19 +360,26 @@ struct ReduceConfig {
constexpr int max_num_threads = detail::kMaxThread;

// set block size.
// 1. if reduce_lastdim == true, block is 1-D, no need reduction in block y;
// 2. if reduce_lastdim == false, block is 2-D, if it is necessary,
// it should reduce in block y.
// 1. If reduce_lastdim == true, all threads with the same threadIdx.y
// cooperate on the reduction for one output.
// The number of outputs per block is blockDim.y;
// 2. If reduce_lastdim == false, each threadIdx.x handles a different
// reduction and produces its output independently. If it is
// necessary, a further reduction is done along block y.
// The number of outputs per block is blockDim.x;
int block_x, block_y;
int grid_num, reduce_num_per_thread;
if (reduce_lastdim) {
block_dim->x = detail::GetBlockDim(reduce_num);
block_dim->y = 1;
grid_num = left_num;
reduce_num_per_thread =
detail::AlignUp(reduce_num, block_dim->x * block_dim->y);
block_x = detail::GetBlockDim(reduce_num);
block_y = detail::GetBlockDim(left_num);
block_dim->x = block_x;
block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
grid_num = detail::AlignUp(left_num, block_dim->y);
reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x);
} else {
int block_x = detail::GetBlockDim(left_num);
int block_y = detail::GetBlockDim(reduce_num);
block_x = detail::GetBlockDim(left_num);
block_y = detail::GetBlockDim(reduce_num);
block_dim->x = std::min(block_x, 32);
block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
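A minimal, self-contained sketch (editorial, not the PR's code) of the new block-shape choice in the reduce_lastdim branch; GetBlockDimSketch, AlignUpSketch, and kMaxThread = 1024 are stand-ins for the header's detail:: helpers:

// Editorial sketch: pick a 2-D block so each row of threads (fixed
// threadIdx.y) reduces one output and blockDim.y outputs share a block.
#include <algorithm>
#include <cstdio>

constexpr int kMaxThread = 1024;  // assumed cap on threads per block

// Assumed stand-in for detail::GetBlockDim: largest power of two <= n, capped.
static int GetBlockDimSketch(int n) {
  int p = 1;
  while (p * 2 <= n && p * 2 <= kMaxThread) p *= 2;
  return p;
}

static int AlignUpSketch(int a, int b) { return (a + b - 1) / b; }

int main() {
  int reduce_num = 1000, left_num = 37;           // example shape
  int block_x = GetBlockDimSketch(reduce_num);    // threads along the reduce dim
  int block_y = GetBlockDimSketch(left_num);      // candidate outputs per block
  block_y = std::min(block_y, kMaxThread / block_x);
  int grid_x = AlignUpSketch(left_num, block_y);  // each block covers block_y outputs
  int reduce_num_per_thread = AlignUpSketch(reduce_num, block_x);
  std::printf("block=(%d,%d) grid.x=%d per-thread=%d\n",
              block_x, block_y, grid_x, reduce_num_per_thread);
}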
@@ -467,7 +474,7 @@ struct ReduceConfig {
grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
grid_dim.y = 1;
}
} else if (reduce_type == ReduceType::kReduceAny) {
} else {
SetBlockDimForReduceAny(&block_dim, &grid_dim);
}

@@ -524,18 +531,20 @@ static __device__ T WarpReduce(T val, ReduceOp reducer) {
template <typename T, typename ReduceOp>
static __device__ T BlockXReduce(T val, ReduceOp reducer) {
using detail::kWarpSize;
__shared__ T shared[kWarpSize];
__shared__ T shared[2 * kWarpSize];
Review comment — @Xreki (Contributor), Aug 2, 2021:

Why 2 * kWarpSize? As a follow-up, we could keep optimizing this function so that it becomes a basic building block usable by operators such as softmax and batch_norm.

int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
int wid = threadIdx.x / kWarpSize;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
val = shared[lane];
val = shared[bid * block_dim_x + lane];
}

unsigned mask = 0u;
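One way to read the 2 * kWarpSize choice questioned above (an editorial note, not from the PR thread): after WarpReduce, each row of threads (fixed threadIdx.y) contributes blockDim.x / kWarpSize partial values, and with at most 1024 threads per block the shared-memory index bid * block_dim_x + lane stays below 2 * kWarpSize = 64. A small host-side check of that bound, assuming blockDim.x is a multiple of kWarpSize:

// Editorial sketch: verify bid * block_dim_x + lane < 2 * kWarpSize for every
// block shape with blockDim.x a multiple of kWarpSize and
// blockDim.x * blockDim.y <= 1024 (assumptions mirroring the kernel's use).
#include <cassert>

int main() {
  constexpr int kWarpSize = 32;
  for (int bx = kWarpSize; bx <= 1024; bx += kWarpSize) {
    for (int by = 1; bx * by <= 1024; ++by) {
      int block_dim_x = bx / kWarpSize;                        // warps per row
      int max_idx = (by - 1) * block_dim_x + (kWarpSize - 1);  // worst-case read
      assert(max_idx < 2 * kWarpSize);                         // fits in shared[]
    }
  }
  return 0;
}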
@@ -562,29 +571,6 @@ static __device__ T BlockYReduce(T val, ReduceOp reducer) {
return val;
}

// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, this
// function will be used
// blockId.x -> left_num, threadId.x -> reduce_num
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
__device__ void ReduceLastDim(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, Ty init,
int reduce_num) {
int idx_x = blockIdx.x * reduce_num;
int idx_y = threadIdx.x;
Ty reduce_var = init;
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += blockDim.x) {
reduce_var =
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
}
__syncthreads();

reduce_var = BlockXReduce(reduce_var, reducer);

if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
}
}

// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
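A host-side analogue (editorial sketch; sum reduction and x_dim = {2, 3, 4} are assumptions) of the addressing this comment describes for axis = 1:

// Editorial sketch: for x_dim = {nz, ny, nx} and axis = 1, output (z, x)
// accumulates input[z * ny * nx + y * nx + x] over y.
#include <cstdio>

int main() {
  constexpr int nz = 2, ny = 3, nx = 4;
  float x[nz * ny * nx], y[nz * nx] = {};
  for (int i = 0; i < nz * ny * nx; ++i) x[i] = 1.0f;  // all ones -> each output equals ny

  for (int z = 0; z < nz; ++z)
    for (int ix = 0; ix < nx; ++ix)
      for (int iy = 0; iy < ny; ++iy)                  // the reduced axis
        y[z * nx + ix] += x[z * ny * nx + iy * nx + ix];

  std::printf("y[0] = %g (expect %d)\n", y[0], ny);
}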
@@ -613,27 +599,29 @@ __device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer,
}
}

// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
typename ReduceIndexCal, typename LeftIndexCal>
__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, Ty init, int reduce_num,
int left_num, bool reduce_lastdim,
const IndexCalculator& reduce_index_calculator,
const IndexCalculator& left_index_calculator) {
ReduceIndexCal reduce_index_calculator,
LeftIndexCal left_index_calculator) {
int input_idx, left_idx, stride;
// the last dim gets involved in reduction
if (reduce_lastdim) {
input_idx = blockIdx.y * blockDim.x + threadIdx.x;
left_idx = blockIdx.x;
left_idx = blockIdx.x * blockDim.y + threadIdx.y;
stride = gridDim.y * blockDim.x;
} else {
input_idx = blockIdx.y * blockDim.y + threadIdx.y;
left_idx = blockIdx.x * blockDim.x + threadIdx.x;
stride = gridDim.y * blockDim.y;
}
// calculate the offset, i.e. the address where each thread actually starts.
int input_offset = left_index_calculator.Get(left_idx);
int input_offset = left_index_calculator(left_idx);
const Tx* input = x + input_offset;
Ty reduce_var = init;

@@ -646,7 +634,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
#pragma unroll
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
int reduce_idx = input_idx + i * stride;
int idx_x = reduce_index_calculator.Get(reduce_idx);
int idx_x = reduce_index_calculator(reduce_idx);
input_reg[i] = input[idx_x];
}
#pragma unroll
Expand All @@ -664,7 +652,7 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
break;
}
int reduce_idx = input_idx;
int idx_x = reduce_index_calculator.Get(reduce_idx);
int idx_x = reduce_index_calculator(reduce_idx);
input_reg[i] = input[idx_x];
input_idx += stride;
}
@@ -680,16 +668,16 @@ __device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer,
}

// 2. reduce in block y
if (blockDim.y > 1) {
if (!reduce_lastdim && blockDim.y > 1) {
reduce_var = BlockYReduce(reduce_var, reducer);
}
__syncthreads();

if (reduce_lastdim) {
// 3. reduce in block x
reduce_var = BlockXReduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x + blockIdx.y * gridDim.x] = reduce_var;
if (left_idx < left_num && threadIdx.x == 0) {
y[blockIdx.y * left_num + left_idx] = reduce_var;
}
} else {
if (left_idx < left_num && threadIdx.y == 0) {
@@ -707,8 +695,10 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
const IndexCalculator& reduce_index_calculator,
const IndexCalculator& left_index_calculator) {
if (reduce_type == ReduceType::kReduceLastDim) {
ReduceLastDim<Tx, Ty, ReduceOp, TransformOp>(x, y, reducer, transformer,
init, reduce_num);
ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
[&](int idx) { return idx; },
[&](int idx) { return idx * reduce_num; });

// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1
} else if (reduce_type == ReduceType::kReduceHigherDim) {
@@ -719,7 +709,8 @@ __device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer,
} else {
ReduceAny<Tx, Ty, ReduceOp, TransformOp>(
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim,
reduce_index_calculator, left_index_calculator);
[&](int idx) { return reduce_index_calculator.Get(idx); },
[&](int idx) { return left_index_calculator.Get(idx); });
}
}
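The dispatch change above funnels the kReduceLastDim case into ReduceAny with an identity reduce-index lambda and an idx * reduce_num left-offset lambda. A host-side analogue (editorial sketch, sum reduction assumed) of why those two lambdas address a row-major [left_num, reduce_num] input correctly:

// Editorial sketch: output l reduces the elements x[l * reduce_num + r],
// so the reduce index is the identity and the left offset is the row start.
#include <cstdio>

int main() {
  constexpr int left_num = 3, reduce_num = 4;
  float x[left_num * reduce_num], y[left_num];
  for (int i = 0; i < left_num * reduce_num; ++i) x[i] = 1.0f + i;

  auto reduce_index = [](int idx) { return idx; };               // identity
  auto left_offset  = [](int idx) { return idx * reduce_num; };  // row start

  for (int l = 0; l < left_num; ++l) {
    const float* row = x + left_offset(l);
    float acc = 0.0f;
    for (int r = 0; r < reduce_num; ++r) acc += row[reduce_index(r)];
    y[l] = acc;
  }
  for (int l = 0; l < left_num; ++l) std::printf("%g ", y[l]);  // 10 26 42
}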
