Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug in ReadData, ReadDataBc and ReadDataReduce when NX != 1 #36373

Merged
merged 32 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7d58b91
Merge pull request #1 from PaddlePaddle/develop
AnnaTrainingG Mar 25, 2021
1021e08
Merge pull request #2 from PaddlePaddle/develop
AnnaTrainingG Mar 29, 2021
43f53fe
Merge pull request #3 from PaddlePaddle/develop
AnnaTrainingG Apr 19, 2021
d25ab26
Merge pull request #4 from PaddlePaddle/develop
AnnaTrainingG May 7, 2021
8c8717f
Merge pull request #5 from PaddlePaddle/develop
AnnaTrainingG May 25, 2021
9ddf5e8
Merge pull request #6 from PaddlePaddle/develop
AnnaTrainingG May 26, 2021
b0cbcca
Merge pull request #9 from PaddlePaddle/develop
AnnaTrainingG Jun 1, 2021
cdecaf0
Merge pull request #14 from PaddlePaddle/develop
AnnaTrainingG Jun 11, 2021
0da14c9
Merge pull request #16 from PaddlePaddle/develop
AnnaTrainingG Jun 15, 2021
ca95763
Merge pull request #17 from PaddlePaddle/develop
AnnaTrainingG Jun 22, 2021
25ba21c
Merge pull request #18 from PaddlePaddle/develop
AnnaTrainingG Jul 5, 2021
3ce9983
Merge pull request #19 from PaddlePaddle/develop
AnnaTrainingG Jul 6, 2021
61842ed
Merge pull request #20 from PaddlePaddle/develop
AnnaTrainingG Jul 12, 2021
0e2c73b
Merge pull request #21 from PaddlePaddle/develop
AnnaTrainingG Jul 28, 2021
c1e59cf
Merge pull request #22 from PaddlePaddle/develop
AnnaTrainingG Aug 2, 2021
3a54149
Merge pull request #23 from PaddlePaddle/develop
AnnaTrainingG Aug 4, 2021
7addd79
Merge pull request #24 from PaddlePaddle/develop
AnnaTrainingG Aug 11, 2021
1e843d1
Merge pull request #25 from PaddlePaddle/develop
AnnaTrainingG Aug 23, 2021
e1a92d6
Merge pull request #26 from PaddlePaddle/develop
AnnaTrainingG Sep 1, 2021
05da032
Merge pull request #27 from PaddlePaddle/develop
AnnaTrainingG Sep 3, 2021
e1fe6dc
Merge pull request #28 from PaddlePaddle/develop
AnnaTrainingG Sep 6, 2021
650013c
Update the implement of reduceAnyKernel according to kernel primitive…
AnnaTrainingG Oct 11, 2021
6f7ff7b
datamover_primitives.h
AnnaTrainingG Oct 13, 2021
081145a
add ReadDataBc for 1D data
AnnaTrainingG Oct 13, 2021
33b5c40
add writeData with stride
AnnaTrainingG Oct 14, 2021
1a61ef9
add attn_bias_add.cu.h
AnnaTrainingG Oct 14, 2021
b4f4293
fix a bug in readDataReduce
AnnaTrainingG Oct 18, 2021
85a036c
update notes
AnnaTrainingG Oct 18, 2021
d616252
update the notes of compute
AnnaTrainingG Oct 18, 2021
7bde693
update the notes of ReadData and WriteData
AnnaTrainingG Oct 18, 2021
b9d3555
Merge branch 'develop' of https://github.com/niuliling123/Paddle into…
AnnaTrainingG Oct 19, 2021
469b364
add &
AnnaTrainingG Oct 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions paddle/fluid/operators/kernel_primitives/datamover_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ struct BroadcastConfig {
* parameter will be used when IsBoundary = true.
* size_ny: The current block needs to load size_ny rows of data. This parameter
* will be used when IsBoundary = true.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
*/
template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
bool IsBoundary = false>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
int size_nx, int size_ny,
int stride_nx, int stride_ny) {
int thread_offset = threadIdx.x * NX;
int thread_offset = threadIdx.x;
int left_size_nx = size_nx - thread_offset;

// Each branch is added for better performance
Expand All @@ -165,7 +165,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
#pragma unroll
for (int idy = 0; idy < NY; ++idy) {
if (IsBoundary) {
if (idy >= size_ny) {
if (idy * stride_ny >= size_ny) {
break;
}
}
Expand All @@ -175,7 +175,7 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) {
if (idx >= left_size_nx) {
if (idx * stride_nx >= left_size_nx) {
break;
}
}
Expand All @@ -185,14 +185,14 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) {
if (idx >= left_size_nx) {
if (idx * stride_nx >= left_size_nx) {
break;
}
}
#pragma unroll
for (int idy = 0; idy < NY; ++idy) {
if (IsBoundary) {
if (idy >= size_ny) {
if (idy * stride_ny >= size_ny) {
break;
}
}
Expand Down Expand Up @@ -299,8 +299,8 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
* coordinate mapping relationship between output data and input data. Please
* refer to the sample code for specific usage.
* total_num_output: Total number of original output.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
*/
template <typename T, int NX, int NY, int BlockSize, int Rank,
bool IsBoundary = false>
Expand Down Expand Up @@ -363,8 +363,8 @@ __device__ __forceinline__ void ReadDataBc(
* parameter will be used when IsBoundary = true.
* size_ny: The current block needs to load size_ny rows of data. This parameter
* will be used when IsBoundary = true.
* stride_nx: The stride of cols.
* stride_ny: The stride of rows.
* stride_nx: Each read one element stride stride_nx elements in the last dim.
* stride_ny: Each read one element stride stride_ny elements in the first dim.
* reduce_last_dim: Used to indicate whether the dimension of reduce contains
* the lowest dimension.
*/
Expand All @@ -375,17 +375,21 @@ __device__ __forceinline__ void ReadDataReduce(
const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
int stride_ny, bool reduce_last_dim) {
int thread_offset = 0;
int left_size_nx = size_nx;
int left_size_ny = size_ny;
if (reduce_last_dim) {
thread_offset = block_offset + threadIdx.x;
left_size_nx -= thread_offset;
} else {
thread_offset = block_offset + threadIdx.y;
left_size_ny -= thread_offset;
}

if (NX == 1) {
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
if (IsBoundary) {
if (thread_offset >= size_ny) {
if (ny * stride_ny >= left_size_ny) {
break;
}
}
Expand All @@ -396,15 +400,11 @@ __device__ __forceinline__ void ReadDataReduce(
} else {
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
if (IsBoundary) {
if (nx * stride_nx >= size_nx) {
break;
}
}
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
if (IsBoundary) {
if (nx * stride_nx >= size_nx) {
if ((ny * stride_ny >= left_size_ny) ||
(nx * stride_nx >= left_size_nx)) {
break;
}
}
Expand Down Expand Up @@ -467,6 +467,30 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
}
}

/**
* @brief Initialize register with init_data.
*
* @template paraments
* T: Data type of register.
* NX: Number of data to initialize.
*
* @param:
* dst: The register pointer of the thread, the size is NX.
* init_data: The register pointer of init data, the size is NX.
*/
template <typename T, int NX, bool IsBoundary = false>
__device__ __forceinline__ void Init(T* dst, T* init_data, int num) {
#pragma unroll
for (int i = 0; i < NX; i++) {
if (IsBoundary) {
if (i >= num) {
break;
}
}
dst[i] = init_data[i];
}
}

} // namespace kernel_primitives
} // namespace operators
} // namespace paddle
59 changes: 32 additions & 27 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,31 @@ __device__ void HigherDimDealSegment(const Tx* x, Ty* y, ReduceOp reducer,
kps::WriteData<Ty, 1, 1, 1, IsBoundary>(y + store_offset, &temp_data, size);
}

template <typename Tx, typename MPType, typename ReduceOp, typename TransformOp,
typename Calculator, bool IsBoundary>
__device__ void ReduceAnyKernelImpl(const Tx* input, MPType* reduce_var,
ReduceOp reducer, TransformOp transformer,
MPType init, int reduce_num, int input_idx,
bool reduce_last_dim,
const Calculator reduce_index_calculator,
int stride, int num) {
Tx input_reg[REDUCE_VEC_SIZE];
MPType input_compute[REDUCE_VEC_SIZE];
MPType input_transform[REDUCE_VEC_SIZE];

kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator, IsBoundary>(
&input_reg[0], input, input_idx, reduce_index_calculator, 1, reduce_num,
1, stride, reduce_last_dim);
kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
&input_transform[0], &input_reg[0], transformer);
kps::Init<MPType, REDUCE_VEC_SIZE, IsBoundary>(input_compute, input_transform,
num);
kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
kps::details::ReduceMode::kLocalMode>(
reduce_var, &input_compute[0], reducer, reduce_last_dim);
}

// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
Expand Down Expand Up @@ -570,37 +595,17 @@ __global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer,
// 1. reduce for each thread
if (left_idx < left_num) {
// load REDUCE_VEC_SIZE data once, and then compute
Tx input_reg[REDUCE_VEC_SIZE];
MPType input_compute[REDUCE_VEC_SIZE];
int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
for (; input_idx + block_size < bound;
input_idx += REDUCE_VEC_SIZE * stride) {
kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator>(
&input_reg[0], input, input_idx, reduce_index_calculator, 1,
reduce_num, 1, stride, reduce_last_dim);
kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
&input_compute[0], &input_reg[0], transformer);
kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
kps::details::ReduceMode::kLocalMode>(
&reduce_var, &input_compute[0], reducer, reduce_last_dim);
}

kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator, true>(
&input_reg[0], input, input_idx, reduce_index_calculator, 1, reduce_num,
1, stride, reduce_last_dim);
input_idx += tid;
#pragma unroll
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
if (input_idx >= reduce_num) {
break;
}
input_compute[i] = static_cast<MPType>(transformer(input_reg[i]));
input_idx += stride;
ReduceAnyKernelImpl<Tx, MPType, ReduceOp, TransformOp, Calculator, false>(
input, &reduce_var, reducer, transformer, init, reduce_num, input_idx,
reduce_last_dim, reduce_index_calculator, stride, reduce_num);
}
kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
kps::details::ReduceMode::kLocalMode>(
&reduce_var, &input_compute[0], reducer, reduce_last_dim);
int num = (reduce_num - input_idx - tid + stride - 1) / stride;
ReduceAnyKernelImpl<Tx, MPType, ReduceOp, TransformOp, Calculator, true>(
input, &reduce_var, reducer, transformer, init, reduce_num, input_idx,
reduce_last_dim, reduce_index_calculator, stride, num);
}

kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
Expand Down