
fix bug when block < 32
ZzSean committed Jul 2, 2021
1 parent 148be36 commit bc5d183
Showing 1 changed file with 52 additions and 15 deletions.
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -90,10 +90,13 @@ constexpr int kMaxThread = 256;
 #else
 constexpr int kMaxThread = 128;
 #endif
+constexpr int warp_size = 32;
 
 // get blockDim for reduceLastDim and reduceAny
 static inline int GetBlockDim(int block_dim) {
-  return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim);
+  return block_dim >= kMaxThread
+             ? kMaxThread
+             : (block_dim <= warp_size ? warp_size : GetLastPow2(block_dim));
 }
 
 // check reduce rand is valid
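
Why the clamp matters: GetLastPow2 rounds down to a power of two, so a reduce length such as 20 used to yield a 16-thread block, while the warp-shuffle phase of BlockReduce always exchanges values across a full 32-lane warp; shuffling with a full mask while lanes 16-31 were never launched is undefined behavior. A hedged sketch of the helper (written from the usual bit-twiddling idiom, not copied from Paddle) makes the before/after concrete:

#include <algorithm>

// Assumed shape of the helper: round n down to the previous power of two.
static inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

// Before the fix: GetBlockDim(20) == GetLastPow2(20) == 16  (partial warp)
// After the fix:  GetBlockDim(20) == warp_size == 32        (full warp)
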
@@ -393,26 +396,62 @@ struct ReduceConfig {
   dim3 grid;
 };
 
+// version 1
 template <typename T, typename ReduceOp>
-__device__ __forceinline__ T BlockReduce(T* shared, T val, ReduceOp reducer) {
+__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, true);
+  for (int stride = warpSize / 2; stride > 0; stride >>= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
+    val = reducer(val, temp);
+  }
+  return val;
+}
+
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockReduce(T val, T init, ReduceOp reducer) {
+  __shared__ T shared[32];
+  int lane = threadIdx.x % warpSize;
+  int wid = threadIdx.x / warpSize;
+
+  val = WarpReduce(val, reducer);
+
+  if (lane == 0) {
+    shared[wid] = val;
+  }
+
+  __syncthreads();
+
+  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : init;
+
+  if (wid == 0) {
+    val = WarpReduce(val, reducer);
+  }
+  return val;
+}
+
+// version 2
+template <typename T, typename ReduceOp>
+__device__ __forceinline__ T BlockReduce(T val, ReduceOp reducer) {
+  __shared__ T shared[detail::kMaxThread];
-  constexpr int warp_size = 32;
   if (blockDim.x > warp_size) {
     shared[threadIdx.x] = val;
-  }
-  for (int offset = blockDim.x / 2; offset >= warp_size; offset >>= 1) {
-    __syncthreads();
-    if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
-      T temp = shared[threadIdx.x + offset];
-      val = reducer(val, temp);
-      shared[threadIdx.x] = val;
+    for (int stride = blockDim.x / 2; stride >= warp_size; stride >>= 1) {
+      __syncthreads();
+      if (threadIdx.x < stride) {
+        T temp = shared[threadIdx.x + stride];
+        val = reducer(val, temp);
+        shared[threadIdx.x] = val;
+      }
     }
   }
   __syncthreads();
 
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, true);
-  for (int offset = warp_size / 2; offset > 0; offset >>= 1) {
-    T temp = paddle::platform::CudaShuffleDownSync(mask, val, offset);
+  for (int stride = warp_size / 2; stride > 0; stride >>= 1) {
+    T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
     val = reducer(val, temp);
   }
   return val;
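
The "version 1" pair is the classic two-level block reduction: every warp reduces its own 32 values with shuffles, each warp leader parks its partial in shared memory, and warp 0 reduces the partials. The init argument is what keeps small blocks safe: lanes of warp 0 that do not map to a real warp load the identity element instead of stale shared memory. A minimal standalone restatement, using raw __shfl_down_sync in place of Paddle's CREATE_SHFL_MASK/CudaShuffleDownSync wrappers (the *Sketch names are ours; it assumes blockDim.x is a power of two in [32, 1024], which the patched GetBlockDim guarantees, and a T that __shfl_down_sync accepts):

template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduceSketch(T val, ReduceOp reducer) {
  // Shuffle-down tree over one warp; all 32 lanes must be live.
  for (int stride = warpSize / 2; stride > 0; stride >>= 1) {
    val = reducer(val, __shfl_down_sync(0xffffffffu, val, stride));
  }
  return val;
}

template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockReduceSketch(T val, T init, ReduceOp reducer) {
  __shared__ T shared[32];               // one partial per possible warp
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = WarpReduceSketch(val, reducer);  // 1. reduce inside each warp
  if (lane == 0) shared[wid] = val;      // 2. warp leaders publish partials
  __syncthreads();

  // 3. warp 0 re-reduces the partials; lanes beyond the actual warp count
  //    contribute the identity instead of uninitialized shared memory.
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : init;
  if (wid == 0) val = WarpReduceSketch(val, reducer);
  return val;                            // meaningful in thread 0 only
}

With blockDim.x == 32 there is exactly one warp: step 3 leaves a single valid lane, and the value is already correct after step 1.
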
@@ -426,7 +465,6 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
                                               ReduceOp reducer,
                                               TransformOp transformer, Ty init,
                                               int reduce_num) {
-  __shared__ Ty shared_memory[detail::kMaxThread];
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
   Ty reduce_var = init;
@@ -436,7 +474,7 @@ __device__ __forceinline__ void ReduceLastDim(const Tx* x, Ty* y,
   }
   __syncthreads();
 
-  reduce_var = BlockReduce(shared_memory, reduce_var, reducer);
+  reduce_var = BlockReduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
     y[blockIdx.x] = reduce_var;
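
For ReduceLastDim the whole fix is visible end to end: one block owns one output element, threads past the row length keep init, and since GetBlockDim never returns less than a warp anymore, the closing shuffle always runs over 32 live lanes. A compact sketch of that small-row path (the kernel name and the fixed <<<grid, 32>>> launch are illustrative, not Paddle's):

// Each block sums one row of length reduce_num, which may be < 32.
__global__ void RowSumSmall(const float* x, float* y, int reduce_num) {
  const float* row = x + blockIdx.x * reduce_num;
  // Lanes past the row contribute the identity (0 for a sum), the same
  // role init plays in the patched ReduceLastDim.
  float v = (threadIdx.x < reduce_num) ? row[threadIdx.x] : 0.0f;
  // All 32 lanes of the single warp exist, so a full mask is well defined.
  for (int stride = warpSize / 2; stride > 0; stride >>= 1) {
    v += __shfl_down_sync(0xffffffffu, v, stride);
  }
  if (threadIdx.x == 0) y[blockIdx.x] = v;
}
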
@@ -485,7 +523,6 @@ __device__ __forceinline__ void ReduceAny(
     paddle::framework::Array<int, ReduceRank> reduce_strides,
     paddle::framework::Array<int, Rank - ReduceRank> left_dim,
     paddle::framework::Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ Ty shared_memory[detail::kMaxThread];
   int sub_index[Rank];
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
Expand Down Expand Up @@ -523,7 +560,7 @@ __device__ __forceinline__ void ReduceAny(
}
__syncthreads();

reduce_var = BlockReduce(shared_memory, reduce_var, reducer);
reduce_var = BlockReduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
}
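A hypothetical driver for the RowSumSmall sketch above, exercising exactly the shape named by the commit message (reduce length 20 < 32). Under the old GetBlockDim this row would have been handled by a 16-thread block whose full-mask shuffle touched lanes that were never launched; with one full warp per row the sums come out right:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int rows = 4, reduce_num = 20;  // 20 < 32: the previously buggy shape
  float h_x[rows * reduce_num], h_y[rows];
  for (int i = 0; i < rows * reduce_num; ++i) h_x[i] = 1.0f;

  float *d_x, *d_y;
  cudaMalloc(&d_x, sizeof(h_x));
  cudaMalloc(&d_y, sizeof(h_y));
  cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);

  RowSumSmall<<<rows, 32>>>(d_x, d_y, reduce_num);  // one full warp per row

  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);
  for (int i = 0; i < rows; ++i) {
    printf("row %d sum = %g (expect 20)\n", i, h_y[i]);
  }
  cudaFree(d_x);
  cudaFree(d_y);
  return 0;
}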
